Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{FloatType, IntegerType, StringType, StructField, StructType}
import spark.implicits._
// Directory (relative path) holding the dota-2-matches CSV files.
val dotaDataDir = "dota-2-matches"

/** Loads `dotaDataDir/fileName` through Spark's built-in CSV data source.
  *
  * Replaces the legacy `com.databricks.spark.csv` format name (the old external spark-csv
  * package) with `DataFrameReader.csv`, and removes five copies of the same option chain.
  *
  * @param fileName    CSV file name inside `dotaDataDir`, e.g. "match.csv"
  * @param header      treat the first line of the file as column names
  * @param inferSchema scan the data to infer column types; when false every column is a string
  * @return the loaded DataFrame
  */
def readDotaCsv(fileName: String, header: Boolean = true, inferSchema: Boolean = true): org.apache.spark.sql.DataFrame = {
  spark.read
    .option("header", header.toString)
    .option("inferSchema", inferSchema.toString)
    .csv(s"$dotaDataDir/$fileName")
}

// `var` is kept where a later statement reassigns the DataFrame.
var df_match = readDotaCsv("match.csv")
var df_players = readDotaCsv("players.csv")
val df_players_rating = readDotaCsv("player_ratings.csv")
var df_chat = readDotaCsv("chat.csv")
// No header row: Spark names the columns _c0, _c1, ... and, without inference, keeps them as strings.
var df_bad_words = readDotaCsv("bad_words.csv", header = false, inferSchema = false)
// Discard chat rows whose `key` is null or NaN (na.drop defaults to the "any" mode).
df_chat = df_chat.na.drop(Seq("key"))
// Keep a single copy of each bad-word row; `distinct` is an alias of `dropDuplicates()`.
df_bad_words = df_bad_words.distinct()
// Building-status columns are removed from the match table.
val buildingStatusColumns = Seq("tower_status_radiant", "tower_status_dire", "barracks_status_dire", "barracks_status_radiant")
df_match = df_match.drop(buildingStatusColumns: _*)
// Attach rating columns to players; the inner join drops players with no matching account_id in the ratings table.
df_players = df_players.join(df_players_rating, usingColumns = Seq("account_id"), joinType = "inner")
// Matches with at least six rows in df_players. Filtering with a Column predicate (instead of a
// Row lambda reading column 1 by position) keeps the filter visible to the Catalyst optimizer.
val df_relevant_matches = df_players.groupBy("match_id").count().filter($"count" >= 6).withColumnRenamed("count", "player_count")
// Shell echo of how many matches qualify (runs a Spark job).
df_relevant_matches.count()

// All per-match player statistics in a single groupBy pass. The pair form of `agg` names its
// outputs "max(gold_per_min)", "max(xp_per_min)" and "avg(trueskill_mu)"; the first two names are
// read by the regression feature assembler, the last is renamed for the regression label.
val df_match_player_stats = df_players.groupBy("match_id").agg("gold_per_min" -> "max", "xp_per_min" -> "max", "trueskill_mu" -> "avg").withColumnRenamed("avg(trueskill_mu)", "trueskill_mu_avg")

// Inner joins keep only matches that qualify and have player statistics.
df_match = df_match.join(df_relevant_matches, Seq("match_id"), "inner").join(df_match_player_stats, Seq("match_id"), "inner")
df_match.describe("trueskill_mu_avg").show()
// Pull the bad-word list to the driver: for a String target, `as[String]` decodes the first column of each row.
var bad_words: Array[String] = df_bad_words.as[String].collect()
// Regression: predict a match's average TrueSkill mu from the match's maximum GPM and XPM.
val assembler = new VectorAssembler().setInputCols(Array("max(gold_per_min)", "max(xp_per_min)")).setOutputCol("features")
val df_ml = assembler.transform(df_match)
// Fixed seed so the 50/50 train/test split, and therefore the metrics, are reproducible.
val splits = df_ml.randomSplit(Array(0.5, 0.5), 42L)
// Gaussian family with identity link is linear regression; regParam is the L2 regularization strength.
val glr = new GeneralizedLinearRegression().setFamily("gaussian").setLink("identity").setLabelCol("trueskill_mu_avg").setMaxIter(10).setRegParam(0.3)
val model = glr.fit(splits(0))
// Evaluate on the held-out half and report the summary's fit statistics.
val summary = model.evaluate(splits(1))
println(s"GLR held-out deviance = ${summary.deviance}, AIC = ${summary.aic}")
// Classification: predict radiant_win from a match's negative/positive vote counts.
// NOTE(review): inferSchema likely reads "True"/"False" as BooleanType, which StringIndexer rejects
// (it accepts only string or numeric input). Casting to string works whichever type was inferred.
val df_rf = df_match.withColumn("radiant_win", df_match("radiant_win").cast("string"))
val labelIndexer = new StringIndexer().setInputCol("radiant_win").setOutputCol("indexedLabel").fit(df_rf)
// Despite its name, rf_assembler is the assembled DataFrame (the transform result), not the assembler.
val rf_assembler = new VectorAssembler().setInputCols(Array("negative_votes", "positive_votes")).setOutputCol("features").transform(df_rf)
// Features with at most 4 distinct values are treated as categorical; the rest stay continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(rf_assembler)
// Fixed seed so the 70/30 split, and therefore the reported error, is reproducible.
val Array(trainingData, testData) = rf_assembler.randomSplit(Array(0.7, 0.3), 1234L)
// Train a RandomForest model (seeded for reproducible trees).
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10).setSeed(1234L)
// Convert indexed labels back to original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
// Train model. This also runs the indexers. (In the shell this `model` shadows the GLR model.)
val model = pipeline.fit(trainingData)
// Make predictions.
val predictions = model.transform(testData)
// Select example rows to display; the true label column is radiant_win (there is no `label` column).
predictions.select("predictedLabel", "radiant_win", "features").show(5)
// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement