SHARE
TWEET

Untitled

Posted by a guest on Oct 23rd, 2019 · 83 views · Expires: Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. !pip install mxnet
  2. !pip install opencv-python
  3. !pip install matplotlib
  4.  
  5. import mxnet as mx
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from collections import namedtuple
  9.  
  10. #test that MXnet works on hosted runtime (I said this to make me sound smart)
  11.  
  12. test_images = []
  13. for i in range(1,10):
  14.     x_data = np.random.randint(0,255,[3,256,256])
  15.     test_images.append(x_data)
  16.        
  17. test_labels = np.ones(10)
  18.  
  19. '''
  20. #Define iterator class
  21. class RandIter(mx.io.DataIter):
  22.    
  23.     def __init__ (self,batch_size,ndim):
  24.         self.batch_size = batch_size
  25.         self.ndim = ndim
  26.         self.provide_data = [('data'),(batch_size,ndim,256,256)]
  27.         self.provide_label = []
  28.    
  29.     def iter_next(self):
  30.         return True
  31.    
  32.     def getdata(self):
  33.         return [mx.random.normal(0,1,shape=(self.batch_size,self.ndim,256,256))]
  34.    
  35. x = RandIter(1,3)
  36. '''
  37.  
  38. Batch = namedtuple('Batch', ['data'])
  39.  
  40. test_batch = Batch(test)
  41.  
  42. test = mx.io.NDArrayIter(test_images,label=test_labels,batch_size=10)
  43.  
  44. test_batch = Batch(test)
  45. #rand_iter = RandIter(batch_size=10,ndim=3)
  46.  
  47. #Define Network Architecture
  48. data = mx.sym.Variable('data')
  49. g1 = mx.sym.Convolution(data, name = 'CONV1',kernel=(1,1),num_filter=1000)
  50. g2 = mx.sym.Convolution(g1,name = 'Conv2', kernel=(1,1),num_filter=3)
  51.  
  52. conv_model = mx.mod.Module(symbol=g2,data_names=('data',),label_names=None,context=mx.cpu())
  53.  
  54. conv_model.bind(for_training=True,data_shapes=[('data',(1,3,256,256))])
  55. conv_model.init_params(initializer=mx.init.Xavier(magnitude=2.))
  56. conv_model.init_optimizer(
  57.             optimizer='adam',
  58.             optimizer_params= {'learning_rate': 0.0002,'beta1':0.5})
  59.  
  60. conv_model.forward(test_batch)
  61. z = conv_model.get_outputs()
  62.  
  63. x = z[0].asnumpy()
  64. x = x.reshape(3,256,256)
  65.  
  66. #plt.imshow(x)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top