Advertisement
Guest User

Untitled

a guest
Jul 25th, 2017
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.22 KB | None | 0 0
  1. import tensorflow as tf
  2. import tensorlayer as tl
  3.  
  4. sess = tf.InteractiveSession()
  5.  
  6. # prepare data
  7. X_train, y_train, X_val, y_val, X_test, y_test = \
  8. tl.files.load_mnist_dataset(shape=(-1,784))
  9.  
  10. # define placeholder
  11. x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
  12. y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_')
  13.  
  14. # define the network
  15. network = tl.layers.InputLayer(x, name='input_layer')
  16. network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
  17. network = tl.layers.DenseLayer(network, n_units=800,
  18. act = tf.nn.relu, name='relu1')
  19. network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
  20. network = tl.layers.DenseLayer(network, n_units=800,
  21. act = tf.nn.relu, name='relu2')
  22. network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
  23. # the softmax is implemented internally in tl.cost.cross_entropy(y, y_, 'cost') to
  24. # speed up computation, so we use identity here.
  25. # see tf.nn.sparse_softmax_cross_entropy_with_logits()
  26. network = tl.layers.DenseLayer(network, n_units=10,
  27. act = tf.identity,
  28. name='output_layer')
  29. # define cost function and metric.
  30. y = network.outputs
  31. cost = tl.cost.cross_entropy(y, y_, 'cost')
  32. correct_prediction = tf.equal(tf.argmax(y, 1), y_)
  33. acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  34. y_op = tf.argmax(tf.nn.softmax(y), 1)
  35.  
  36. # define the optimizer
  37. train_params = network.all_params
  38. train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999,
  39. epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)
  40.  
  41. # initialize all variables in the session
  42. tl.layers.initialize_global_variables(sess)
  43.  
  44. # print network information
  45. network.print_params()
  46. network.print_layers()
  47.  
  48. # train the network
  49. tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
  50. acc=acc, batch_size=500, n_epoch=500, print_freq=5,
  51. X_val=X_val, y_val=y_val, eval_train=False)
  52.  
  53. # evaluation
  54. tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost)
  55.  
  56. # save the network to .npz file
  57. tl.files.save_npz(network.all_params , name='model.npz')
  58. sess.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement