package adapters

// This is an example of a multiclass classifier using Naive Bayes.
// DATASET SOURCE: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html
// ===========================================================================================================
// I also took the opportunity to test how to save the already-trained model:
// TO SAVE:
//   modelo.save("modelo")
// TO LOAD (via the Scala REPL (Read-Evaluate-Print Loop) console):
//   import org.apache.spark.ml.PipelineModel
//   val modelo_treinado = PipelineModel.load("modelo")
// TO USE:
//   val data2 = spark.read.format("libsvm").option("header", "false").option("inferSchema", "true").load("news20.full")
//   modelo_treinado.transform(data2)
//
// Example: loading the Naive Bayes model already saved in the previous step
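//
// A compact end-to-end sketch of that save/load/use flow (my reading of how the
// pieces above fit together, using only the calls already shown; paths are the
// ones from the comments, not verified here):
//
//   trainned_model.save("modelo")                      // after pipeline.fit(...)
//   import org.apache.spark.ml.PipelineModel
//   val modelo_treinado = PipelineModel.load("modelo")
//   val data2 = spark.read.format("libsvm").load("news20.full")
//   modelo_treinado.transform(data2).show(5)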

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.log4j._
import org.apache.spark.sql._

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer, StopWordsRemover}
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.functions._

object SparkML4_Testing {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("akka").setLevel(Level.OFF)
    Logger.getLogger("org.apache.spark.SparkContext").setLevel(Level.OFF)

    // Spark session
    val spark = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("OFF")

    // Load the training and test data
    val entrada = "data/IMDB/trainning1.csv"
    val test = "data/IMDB/test1.csv"
    /*val trainData = (spark.read.format("csv") // disabled for the tests with the IFEEL methods
      .option("header", "false")
      .option("delimiter", "\t")
      .option("inferSchema", "true")
      .load(entrada).toDF("sentence", "label"))*/

    val testData = (spark.read.format("csv")
      .option("header", "false")
      .option("delimiter", "\t")
      .option("inferSchema", "true")
      .load(test).toDF("sentence", "label"))

    /*
    // CONVERTING THE TEXT TO LIBSVM FORMAT ========================================
    // Define the tokenizer
    val tokenizer = new Tokenizer()
      .setInputCol("sentence")
      .setOutputCol("words")

    // Define the stop-word remover
    val remover = new StopWordsRemover()
      .setInputCol("words")
      .setOutputCol("filtered")

    // Create and configure TF-IDF
    val numFeatures = 5000
    //val minDocFreq = 5 // NB
    val minDocFreq = 1 // RF

    // TF
    val hashingTF = new HashingTF()
      .setInputCol("filtered")
      .setOutputCol("tf")
      .setNumFeatures(numFeatures)

    // IDF
    val idf = new IDF()
      .setInputCol("tf")
      .setOutputCol("features")
      .setMinDocFreq(minDocFreq)

    // Instantiate the classifier (NaiveBayes and LogisticRegression variants kept for reference)
    //val modelo = new NaiveBayes().setSmoothing(0.2)
    //val modelo = new LogisticRegression()
    val modelo = new RandomForestClassifier().setNumTrees(5).setMaxBins(27).setMaxDepth(15)

    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, remover, hashingTF, idf, modelo))

    val trainned_model = pipeline.fit(trainData)

    // Run the sentiment classification
    val predicoes = trainned_model.transform(testData)

    //predicoes.show()
    //predicoes.printSchema()
    */
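
    // The CrossValidator / ParamGridBuilder imports above are never used in this
    // snippet. A minimal sketch of how they could wrap the pipeline (my assumption,
    // not part of the original experiments; the grid values are hypothetical and
    // this needs the commented-out pipeline block above to be active):
    /*
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(1000, 5000))
      .addGrid(modelo.numTrees, Array(5, 20))
      .build()

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator().setMetricName("f1"))
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    val cvModel = cv.fit(trainData) // cvModel.bestModel holds the best PipelineModel
    */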

    // TESTS WITH THE IFEEL METHODS ============================================================
    import spark.implicits._
    var resultado: Option[DataFrame] = None

    /*import adapters.Afinn
    val afinn = spark.udf.register("Metodo", (input: String) => { Afinn.as(input) })
    resultado = Some(testData.withColumn("analise", afinn(testData.col("sentence"))))*/

    /*import adapters.Emolex
    val emolex = spark.udf.register("Metodo", (input: String) => { Emolex.as(input) })
    resultado = Some(testData.withColumn("analise", emolex(testData.col("sentence"))))*/

    /*import adapters.Emoticons
    val emoticons = spark.udf.register("Metodo", (input: String) => { Emoticons.as(input) })
    resultado = Some(testData.withColumn("analise", emoticons(testData.col("sentence"))))*/

    /*import adapters.EmoticonDS
    val emoticonDS = spark.udf.register("Metodo", (input: String) => { EmoticonDS.as(input) })
    resultado = Some(testData.withColumn("analise", emoticonDS(testData.col("sentence"))))*/

    /*import adapters.HappinessIndex
    val happiness = spark.udf.register("Metodo", (input: String) => { HappinessIndex.as(input) })
    resultado = Some(testData.withColumn("analise", happiness(testData.col("sentence"))))*/

    /*import adapters.MPQA
    val mpqa = spark.udf.register("Metodo", (input: String) => { MPQA.as(input) })
    resultado = Some(testData.withColumn("analise", mpqa(testData.col("sentence"))))*/

    /*import adapters.NRC
    val nrc = spark.udf.register("Metodo", (input: String) => { NRC.as(input) })
    resultado = Some(testData.withColumn("analise", nrc(testData.col("sentence"))))*/

    /*import adapters.Opinion
    val opinion = spark.udf.register("Metodo", (input: String) => { Opinion.as(input) })
    resultado = Some(testData.withColumn("analise", opinion(testData.col("sentence"))))*/

    /*import adapters.PanasT
    val panasT = spark.udf.register("Metodo", (input: String) => { PanasT.as(input) })
    resultado = Some(testData.withColumn("analise", panasT(testData.col("sentence"))))*/

    /*import adapters.Sann
    val sann = spark.udf.register("Metodo", (input: String) => { Sann.as(input) })
    resultado = Some(testData.withColumn("analise", sann(testData.col("sentence"))))*/

    /*import adapters.Sasa
    val sasa = spark.udf.register("Metodo", (input: String) => { Sasa.as(input) })
    resultado = Some(testData.withColumn("analise", sasa(testData.col("sentence"))))*/

    /*import adapters.SenticNet
    val senticNet = spark.udf.register("Metodo", (input: String) => { SenticNet.as(input) })
    resultado = Some(testData.withColumn("analise", senticNet(testData.col("sentence"))))*/

    /*import adapters.Sentiment140
    val sentiment140 = spark.udf.register("Metodo", (input: String) => { Sentiment140.as(input) })
    resultado = Some(testData.withColumn("analise", sentiment140(testData.col("sentence"))))*/

    /*import adapters.SentiStrength
    val sentiStrength = spark.udf.register("Metodo", (input: String) => { SentiStrength.as(input) })
    resultado = Some(testData.withColumn("analise", sentiStrength(testData.col("sentence"))))*/

    // Active method: SentiWordNet
    import adapters.SentiWordNet
    val sentiWordNet = spark.udf.register("Metodo", (input: String) => { SentiWordNet.as(input) })
    resultado = Some(testData.withColumn("analise", sentiWordNet(testData.col("sentence"))))
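
    // Optional sanity check (a sketch, not part of the original run): eyeball the
    // lexicon scores before computing metrics
    //resultado.get.select("sentence", "analise").show(5, truncate = false)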

    // Map the lexicon output (-1/0/1, presumably negative/neutral/positive) to the
    // 0.0/1.0/2.0 label encoding expected by MulticlassMetrics
    val predictions = resultado.get.select(when($"analise" === -1, 0.0)
        .when($"analise" === 0, 1.0)
        .when($"analise" === 1, 2.0)
        .as("predicao"))
      .rdd.map(_.getDouble(0))
    val labels = resultado.get.select("label")
      .rdd.map(_.getDouble(0))

    // METRICS ===============================================================

    //val predictions = resultado.get.select("prediction").rdd.map(_.getDouble(0)) // for the Spark ML implementations
    //val labels = resultado.get.select("label").rdd.map(_.getDouble(0)) // for the Spark ML implementations

    val predictionAndLabels = predictions.zip(labels)
    val pl = predictionAndLabels.collect() // materializes the pairs locally (not used below)

    val metrics = new MulticlassMetrics(predictionAndLabels)

    // Confusion matrix
    //println("Confusion matrix:")
    println(metrics.confusionMatrix)

    // Overall statistics
    val accuracy = metrics.accuracy
    //println("Summary Statistics")
    //println(s"Accuracy = $accuracy")
    println(accuracy)

    // Precision by label
    val labels2 = metrics.labels
    labels2.foreach { l =>
      //println(s"Precision($l) = " + metrics.precision(l))
      println(metrics.precision(l))
    }

    // Recall by label
    labels2.foreach { l =>
      //println(s"Recall($l) = " + metrics.recall(l))
      println(metrics.recall(l))
    }

    // False positive rate by label
    labels2.foreach { l =>
      //println(s"FPR($l) = " + metrics.falsePositiveRate(l))
      println(metrics.falsePositiveRate(l))
    }

    // F-measure by label
    labels2.foreach { l =>
      //println(s"F1-Score($l) = " + metrics.fMeasure(l))
      println(metrics.fMeasure(l))
    }

    // Weighted stats
    //println(s"Weighted precision = ${metrics.weightedPrecision}")
    //println(s"Weighted recall = ${metrics.weightedRecall}")
    //println(s"Weighted F1 score = ${metrics.weightedFMeasure}")
    //println(s"Weighted false positive rate = ${metrics.weightedFalsePositiveRate}")
    println(metrics.weightedPrecision)
    println(metrics.weightedRecall)
    println(metrics.weightedFMeasure)
    println(metrics.weightedFalsePositiveRate)

    // Possible metric names for MulticlassClassificationEvaluator: f1, accuracy, weightedPrecision, weightedRecall
    /*val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")

    val accuracy = evaluator.setMetricName("accuracy").evaluate(predicoes)
    val weightedPrecision = evaluator.setMetricName("weightedPrecision").evaluate(predicoes)
    val weightedRecall = evaluator.setMetricName("weightedRecall").evaluate(predicoes)
    val f1 = evaluator.setMetricName("f1").evaluate(predicoes)

    println()
    println("Test accuracy = " + accuracy)
    println("Test weightedPrecision = " + weightedPrecision)
    println("Test weightedRecall = " + weightedRecall)
    println("Test f1_score = " + f1)*/

    /*import spark.implicits._
    val toDouble = udf[Double, String](_.toDouble)

    val predictionAndLabels = predicoes.withColumn("prediction", predicoes("prediction"))
      .withColumn("label", predicoes("label"))
      .rdd.map(r => (r.getDouble(0), r.getDouble(1)))*/

    /*val metrics = new MulticlassMetrics(predictionAndLabels)
    println("Confusion Matrix: " + metrics.confusionMatrix)*/

    spark.stop()
  }
}