Guest User

Untitled

a guest
Jun 28th, 2017
147
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # Modules
  2. import numpy
  3. import pandas
  4. from keras.models import Sequential
  5. from keras.layers import Dense
  6. from keras.utils import np_utils
  7. from keras.wrappers.scikit_learn import KerasClassifier
  8. from sklearn.model_selection import cross_val_score
  9. from sklearn.model_selection import KFold
  10. from sklearn.preprocessing import LabelEncoder
  11. from keras import backend as K
  12. import os
  13.  
  14.  
  15. def set_keras_backend(backend):
  16.     if K.backend() != backend:
  17.         os.environ['KERAS_BACKEND'] = backend
  18.         reload(K)
  19.         assert K.backend() == backend
  20.  
  21.  
  22. set_keras_backend("theano")
  23. # seed
  24. seed = 7
  25. numpy.random.seed(seed)
  26.  
  27. # load dataset
  28. dataFrame = pandas.read_csv("iris.csv", header=None)
  29. dataset = dataFrame.values
  30.  
  31. X = dataset[:, 0:4].astype(float)
  32. Y = dataset[:, 4]
  33.  
  34. # encode class values
  35. encoder = LabelEncoder()
  36. encoder.fit(Y)
  37. encoded_Y = encoder.transform(Y)
  38.  
  39. dummy_Y = np_utils.to_categorical(encoded_Y)
  40.  
  41.  
  42. # baseline model
  43. def baseline_model():
  44.     # create model
  45.     model = Sequential()
  46.     model.add(Dense(8, input_dim=4, kernel_initializer='normal', activation='softplus'))
  47.     model.add(Dense(3, kernel_initializer='normal', activation='softmax'))
  48.     # compile model
  49.     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  50.     return model
  51.  
  52.  
  53. estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0)
  54. kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
  55.  
  56. results = cross_val_score(estimator, X, dummy_Y, cv=kfold)
  57.  
  58. print("Accuracy: %.2f%% (%.2f%%)" % (results.mean() * 100, results.std() * 100))
RAW Paste Data