Guest User

Untitled

a guest
Jul 21st, 2018
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.28 KB | None | 0 0
  1. library(Rcpp)
  2. library(inline)
  3. plug <- Rcpp:::Rcpp.plugin.maker(include.before = "#include <daal.h> ", libs = paste("-L$DAALROOT/lib/ -ldaal_core -ldaal_thread ","-ltbb -lpthread -lm", sep=""))
  4. registerPlugin("daalNB", plug)
  5.  
  6. readCSV <- '
  7. using namespace daal;
  8. using namespace daal::data_management;
  9. // Inputs:
  10. // file - file name
  11. // ncols - number of columns in file
  12. std::string fname = Rcpp::as<std::string>(file);
  13. int k = Rcpp::as<int>(ncols);
  14. // Data source
  15. FileDataSource<CSVFeatureManager> dataSource(fname, DataSource::notAllocateNumericTable, DataSource::doDictionaryFromContext);
  16. // DAAL NumericTables for data and labels
  17. ervices::SharedPtr<NumericTable> data(
  18. new HomogenNumericTable<double>(k-1, 0, NumericTable::notAllocate));
  19. services::SharedPtr<NumericTable> labels(
  20. new HomogenNumericTable<int>(1, 0, NumericTable::notAllocate));
  21. services::SharedPtr<NumericTable> merged(new MergedNumericTable(data, labels));
  22. // Load data
  23. dataSource.loadDataBlock(merged.get());
  24. // Serialize NumericTables
  25. InputDataArchive dataArch, labelsArch;
  26. data->serialize(dataArch);
  27. labels->serialize(labelsArch);
  28. Rcpp::RawVector dataBytes(dataArch.getSizeOfArchive());
  29. dataArch.copyArchiveToArray(&dataBytes[0], dataArch.getSizeOfArchive());
  30. Rcpp::RawVector labelsBytes(labelsArch.getSizeOfArchive());
  31. abelsArch.copyArchiveToArray(&labelsBytes[0], labelsArch.getSizeOfArchive());
  32. // Return a list of RawVectors
  33. return Rcpp::List::create(
  34. ["data"] = dataBytes,
  35. ["labels"] = labelsBytes);'
  36.  
  37. train <- '
  38. using namespace daal;
  39. using namespace daal::algorithms;
  40. using namespace daal::algorithms::multinomial_naive_bayes;
  41. using namespace daal::data_management;
  42. // Inputs:
  43. // X - training dataset
  44. // y - training data groundtruth
  45. // nclasses - number of classes
  46. Rcpp::RawVector Xr(X);
  47. Rcpp::RawVector yr(y);
  48. int nClasses = Rcpp::as<int>(nclasses);
  49. // Deserialize data and labels
  50. OutputDataArchive dataArch(&Xr[0], Xr.length());
  51. services::SharedPtr<NumericTable> ntData(new HomogenNumericTable<double>());
  52. ntData->deserialize(dataArch);
  53. OutputDataArchive labelsArch(&yr[0], yr.length());
  54. services::SharedPtr<NumericTable> ntLabels(new HomogenNumericTable<int>());
  55. ntLabels->deserialize(labelsArch);
  56. // Train a model
  57. training::Batch<> algorithm(nClasses);
  58. algorithm.input.set(classifier::training::data, ntData);
  59. algorithm.input.set(classifier::training::labels, ntLabels);
  60. algorithm.compute();
  61. // Get result
  62. services::SharedPtr<training::Result> result = algorithm.getResult();
  63. InputDataArchive archive;
  64. result->get(classifier::training::model)->serialize(archive);
  65. Rcpp::RawVector out(archive.getSizeOfArchive());
  66. archive.copyArchiveToArray(&out[0], archive.getSizeOfArchive());
  67. return out;'
  68.  
  69.  
  70. # Naive Bayes: predict
  71. predict <- '
  72. using namespace daal;
  73. using namespace daal::algorithms;
  74. using namespace daal::algorithms::multinomial_naive_bayes;
  75. using namespace daal::data_management;
  76. // Inputs:
  77. // model - a trained model
  78. // X - input data
  79. // nclasses - number of classes
  80. Rcpp::RawVector modelBytes(model);
  81. Rcpp::RawVector dataBytes(X);
  82. int nClasses = Rcpp::as<int>(nclasses);
  83. // Retrieve model
  84. OutputDataArchive modelArch(&modelBytes[0], modelBytes.length());
  85. services::SharedPtr<multinomial_naive_bayes::Model> nb(
  86. new multinomial_naive_bayes::Model());
  87. nb->deserialize(modelArch);
  88. // Deserialize data
  89. OutputDataArchive dataArch(&dataBytes[0], dataBytes.length());
  90. services::SharedPtr<NumericTable> ntData(new HomogenNumericTable<double>());
  91. ntData->deserialize(dataArch);
  92. // Predict for new data
  93. prediction::Batch<> algorithm(nClasses);
  94. algorithm.input.set(classifier::prediction::data, ntData);
  95. algorithm.input.set(classifier::prediction::model, nb);
  96. algorithm.compute();
  97. // Return newlabels
  98. services::SharedPtr<NumericTable> predictionResult =
  99. algorithm.getResult()->get(classifier::prediction::prediction);
  100. BlockDescriptor<int> block;
  101. int n = predictionResult->getNumberOfRows();
  102. predictionResult->getBlockOfRows(0, n, readOnly, block);
  103. int* newlabels = block.getBlockPtr();
  104. IntegerVector predictedLabels(n);
  105. std::copy(newlabels, newlabels+n, predictedLabels.begin());
  106. return predictedLabels;'
  107.  
  108. loadData <- cxxfunction(signature(file="character", ncols="integer"), readCSV, plugin="daalNB")
  109. nbTrain <- cxxfunction(signature(X="raw", y="raw", nclasses="integer"), train, plugin="daalNB")
  110. nbPredict <- cxxfunction(signature(model="raw", X="raw", nclasses="integer"), predict, plugin="daalNB")
Add Comment
Please, Sign In to add comment