• API
• FAQ
• Tools
• Archive
SHARE
TWEET

# Untitled

a guest Sep 16th, 2015 223 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import sys
2. import random
3. import numpy
4. from pyspark import SparkContext, SparkConf
5.
def file2triplets(infile):
    """Parse one (filename, contents) pair into (movieid, userid, rating) triplets.

    The first line of the file body is the movie id followed by a colon;
    each later line starts with "userid,rating".  The empty final element
    produced by the trailing newline is skipped.
    """
    rows = infile[1].split("\n")
    movieid = int(rows[0].strip(':'))
    triplets = []
    for row in rows[1:-1]:
        fields = row.split(",")
        triplets.append((movieid, int(fields[0]), int(fields[1])))
    return triplets
10.
def hashfunc(idx, numworkers, seed):
    """Assign an id to a worker bucket in [0, numworkers) via a seeded hash."""
    key = "%s%s" % (idx, seed)
    return hash(key) % numworkers
13.
def updateWH(VWHblock, num_updates, beta_val, lambda_val, Ni, Nj):
    """Run one SGD pass over a diagonal block of V, updating its W and H rows.

    Args:
        VWHblock: tuple (Vblock, Wblock, Hblock) -- Vblock is an iterable of
            (movieid, userid, rating) triplets; Wblock / Hblock are iterables
            of (id, factor-vector) pairs covering the ids appearing in Vblock.
        num_updates: total number of SGD updates performed in previous
            iterations; feeds the decaying step size.
        beta_val: step-size decay exponent.
        lambda_val: L2 regularization weight.
        Ni, Nj: ratings-per-movie / ratings-per-user counts, used to scale
            the regularization term.

    Returns:
        (W_items, H_items): lists of the updated (id, factor-vector) pairs.
    """
    # Unpack inside the body (the tuple-parameter form is Python-2-only
    # syntax); callers still pass the triple as a single argument.
    Vblock, Wblock, Hblock = VWHblock
    Wdict = dict(Wblock)
    Hdict = dict(Hblock)
    it = 0
    for (movieid, userid, rating) in Vblock:
        # Count the updates done in this block
        it += 1
        # Decaying DSGD step size (Gemulla et al.).  The original body used
        # an undefined eps_val -- NameError at runtime -- while num_updates
        # and beta_val went unused; this is the intended schedule.
        eps_val = pow(100 + num_updates + it, -beta_val)
        Wi = Wdict[movieid]
        Hj = Hdict[userid]
        WiHj = numpy.dot(Wi, Hj)
        # L_NZSL loss gradient coefficient: d/dx (rating - WiHj)^2
        LNZSL_coeff = -2*(rating - WiHj)
        # L_2 loss gradient coefficients, split over the Ni[i] / Nj[j]
        # strata that touch row i / column j
        L2_coeff1 = 2*lambda_val/Ni[movieid]
        L2_coeff2 = 2*lambda_val/Nj[userid]
        # The H step deliberately uses the pre-update Wi (joint-loss SGD step)
        Wdict[movieid] = Wi - eps_val*(LNZSL_coeff*Hj + L2_coeff1*Wi)
        Hdict[userid]  = Hj - eps_val*(LNZSL_coeff*Wi + L2_coeff2*Hj)
    return (list(Wdict.items()), list(Hdict.items()))
34.
def lossNZSL(Ventry, W, H):
    """Squared reconstruction error of one (movieid, userid, rating) entry."""
    movieid, userid, rating = Ventry[0], Ventry[1], Ventry[2]
    predicted = numpy.dot(W[movieid], H[userid])
    return (rating - predicted) ** 2
37.
# Read command line arguments
num_factors = int(sys.argv[1])     # rank of the factorization (length of each factor vector)
num_workers = int(sys.argv[2])     # number of stratum blocks / parallel partitions
num_iterations = int(sys.argv[3])  # number of DSGD passes over V
beta_val = float(sys.argv[4])      # step-size decay exponent handed to updateWH
lambda_val = float(sys.argv[5])    # L2 regularization weight handed to updateWH
inputV_path = sys.argv[6]          # path to the ratings input V
outputW_path = sys.argv[7]         # NOTE(review): never written in the visible code -- presumably used further down
outputH_path = sys.argv[8]         # NOTE(review): same -- verify the save step exists past this chunk
47.
# Local-mode Spark context for the DSGD matrix-factorization job
conf = SparkConf().setAppName("dsgd_mf").setMaster("local")
sc = SparkContext(conf=conf)

# Alternative loader for a one-file-per-movie layout (see file2triplets):
# V = sc.wholeTextFiles(inputV_path)
# triples = V.flatMap(file2triplets)

# Each input line is "movieid,userid,rating" -> [movieid, userid, rating]
triples = sc.textFile(inputV_path).map(lambda a: [int(x) for x in a.split(",")])

triples.persist()
# Max ids size W and H; assumes ids are small enough to enumerate densely
num_movies = triples.map(lambda trip : trip[0]).reduce(max)
num_users = triples.map(lambda trip : trip[1]).reduce(max)
# Ni[movieid] / Nj[userid]: ratings count per movie / per user,
# used by updateWH to scale the L2 term
Ni = triples.keyBy(lambda trip: trip[0]).countByKey()
Nj = triples.keyBy(lambda trip: trip[1]).countByKey()

# W*H = V
# W is a list of (movieid, factors) kv pairs
# H is a list of (userid, factors) kv pairs
#   where factors is a list of floats of length num_factors
# range(num_movies+1) lets 1-based ids index directly (slot 0 goes unused)
W = sc.parallelize(range(num_movies+1)).map(lambda a : (a, numpy.random.rand(num_factors))).persist()
H = sc.parallelize(range(num_users+1)).map(lambda a : (a, numpy.random.rand(num_factors))).persist()
68.
70. loss_all = []
71. f = open('log','w')
72.
73. for it in range(num_iterations):
74.     seed = random.randrange(100000)
75.     # Get the diagonal blocks of V
76.     filtered = triples.filter(lambda trip : hashfunc(trip[0],num_workers,seed) == hashfunc(trip[1],num_workers,seed)).persist()
77.     Vblocks = filtered.keyBy(lambda trip : hashfunc(trip[0], num_workers, seed))
79.     filtered.unpersist()
80.     # Partition W and H and group them with V by they block number
81.     Wblocks = W.keyBy(lambda pair: hashfunc(pair[0], num_workers, seed))
82.     Hblocks = H.keyBy(lambda pair: hashfunc(pair[0], num_workers, seed))
83.     grouped = Vblocks.groupWith(Wblocks, Hblocks).coalesce(num_workers)
84.     # Perform the updates to W and H in parallel
85.     updatedWH = grouped.map(lambda a: updateWH(a[1], num_updates, beta_val, lambda_val, Ni, Nj)).persist()
86.     W = updatedWH.flatMap(lambda a: a[0]).persist()
87.     H = updatedWH.flatMap(lambda a: a[1]).persist()
88.     Wdict = dict(W.collect())
89.     Hdict = dict(H.collect())
90.     loss = triples.map(lambda a: lossNZSL(a, Wdict, Hdict)).reduce(lambda a,b: a+b)
91.     print "L_NZSL: " + str(loss)
92.     f.write("Iteration: " + str(it) + "\t L_NZSL: " + str(loss) + "\n")
93.     loss_all.append(loss)