Guest User

Untitled

a guest
May 31st, 2016
42
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.27 KB | None | 0 0
  def prepare_data(self, input_data):
      """
      Sample the training data (features + structured labels) for each tree.

      For tree ``i`` (1-based) this writes ``<data_prefix><i>.h5`` into
      ``self.data_dir`` holding three arrays:

      * ``ftrs`` -- float32, one row per sampled location, restricted to a
        random subset of the full feature vector;
      * ``lbls`` -- int32, one ``(g_size, g_size)`` ground-truth patch per
        row of ``ftrs``, with the segment ids inside each patch relabelled
        to ``0..k-1``;
      * ``sids`` -- int32 indices of the feature dimensions kept in ``ftrs``.

      At most ``n_pos + n_neg`` rows are stored per tree. A tree whose file
      already exists is skipped and that file is reused as-is.

      :param input_data: sequence of ``(img, bnds, segs)`` triples; ``bnds``
          and ``segs`` hold one 2-D boundary map and one 2-D segmentation per
          ground-truth annotation of ``img`` (``segs[k]`` pairs with
          ``bnds[k]``). Images with no annotations contribute no samples.
      """
      n_img = len(input_data)

      if not os.path.exists(self.data_dir):
          os.makedirs(self.data_dir)

      opts = self.options
      n_tree = opts["n_tree"]
      n_pos = opts["n_pos"]
      n_neg = opts["n_neg"]
      fraction = opts["fraction"]
      p_size = opts["p_size"]
      g_size = opts["g_size"]
      shrink = opts["shrink"]
      # Half-widths of the image patch and of the ground-truth patch. ``//``
      # keeps them integers on Python 3 as well (they are used as slice
      # bounds below); for ints it matches Python 2's ``/``.
      p_rad, g_rad = p_size // 2, g_size // 2
      # Full feature dimensionality (summed over the groups reported by
      # get_ftr_dim) and the size of the random subset each tree keeps.
      n_ftr_dim = N.sum(self.get_ftr_dim())
      n_smp_ftr_dim = int(n_ftr_dim * fraction)
      rand = self.rand

      def sample_locations(region, n_max):
          """Return up to ``n_max`` random (row, col) tuples taken from the
          nonzero pixels of ``region``."""
          nz = region.nonzero()
          rows, cols = nz[0], nz[1]
          # Same single RNG draw as shuffling a full list of tuples, but only
          # the selected coordinates are converted to Python tuples.
          order = rand.permutation(len(rows))[:n_max]
          return list(zip(rows[order].tolist(), cols[order].tolist()))

      def sample_image(img, bnds, segs, ftrs, lbls, sids, total):
          """Sample every annotation of one image into ``ftrs``/``lbls``
          starting at row ``total``; return the updated row count."""
          n_gt = len(bnds)
          # Candidate centres: every ``shrink``-th pixel, excluding a p_rad
          # border so the whole image patch fits inside the image.
          mask = N.zeros(bnds[0].shape, dtype=bnds[0].dtype)
          mask[::shrink, ::shrink] = 1
          n_rows, n_cols = mask.shape[0], mask.shape[1]
          # The trailing border is addressed by absolute index because
          # ``mask[-p_rad:]`` would clear the entire mask when p_rad == 0.
          mask[:p_rad] = 0
          mask[max(n_rows - p_rad, 0):] = 0
          mask[:, :p_rad] = 0
          mask[:, max(n_cols - p_rad, 0):] = 0

          # Per-annotation quotas so that, over all images and annotations,
          # about n_pos positive and n_neg negative samples are drawn.
          n_pos_per_gt = int(ceil(float(n_pos) / n_img / n_gt))
          n_neg_per_gt = int(ceil(float(n_neg) / n_img / n_gt))
          capacity = ftrs.shape[0]

          for k, boundary in enumerate(bnds):
              # Distance from every pixel to its nearest boundary pixel: the
              # EDT measures distance to the nearest zero of its input, and
              # ``boundary == 0`` is zero exactly on boundary pixels.
              # NOTE(port): Armadillo has no distance transform, so a C++
              # port needs its own EDT (check whether mlpack provides one).
              dis = distance_transform_edt(boundary == 0)

              # Positives are centres within g_rad of a boundary; negatives
              # are all other candidate centres.
              pos_loc = sample_locations((dis < g_rad) * mask, n_pos_per_gt)
              neg_loc = sample_locations((dis >= g_rad) * mask, n_neg_per_gt)

              # Positives and negatives form one training set: shuffle them
              # together and keep only what still fits in the buffers.
              loc = pos_loc + neg_loc
              n_loc = min(len(loc), capacity - total)
              loc = [loc[item] for item in rand.permutation(len(loc))[:n_loc]]
              if n_loc == 0:
                  continue

              # Features of all sampled centres, restricted to this tree's
              # feature subset.
              ftr = N.concatenate(self.get_features(img, loc), axis=1)
              assert ftr.shape[1] == n_ftr_dim
              ftrs[total: total + n_loc] = ftr[:, sids]

              # Ground-truth patch around each centre. return_inverse maps the
              # segment ids present in the patch to 0..k-1, so the label
              # encodes how the patch is partitioned rather than which global
              # segment ids occur. Writing straight into the int32 buffer
              # avoids a narrower temporary that could wrap around.
              seg = segs[k]
              for m, (r, c) in enumerate(loc):
                  sub = seg[r - g_rad: r + g_rad, c - g_rad: c + g_rad]
                  inv = N.unique(sub, return_inverse=True)[1]
                  lbls[total + m] = inv.reshape((g_size, g_size))

              total += n_loc
          return total

      for i in range(n_tree):
          data_file = self.data_prefix + str(i + 1) + ".h5"
          data_path = os.path.join(self.data_dir, data_file)
          if os.path.exists(data_path):
              print("Found Data %d '%s', reusing..." % ((i + 1), data_file))
              continue

          # Buffers sized for the maximum sample count; only the first
          # ``total`` rows are filled and saved.
          ftrs = N.zeros((n_pos + n_neg, n_smp_ftr_dim), dtype=N.float32)
          lbls = N.zeros((n_pos + n_neg, g_size, g_size), dtype=N.int32)
          # Each tree is trained on its own random subset of feature dims.
          sids = rand.permutation(n_ftr_dim)[:n_smp_ftr_dim]
          total = 0

          for j, (img, bnds, segs) in enumerate(input_data):
              # An image without annotations has nothing to sample (and no
              # bnds[0] to take the mask shape from).
              if len(bnds) > 0:
                  total = sample_image(img, bnds, segs, ftrs, lbls, sids, total)
              sys.stdout.write("Processing Data %d: %d/%d\r" % (i + 1, j + 1, n_img))
              sys.stdout.flush()
          sys.stdout.write("\n")

          # Write under a temporary name and rename once complete, so an
          # interrupted write cannot leave a truncated file that the
          # existence check above would silently reuse on the next run.
          tmp_path = data_path + ".tmp"
          with tables.open_file(tmp_path, "w", filters=self.comp_filt) as dfile:
              dfile.create_carray("/", "ftrs", obj=ftrs[:total])
              dfile.create_carray("/", "lbls", obj=lbls[:total])
              dfile.create_carray("/", "sids", obj=sids.astype(N.int32))
          os.rename(tmp_path, data_path)
          print("Saving %d samples to '%s'..." % (total, data_file))
Add Comment
Please, Sign In to add comment