import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.contrib.slim as slim
from sklearn.utils import shuffle
from scipy import ndimage
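
# NOTE: the paste calls print_n_txt, mixup, and gpusession without defining
# them. The helpers below are reconstructions from the call sites, not the
# author's originals: print_n_txt(_f=...,_chars=...,_DO_PRINT=...) logs a line
# to file and optionally to stdout; mixup(x,t,32) is read as standard mixup
# with Beta(32,32) mixing coefficients; gpusession() is assumed to be a
# Session configured with GPU memory growth.
def print_n_txt(_f,_chars,_DO_PRINT=True):
    _f.write(_chars+'\n'); _f.flush() # Log to the open text file
    if _DO_PRINT: print (_chars)

def mixup(_x,_t,_alpha=32):
    n = _x.shape[0]
    lam = np.random.beta(_alpha,_alpha,size=(n,1)).astype(np.float32) # [N x 1]
    idx = np.random.permutation(n) # Random pairing of examples
    xMix = lam*_x + (1.0-lam)*_x[idx,:] # Convex-combine inputs
    tMix = lam*_t + (1.0-lam)*_t[idx,:] # Convex-combine one-hot labels
    return xMix,tMix

def gpusession():
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True # Grab GPU memory on demand
    return tf.Session(config=config)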

def augment_img(_imgVec,_imgSize):
    # Reshape flattened vectors back to images
    imgVecAug = np.copy(_imgVec)
    n = _imgVec.shape[0]
    imgs = np.reshape(_imgVec,[n]+_imgSize)
    for i in range(n):
        cImg = imgs[i,:,:,:] # Current img
        # Rotate by a random angle in [-20,20) degrees
        angle = np.random.randint(-20,20)
        cImg = ndimage.rotate(cImg,angle,reshape=False
                              ,mode='reflect',prefilter=True,order=1)
        # Flip horizontally with probability 0.5
        if np.random.rand()>0.5: cImg = np.fliplr(cImg)
        # Shift by up to 3 pixels (no shift along the channel axis)
        shift = np.random.randint(-3,3,3); shift[2] = 0
        cImg = ndimage.shift(cImg,shift,mode='reflect')
        # Flatten back into the augmented batch
        imgVecAug[i,:] = np.reshape(cImg,[1,-1])
    imgVecAug = np.clip(imgVecAug,a_min=0.0,a_max=1.0)
    return imgVecAug
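
# A small sanity check for augment_img (not in the original paste): push a
# random batch through the augmenter and confirm that the shape and the [0,1]
# value range are preserved. The batch and image sizes are illustrative.
def check_augment_img():
    imgVec = np.random.rand(4,32*32*3).astype(np.float32) # Four flattened RGB images
    imgVecAug = augment_img(imgVec,[32,32,3])
    assert imgVecAug.shape == imgVec.shape
    assert 0.0 <= imgVecAug.min() and imgVecAug.max() <= 1.0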

class cnn_cls_class(object):
    def __init__(self,_name='basic_cnn',_xdim=[28,28,1],_ydim=10,_hdims=[64,64],_filterSizes=[3,3],_max_pools=[2,2]
                 ,_feat_dim=128,_actv=tf.nn.relu,_bn=slim.batch_norm
                 ,_l2_reg_coef=1e-5
                 ,_momentum=0.5
                 ,_USE_INPUT_BN=False,_USE_RESNET=False,_USE_GAP=False,_USE_SGD=False
                 ,_USE_MIXUP=False
                 ,_GPU_ID=0,_VERBOSE=True):
        self.name = _name
        self.xdim = _xdim
        self.ydim = _ydim
        self.hdims = _hdims
        self.filterSizes = _filterSizes
        self.max_pools = _max_pools
        self.feat_dim = _feat_dim
        self.actv = _actv
        self.bn = _bn
        self.l2_reg_coef = _l2_reg_coef
        self.momentum = _momentum
        self.USE_INPUT_BN = _USE_INPUT_BN
        self.USE_RESNET = _USE_RESNET
        self.USE_GAP = _USE_GAP
        self.USE_SGD = _USE_SGD
        self.USE_MIXUP = _USE_MIXUP
        self.GPU_ID = int(_GPU_ID)
        self.VERBOSE = _VERBOSE
        with tf.device('/device:GPU:%d'%(self.GPU_ID)):
            # Build model
            self.build_model()
            # Build graph
            self.build_graph()
            # Check parameters
            self.check_params()

    def build_model(self):
        # Set placeholders
        _xdim = self.xdim[0]*self.xdim[1]*self.xdim[2] # Total input dimension
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,_xdim]) # Input [N x xdim]
        self.t = tf.placeholder(dtype=tf.float32,shape=[None,self.ydim]) # Target [N x D]
        self.kp = tf.placeholder(dtype=tf.float32,shape=[]) # Dropout keep probability
        self.is_training = tf.placeholder(dtype=tf.bool,shape=[]) # Batch-norm / dropout mode
        self.lr = tf.placeholder(dtype=tf.float32,shape=[]) # Learning rate
        self.bn_init = {'beta': tf.constant_initializer(0.),
                        'gamma': tf.random_normal_initializer(1., 0.01)}
        batch_norm_params = {'is_training':self.is_training,'decay':0.9,'updates_collections':None}

        with tf.variable_scope(self.name,reuse=False) as scope:

            # List of features
            self.layers = []
            self.layers.append(self.x)

            # Reshape input
            _net = tf.reshape(self.x,[-1]+self.xdim)
            self.layers.append(_net)

            # Input normalization
            if self.USE_INPUT_BN:
                _net = slim.batch_norm(_net,param_initializers=self.bn_init,is_training=self.is_training,updates_collections=None)

            # Convolution layers
            for hidx,hdim in enumerate(self.hdims):
                fs = self.filterSizes[hidx]
                if self.USE_RESNET: # Use residual connection
                    cChannelSize = _net.get_shape()[3] # Current channel size
                    if cChannelSize == hdim:
                        _identity = _net
                    else: # Expand dimension with a 1x1 conv if required
                        _identity = slim.conv2d(_net,hdim,[1,1],padding='SAME',activation_fn=None
                                                ,weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                                                ,normalizer_fn=self.bn
                                                ,normalizer_params=batch_norm_params
                                                ,scope='identity_%d'%(hidx))
                    # First conv
                    _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                       ,activation_fn=None
                                       ,weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                                       ,normalizer_fn=self.bn
                                       ,normalizer_params=batch_norm_params
                                       ,scope='res_a_%d'%(hidx))
                    # ReLU
                    _net = self.actv(_net)
                    self.layers.append(_net) # Append to list
                    # Second conv
                    _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                       ,activation_fn=None
                                       ,weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                                       ,normalizer_fn=self.bn
                                       ,normalizer_params=batch_norm_params
                                       ,scope='res_b_%d'%(hidx))
                    # Skip connection
                    _net = _net + _identity
                    # ReLU
                    _net = self.actv(_net)
                    self.layers.append(_net) # Append to list
                else:
                    _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                       ,activation_fn=self.actv
                                       ,weights_initializer=tf.truncated_normal_initializer(stddev=0.01)
                                       ,normalizer_fn=self.bn
                                       ,normalizer_params=batch_norm_params
                                       ,scope='conv_%d'%(hidx))
                    self.layers.append(_net) # Append to list
                # Max pooling (if required)
                max_pool = self.max_pools[hidx]
                if max_pool > 1:
                    _net = slim.max_pool2d(_net,[max_pool,max_pool],scope='pool_%d'%(hidx))
                    self.layers.append(_net) # Append to list

            # Global average pooling
            if self.USE_GAP:
                _net = tf.reduce_mean(_net,[1,2])
                self.layers.append(_net) # Append to list
                # Feature
                self.feat = _net # [N x Q]
            else:
                # Flatten and output
                _net = slim.flatten(_net,scope='flatten')
                self.layers.append(_net) # Append to list
                # Dense
                _net = slim.fully_connected(_net,self.feat_dim,scope='fc')
                self.layers.append(_net) # Append to list
                # Feature
                self.feat = _net # [N x Q]

            # Dropout at the last layer
            _net = slim.dropout(_net,keep_prob=self.kp,is_training=self.is_training,scope='dropout')
            _out = slim.fully_connected(_net,self.ydim,activation_fn=None,normalizer_fn=None,scope='out') # [N x D]
            self.layers.append(_out) # Append to list
            self.out = _out

    # Build graph
    def build_graph(self):
        # Cross-entropy loss
        self._loss_ce = tf.nn.softmax_cross_entropy_with_logits(labels=self.t,logits=self.out) # [N]
        self.loss_ce = tf.reduce_mean(self._loss_ce) # []
        # Weight decay regularizer
        _g_vars = tf.global_variables()
        _c_vars = [var for var in _g_vars if '%s/'%(self.name) in var.name]
        self.l2_reg = self.l2_reg_coef*tf.reduce_sum(tf.stack([tf.nn.l2_loss(v) for v in _c_vars])) # []
        # Total loss
        self.loss_total = self.loss_ce + self.l2_reg
        if self.USE_SGD:
            # self.optm = tf.train.GradientDescentOptimizer(learning_rate=self.lr).minimize(self.loss_total)
            self.optm = tf.train.MomentumOptimizer(learning_rate=self.lr,momentum=self.momentum).minimize(self.loss_total)
        else:
            self.optm = tf.train.AdamOptimizer(learning_rate=self.lr
                                               ,beta1=0.9,beta2=0.999,epsilon=1e-6).minimize(self.loss_total)
        # Accuracy
        _corr = tf.equal(tf.argmax(self.out,1),tf.argmax(self.t,1))
        self.accr = tf.reduce_mean(tf.cast(_corr,tf.float32))

    # Check parameters
    def check_params(self):
        _g_vars = tf.global_variables()
        self.g_vars = [var for var in _g_vars if '%s/'%(self.name) in var.name]
        if self.VERBOSE:
            print ("==== Global Variables ====")
        for i in range(len(self.g_vars)):
            w_name = self.g_vars[i].name
            w_shape = self.g_vars[i].get_shape().as_list()
            if self.VERBOSE:
                print (" [%02d] Name:[%s] Shape:[%s]" % (i,w_name,w_shape))
        # Print layers
        if self.VERBOSE:
            print ("====== Layers ======")
            nLayers = len(self.layers)
            for i in range(nLayers):
                print ("[%02d/%d] %s %s"%(i,nLayers,self.layers[i].name,self.layers[i].shape))

    # Saver
    def save(self,_sess,_savename=None):
        if _savename is None:
            _savename = '../net/net_%s.npz'%(self.name) # Assumes ../net exists
        # Get global variables
        self.g_wnames,self.g_wvals,self.g_wshapes = [],[],[]
        for i in range(len(self.g_vars)):
            curr_wname = self.g_vars[i].name
            curr_wvar = [v for v in tf.global_variables() if v.name==curr_wname][0]
            curr_wval = _sess.run(curr_wvar)
            self.g_wnames.append(curr_wname)
            self.g_wvals.append(np.asanyarray(curr_wval)) # Store the raw weight array
            self.g_wshapes.append(curr_wval.shape)
        # Save
        np.savez(_savename,g_wnames=self.g_wnames,g_wvals=self.g_wvals,g_wshapes=self.g_wshapes)
        if self.VERBOSE:
            print ("[%s] Saved. Size is [%.4f]MB" %
                   (_savename,os.path.getsize(_savename)/1000./1000.))

    # Restore
    def restore(self,_sess,_loadname=None):
        if _loadname is None:
            _loadname = '../net/net_%s.npz'%(self.name)
        l = np.load(_loadname)
        g_wnames = l['g_wnames']
        g_wvals = l['g_wvals']
        g_wshapes = l['g_wshapes']
        for widx,wname in enumerate(g_wnames):
            curr_wvar = [v for v in tf.global_variables() if v.name==wname][0]
            _sess.run(tf.assign(curr_wvar,g_wvals[widx].reshape(g_wshapes[widx])))
        if self.VERBOSE:
            print ("Weight restored from [%s] Size is [%.4f]MB" %
                   (_loadname,os.path.getsize(_loadname)/1000./1000.))

    # Train
    def train(self,_sess,_trainimg,_trainlabel,_testimg,_testlabel,_valimg,_vallabel
              ,_maxEpoch=10,_batchSize=256,_lr=1e-3,_kp=0.9
              ,_LR_SCHEDULE=False,_PRINT_EVERY=10,_SAVE_BEST=True,_DO_AUGMENTATION=False,_VERBOSE_TRAIN=True):
        tf.set_random_seed(0)
        nTrain,nVal,nTest = _trainimg.shape[0],_valimg.shape[0],_testimg.shape[0]
        txtName = ('../res/res_%s.txt'%(self.name)) # Assumes ../res exists
        f = open(txtName,'w') # Open txt file
        print_n_txt(_f=f,_chars='Text name: '+txtName)
        print_period = max(1,_maxEpoch//_PRINT_EVERY)
        maxIter,maxValAccr,maxTestAccr = max(nTrain//_batchSize,1),0.0,0.0
        for epoch in range(_maxEpoch+1): # For every epoch
            _trainimg,_trainlabel = shuffle(_trainimg,_trainlabel)
            for iter in range(maxIter): # For every iteration in one epoch
                start,end = iter*_batchSize,(iter+1)*_batchSize
                # Learning rate scheduling (step decay at 50% and 75% of training)
                if _LR_SCHEDULE:
                    if epoch < 0.5*_maxEpoch:
                        _lr_use = _lr
                    elif epoch < 0.75*_maxEpoch:
                        _lr_use = _lr/2.0
                    else:
                        _lr_use = _lr/10.0
                else:
                    _lr_use = _lr
                if _DO_AUGMENTATION:
                    trainImgBatch = augment_img(_trainimg[start:end,:],self.xdim)
                else:
                    trainImgBatch = _trainimg[start:end,:]
                if self.USE_MIXUP:
                    xBatch = trainImgBatch
                    tBatch = _trainlabel[start:end,:]
                    xBatch,tBatch = mixup(xBatch,tBatch,32)
                else:
                    xBatch = trainImgBatch
                    tBatch = _trainlabel[start:end,:]
                feeds = {self.x:xBatch,self.t:tBatch
                         ,self.kp:_kp,self.lr:_lr_use,self.is_training:True}
                _sess.run(self.optm,feed_dict=feeds)
            # Print training losses, training accuracy, validation accuracy, and test accuracy
            if (epoch%print_period)==0 or (epoch==_maxEpoch):
                batchSize4print = 512
                # Compute train loss and accuracy (last partial batch is dropped)
                maxIter4print = max(nTrain//batchSize4print,1)
                trainLoss,trainAccr,nTemp = 0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_train = {self.x:_trainimg[start:end,:],self.t:_trainlabel[start:end,:]
                                   ,self.kp:1.0,self.is_training:False}
                    _trainLoss,_trainAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_train)
                    _nTemp = end-start; nTemp += _nTemp
                    trainLoss += (_nTemp*_trainLoss); trainAccr += (_nTemp*_trainAccr)
                trainLoss /= nTemp; trainAccr /= nTemp
                # Compute validation loss and accuracy
                maxIter4print = max(nVal//batchSize4print,1)
                valLoss,valAccr,nTemp = 0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_val = {self.x:_valimg[start:end,:],self.t:_vallabel[start:end,:]
                                 ,self.kp:1.0,self.is_training:False}
                    _valLoss,_valAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_val)
                    _nTemp = end-start; nTemp += _nTemp
                    valLoss += (_nTemp*_valLoss); valAccr += (_nTemp*_valAccr)
                valLoss /= nTemp; valAccr /= nTemp
                # Compute test loss and accuracy
                maxIter4print = max(nTest//batchSize4print,1)
                testLoss,testAccr,nTemp = 0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_test = {self.x:_testimg[start:end,:],self.t:_testlabel[start:end,:]
                                  ,self.kp:1.0,self.is_training:False}
                    _testLoss,_testAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_test)
                    _nTemp = end-start; nTemp += _nTemp
                    testLoss += (_nTemp*_testLoss); testAccr += (_nTemp*_testAccr)
                testLoss /= nTemp; testAccr /= nTemp
                # Track the best validation accuracy (and its test accuracy)
                if valAccr > maxValAccr:
                    maxValAccr = valAccr
                    maxTestAccr = testAccr
                    if _SAVE_BEST: self.save(_sess)
                strTemp = (("[%02d/%d] [Loss] train:%.3f val:%.3f test:%.3f"
                            +" [Accr] train:%.1f%% val:%.1f%% test:%.1f%% maxVal:%.1f%% maxTest:%.1f%%")
                           %(epoch,_maxEpoch,trainLoss,valLoss,testLoss
                             ,trainAccr*100,valAccr*100,testAccr*100,maxValAccr*100,maxTestAccr*100))
                print_n_txt(_f=f,_chars=strTemp,_DO_PRINT=_VERBOSE_TRAIN)
        # Done
        f.close()
        print ("Training finished.")

    # Test
    def test(self,_sess,_trainimg,_trainlabel,_testimg,_testlabel,_valimg,_vallabel):
        nTrain,nVal,nTest = _trainimg.shape[0],_valimg.shape[0],_testimg.shape[0]
        # Check accuracies (train, val, and test)
        batchSize4print = 512
        # Compute train loss and accuracy
        maxIter4print = max(nTrain//batchSize4print,1)
        trainLoss,trainAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_train = {self.x:_trainimg[start:end,:],self.t:_trainlabel[start:end,:]
                           ,self.kp:1.0,self.is_training:False}
            _trainLoss,_trainAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_train)
            _nTemp = end-start; nTemp += _nTemp
            trainLoss += (_nTemp*_trainLoss); trainAccr += (_nTemp*_trainAccr)
        trainLoss /= nTemp; trainAccr /= nTemp
        # Compute validation loss and accuracy
        maxIter4print = max(nVal//batchSize4print,1)
        valLoss,valAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_val = {self.x:_valimg[start:end,:],self.t:_vallabel[start:end,:]
                         ,self.kp:1.0,self.is_training:False}
            _valLoss,_valAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_val)
            _nTemp = end-start; nTemp += _nTemp
            valLoss += (_nTemp*_valLoss); valAccr += (_nTemp*_valAccr)
        valLoss /= nTemp; valAccr /= nTemp
        # Compute test loss and accuracy
        maxIter4print = max(nTest//batchSize4print,1)
        testLoss,testAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_test = {self.x:_testimg[start:end,:],self.t:_testlabel[start:end,:]
                          ,self.kp:1.0,self.is_training:False}
            _testLoss,_testAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_test)
            _nTemp = end-start; nTemp += _nTemp
            testLoss += (_nTemp*_testLoss); testAccr += (_nTemp*_testAccr)
        testLoss /= nTemp; testAccr /= nTemp
        strTemp = (("[%s] [Loss] train:%.3f val:%.3f test:%.3f"
                    +" [Accr] train:%.3f%% val:%.3f%% test:%.3f%%")
                   %(self.name,trainLoss,valLoss,testLoss,trainAccr*100,valAccr*100,testAccr*100))
        print (strTemp)
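
# The main script below relies on trainimg/trainlabel, valimg/vallabel, and
# testimg/testlabel arrays that the paste never defines. A minimal loader is
# sketched here as an assumption (flattened, [0,1]-scaled CIFAR-10 images with
# one-hot labels; the last 5,000 training images held out for validation);
# the author's actual data pipeline is unknown.
def load_cifar10_flat(_nVal=5000):
    (xTr,yTr),(xTe,yTe) = tf.keras.datasets.cifar10.load_data()
    _flat = lambda x: np.reshape(x,[x.shape[0],-1]).astype(np.float32)/255.0
    _onehot = lambda y: np.eye(10,dtype=np.float32)[y.reshape(-1)]
    xTr,xTe,yTr,yTe = _flat(xTr),_flat(xTe),_onehot(yTr),_onehot(yTe)
    return xTr[:-_nVal],yTr[:-_nVal],xTr[-_nVal:],yTr[-_nVal:],xTe,yTe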

if __name__ == "__main__":
    xdim,ydim,hdims,filterSizes,max_pools,feat_dim \
        = [32,32,3],10,[64,64,64,64,128,128,128],[3,3,3,3,3,3,3],[1,1,1,2,1,1,2],256
    actv,bn,VERBOSE = tf.nn.relu,slim.batch_norm,False
    maxEpoch,batchSize,lr_base = 200,128,1e-1
    USE_RESNET,USE_GAP,USE_SGD = True,True,True
    # errType and outlierRatio are referenced but never defined in the paste;
    # placeholder values are used here so the model name can be formed.
    errType,outlierRatio = 'none',0.0
    trainimg,trainlabel,valimg,vallabel,testimg,testlabel = load_cifar10_flat()
    tf.reset_default_graph(); tf.set_random_seed(0)
    CNN = cnn_cls_class(_name=('cifar10_%s_err%.0f_cnn'%(errType,outlierRatio*100))
                        ,_xdim=xdim,_ydim=ydim,_hdims=hdims,_filterSizes=filterSizes
                        ,_max_pools=max_pools,_feat_dim=feat_dim
                        ,_actv=actv,_bn=bn,_l2_reg_coef=1e-6
                        ,_USE_RESNET=USE_RESNET,_USE_GAP=USE_GAP,_USE_SGD=USE_SGD,_VERBOSE=VERBOSE)
    sess = gpusession(); sess.run(tf.global_variables_initializer())
    CNN.train(_sess=sess,_trainimg=trainimg,_trainlabel=trainlabel
              ,_testimg=testimg,_testlabel=testlabel,_valimg=valimg,_vallabel=vallabel
              ,_maxEpoch=maxEpoch,_batchSize=batchSize,_lr=lr_base
              ,_LR_SCHEDULE=True,_PRINT_EVERY=100,_SAVE_BEST=True,_DO_AUGMENTATION=True)
    sess.close()
    print ("Class defined.")
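
# Usage note (not part of the original paste): with _SAVE_BEST=True, train()
# writes the weights of the best-validation epoch to ../net/net_<name>.npz via
# save(); a fresh session can later call CNN.restore(sess) followed by
# CNN.test(...) to re-score those weights. Both save() and train() assume the
# ../net and ../res directories already exist.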