Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# -*- coding: utf-8 -*-
# In[1]:
import logging

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import mxnet as mx

# Emit MXNet/module training progress (epoch metrics) to stderr.
logging.getLogger().setLevel(logging.INFO)

# In[2]:
# Global minibatch size; mx.module.Module slices each batch across the
# contexts it is given, so this is 128 samples per GPU.
N_GPUS = 8
mb_size = 128 * N_GPUS

# In[3]:
# one_hot=False: integer class labels, as expected by mx.sym.SoftmaxOutput.
# Downloads MNIST into ./MNIST_data/ on first run.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
- # In[4]:
def to4d(img):
    """Convert flat MNIST images (N, 784) to NCHW float32 in [0, 1].

    Returns an array of shape (N, 1, 28, 28) with pixel values
    scaled from [0, 255] down to [0, 1].
    """
    scaled = img.astype(np.float32) / 255
    return scaled.reshape(-1, 1, 28, 28)
def get_mnist_iter():
    """Wrap the loaded MNIST splits in mx.io.NDArrayIter instances.

    Returns (train_iter, val_iter); the training iterator shuffles,
    the validation iterator does not.
    """
    train_iter = mx.io.NDArrayIter(
        to4d(mnist.train.images), mnist.train.labels, mb_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(
        to4d(mnist.test.images), mnist.test.labels, mb_size)
    return train_iter, val_iter
# In[5]:
# One compute context per GPU for data-parallel training.
ctx = [mx.gpu(device_id) for device_id in range(N_GPUS)]
- # In[6]:
def fc_simple():
    """Two-layer MLP: 128 ReLU hidden units, 10-way output.

    Applies dropout (p=0.5) to the logits before the softmax loss.
    Returns the mx.sym.SoftmaxOutput symbol.
    """
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(net, name='fc1', num_hidden=128)
    net = mx.sym.Activation(net, name='relu1', act_type="relu")
    net = mx.sym.FullyConnected(net, name='fc2', num_hidden=10)
    net = mx.sym.Dropout(net, p=0.5)
    return mx.sym.SoftmaxOutput(net, name='softmax')
- # In[7]:
def conv():
    """Three-block convolutional classifier for MNIST.

    conv1 (128 filters, stride 1) -> conv2 (256, stride 2) ->
    conv3 (512, stride 2), each followed by LeakyReLU and with
    BatchNorm between blocks; then a 10-way FC layer, dropout
    (p=0.5) on the logits, and a softmax loss. Returns the
    mx.sym.SoftmaxOutput symbol.
    """
    net = mx.sym.Variable('data')
    net = mx.sym.Convolution(net, name='conv1', kernel=(5, 5),
                             stride=(1, 1), num_filter=128)
    net = mx.sym.LeakyReLU(net, name='lrelu1')
    net = mx.sym.BatchNorm(net)
    net = mx.sym.Convolution(net, name='conv2', kernel=(5, 5),
                             stride=(2, 2), num_filter=256)
    net = mx.sym.LeakyReLU(net, name='lrelu2')
    net = mx.sym.BatchNorm(net)
    net = mx.sym.Convolution(net, name='conv3', kernel=(5, 5),
                             stride=(2, 2), num_filter=512)
    net = mx.sym.LeakyReLU(net, name='lrelu3')
    net = mx.sym.FullyConnected(net, name='fc', num_hidden=10)
    net = mx.sym.Dropout(net, p=0.5)
    return mx.sym.SoftmaxOutput(net, name='softmax')
# In[10]:
out = fc_simple()
mx.viz.plot_network(out)

# In[11]:
train, val = get_mnist_iter()

# In[12]:
mod = mx.module.Module(out, context=ctx)

# In[13]:
# Allocate parameter/gradient memory for the given input shapes.
mod.bind(
    data_shapes=train.provide_data,
    label_shapes=train.provide_label,
)

# In[14]:
# Initialize parameters (Xavier) and the optimizer (Adam) up front;
# fit() will not re-initialize already-bound state, so its defaults
# (uniform init, SGD) do not override these choices.
mod.init_params(initializer=mx.initializer.Xavier())
mod.init_optimizer(optimizer='adam')

# In[ ]:
mod.fit(
    train,
    eval_data=val,  # fix: val was created but never used for evaluation
    num_epoch=10,
    eval_metric='acc',
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement