Advertisement
Guest User

Untitled

a guest
Jan 23rd, 2017
114
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.68 KB | None | 0 0
  1. import numpy as np
  2. from keras.models import Sequential
  3. from keras.layers.core import Dense, Dropout, Activation
  4. from keras.optimizers import SGD, Adam, RMSprop
  5. from keras.utils import np_utils
  6. from time import time
  7. from pickle import dump
  8. trainfile = "train.csv"
  9. t = time()
  10. data = np.loadtxt(open(trainfile,'r'),delimiter=',',skiprows=1) #http://softwarerecs.stackexchange.com/questions/7463/fastest-python-library-to-read-a-csv-file
  11. y,X = data[:,0:1],data[:,1:]/255
  12. print("Loading :",time()-t," sec")
  13. #check shape
  14. print(X,y)
  15.  
  16. outname = "modell"
  17. dropout_rate = 0.4
  18. batch_size = 128
  19. nb_classes = 10
  20. nb_epoch = 20
  21. #split train|valid
  22. #if (False):
  23. # split_len = 2*len(X)//3
  24. # X_train,X_test = X[:split_len],X[split_len:]
  25. # y_train,y_test = y[:split_len].astype('int32'),y[split_len:].astype('int32')
  26. # y_train,y_test = np_utils.to_categorical(y_train,nb_classes),np_utils.to_categorical(y_test,nb_classes)
  27. # print(y_train,y_test)
  28. #else:
  29. X_train = X
  30. y_train = np_utils.to_categorical(y.astype('int32'),nb_classes)
  31.  
  32.  
  33. #build nn
  34. model = Sequential()
  35. model.add(Dense(512, input_shape=(784,)))
  36. model.add(Activation('relu'))
  37. model.add(Dropout(dropout_rate))
  38. model.add(Dense(512))
  39. model.add(Activation('relu'))
  40. model.add(Dropout(dropout_rate))
  41. model.add(Dense(10))
  42. model.add(Activation('softmax'))
  43. model.compile(loss='categorical_crossentropy',
  44. optimizer=RMSprop(),
  45. metrics=['accuracy'])
  46. #model.summary()
  47. history = model.fit(X_train, y_train,
  48. batch_size=batch_size, nb_epoch=nb_epoch,
  49. verbose=1)
  50. with open(outname+".json", "w") as json_file:
  51. json_file.write(model.to_json())
  52. model.save_weights(outname+".h5")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement