Advertisement
Guest User

Untitled

a guest
Aug 29th, 2016
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.10 KB | None | 0 0
  1. import numpy as np
  2. import sys
  3. import scipy.io as sio
  4.  
  5. def sample(hprev, xt, n):
  6. for index in xt:
  7. x = np.zeros((vocab_size, 1))
  8. x[index] = 1
  9. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, hprev) + bh)
  10. y = np.dot(Why, h) + by
  11. p = np.exp(y) / np.sum(np.exp(y))
  12. ix = np.random.choice(range(vocab_size), p=p.ravel())
  13. hprev = h
  14.  
  15. generated_seq = []
  16. generated_seq.append(ix)
  17. x = np.zeros((vocab_size, 1))
  18. x[ix] = 1
  19. h = hprev
  20. for t in range(n):
  21. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)
  22. y = np.dot(Why, h) + by
  23. p = np.exp(y) / np.sum(np.exp(y))
  24. ix = np.random.choice(range(vocab_size), p=p.ravel())
  25. x = np.zeros((vocab_size, 1))
  26. x[ix] = 1
  27. generated_seq.append(ix)
  28. return generated_seq
  29.  
  30. def generate_h(xt):
  31. hidden_size = Whh.shape[0]
  32. hprev = np.zeros((hidden_size,1))
  33. for index in xt:
  34. x = np.zeros((vocab_size, 1))
  35. x[index] = 1
  36. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, hprev) + bh)
  37. y = np.dot(Why, h) + by
  38. p = np.exp(y) / np.sum(np.exp(y))
  39. ix = np.random.choice(range(vocab_size), p=p.ravel())
  40. hprev = h
  41. return hprev
  42.  
  43. if __name__ == '__main__':
  44.  
  45. dataset = 'lstm_from_wiki.txt'
  46. data = open('data/%s' % dataset, 'r').read()
  47. chars = list(set(data))
  48. print '%d unique characters in data.' % (len(chars))
  49. vocab_size, data_size = len(chars), len(data)
  50. char_to_ix = { ch:i for i,ch in enumerate(chars) }
  51. ix_to_char = { i:ch for i,ch in enumerate(chars) }
  52.  
  53. # nb. of sequence to generate from model
  54. seq_length_sample = 2048
  55.  
  56. model = sio.loadmat(sys.argv[1])
  57. p, _, Wxh, Whh, Why, bh, by = model['p'], model['hprev'], model['Wxh'], model['Whh'], model['Why'], model['bh'], model['by']
  58. inputs = [char_to_ix[ch] for ch in data[1:p]]
  59. hidden_size = Whh.shape[0]
  60.  
  61. while (True):
  62. question = raw_input("Inital seq.: ")
  63. print "gen-seq.: "
  64. inputs = [char_to_ix[ch] for ch in question]
  65. # initial state
  66. hprev = np.zeros((hidden_size,1))
  67. sample_ix = sample(hprev, inputs, seq_length_sample)
  68. txt = ''.join(ix_to_char[ix] for ix in inputs) + ''.join(ix_to_char[ix] for ix in sample_ix)
  69. print '----\n %s \n----' % txt
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement