Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2019
118
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.71 KB | None | 0 0
  1. !pip install mxnet
  2. !pip install opencv-python
  3. !pip install matplotlib
  4.  
  5. import mxnet as mx
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from collections import namedtuple
  9.  
  # Smoke test: check that MXNet works on the hosted runtime by pushing
# random images through a small convolutional network.

NUM_TEST_IMAGES = 10              # one label per image; also the batch size
TEST_IMAGE_SHAPE = (3, 256, 256)  # (channels, height, width)

# Random pixel values in [0, 255]. np.random.randint's upper bound is
# exclusive, so 256 is needed for 255 to be reachable.
test_images = [
    np.random.randint(0, 256, TEST_IMAGE_SHAPE)
    for _ in range(NUM_TEST_IMAGES)
]

# One dummy label per image (the counts must match for NDArrayIter).
test_labels = np.ones(NUM_TEST_IMAGES)
  18.  
  # Optional synthetic-input iterator (usable in place of the random images).
class RandIter(mx.io.DataIter):
    """Endless iterator of standard-normal noise batches.

    Each batch holds a single NDArray of shape
    ``(batch_size, ndim, 256, 256)`` drawn from N(0, 1). No labels are
    provided, and ``iter_next`` always returns True, so iteration never ends
    on its own; callers must limit how many batches they draw.

    Example::

        x = RandIter(1, 3)
        batch = x.next()   # batch.data[0].shape == (1, 3, 256, 256)
    """

    def __init__(self, batch_size, ndim):
        super().__init__(batch_size)
        self.batch_size = batch_size
        self.ndim = ndim
        # One (name, shape) pair per input, the format Module.bind expects.
        self.provide_data = [('data', (batch_size, ndim, 256, 256))]
        self.provide_label = []

    def iter_next(self):
        # Noise never runs out.
        return True

    def getdata(self):
        return [mx.random.normal(0, 1, shape=(self.batch_size, self.ndim, 256, 256))]
  37.  
  # Minimal batch type: Module.forward only reads its ``data`` field, which
# must be a list of NDArrays (one per data input).
Batch = namedtuple('Batch', ['data'])

# Stack the per-image arrays into one (N, C, H, W) float32 array. Passing the
# Python list itself to NDArrayIter would make it treat each image as a
# separate named input rather than as N samples of one input.
test_data = np.stack(test_images).astype(np.float32)
if len(test_labels) != len(test_data):
    raise ValueError(
        'expected one label per image, got {} labels for {} images'.format(
            len(test_labels), len(test_data)))

test = mx.io.NDArrayIter(test_data, label=test_labels, batch_size=len(test_data))

test_batch = Batch([mx.nd.array(test_data)])
# Alternative noise input:
# rand_iter = RandIter(batch_size=10, ndim=3)

# Define network architecture: two 1x1 convolutions, 3 -> 1000 -> 3 channels.
data = mx.sym.Variable('data')
g1 = mx.sym.Convolution(data, name='CONV1', kernel=(1, 1), num_filter=1000)
g2 = mx.sym.Convolution(g1, name='Conv2', kernel=(1, 1), num_filter=3)

conv_model = mx.mod.Module(symbol=g2, data_names=('data',), label_names=None,
                           context=mx.cpu())

# Bind with the real batch shape so forward() runs on exactly what was bound.
# NOTE: the 1000-channel intermediate holds 1000 * 256 * 256 float32 values
# (~262 MB) per image, so memory grows linearly with the batch size.
conv_model.bind(for_training=True, data_shapes=[('data', test_data.shape)])
conv_model.init_params(initializer=mx.init.Xavier(magnitude=2.))
conv_model.init_optimizer(
    optimizer='adam',
    optimizer_params={'learning_rate': 0.0002, 'beta1': 0.5})

conv_model.forward(test_batch)
z = conv_model.get_outputs()

# Output is (N, 3, 256, 256); keep the first image as (3, 256, 256).
outputs = z[0].asnumpy()
x = outputs[0]

# plt.imshow expects (H, W, C), so move the channel axis last:
# plt.imshow(np.transpose(x, (1, 2, 0)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement