Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
def prepare_data(self, input_data):
    """
    Prepare data for model training.

    For every tree ``i`` in ``range(self.options["n_tree"])`` this samples up
    to ``n_pos + n_neg`` patch centres from the training images. At each
    centre it keeps a random subset of the feature dimensions and the
    ground-truth segmentation patch around the centre. The results are
    written to ``os.path.join(self.data_dir, self.data_prefix + str(i + 1)
    + ".h5")``; a file that already exists is reused and not regenerated.

    Reads from ``self``: ``data_dir``, ``data_prefix``, ``comp_filt``,
    ``rand`` (must provide ``permutation``), ``get_ftr_dim()``,
    ``get_features(img, locations)``, and ``options`` with the keys
    ``n_tree``, ``n_pos``, ``n_neg``, ``fraction``, ``p_size``, ``g_size``
    and ``shrink``.

    :param input_data: sized, re-iterable sequence (it is iterated once per
        tree, so not a generator) of ``(img, bnds, segs)`` tuples.
        ``bnds[k]`` is a 2-D map that is non-zero on boundary pixels, and
        ``segs[k]`` is the matching 2-D segmentation label map. All maps of
        one image are assumed to share a shape, because the sampling mask is
        built from ``bnds[0]``. (Presumably one ``(bnds[k], segs[k])`` pair
        per annotation -- confirm against the data loader.)

    Each written file holds three arrays:

    * ``ftrs`` -- float32, ``(n_samples, int(n_ftr_dim * fraction))``.
    * ``lbls`` -- int32, ``(n_samples, g_size, g_size)``; each segmentation
      patch relabelled to consecutive ids ``0..n_segments-1``.
    * ``sids`` -- int32 indices of the feature dimensions kept in ``ftrs``.
    """
    n_img = len(input_data)
    if not os.path.exists(self.data_dir):
        os.makedirs(self.data_dir)

    n_tree = self.options["n_tree"]
    n_pos = self.options["n_pos"]
    n_neg = self.options["n_neg"]
    fraction = self.options["fraction"]
    p_size = self.options["p_size"]
    g_size = self.options["g_size"]
    shrink = self.options["shrink"]
    # Half-widths of the feature patch and of the ground-truth patch. Floor
    # division keeps them integers (they are used as slice bounds) on
    # Python 3 as well as Python 2.
    p_rad, g_rad = p_size // 2, g_size // 2

    # Full feature dimensionality, and how many dimensions each tree keeps.
    n_ftr_dim = N.sum(self.get_ftr_dim())
    n_smp_ftr_dim = int(n_ftr_dim * fraction)
    n_max = n_pos + n_neg
    rand = self.rand

    def draw(region, n_keep):
        # Up to n_keep (row, col) tuples where the 2-D `region` is true,
        # in random order (one rand.permutation call per invocation).
        rows, cols = region.nonzero()
        keep = rand.permutation(len(rows))[:n_keep]
        return list(zip(rows[keep].tolist(), cols[keep].tolist()))

    for i in range(n_tree):
        data_file = self.data_prefix + str(i + 1) + ".h5"
        data_path = os.path.join(self.data_dir, data_file)
        if os.path.exists(data_path):
            sys.stdout.write("Found Data %d '%s', reusing...\n" % ((i + 1), data_file))
            continue

        ftrs = N.zeros((n_max, n_smp_ftr_dim), dtype=N.float32)
        lbls = N.zeros((n_max, g_size, g_size), dtype=N.int32)
        # Random subset of feature dimensions seen by this tree.
        sids = rand.permutation(n_ftr_dim)[:n_smp_ftr_dim]
        total = 0

        # NOTE(review): rand is consumed for every image/annotation even after
        # ftrs is full. Keep it that way if files generated from the same
        # seed must stay reproducible (later trees share the stream).
        for j, (img, bnds, segs) in enumerate(input_data):
            # Candidate centres: every `shrink`-th pixel, at least p_rad away
            # from each border (presumably so the p_size feature patch fits).
            # Explicit end bounds avoid the `-0:` slice, which would clear
            # the whole mask when p_rad == 0.
            mask = N.zeros(bnds[0].shape, dtype=bool)
            mask[::shrink, ::shrink] = True
            n_rows, n_cols = mask.shape[0], mask.shape[1]
            mask[:p_rad] = False
            mask[max(n_rows - p_rad, 0):] = False
            mask[:, :p_rad] = False
            mask[:, max(n_cols - p_rad, 0):] = False

            # Sampling quota per annotation, spread evenly over all images.
            n_pos_per_gt = int(ceil(float(n_pos) / n_img / len(bnds)))
            n_neg_per_gt = int(ceil(float(n_neg) / n_img / len(bnds)))

            for k, boundary in enumerate(bnds):
                # Euclidean distance from every pixel to the nearest boundary
                # pixel (the EDT measures distance to the nearest zero of its
                # input, and `boundary == 0` is zero exactly on boundaries).
                # NOTE(port): Armadillo has no distance transform; a C++ port
                # would need its own implementation.
                dis = distance_transform_edt(boundary == 0)
                # Positives: centres closer than g_rad to a boundary pixel,
                # so their ground-truth patch contains part of an edge.
                pos_loc = draw((dis < g_rad) & mask, n_pos_per_gt)
                neg_loc = draw((dis >= g_rad) & mask, n_neg_per_gt)

                # Positives and negatives form one training set; the split
                # only controls how many of each are drawn. Shuffle, then
                # truncate to the remaining buffer capacity.
                loc = pos_loc + neg_loc
                n_loc = min(len(loc), n_max - total)
                loc = [loc[item] for item in rand.permutation(len(loc))[:n_loc]]
                if n_loc == 0:
                    continue

                ftr = N.concatenate(self.get_features(img, loc), axis=1)
                assert ftr.shape[1] == n_ftr_dim, \
                    "get_features() output disagrees with get_ftr_dim()"
                ftrs[total:total + n_loc] = ftr[:, sids]

                seg = segs[k]
                for m, (x, y) in enumerate(loc):
                    # g_size x g_size patch around (x, y). The explicit end
                    # bound keeps the size right for odd g_size as well.
                    # NOTE(review): assumes g_rad <= p_rad; the mask only
                    # guarantees a p_rad margin, and a negative start would
                    # wrap around and yield a wrong-sized slice.
                    x0, y0 = x - g_rad, y - g_rad
                    sub = seg[x0:x0 + g_size, y0:y0 + g_size]
                    # Relabel to 0..n_segments-1 so only the grouping of
                    # pixels into segments is kept, not the image-global ids.
                    # Written straight into the int32 buffer so large segment
                    # counts cannot overflow.
                    inverse = N.unique(sub, return_inverse=True)[1]
                    lbls[total + m] = inverse.reshape((g_size, g_size))
                total += n_loc

            sys.stdout.write("Processing Data %d: %d/%d\r" % (i + 1, j + 1, n_img))
            sys.stdout.flush()
        sys.stdout.write("\n")

        with tables.open_file(data_path, "w", filters=self.comp_filt) as dfile:
            dfile.create_carray("/", "ftrs", obj=ftrs[:total])
            dfile.create_carray("/", "lbls", obj=lbls[:total])
            dfile.create_carray("/", "sids", obj=sids.astype(N.int32))
        sys.stdout.write("Saving %d samples to '%s'...\n" % (total, data_file))
Add comment
Please sign in to add a comment.