Advertisement
Guest User

Untitled

a guest
Mar 28th, 2017
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.72 KB | None | 0 0
import time

import numpy as np
import pandas as pd
import scipy
from IPython.display import HTML, display
from scipy.special import logsumexp
  6.  
  7.  
  8. letters = ['e','t','a','i','n','o','s','h','r','d']
  9. feature_params = np.fromfile("./model/feature-params.txt", sep=" ")
  10. feature_params = feature_params.reshape((len(letters), 321))
  11. wf = feature_params
  12.  
  13. transition_params = np.fromfile("./model/transition-params.txt", sep=" ")
  14. transition_params = transition_params.reshape((len(letters), len(letters)))
  15. wt = transition_params
  16.  
  17. train_words = np.genfromtxt("./data/train_words.txt", dtype='str')
  18. test_words = np.genfromtxt("./data/test_words.txt", dtype='str')
  19.  
  20. def train_img_feat(i):
  21. r = np.fromfile("./data/train_img"+str(i+1)+".txt", sep=" ")
  22. r = r.reshape(len(train_words[i]), 321)
  23. return r
  24.  
  25. def test_img_feat(i):
  26. r = np.fromfile("./data/test_img"+str(i+1)+".txt", sep=" ")
  27. r = r.reshape(len(test_words[i]), 321)
  28. return r
  29.  
  30. def psi_y_given_x(y, x):
  31. return np.dot(wf[letters.index(y)], x.T)
  32.  
  33. def psi_y1_y2(y1, y2):
  34. return wt[letters.index(y1), letters.index(y2)]
  35.  
  36. def messages(x, wf_d, wt_d):
  37. m = dict()
  38. n=len(letters)
  39. m[len(x)+1,len(x)] = np.zeros(n)
  40. m[0,1] = np.zeros(n)
  41.  
  42. for t in range(len(x), 1, -1):
  43. s = t-1
  44. m[t, s] = np.zeros(n)
  45.  
  46. for xs in range(n):
  47. lse = []
  48. for xt in range(n):
  49. lse.append( np.dot(wf_d[xt], x[t-1].T) + wt_d[xs][xt] + m[t+1,t][xt] )
  50. m [t,s][xs] = scipy.misc.logsumexp(lse)
  51.  
  52.  
  53. for t in range(1, len(x) , 1):
  54. s = t+1
  55. m[t,s] = np.zeros(n)
  56. for xs in range(n):
  57. lse = []
  58. for xt in range(n):
  59. lse.append( np.dot(wf_d[xt], x[t-1].T) + wt_d[xs][xt] + m[t-1,t][xt] )
  60. m [t,s][xs] = scipy.misc.logsumexp(lse)
  61.  
  62. return m
  63.  
  64. def get_Z(t, x, wf_d, wt_d):
  65. m = messages(x, wf_d, wt_d)
  66. sum = 0
  67. lse = []
  68. for xt in range(len(letters)):
  69. lse.append(np.dot(wf_d[xt], x[t-1].T) + m[t-1,t][xt] + m[t+1,t][xt])
  70. return scipy.misc.logsumexp(lse)
  71.  
  72. def marginals(x,wf_d,wt_d):
  73. m = messages(x,wf_d,wt_d)
  74. marg = dict()
  75. Z = get_Z(1, x, wf_d, wt_d)
  76. for pos in range(1,len(x)+1):
  77. marg[pos] = np.zeros(len(letters))
  78. for k in range(len(letters)):
  79. marg[pos][k] = np.exp(np.dot(wf_d[k], x[pos-1].T) + m[pos-1,pos][k] + m[pos+1,pos][k] - Z)
  80. pairwise_marg=dict()
  81.  
  82. for i in range(1,len(x)):
  83. j = i+1
  84. pairwise_marg[i,j] = dict()
  85. for k1 in range(len(letters)):
  86. pairwise_marg[i,j][k1] = np.zeros(len(letters))
  87. for k2 in range(len(letters)):
  88. pairwise_marg[i,j][k1][k2] = np.exp( np.dot(wf_d[k1], x[i-1].T) + np.dot(wf_d[k2], x[j-1].T) + m[i-1,i][k1] + m[j+1,j][k2] + wt_d[k1][k2] - Z )
  89. return marg, pairwise_marg
  90.  
  91. def der_feat(d,count):
  92. reshaped = d.reshape((321+len(letters), len(letters)))
  93. wf_d = reshaped[0:321][:].T
  94. wt_d = reshaped[321:][:]
  95.  
  96.  
  97. def der(d,count):
  98. reshaped = d.reshape((321+len(letters), len(letters)))
  99. wf_d = reshaped[0:321][:].T
  100. wt_d = reshaped[321:][:]
  101.  
  102. dt = np.zeros((len(letters),len(letters)))
  103. df = np.zeros((len(letters), 321))
  104. '''
  105. for c in range(len(letters)):
  106. for c_ in range(len(letters)):
  107. sum = 0
  108. for i in range(count):
  109. x = train_img_feat(i)
  110. mg, pmg = marginals(x,wf_d,wt_d)
  111. for j in range(len(x)-1):
  112. if (train_words[i][j] == letters[c] and train_words[i][j+1] == letters[c_]):
  113. sum += 1
  114. sum -= pmg[j+1,j+2][c][c_]
  115. dt[c][c_] = -sum/count
  116. '''
  117. for i in range(count):
  118. x = train_img_feat(i)
  119. mg, pmg = marginals(x,wf_d,wt_d)
  120. for c in range(len(letters)):
  121. for c_ in range(len(letters)):
  122. for j in range(len(x)-1):
  123. if (train_words[i][j] == letters[c] and train_words[i][j+1] == letters[c_]):
  124. dt[c][c_] += 1
  125. dt[c][c_] -= pmg[j+1,j+2][c][c_]
  126. for f in range(321):
  127. for j in range(len(x)):
  128. if (train_words[i][j] == letters[c]):
  129. df[c][f] += x[j][f]
  130. df[c][f] -= mg[j+1][c] * x[j][f]
  131. #return df/-count, dt/-count
  132. return np.ravel(np.vstack((df.T/-count,dt/-count)))
  133. '''
  134.  
  135.  
  136. for c in range(len(letters)):
  137. for f in range(321):
  138. sum = 0
  139. for i in range(count):
  140. x = train_img_feat(i)
  141. mg, pmg = marginals(x,wf_d,wt_d)
  142. for j in range(len(x)):
  143. if (train_words[i][j] == letters[c]):
  144. sum += x[j][f]
  145. sum -= mg[j+1][c] * x[j][f]
  146. df[c][f] = -sum/count
  147. return df
  148. return dt
  149. '''
  150.  
  151.  
  152. def avg_log_likelihood(d,count):
  153. reshaped = d.reshape((321+len(letters), len(letters)))
  154. wf_d = reshaped[0:321][:].T
  155. wt_d = reshaped[321:][:]
  156. sum = 0
  157. for i in range(count):
  158. x = train_img_feat(i)
  159. y = train_words[i]
  160. Z = get_Z(1,x, wf_d, wt_d)
  161. psi_1 = 0
  162. for j in range(len(y)):
  163. psi_1 += np.dot(wf_d[letters.index(y[j])], x[j].T)
  164. psi_2 = 0
  165. for j in range(len(y)-1):
  166. psi_2 += wt_d[letters.index(y[j])][letters.index(y[j+1])]
  167. sum += ( (psi_1 + psi_2 - Z))
  168. return -sum/count
  169.  
  170. d=np.ravel(np.vstack((wf.T,wt)))
  171. #t1=time.time()
  172. #print(der(d,50))
  173. #t2=time.time()
  174.  
  175. #print(t2-t1)
  176. #display(pd.DataFrame(der_feat(d,50)))
  177. #print(der_trans(d,50))
  178. #print(der_trans(d,"e","e",50))
  179. #print(avg_log_likelihood(d,50))
  180. #for i in range(50):
  181. # x = train_img_feat(i)
  182. df, dt = der_trans(d, 50)
  183. display(pd.DataFrame(df))
  184. display(pd.DataFrame(dt))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement