Advertisement
Guest User

Untitled

a guest
Jul 21st, 2017
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.24 KB | None | 0 0
  import org.apache.spark.ml.Pipeline
  import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
  import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
  import org.apache.spark.ml.linalg.Vectors
  import org.apache.spark.ml.regression.GeneralizedLinearRegression
  import org.apache.spark.mllib.stat.Statistics
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types.{FloatType, IntegerType, StringType, StructField, StructType}
  // `spark` must already be a SparkSession in scope (the script never creates one; presumably
  // the one spark-shell provides). Its implicits supply the `$"col"` syntax used below.
  import spark.implicits._

  // Raw inputs. Every table except bad_words has a header row and gets its column types
  // inferred; bad_words.csv is read header-less with no inference (all columns strings).
  def readTypedCsv(path: String): org.apache.spark.sql.DataFrame = {
    spark.read
      .format("com.databricks.spark.csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load(path)
  }

  var df_match = readTypedCsv("dota-2-matches/match.csv")
  var df_players = readTypedCsv("dota-2-matches/players.csv")
  val df_players_rating = readTypedCsv("dota-2-matches/player_ratings.csv")
  var df_chat = readTypedCsv("dota-2-matches/chat.csv")
  var df_bad_words = spark.read.format("com.databricks.spark.csv").option("header", "false").load("dota-2-matches/bad_words.csv")

  // Cleaning pass:
  //   * chat: drop rows whose `key` is null or NaN,
  //   * bad words: remove duplicate rows,
  //   * matches: discard the tower/barracks status columns.
  val structureStatusColumns = Seq("tower_status_radiant", "tower_status_dire", "barracks_status_dire", "barracks_status_radiant")
  df_chat = df_chat.na.drop("any", Seq("key"))
  df_bad_words = df_bad_words.dropDuplicates()
  df_match = df_match.drop(structureStatusColumns: _*)

  // Inner join: keep only player rows whose account_id also appears in the ratings table,
  // attaching the rating columns (e.g. trueskill_mu, aggregated further down).
  df_players = df_players.join(df_players_rating, usingColumns = Seq("account_id"), joinType = "inner")

  // Matches that have at least 6 rows in df_players (after the ratings join), with that row
  // count exposed as `player_count`. The filter is a Column predicate so it stays a Catalyst
  // expression, instead of an opaque Row => Boolean function with positional getLong(1) access.
  val df_relevant_matches = df_players.groupBy("match_id").count().filter($"count" >= 6).withColumnRenamed("count", "player_count")

  // Action: runs the job; in spark-shell the number of qualifying matches is echoed.
  df_relevant_matches.count()

  // Per-match player aggregates, computed in ONE groupBy over df_players (one shuffle and one
  // join instead of two of each). The generated names "max(gold_per_min)" and
  // "max(xp_per_min)" are kept because the regression's VectorAssembler reads them by name.
  val df_player_aggregates = df_players.groupBy("match_id").agg("gold_per_min" -> "max", "xp_per_min" -> "max", "trueskill_mu" -> "avg").withColumnRenamed("avg(trueskill_mu)", "trueskill_mu_avg")

  // Restrict to the relevant matches (adds `player_count`), then attach the aggregates.
  // Both right-hand sides are grouped by match_id over df_players, so chaining the inner
  // joins keeps exactly the rows the original two-step join kept.
  df_match = df_match.join(df_relevant_matches, Seq("match_id"), "inner").join(df_player_aggregates, Seq("match_id"), "inner")

  df_match.describe("trueskill_mu_avg").show()

  // Driver-side array holding the first column of every (de-duplicated) bad-words row.
  // NOTE(review): not referenced anywhere else in this script.
  val bad_words: Array[String] = df_bad_words.collect().map { row => row.getString(0) }

  // Gaussian GLM with identity link and regParam 0.3 (an L2-regularized linear regression):
  // predict a match's average player trueskill_mu from the max GPM / max XPM of its players.
  val assembler = new VectorAssembler().setInputCols(Array("max(gold_per_min)", "max(xp_per_min)")).setOutputCol("features")
  val df_ml = assembler.transform(df_match)

  // 50/50 train/test split. The fixed (arbitrary) seed makes the split and metrics reproducible.
  val splits = df_ml.randomSplit(Array(0.5, 0.5), seed = 42L)

  val glr = new GeneralizedLinearRegression().setFamily("gaussian").setLink("identity").setLabelCol("trueskill_mu_avg").setMaxIter(10).setRegParam(0.3)

  // Named glrModel rather than `model`: the random-forest section below defines its own
  // `val model`, which would be a duplicate definition anywhere outside the REPL.
  val glrModel = glr.fit(splits(0))

  // Evaluation summary on the held-out half.
  val summary = glrModel.evaluate(splits(1))

  // Random-forest classifier predicting `radiant_win` from the match's vote counts.

  // StringIndexer accepts only string or numeric input columns. NOTE(review): inferSchema may
  // type a True/False column as boolean, which StringIndexer rejects. Casting to string is a
  // no-op for string input and mirrors the cast StringIndexer applies to numeric input itself.
  val df_rf_input = df_match.withColumn("radiant_win", $"radiant_win".cast("string"))

  val labelIndexer = new StringIndexer().setInputCol("radiant_win").setOutputCol("indexedLabel").fit(df_rf_input)

  // Despite its name, this is a DataFrame: df_rf_input plus an assembled `features` vector.
  val rf_assembler = new VectorAssembler().setInputCols(Array("negative_votes", "positive_votes")).setOutputCol("features").transform(df_rf_input)

  // Features with at most 4 distinct values are treated as categorical.
  val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(rf_assembler)

  // 70/30 train/test split with a fixed (arbitrary) seed for reproducibility.
  val Array(trainingData, testData) = rf_assembler.randomSplit(Array(0.7, 0.3), seed = 42L)

  // Train a RandomForest model.
  val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10)

  // Convert indexed labels back to original labels.
  val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

  val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

  // Train model. This also runs the indexers.
  val model = pipeline.fit(trainingData)

  // Make predictions.
  val predictions = model.transform(testData)

  // Show predicted vs. actual label. There is no `label` column in this data (the true label
  // is `radiant_win`), so selecting "label" would fail with an AnalysisException.
  predictions.select("predictedLabel", "radiant_win", "features").show(5)

  // Accuracy on the indexed labels; its complement is reported as the test error.
  val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
  val accuracy = evaluator.evaluate(predictions)
  println(s"Test Error = ${1.0 - accuracy}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement