Guest User

Untitled

a guest
Feb 14th, 2016
49
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.03 KB | None | 0 0
  1. from __future__ import print_function
  2.  
  3. from pyspark import SparkContext, SparkConf
  4. from pyspark.mllib.linalg import DenseVector, VectorUDT
  5. from pyspark.sql import SQLContext
  6.  
  7. from pyspark.ml.classification import MultilayerPerceptronClassifier
  8. from pyspark.ml.evaluation import MulticlassClassificationEvaluator
  9. from pyspark.sql.types import StructType, StructField, StringType, DoubleType, ArrayType
  10.  
  11.  
  12. def data_frame_from_file(sqlContext, file_name, fraction):
  13. lines = sc.textFile(file_name).sample(False, fraction)
  14. parts = lines.map(lambda l: map(lambda s: int(s), l.split(",")))
  15. samples = parts.map(lambda p: (
  16. float(p[0]),
  17. DenseVector(map(lambda el: el / 255.0, p[1:]))
  18. ))
  19.  
  20. fields = [
  21. StructField("label", DoubleType(), True),
  22. StructField("features", VectorUDT(), True)
  23. ]
  24. schema = StructType(fields)
  25.  
  26. data = sqlContext.createDataFrame(samples, schema)
  27. return data
  28.  
  29.  
  30. if __name__ == "__main__":
  31. conf = SparkConf(True)
  32. conf.set("spark.executor.memory", "8g")
  33.  
  34. sc = SparkContext(
  35. master="spark://169.254.147.148:7077",
  36. appName="multilayer_perceptron_classification_example",
  37. conf=conf
  38. )
  39.  
  40. sqlContext = SQLContext(sc)
  41.  
  42. # train = data_frame_from_file(sqlContext, "mnist_train.csv", 0.01)
  43. # test = data_frame_from_file(sqlContext, "mnist_test.csv", 0.01)
  44.  
  45. train = data_frame_from_file(sqlContext, "mnist_train.csv", 1)
  46. test = data_frame_from_file(sqlContext, "mnist_test.csv", 1)
  47.  
  48. # layers = [28*28, 14*14, 5*5, 10]
  49. layers = [28*28, 1024, 10]
  50.  
  51. # create the trainer and set its parameters
  52. trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234)
  53. # train the model
  54. model = trainer.fit(train)
  55. # compute precision on the test set
  56. result = model.transform(test)
  57. predictionAndLabels = result.select("prediction", "label")
  58. evaluator = MulticlassClassificationEvaluator(metricName="precision")
  59. print("Precision: " + str(evaluator.evaluate(predictionAndLabels)))
  60.  
  61. sc.stop()
Add Comment
Please, Sign In to add comment