Advertisement
Guest User

Untitled

a guest
Feb 13th, 2016
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
  1. import numpy as np
  2. from scipy import linalg,sparse,random
  3.  
  4. class RESCAL:
  5. def __init__(self,r,lamb_A,lamb_R):
  6. self.r = r
  7. self.lamb_A = lamb_A
  8. self.lamb_R = lamb_R
  9.  
  10. def fit(self,X,niter=30):
  11. m = len(X)
  12. n,_ = X[0].shape
  13. self.A = random.randn(n,self.r)
  14. self.R = [random.randn(self.r,self.r) for i in range(m)]
  15. t = 0
  16. while True:
  17. """ update A """
  18. AA = self.A.T.dot(self.A)
  19. F = sum([X[k].dot(self.A).dot(self.R[k].T) + X[k].T.dot(self.A).dot(self.R[k]) for k in range(m)])
  20. S = sum([self.R[k].dot(AA).dot(self.R[k].T) + self.R[k].T.dot(AA).dot(self.R[k]) for k in range(m)])
  21. S += m * self.lamb_A * np.identity(self.r)
  22. self.A = F.dot(linalg.inv(S))
  23.  
  24. """ update R """
  25. Q,A_bar = linalg.qr(self.A,mode='economic')
  26. Z = sparse.kron(A_bar,A_bar)
  27. for k in range(m):
  28. vec_Xk = Q.T.dot(X[k].dot(Q)).reshape(self.r**2,1)
  29. self.R[k] = linalg.inv(Z.T.dot(Z) + self.lamb_R*np.identity(self.r**2)).dot(Z.T.dot(vec_Xk)).reshape(self.r,self.r)
  30.  
  31. t += 1
  32. if t >= niter: break
  33.  
  34. if __name__ == '__main__':
  35.  
  36. # Example graph from ICML'11 paper
  37. m = 2 # number of edge types
  38. X = [sparse.lil_matrix((5,5)) for i in range(m)]
  39. X[0][0,1] = 1 # vicePresidentOf
  40. X[0][2,3] = 1 # vicePresidentOf
  41. X[1][0,4] = 1 # party
  42. X[1][1,4] = 1 # party
  43. X[1][2,4] = 1 # party
  44. nodenames = {0:'Lyndon', 1:'John', 2:'AI', 3:'Bill', 4:'Party X'}
  45. edgetypes = {0:'vicePresidentOf', 1:'party'}
  46.  
  47. # Parameters
  48. r = 3 # number of latent component
  49. lamb_A = 0.00001 # regularization
  50. lamb_R = 0.00001 # regularization
  51.  
  52. rescal = RESCAL(r,lamb_A,lamb_R)
  53. rescal.fit(X)
  54.  
  55. # TEST (link prediction)
  56. X_bar = [rescal.A.dot(rescal.R[i]).dot(rescal.A.T) for i in range(m)] # reconstract X
  57. estimated_facts = []
  58. for k in range(m):
  59. indices = np.where(X_bar[k]>0.1) # triples with high likelihood
  60. for i in range(indices[0].shape[0]):
  61. if X[k][indices[0][i],indices[1][i]] == 0: # only for non-existence triples
  62. estimated_facts.append((nodenames[indices[0][i]], edgetypes[k], nodenames[indices[1][i]]))
  63. for f in estimated_facts:
  64. print f
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement