Advertisement
Guest User

Untitled

a guest
Jan 16th, 2020
411
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.97 KB | None | 0 0
  1. import numpy as np
  2. import mxnet as mx
  3. import random
  4. import time
  5.  
  6. from multiprocessing import cpu_count
  7. from mxnet import autograd as ag
  8. from mxnet import nd
  9. from mxnet.metric import Accuracy
  10. from mxnet.gluon import Block, Trainer
  11. from mxnet.gluon.data import DataLoader
  12. from mxnet.gluon.data.vision import MNIST
  13. from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
  14. from mxnet.gluon.nn import Conv2D, Dense, Dropout, Flatten, MaxPool2D, HybridBlock
  15. from mxnet.gluon.utils import split_and_load
  16.  
  17.  
  18. BATCH_SIZE_PER_REPLICA = 512
  19. BATCH_SIZE = BATCH_SIZE_PER_REPLICA * 1
  20. NUM_CLASSES = 10
  21. EPOCHS = 10
  22. GPU_COUNT = 2
  23.  
  24.  
  25. class Model(HybridBlock):
  26.     def __init__(self, **kwargs):
  27.         super(Model, self).__init__(**kwargs)
  28.         with self.name_scope():
  29.             self.conv1 = Conv2D(32, (3, 3))
  30.             self.conv2 = Conv2D(64, (3, 3))
  31.             self.pool = MaxPool2D(pool_size=(2, 2))
  32.             self.dropout1 = Dropout(0.25)
  33.             self.flatten = Flatten()
  34.             self.dense1 = Dense(128)
  35.             self.dropout2 = Dropout(0.5)
  36.             self.dense2 = Dense(NUM_CLASSES)
  37.  
  38.     def hybrid_forward(self, F, x):
  39.         x = F.relu(self.conv1(x))
  40.         x = F.relu(self.conv2(x))
  41.         x = self.pool(x)
  42.         x = self.dropout1(x)
  43.         x = self.flatten(x)
  44.         x = F.relu(self.dense1(x))
  45.         x = self.dropout2(x)
  46.         x = self.dense2(x)
  47.         return x
  48.  
  49.  
  50. def transform(data, label):
  51.     return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)
  52.  
  53.  
  54. def data_loader(train, batch_size, num_workers):
  55.     dataset = MNIST(train=train, transform=transform)
  56.     return DataLoader(dataset, batch_size, shuffle=train, num_workers=num_workers)
  57.  
  58.  
# Seed both MXNet's and Python's RNGs for reproducible runs.
mx.random.seed(42)
random.seed(42)

train_data = data_loader(train=True, batch_size=BATCH_SIZE, num_workers=cpu_count())
test_data = data_loader(train=False, batch_size=BATCH_SIZE, num_workers=cpu_count())

model = Model()
# Compile the network to a symbolic graph with static memory/shape for speed.
model.hybridize(static_alloc=True, static_shape=True)

# One context per GPU; batches are sharded across these devices.
ctx = [mx.gpu(i) for i in range(GPU_COUNT)]

# Adam optimizer with the standard default hyperparameters.
opt_params={'learning_rate':0.001, 'beta1':0.9, 'beta2':0.999, 'epsilon':1e-08}
opt = mx.optimizer.create('adam', **opt_params)
# Initialize parameters randomly on every listed device (must precede Trainer).
model.initialize(force_reinit=True, ctx=ctx)
# Fetch the parameter dict shared by all device copies.
params = model.collect_params()
# 'device' kvstore aggregates gradients on the GPUs themselves.
trainer = Trainer(params=params,
                  optimizer=opt,
                  kvstore='device')
loss_fn = SoftmaxCrossEntropyLoss()
metric = Accuracy()
  83.  
  84. start = time.perf_counter()
  85. for epoch in range(EPOCHS):
  86.     tick = time.time()
  87.     for i, (data, label) in enumerate(train_data):
  88.         if i == 0:
  89.             tick_0 = time.time()
  90.         data = split_and_load(data, ctx_list=ctx, batch_axis=0)
  91.         label = split_and_load(label, ctx_list=ctx, batch_axis=0)
  92.         output = []
  93.         losses = []
  94.         with ag.record():
  95.             for x, y in zip(data, label):
  96.                 z = model(x)
  97.                 # computes softmax cross entropy loss
  98.                 l = loss_fn(z, y)
  99.                 output.append(z)
  100.                 losses.append(l)
  101.         # backpropagate the error for one iteration.
  102.         for l in losses:
  103.             l.backward()
  104.         # Update network weights
  105.         trainer.step(BATCH_SIZE)
  106.         # Update metric
  107.         metric.update(label, output)
  108.     str1 = 'Epoch [{}], Accuracy {:.4f}'.format(epoch, metric.get()[1])
  109.     str2 = '~Samples/Sec {:.4f}'.format(BATCH_SIZE*(i+1)/(time.time()-tick_0))
  110.     print('%s  %s' % (str1, str2))
  111.     metric.reset()
  112.  
  113. elapsed = time.perf_counter() - start
  114. print('elapsed: {:0.3f}'.format(elapsed))
  115.  
  116. # use Accuracy as the evaluation metric
  117. metric = Accuracy()
  118. for data, label in test_data:
  119.     data = split_and_load(data, ctx_list=ctx, batch_axis=0)
  120.     label = split_and_load(label, ctx_list=ctx, batch_axis=0)
  121.     outputs = []
  122.     for x in data:
  123.         outputs.append(model(x))
  124.     metric.update(label, outputs)
  125. print('validation %s=%f' % metric.get())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement