SHARE
TWEET

Untitled

a guest Sep 16th, 2015 223 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import random
import sys
import zlib

import numpy
from pyspark import SparkContext, SparkConf
  5.  
  6. def file2triplets(infile):
  7.     # Returns a triplet of (movieid, userid, rating) for each line in the files
  8.     lines = infile[1].split("\n")
  9.     return [(int(lines[0].strip(':')), int(line.split(",")[0]), int(line.split(",")[1])) for line in lines[1:-1]]
  10.  
  11. def hashfunc(idx, numworkers, seed):
  12.     return hash(str(idx) + str(seed)) % numworkers
  13.  
  14. def updateWH((Vblock, Wblock, Hblock), num_updates, beta_val, lambda_val, Ni, Nj):
  15.     Wdict = dict(Wblock)
  16.     Hdict = dict(Hblock)
  17.     it=0
  18.     for (movieid, userid, rating) in Vblock:
  19.         # Compute the number of updates
  20.         it += 1
  21.         eps_val = pow(100+num_updates+it, -beta_val)
  22.         #print "M: " + str(movieid) + "U: " + str(userid) + "R: " + str(rating)
  23.         Wi = Wdict[movieid]
  24.         Hj = Hdict[userid]
  25.         WiHj = numpy.dot(Wi,Hj)
  26.         # L_NZSL loss gradient coefficient
  27.         LNZSL_coeff = -2*(rating - WiHj)
  28.         # L_2 loss gradient coefficients
  29.         L2_coeff1 = 2*lambda_val/Ni[movieid]
  30.         L2_coeff2 = 2*lambda_val/Nj[userid]
  31.         Wdict[movieid] = Wi - eps_val*(LNZSL_coeff*Hj + L2_coeff1*Wi)      
  32.         Hdict[userid]  = Hj - eps_val*(LNZSL_coeff*Wi + L2_coeff2*Hj)
  33.     return (Wdict.items(), Hdict.items())
  34.  
  35. def lossNZSL(Ventry, W, H):
  36.     return pow(Ventry[2] - numpy.dot(W[Ventry[0]],H[Ventry[1]]),2)
  37.  
# ---------------------------------------------------------------------------
# DSGD matrix-factorization driver (Python 2 / PySpark).
# Factorizes a sparse ratings matrix V ~= W * H by distributed SGD over
# "diagonal" strata: per epoch, only (i, j) cells whose row and column hash
# to the same block are updated, so block updates are independent.
# ---------------------------------------------------------------------------
# Read command line arguments
num_factors = int(sys.argv[1])      # rank of the factorization
num_workers = int(sys.argv[2])      # number of stratum blocks / partitions
num_iterations = int(sys.argv[3])   # outer SGD epochs
beta_val = float(sys.argv[4])       # step-size decay exponent
lambda_val = float(sys.argv[5])     # L2 regularization weight
inputV_path = sys.argv[6]           # ratings input: "movieid,userid,rating" CSV lines
outputW_path = sys.argv[7]          # where the W factor matrix is written (CSV)
outputH_path = sys.argv[8]          # where the H factor matrix is written (CSV)

conf = SparkConf().setAppName("dsgd_mf").setMaster("local")
sc = SparkContext(conf=conf)

# V = sc.wholeTextFiles(inputV_path)
# triples = V.flatMap(file2triplets)

# Each input line "movieid,userid,rating" -> [movieid, userid, rating] ints.
triples = sc.textFile(inputV_path).map(lambda a: [int(x) for x in a.split(",")])

triples.persist()
# Max observed ids size the factor matrices (ids assumed to start near 0).
num_movies = triples.map(lambda trip : trip[0]).reduce(max)
num_users = triples.map(lambda trip : trip[1]).reduce(max)
# Ni / Nj: ratings count per movie / per user; scales the L2 gradient in updateWH.
Ni = triples.keyBy(lambda trip: trip[0]).countByKey()
Nj = triples.keyBy(lambda trip: trip[1]).countByKey()

# W*H = V
# W is a list of (movieid, factors) kv pairs
# H is a list of (userid, factors) kv pairs
#   where factors is a list of floats of length num_factors
W = sc.parallelize(range(num_movies+1)).map(lambda a : (a, numpy.random.rand(num_factors))).persist()
H = sc.parallelize(range(num_users+1)).map(lambda a : (a, numpy.random.rand(num_factors))).persist()

num_updates = 0   # global SGD step clock (feeds the decaying learning rate)
loss_all = []
f = open('log','w')

for it in range(num_iterations):
    # New seed each epoch -> a different diagonal stratum is selected.
    seed = random.randrange(100000)
    # Get the diagonal blocks of V: keep only cells whose row block == column block.
    filtered = triples.filter(lambda trip : hashfunc(trip[0],num_workers,seed) == hashfunc(trip[1],num_workers,seed)).persist()
    Vblocks = filtered.keyBy(lambda trip : hashfunc(trip[0], num_workers, seed))
    cur_num_updates = filtered.count()
    filtered.unpersist()    
    # Partition W and H and group them with V by their block number
    Wblocks = W.keyBy(lambda pair: hashfunc(pair[0], num_workers, seed))
    Hblocks = H.keyBy(lambda pair: hashfunc(pair[0], num_workers, seed))
    grouped = Vblocks.groupWith(Wblocks, Hblocks).coalesce(num_workers)
    # Perform the updates to W and H in parallel (one sequential SGD pass per block)
    updatedWH = grouped.map(lambda a: updateWH(a[1], num_updates, beta_val, lambda_val, Ni, Nj)).persist()
    W = updatedWH.flatMap(lambda a: a[0]).persist()
    H = updatedWH.flatMap(lambda a: a[1]).persist()
    # Collect the factors to the driver to evaluate the full NZSL loss.
    Wdict = dict(W.collect())
    Hdict = dict(H.collect())
    loss = triples.map(lambda a: lossNZSL(a, Wdict, Hdict)).reduce(lambda a,b: a+b)
    print "L_NZSL: " + str(loss)
    f.write("Iteration: " + str(it) + "\t L_NZSL: " + str(loss) + "\n")
    loss_all.append(loss)
    num_updates += cur_num_updates

f.close()

# Write the factor matrices, rows sorted by id, as CSV.
Wpy = numpy.vstack(W.sortByKey().map(lambda a : a[1]).collect())
numpy.savetxt(outputW_path, Wpy, delimiter=',')

Hpy = numpy.vstack(H.sortByKey().map(lambda a : a[1]).collect())
numpy.savetxt(outputH_path, Hpy, delimiter=',')
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top