Advertisement
Guest User

Untitled

a guest
Jan 24th, 2017
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.65 KB | None | 0 0
  1. from keras.models import Sequential
  2. from keras.layers import Dense, Activation, Dropout
  3. from keras.layers.recurrent import GRU, LSTM
  4. from keras.preprocessing import sequence
  5. import numpy as np
  6. from BuySessionData import *
  7.  
  8.  
  9. def generator(batch_size, test_type='train'):
  10.     data = BuySessionData('/home/eric/data/formatted/buy-sessions-train.dat',
  11.                           '/home/eric/data/formatted/buy-sessions-test.dat')
  12.     while True:
  13.         x, y = data.next_train_batch(batch_size) if test_type == 'train'\
  14.             else data.next_test_batch(batch_size, test_type)
  15.         x = sequence.pad_sequences(x)
  16.         yield (x, y)
  17.  
  18. # Parameters
  19. HIDDEN_SIZE = 64
  20. BATCH_SIZE = 50
  21. LAYERS = 2
  22. FEATURES = 4 + 339
  23. CLASSES = 2
  24. EPOCH_SAMPLES = 100000
  25. EPOCHS = 1
  26. TEST_SAMPLES = 25000
  27.  
  28. model = Sequential()
  29. model.add(LSTM(HIDDEN_SIZE, input_shape=(None, FEATURES), return_sequences=True))
  30. model.add(Activation('relu'))
  31.  
  32. model.add(LSTM(HIDDEN_SIZE, input_shape=(None, FEATURES)))
  33. model.add(Activation('relu'))
  34.  
  35. model.add(Dense(CLASSES))
  36. model.add(Activation('softmax'))
  37.  
  38. model.compile(loss='binary_crossentropy',
  39.               optimizer='rmsprop',
  40.               metrics=['binary_accuracy'])
  41.  
  42. model.fit_generator(generator(BATCH_SIZE),
  43.                     samples_per_epoch=EPOCH_SAMPLES,
  44.                     nb_epoch=EPOCHS,
  45.                     class_weight={0: 5.5, 1: 94.5})
  46.  
  47. score, acc = model.evaluate_generator(generator(BATCH_SIZE, 'buys'), val_samples=TEST_SAMPLES)
  48. print('Test accuracy buys:', acc)
  49.  
  50. score, acc = model.evaluate_generator(generator(BATCH_SIZE, 'non_buys'), val_samples=TEST_SAMPLES)
  51. print('Test accuracy non buys:', acc)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement