Guest User

training script

a guest
Apr 5th, 2017
156
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.63 KB | None | 0 0
  1. from __future__ import division
  2. import os
  3. caffe_root = '../caffe-portraitseg/'
  4. import sys
  5. sys.path.insert(0,caffe_root + 'python')
  6. import caffe
  7. import numpy as np
  8.  
  9. # make a bilinear interpolation kernel
  10. # credit @longjon
  11. def upsample_filt(size):
  12. factor = (size + 1) // 2
  13. if size % 2 == 1:
  14. center = factor - 1
  15. else:
  16. center = factor - 0.5
  17. og = np.ogrid[:size, :size]
  18. return (1 - abs(og[0] - center) / factor) * \
  19. (1 - abs(og[1] - center) / factor)
  20.  
  21. # set parameters s.t. deconvolutional layers compute bilinear interpolation
  22. # N.B. this is for deconvolution without groups
  23. def interp_surgery(net, layers):
  24. for l in layers:
  25. m, k, h, w = net.params[l][0].data.shape
  26. if m != k:
  27. print('input + output channels need to be the same')
  28. raise
  29. if h != w:
  30. print('filters need to be square')
  31. raise
  32. filt = upsample_filt(h)
  33. net.params[l][0].data[range(m), range(k), :, :] = filt
  34.  
  35. # init
  36. caffe.set_mode_gpu()
  37. caffe.set_device(0)
  38.  
  39. MODEL_FILE = './FCN8s_models/fcn-8s-pascal-deploy.prototxt'
  40. PRETRAINED = './FCN8s_models/fcn-8s-pascal.caffemodel'
  41.  
  42. net = caffe.Net(MODEL_FILE, PRETRAINED, caffe.TEST)
  43.  
  44. solverpath = './model_files/solver_portraitFCN.prototxt'
  45. solver = caffe.SGDSolver(solverpath)
  46.  
  47. # do net surgery to set the deconvolution weights for bilinear interpolation
  48. interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
  49. interp_surgery(solver.net, interp_layers)
  50. # copy base weights for fine-tuning
  51. #solver.net.copy_from(base_weights)
  52. solver.net.params['conv1_1'][0].data[:,0:3:1,:,:] = net.params['conv1_1'][0].data[:,:,:,:]
  53.  
  54. layerkeys = ['conv1_2', 'conv2_1', 'conv2_2','conv3_1','conv3_2','conv3_3','conv4_1','conv4_2','conv4_3','conv5_1','conv5_2','conv5_3', 'conv1_2','fc6','fc7']
  55. for key in layerkeys:
  56. solver.net.params[key][0].data[...] = net.params[key][0].data[...]
  57.  
  58.  
  59. # also copy other weights from the net
  60. solver.net.params['score-fr'][0].data[:,:,:,:] = net.params['score-fr'][0].data[0:15:15,:,:,:]
  61.  
  62. #score2
  63. solver.net.params['score2'][0].data[:,:,:,:] = net.params['score2'][0].data[0:15:15,0:15:15,:,:]
  64. #score-pool4
  65. solver.net.params['score-pool4'][0].data[:,:,:,:] = net.params['score-pool4'][0].data[0:15:15,:,:,:]
  66. #score4
  67. solver.net.params['score4'][0].data[:,:,:,:] = net.params['score4'][0].data[0:15:15,0:15:15,:,:]
  68. #score-pool3
  69. solver.net.params['score-pool3'][0].data[:,:,:,:] = net.params['score-pool3'][0].data[0:15:15,:,:,:]
  70. #upsample
  71. solver.net.params['upsample'][0].data[:,:,:,:] = net.params['upsample'][0].data[0:15:15,0:15:15,:,:]
  72.  
  73.  
  74. solver.step(80000)
Add Comment
Please, Sign In to add comment