Advertisement
Guest User

Untitled

a guest
Oct 7th, 2015
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.65 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3.  
  4. from __future__ import absolute_import
  5.  
  6. import grammar
  7. from keras.models import Sequential
  8. from keras.layers.core import TimeDistributedDense, Masking
  9. from keras.layers.recurrent import LSTM
  10. import numpy as np
  11. import pad
  12. '''
  13. Train a LSTM on UCF11 dataset (preprocessed wirh pretrained CNN)
  14. '''
  15. np.random.seed(1337)
  16.  
  17. print 'Loading data...'
  18. nb_sample = 2048
  19. (X, y) = grammar.gen_ERber(nb_sample)
  20. split = int(len(X)*0.05)
  21. X_train, y_train = X[split:], y[split:]
  22. X_test, y_test = X[:split], y[:split]
  23.  
  24. batch_size = 64
  25. maxlen = max(len(x) for x in X) # padding the sequence to limited length
  26.  
  27. print 'padding sequence (samples x times x inputdim)'
  28. X_train = pad.pad_sequences(X_train, maxlen=maxlen)
  29. X_test = pad.pad_sequences(X_test, maxlen=maxlen)
  30. y_train = pad.pad_sequences(y_train, maxlen=maxlen)
  31. y_test = pad.pad_sequences(y_test, maxlen=maxlen)
  32.  
  33. print 'X_train shape:', X_train.shape
  34. print 'X_test shape:', X_test.shape
  35.  
  36. print 'Build model...'
  37. model = Sequential()
  38. model.add(Masking())
  39. model.add(LSTM(7, 6, return_sequences=True))
  40. model.add(TimeDistributedDense(6, 7, activation='sigmoid'))
  41.  
  42. # complie model...
  43. print 'compling model...'
  44. model.compile(loss='mse', optimizer='adam')
  45.  
  46. #model.load_weights('ereber-lstm-weights')
  47.  
  48. print X_train.shape,y_train.shape
  49. # fit model...
  50. print 'Training...'
  51. model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=100, validation_data=(X_test, y_test), show_accuracy=True)
  52. score, acc = model.evaluate(X_test, y_test, batch_size=batch_size, show_accuracy=True)
  53. print 'test\t score,\t accuracy:', (score, acc)
  54.  
  55. print 'saved model as ereber-lstm-weights'
  56. model.save_weights('ereber-lstm-weights',overwrite=True)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement