Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- import sys
- import itertools
- from math import sqrt
- from operator import add
- from os.path import join, isfile, dirname
- from pyspark import SparkConf, SparkContext
- from pyspark.mllib.recommendation import ALS
def parseRating(line):
    """
    Parse one MovieLens rating record.

    Format: ``userId::movieId::rating::timestamp``.

    Returns a pair ``(timestamp % 10, (userId, movieId, rating))``; the
    last digit of the timestamp is used downstream as a cheap key for
    train/validation/test splitting.
    """
    fields = line.strip().split("::")
    # int() auto-promotes to arbitrary precision in Python 2, so dropping
    # the Python-2-only long() keeps behavior identical while also making
    # the function valid Python 3.
    return int(fields[3]) % 10, (int(fields[0]), int(fields[1]), float(fields[2]))
def parseMovie(line):
    """
    Parse one MovieLens movie record of the form ``movieId::movieTitle``.

    Returns a ``(movieId, movieTitle)`` pair.
    """
    movie_id, title = line.strip().split("::")[:2]
    return int(movie_id), title
def loadRatings(ratingsFile):
    """
    Load personal ratings from a local MovieLens-format file.

    Parameters
    ----------
    ratingsFile : str
        Path to a file of ``userId::movieId::rating::timestamp`` lines.

    Returns
    -------
    list
        ``(userId, movieId, rating)`` tuples whose rating is strictly
        positive.  Exits the process with status 1 when the file is
        missing or yields no positive ratings.
    """
    if not isfile(ratingsFile):
        # print() call form is valid in both Python 2 and Python 3.
        print("File %s does not exist." % ratingsFile)
        sys.exit(1)
    # 'with' guarantees the handle is closed even if a line fails to parse
    # (the original open/close pair leaked the handle on error).
    with open(ratingsFile, 'r') as f:
        ratings = [r for r in (parseRating(line)[1] for line in f) if r[2] > 0]
    if not ratings:
        print("No ratings provided.")
        sys.exit(1)
    return ratings
def computeRmse(model, data, n):
    """
    Compute the Root Mean Squared Error of ``model`` on ``data``.

    ``data`` is an RDD of (userId, movieId, rating) triples; ``n`` is its
    element count, passed in so no extra Spark action is needed here.
    """
    userProducts = data.map(lambda r: (r[0], r[1]))
    predicted = model.predictAll(userProducts) \
                     .map(lambda p: ((p[0], p[1]), p[2]))
    actual = data.map(lambda r: ((r[0], r[1]), r[2]))
    squaredErrors = predicted.join(actual) \
                             .values() \
                             .map(lambda pair: (pair[0] - pair[1]) ** 2)
    return sqrt(squaredErrors.reduce(add) / float(n))
- if __name__ == "__main__":
- if (len(sys.argv) != 3):
- print "Usage: /path/to/spark/bin/spark-submit --driver-memory 2g " + \
- "MovieLensALS.py movieLensDataDir personalRatingsFile"
- sys.exit(1)
- # set up environment
- conf = SparkConf() \
- .setAppName("MovieLensALS") \
- .set("spark.executor.memory", "2g")
- sc = SparkContext(conf=conf)
- # load personal ratings
- myRatings = loadRatings(sys.argv[2])
- myRatingsRDD = sc.parallelize(myRatings, 1)
- # load ratings and movie titles
- movieLensHomeDir = sys.argv[1]
- # ratings is an RDD of (last digit of timestamp, (userId, movieId, rating))
- ratings = sc.textFile(join(movieLensHomeDir, "ratings.dat")).map(parseRating)
- # movies is an RDD of (movieId, movieTitle)
- movies = dict(sc.textFile(join(movieLensHomeDir, "movies.dat")).map(parseMovie).collect())
- # your code here
- # clean up
- sc.stop()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement