Guest User

Untitled

a guest
Feb 24th, 2018
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.66 KB | None | 0 0
import argparse

import keras
import numpy as np
from keras import regularizers
from keras.layers import Dense, LSTM, Dropout, SimpleRNN
from keras.models import Sequential


  9. def main():
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--input_size', type=int, required=True)
  12. parser.add_argument('--model_config', type=str, required=True,
  13. choices=['lstm', 'ffn', 'rnn'])
  14. parser.add_argument('--epochs', type=int, required=True)
  15. parser.add_argument('--reg', type=float, default=0)
  16. parser.add_argument('--hidden_size', type=int, required=True)
  17. parser.add_argument('--train_examples', type=int, default=100000)
  18. parser.add_argument('--dropout', type=int, default=0)
  19. parser.add_argument('--lr', type=float, required=True)
  20. parser.add_argument('--batch_size', type=int, required=True)
  21. args = parser.parse_args()
  22.  
  23. train_fts = np.random.randint(0, 2, (args.train_examples, args.input_size))
  24. train_lbl = (train_fts.sum(axis=1) % 2).astype(int)
  25.  
  26. test_fts = np.random.randint(0, 2, (args.train_examples // 10, args.input_size))
  27. test_lbl = (test_fts.sum(axis=1) % 2).astype(int)
  28.  
  29. set_train_fts = set()
  30. for x in train_fts:
  31. set_train_fts.add(tuple(x))
  32. count_test_present = sum(tuple(x) in set_train_fts for x in test_fts)
  33. print("Fraction of test instances present in training set: {:.3f}".format(count_test_present / len(test_fts)))
  34.  
  35. model = Sequential()
  36. if args.model_config == 'ffn':
  37. model.add(Dense(args.hidden_size, activation='relu', input_dim=args.input_size))
  38. # kernel_regularizer=regularizers.l2(args.reg),
  39. # bias_regularizer=regularizers.l2(args.reg)))
  40. if args.dropout > 0:
  41. model.add(Dropout(args.dropout))
  42. model.add(Dense(1, activation='sigmoid'))
  43. elif args.model_config == 'rnn':
  44. train_fts = np.expand_dims(train_fts, 2)
  45. test_fts = np.expand_dims(test_fts, 2)
  46. model.add(SimpleRNN(args.hidden_size, input_shape=(None, 1)))
  47. model.add(Dense(1, activation='sigmoid'))
  48. else:
  49. train_fts = np.expand_dims(train_fts, 2)
  50. test_fts = np.expand_dims(test_fts, 2)
  51. model.add(LSTM(args.hidden_size, input_shape=(None, 1)))
  52. model.add(Dense(1, activation='sigmoid'))
  53.  
  54. print("Model type:", args.model_config)
  55.  
  56. optimizer = keras.optimizers.Adam(lr=args.lr)
  57.  
  58. model.compile(loss='binary_crossentropy',
  59. optimizer=optimizer,
  60. metrics=['accuracy'])
  61.  
  62. model.fit(train_fts, train_lbl, epochs=args.epochs, batch_size=128,
  63. validation_data=(test_fts, test_lbl))
  64.  
  65.  
  66. if __name__ == "__main__":
  67. main()
Add Comment
Please, Sign In to add comment