Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- from scipy import linalg,sparse,random
class RESCAL:
    """RESCAL factorization of a multi-relational graph, fitted by ALS.

    Each relation slice ``X[k]`` (an n x n matrix) is approximated as
    ``A . R[k] . A^T``, where ``A`` (n x r) holds one latent vector per entity
    and ``R[k]`` (r x r) models the interactions under relation k.

    Parameters
    ----------
    r : int
        Number of latent components.
    lamb_A : float
        Regularization weight for ``A``. The A-step adds ``m * lamb_A * I``,
        where m is the number of relation slices.
    lamb_R : float
        Ridge regularization weight applied to every ``R[k]``.
    """

    def __init__(self, r, lamb_A, lamb_R):
        self.r = r
        self.lamb_A = lamb_A
        self.lamb_R = lamb_R

    def fit(self, X, niter=30, seed=None):
        """Fit ``A`` and ``R`` by alternating least squares.

        Parameters
        ----------
        X : sequence of (n, n) matrices
            One slice per relation: scipy sparse matrices or dense arrays.
            The caller's list is not modified.
        niter : int, optional
            Number of ALS sweeps. Each sweep is one A update followed by one
            R update. ``niter <= 0`` performs no update and keeps the random start.
        seed : int or None, optional
            Seed for the random initialization. ``None`` draws fresh OS entropy.

        Returns
        -------
        RESCAL
            ``self``. ``self.A`` is set to an (n, r) array and ``self.R`` to a
            list of m (r, r) arrays.
        """
        # CSR makes the repeated sparse x dense products cheap (lil is slow).
        X = [x.tocsr() if sparse.issparse(x) else x for x in X]
        m = len(X)
        n = X[0].shape[0]
        rng = np.random.RandomState(seed)
        self.A = rng.randn(n, self.r)
        self.R = [rng.randn(self.r, self.r) for _ in range(m)]
        for _ in range(niter):
            self._update_A(X)
            self._update_R(X)
        return self

    def _update_A(self, X):
        """Update A with R fixed: ``A <- F . S^{-1}``."""
        A = self.A
        AtA = A.T.dot(A)
        F = sum(Xk.dot(A).dot(Rk.T) + Xk.T.dot(A).dot(Rk)
                for Xk, Rk in zip(X, self.R))
        S = sum(Rk.dot(AtA).dot(Rk.T) + Rk.T.dot(AtA).dot(Rk) for Rk in self.R)
        S = S + len(X) * self.lamb_A * np.identity(self.r)
        # Every term of S is symmetric, so F S^{-1} = (S^{-1} F^T)^T.
        # Solving is more accurate and cheaper than forming the inverse.
        self.A = linalg.solve(S, F.T).T

    def _update_R(self, X):
        """Update every R[k] by ridge regression with A fixed.

        The exact problem is
        ``min ||vec(X_k) - (A kron A) vec(R)||^2 + lamb_R ||R||^2``.
        With the thin QR ``A = Q . A_bar`` (Q has orthonormal columns), its
        normal equations are identical to those of the much smaller problem on
        ``Q^T X_k Q`` and ``A_bar kron A_bar``.
        Row-major ``ravel``/``reshape`` matches the ``kron`` ordering.
        """
        Q, A_bar = linalg.qr(self.A, mode='economic')
        Z = np.kron(A_bar, A_bar)
        # The system matrix does not depend on k: build it once and solve for
        # all slices at once (one right-hand-side column per relation).
        lhs = Z.T.dot(Z) + self.lamb_R * np.identity(self.r ** 2)
        rhs = np.column_stack([
            Z.T.dot(np.asarray(Q.T.dot(Xk.dot(Q))).ravel()) for Xk in X
        ])
        vec_R = linalg.solve(lhs, rhs)
        self.R = [vec_R[:, k].reshape(self.r, self.r) for k in range(len(X))]
if __name__ == '__main__':
    # Example graph from the ICML'11 RESCAL paper.
    m = 2  # number of edge (relation) types
    X = [sparse.lil_matrix((5, 5)) for _ in range(m)]
    X[0][0, 1] = 1  # vicePresidentOf
    X[0][2, 3] = 1  # vicePresidentOf
    X[1][0, 4] = 1  # party
    X[1][1, 4] = 1  # party
    X[1][2, 4] = 1  # party
    nodenames = {0: 'Lyndon', 1: 'John', 2: 'Al', 3: 'Bill', 4: 'Party X'}
    edgetypes = {0: 'vicePresidentOf', 1: 'party'}

    # Parameters
    r = 3  # number of latent components
    lamb_A = 0.00001  # regularization weight for A
    lamb_R = 0.00001  # regularization weight for R

    rescal = RESCAL(r, lamb_A, lamb_R)
    rescal.fit(X)

    # TEST (link prediction): reconstruct each slice as A . R[k] . A^T and
    # report high-scoring triples that are not already present in X.
    estimated_facts = []
    for k in range(m):
        X_bar_k = rescal.A.dot(rescal.R[k]).dot(rescal.A.T)
        rows, cols = np.where(X_bar_k > 0.1)  # triples with high likelihood
        for i, j in zip(rows, cols):
            if X[k][i, j] == 0:  # only triples absent from the input graph
                estimated_facts.append((nodenames[i], edgetypes[k], nodenames[j]))
    for f in estimated_facts:
        print(f)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement