Advertisement
Guest User

Untitled

a guest
Sep 20th, 2017
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.60 KB | None | 0 0
# coding: utf-8

# In[1]:


# Raise the root logger to INFO so MXNet's training progress messages
# (epoch metrics from mod.fit) are printed.
import logging
logging.getLogger().setLevel(logging.INFO)

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import mxnet as mx
  13.  
  14. # In[2]:
  15.  
  16.  
# Number of GPU devices to train on (one context per GPU, see `ctx` below).
N_GPUS = 8
# Global minibatch size: 128 examples per GPU.
mb_size = 128*N_GPUS
  19.  
  20.  
  21. # In[3]:
  22.  
  23.  
  24. mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
  25.  
  26.  
  27. # In[4]:
  28.  
  29.  
  30. def to4d(img):
  31. """
  32. reshape to 4D arrays
  33. """
  34. return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
  35.  
  36. def get_mnist_iter():
  37. """
  38. create data iterator with NDArrayIter
  39. """
  40. (train_lbl, train_img) = (mnist.train.labels, mnist.train.images)
  41. (val_lbl, val_img) = (mnist.test.labels, mnist.test.images)
  42. train = mx.io.NDArrayIter(
  43. to4d(train_img), train_lbl, mb_size, shuffle=True)
  44. val = mx.io.NDArrayIter(
  45. to4d(val_img), val_lbl, mb_size)
  46. return (train, val)
  47.  
  48.  
  49. # In[5]:
  50.  
  51.  
  52. ctx = [mx.gpu(i) for i in range(N_GPUS)]
  53.  
  54.  
  55. # In[6]:
  56.  
  57.  
  58. def fc_simple():
  59. data = mx.sym.Variable('data')
  60. fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=128)
  61. act1 = mx.sym.Activation(fc1, name='relu1', act_type="relu")
  62. fc2 = mx.sym.FullyConnected(act1, name='fc2', num_hidden=10)
  63. fc_do = mx.sym.Dropout(fc2, p=0.5)
  64. out = mx.sym.SoftmaxOutput(fc_do, name = 'softmax')
  65. return out
  66.  
  67.  
  68. # In[7]:
  69.  
  70.  
  71. def conv():
  72. data = mx.sym.Variable('data')
  73. conv1 = mx.sym.Convolution(
  74. data,
  75. name='conv1',
  76. kernel=(5,5),
  77. stride=(1,1),
  78. num_filter=128,
  79. )
  80. act1 = mx.sym.LeakyReLU(conv1, name='lrelu1')
  81.  
  82. conv2 = mx.sym.Convolution(
  83. mx.sym.BatchNorm(act1),
  84. name='conv2',
  85. kernel=(5,5),
  86. stride=(2,2),
  87. num_filter=256,
  88. )
  89. act2 = mx.sym.LeakyReLU(conv2, name='lrelu2')
  90.  
  91. conv3 = mx.sym.Convolution(
  92. mx.sym.BatchNorm(act2),
  93. name='conv3',
  94. kernel=(5,5),
  95. stride=(2,2),
  96. num_filter=512,
  97. )
  98. act3 = mx.sym.LeakyReLU(conv3, name='lrelu3')
  99.  
  100. fc = mx.sym.FullyConnected(act3, name='fc', num_hidden=10)
  101.  
  102. fc_do = mx.sym.Dropout(fc, p=0.5)
  103.  
  104. out = mx.sym.SoftmaxOutput(fc_do, name = 'softmax')
  105.  
  106. return out
  107.  
  108.  
  109. # In[10]:
  110.  
  111.  
# Build the simple fully-connected network (the conv network defined above
# is not the one trained here) and render its computation graph.
out = fc_simple()
mx.viz.plot_network(out)


# In[11]:


# Create the training and validation data iterators.
train, val = get_mnist_iter()
  120.  
  121.  
  122. # In[12]:
  123.  
  124.  
# Wrap the symbol in a Module that runs across all GPU contexts.
mod = mx.module.Module(out, context=ctx)


# In[13]:


# Allocate memory for parameters and gradients from the iterator's shapes.
mod.bind(
data_shapes=train.provide_data,
label_shapes=train.provide_label
) # create memory by given input shapes
  135.  
  136.  
  137. # In[14]:
  138.  
  139.  
# Initialize parameters with the Xavier scheme (the original comment said
# "uniform random numbers", which did not match the Xavier initializer used),
# then set up the Adam optimizer.
mod.init_params(initializer=mx.initializer.Xavier())
mod.init_optimizer(optimizer='adam')
  143.  
  144.  
  145. # In[ ]:
  146.  
  147.  
  148. mod.fit(
  149. train,
  150. num_epoch=10,
  151. eval_metric='acc',
  152. )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement