Advertisement
Guest User

Untitled

a guest
Feb 19th, 2019
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 34.01 KB | None | 0 0
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. #from train.dataprocessor import DataProcessor
  5. from tensorflow.data import Iterator
  6. import os.path
  7.  
  8.  
  9.  
  10. # Path for tf.summary.FileWriter and to store model checkpoints
  11. train_filewriter_path = os.path.join(FLAGS.project_saves_path,"TCNN_train_tensorboard")
  12. val_filewriter_path = os.path.join(FLAGS.project_saves_path, "TCNN_val_tensorboard")
  13. checkpoint_path = os.path.join(FLAGS.project_saves_path, "TCNN_checkpoints")
  14.  
  15.  
  16. # Create parent path if it doesn't exist
  17. if not os.path.isdir(checkpoint_path):
  18. os.mkdir(checkpoint_path)
  19. # Create parent path if it doesn't exist
  20. if not os.path.isdir(train_filewriter_path):
  21. os.mkdir(train_filewriter_path)
  22. if not os.path.isdir(val_filewriter_path):
  23. os.mkdir(val_filewriter_path)
  24.  
  25. TOWER_NAME = "tower"
  26.  
  27. # Input parameters
  28. IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32)
  29. HEIGHT = 480
  30. WIDTH = 640
  31. CHANNELS = 3 # RGB
  32.  
  33. BATCH_SIZE = 4
  34. DEPTH = 5
  35. SEQ_LEN = 10
  36. NUM_GPUS = 1
  37.  
  38. # Learning parameters
  39. LEARN_RATE = 1e-4
  40. NUM_EPOCHS = 2
  41.  
  42. # Network params
  43. DROPOUT_RATE = 0.5
  44. KEEP_PROB_TRAIN = 0.25
  45. OUTPUT_DIMS = 1 # only steering angles
  46.  
  47. # The parameters of the LSTM that keeps the model state.
  48. RNN_NUM_UNITS = 32
  49. RNN_NUM_PROJ = 32
  50.  
  51.  
  52. def parse_img(img_path):
  53. img_path = tf.read_file(img_path)
  54. img_decoded = tf.image.decode_png(img_path, channels=3)
  55. img_resized = tf.image.resize_images(img_decoded, [227, 227])
  56.  
  57. img_centered = tf.subtract(img_resized, IMAGENET_MEAN) # not needed??????????????????
  58.  
  59. # RGB -> BGR
  60. img_bgr = img_centered[:, :, ::-1]
  61.  
  62. return img_bgr
  63.  
  64.  
  65. def get_optimizer(loss, learn_rate):
  66. optimizer = tf.train.AdamOptimizer(learning_rate=learn_rate)
  67. gradvars = optimizer.compute_gradients(loss)
  68.  
  69. grads, vars = zip(*gradvars)
  70.  
  71. #print([x.name for x in vars])
  72. # grad clipping
  73. grads, _ = tf.clip_by_global_norm(grads, 15.0)
  74.  
  75. return optimizer.apply_gradients(zip(grads, vars))
  76.  
  77.  
  78. def add_depth(prev_img_batch, curr_img_batch):
  79. # curr_img_batch on input - batch size * seq length = 40,227,227,3
  80. # curr_img_batch [BATCH_SIZE = 4, DEPTH = 5 + SEQ_LEN = 10, 227, 227, 3]
  81. # label_batch [BATCH_SIZE= 4, SEQ_LEN = 10 , angle] (40,)
  82.  
  83. cur_batch_images = []
  84. # [40, 227, 227, 3]
  85. for i in range(BATCH_SIZE):
  86. seq_start_id = i * SEQ_LEN
  87.  
  88. # add context frames for sequence
  89. if (seq_start_id - DEPTH < 0):
  90. # not enough previous frames
  91. added_context = curr_img_batch[0:seq_start_id,:,:,:]
  92. #print("added_context", added_context.shape)
  93.  
  94.  
  95. # no prev batch
  96. if (prev_img_batch == None):
  97. first_img = curr_img_batch[0,:,:,:]
  98. first_img = tf.reshape(first_img, [1,227,227,CHANNELS])
  99. for i in range(DEPTH - seq_start_id):
  100. cur_batch_images.append(first_img)
  101. else:
  102. cur_batch_images.append(prev_img_batch[seq_start_id - DEPTH:,:,:,:])
  103. else:
  104. added_context = curr_img_batch[seq_start_id - DEPTH : seq_start_id,:,:,:]
  105.  
  106. # add sequence
  107. if (added_context.shape[0] > 0):
  108. cur_batch_images.append(added_context)
  109. cur_batch_images.append(curr_img_batch[seq_start_id : seq_start_id + SEQ_LEN,:,:,:])
  110.  
  111.  
  112. curr_img_batch = tf.concat(cur_batch_images, 0)
  113.  
  114. return curr_img_batch
  115.  
  116. ################################ Temporal Convolution part ########################################
  117.  
# expects data of shape = [BATCH_SIZE, DEPTH + SEQ_LEN, HEIGHT, WIDTH, CHANNELS]
  119. def temporal_cnn(images, keep_prob):
  120. with tf.variable_scope("TCNN", reuse=None) as scope:
  121. #conv 1
  122. net = tf.contrib.layers.conv3d(images,
  123. num_outputs=64,
  124. kernel_size=[3,12,12],
  125. stride=[1,1,6],
  126. padding='VALID',
  127. rate=1,
  128. activation_fn=tf.nn.relu,
  129. normalizer_fn=None,
  130. normalizer_params=None,
  131. weights_initializer=tf.contrib.layers.xavier_initializer(),
  132. weights_regularizer=None,
  133. biases_initializer=tf.zeros_initializer(),
  134. biases_regularizer=None)
  135. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  136. reshaped_data = tf.reshape(net[:, -SEQ_LEN:, :, :, :],
  137. [BATCH_SIZE, SEQ_LEN, -1])
  138. aux1 = tf.contrib.layers.fully_connected(reshaped_data,
  139. 128, activation_fn=None)
  140.  
  141. #conv 2
  142. net = tf.contrib.layers.conv3d(net,
  143. num_outputs=64,
  144. kernel_size=[2,5,5],
  145. stride=[1,2,2],
  146. padding='VALID')
  147. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  148. aux2 = tf.contrib.layers.fully_connected(reshaped_data,
  149. 128, activation_fn=None)
  150.  
  151. #conv 3
  152. net = tf.contrib.layers.conv3d(net,
  153. num_outputs=64,
  154. kernel_size=[2,5,5],
  155. stride=[1,1,1],
  156. padding='VALID')
  157. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  158. aux3 = tf.contrib.layers.fully_connected(reshaped_data,
  159. 128, activation_fn=None)
  160.  
  161. #conv 4
  162. net = tf.contrib.layers.conv3d(net,
  163. num_outputs=64,
  164. kernel_size=[2,5,5],
  165. stride=[1,1,1],
  166. padding='VALID')
  167. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  168. aux4 = tf.contrib.layers.fully_connected(reshaped_data,
  169. 128, activation_fn=None)
  170.  
  171. # fc 1
  172. net = tf.reshape(net, [BATCH_SIZE, SEQ_LEN, -1])
  173. net = tf.contrib.layers.fully_connected(net,
  174. 1024, activation_fn=tf.nn.relu)
  175. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  176.  
  177. # fc 2
  178. net = tf.contrib.layers.fully_connected(net,
  179. 512, activation_fn=tf.nn.relu)
  180. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  181.  
  182. # fc 3
  183. net = tf.contrib.layers.fully_connected(net,
  184. 256, activation_fn=tf.nn.relu)
  185. net = tf.nn.dropout(x=net, keep_prob=keep_prob)
  186.  
  187. # fc 4
  188. net = tf.contrib.layers.fully_connected(net,
  189. 128, activation_fn=None)
  190.  
  191.  
  192. # define layer normalization function
  193. layer_norm = lambda x: tf.contrib.layers.layer_norm(inputs=x,
  194. center=True,
  195. scale=True,
  196. reuse=tf.AUTO_REUSE,
  197. scope=scope,
  198. activation_fn=None,
  199. trainable=True)
  200.  
  201. # aux[1-4] are residual connections (shortcuts)
  202. return layer_norm(tf.nn.elu(net + aux1 + aux2 + aux3 + aux4))
  203.  
  204. #######################################################################################################
  205. def get_rnn_initial_state(complex_state_tuple_sizes):
  206. # flatten
  207. flat_sizes = tf.contrib.framework.nest.flatten(complex_state_tuple_sizes)
  208. initial_state_flat = [tf.tile(
  209. multiples=[BATCH_SIZE, 1],
  210. input=tf.get_variable("controller_initial_state_%d" % index, initializer=tf.zeros_initializer, shape=([1, size]), dtype=tf.float32))
  211. for index,size in enumerate(flat_sizes)]
  212.  
  213. # pack the flat copy into the original tuple structure
  214. initial_state = tf.contrib.framework.nest.pack_sequence_as(
  215. structure=complex_state_tuple_sizes,
  216. flat_sequence=initial_state_flat)
  217. return initial_state
  218.  
  219. def deep_copy_initial_state(complex_state_tuple):
  220. # flatten
  221. flat_state = tf.contrib.framework.nest.flatten(complex_state_tuple)
  222.  
  223. # copy each each element
  224. flat_copy = [tf.identity(s) for s in flat_state]
  225.  
  226. # pack the flat copy into the original tuple structure
  227. deep_copy = tf.contrib.framework.nest.pack_sequence_as(
  228. structure=complex_state_tuple,
  229. flat_sequence=flat_copy)
  230. return deep_copy
  231.  
  232.  
  233. class DualRNNCell(tf.nn.rnn_cell.RNNCell):
  234.  
  235. def __init__(self, OUTPUT_DIMS, use_ground_truth, internal_cell):
  236. self._OUTPUT_DIMS = OUTPUT_DIMS # predctions
  237. self._use_ground_truth = use_ground_truth # boolean
  238. self._internal_cell = internal_cell # may be LSTM or GRU or anything
  239.  
  240. @property
  241. def state_size(self):
  242. # previous output and bottleneck state
  243. return self._OUTPUT_DIMS, self._internal_cell.state_size
  244.  
  245. @property
  246. def output_size(self):
  247. return self._OUTPUT_DIMS
  248.  
  249. def __call__(self, data, prev_state, scope=None):
  250. (visual_feats, current_ground_truth) = data
  251. #print("visual_feats",visual_feats.shape)#4,128
  252. #print("current_ground_truth",current_ground_truth.shape)#4,1
  253.  
  254. prev_output, prev_state_internal = prev_state
  255. #print("prev_output",prev_output.shape)#4,1
  256. #print("prev_state_internal[0]",prev_state_internal[0].shape)#4,32
  257. #print("prev_state_internal[1]",prev_state_internal[1].shape)#4,32
  258.  
  259. # 4,1 and 4,128 -> 4,129
  260. context = tf.concat([prev_output, visual_feats], axis=1)
  261.  
  262. # call internal cell
  263. new_output_internal, new_state_internal = self._internal_cell(context, prev_state_internal)
  264.  
  265.  
  266. # FC
  267. new_output = tf.contrib.layers.fully_connected(
  268. inputs=tf.concat([new_output_internal, prev_output, visual_feats], axis=1),
  269. num_outputs=self._OUTPUT_DIMS,
  270. activation_fn=None,
  271. scope="OutputProjection")
  272.  
  273. return new_output, (current_ground_truth if self._use_ground_truth else new_output, new_state_internal)
  274.  
  275. ###############################################################################################
  276.  
  277. def model_losses(output_with_gt, output_with_pred, targets, aux_cost_weight):
  278. # mean of the squared error
  279. mse_gt = tf.reduce_mean(tf.squared_difference(output_with_gt, targets))
  280. mse_pred = tf.reduce_mean(tf.squared_difference(output_with_pred, targets))
  281. mse_pred_steering = tf.reduce_mean(tf.squared_difference(output_with_pred[:, :, 0], targets[:, :, 0]))
  282.  
  283. combined_loss = mse_pred_steering + aux_cost_weight * (mse_gt + mse_pred)
  284. tf.add_to_collection('losses', combined_loss)
  285.  
  286. # additional stats
  287. tf.add_to_collection('rmse_collection', tf.sqrt(mse_gt))
  288. tf.add_to_collection('rmse_collection', tf.sqrt(mse_pred))
  289. tf.add_to_collection('rmse_collection', tf.sqrt(mse_pred_steering))
  290.  
  291. # The total loss is defined as the combined_loss plus all of the weight decay terms (L2 loss) - NONE YET.
  292. return tf.add_n(tf.get_collection('losses'), name='total_tower_loss')
  293.  
  294.  
  295.  
  296. def inference(images, labels, keep_prob):
  297. ####################### TCNN part ##########################
  298. visual_data = temporal_cnn(images=images, keep_prob=keep_prob)
  299. visual_data = tf.reshape(visual_data, [BATCH_SIZE, SEQ_LEN, -1])
  300. visual_data = tf.nn.dropout(x=visual_data, keep_prob=keep_prob)
  301.  
  302.  
  303. ######################## LSTM part #########################
  304.  
  305. # inputs for the LSTM part
  306. data_with_gt = (visual_data, labels)
  307.  
  308. #no predictions yet => zeros
  309. zero_pred = tf.zeros(shape=(BATCH_SIZE, SEQ_LEN, OUTPUT_DIMS),dtype=tf.float32)
  310. data_with_pred = (visual_data, zero_pred)
  311.  
  312. # the internal LSTM cell for our custom dual cell
  313. rnn_internal_cell = tf.nn.rnn_cell.LSTMCell(num_units=RNN_NUM_UNITS,
  314. use_peepholes=False,
  315. cell_clip=None,
  316. initializer=None,
  317. num_proj=RNN_NUM_PROJ,
  318. proj_clip=None,
  319. num_unit_shards=None,
  320. num_proj_shards=None,
  321. forget_bias=1.0,
  322. state_is_tuple=True,
  323. activation=None)
  324.  
  325. # cell with ground truth
  326. rnn_cell_with_gt = DualRNNCell(OUTPUT_DIMS=OUTPUT_DIMS,
  327. use_ground_truth=True,
  328. internal_cell=rnn_internal_cell)
  329.  
  330. # cell with predictions
  331. rnn_cell_with_pred = DualRNNCell(OUTPUT_DIMS=OUTPUT_DIMS,
  332. use_ground_truth=False,
  333. internal_cell=rnn_internal_cell)
  334.  
  335.  
  336.  
  337. rnn_initial_state = get_rnn_initial_state(rnn_cell_with_pred.state_size)
  338. # initial states for the 2 customs RNN cells
  339. initial_state_gt = deep_copy_initial_state(rnn_initial_state)
  340. initial_state_pred = deep_copy_initial_state(rnn_initial_state)
  341.  
  342. # predict using our custom cells
  343. with tf.variable_scope("predictor"):
  344. output_with_gt, final_state_gt = tf.nn.dynamic_rnn(cell=rnn_cell_with_gt,
  345. inputs=data_with_gt,
  346. sequence_length=[SEQ_LEN] * BATCH_SIZE,
  347. initial_state=initial_state_gt,
  348. dtype=tf.float32,
  349. swap_memory=True,
  350. time_major=False)
  351.  
  352.  
  353. with tf.variable_scope("predictor", reuse=True):
  354. output_with_pred, final_state_pred = tf.nn.dynamic_rnn(cell=rnn_cell_with_pred,
  355. inputs=data_with_pred,
  356. sequence_length=[SEQ_LEN] * BATCH_SIZE,
  357. initial_state=initial_state_pred,
  358. dtype=tf.float32,
  359. swap_memory=True,
  360. time_major=False)
  361.  
  362. tf.add_to_collection('cell_final_states', (final_state_gt, final_state_pred))
  363.  
  364. return output_with_gt, output_with_pred
  365.  
  366.  
  367.  
  368. def tower_loss(scope, img_batch, label_batch, keep_prob, aux_cost_weight):
  369. # Calculate the total loss on a single tower running the whole model.
  370. # scope: unique prefix string identifying the tower, e.g. 'tower_0'
  371. # returns total loss for a batch of data
  372.  
  373. # Build infer Graph.
  374. output_with_gt, output_with_pred = inference(img_batch, label_batch, keep_prob)
  375.  
  376. # Build the portion of the Graph calculating the losses. Note that we will
  377. # assemble the total_loss across all towers using a custom function below.
  378. _ = model_losses(output_with_gt, output_with_pred, label_batch, aux_cost_weight)
  379.  
  380. # Assemble all of the losses for the current tower only.
  381. losses = tf.get_collection('losses', scope)
  382. rmse_collection = tf.get_collection('rmse_collection', scope)
  383. # Calculate the total loss for the current tower.
  384. total_tower_loss = tf.add_n(losses, name='total_tower_loss')
  385.  
  386. # Compute the moving average of all individual losses and the total loss.
  387. ema = tf.train.ExponentialMovingAverage(0.9, name='avg')
  388.  
  389.  
  390. # need scope: TOWER_NAME_1/mean_sq_error/avg/
  391. #with tf.variable_scope(scope):
  392. with tf.variable_scope("ema", reuse=tf.AUTO_REUSE):
  393. ema_op = ema.apply(losses + rmse_collection + [total_tower_loss]) ##IDKKKKKKKKKKKKKKKKKKKKKKK
  394.  
  395.  
  396.  
  397. # Attach a scalar summary to all individual losses and the total loss; do the
  398. # same for the averaged version of the losses.
  399. for l in losses + rmse_collection + [total_tower_loss]:
  400. # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  401. # session. This helps the clarity of presentation on tensorboard.
  402. #loss_name = l.op.name.replace('%s_[0-9]*/' % TOWER_NAME, '')
  403. loss_name = l.op.name
  404. # Name each loss as '(raw)' and name the moving average version of the loss
  405. # as the original loss name.
  406. tf.summary.scalar(loss_name + ' (raw)', l)
  407. tf.summary.scalar(loss_name, ema.average(l))
  408.  
  409. with tf.control_dependencies([ema_op]):
  410. total_tower_loss = tf.identity(total_tower_loss)
  411.  
  412. return total_tower_loss
  413.  
  414.  
  415. def average_gradients(tower_grads):
  416. """Calculate the average gradient for each shared variable across all towers.
  417. Note that this function provides a synchronization point across all towers.
  418. Args:
  419. tower_grads: List of lists of (gradient, variable) tuples. The outer list
  420. is over individual gradients. The inner list is over the gradient
  421. calculation for each tower.
  422. Returns:
  423. List of pairs of (gradient, variable) where the gradient has been averaged
  424. across all towers.
  425. """
  426. average_grads = []
  427. for grad_and_vars in zip(*tower_grads):
  428. # Note that each grad_and_vars looks like the following:
  429. # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
  430. grads = []
  431. for g, _ in grad_and_vars:
  432. # Add 0 dimension to the gradients to represent the tower.
  433. expanded_g = tf.expand_dims(g, 0)
  434.  
  435. # Append on a 'tower' dimension which we will average over below.
  436. grads.append(expanded_g)
  437.  
  438.  
  439.  
  440. # Average over the 'tower' dimension.
  441. grad = tf.concat(grads, axis=0)
  442. grad = tf.reduce_mean(grad, 0)
  443.  
  444. # Keep in mind that the Variables are redundant because they are shared
  445. # across towers. So .. we will just return the first tower's pointer to
  446. # the Variable.
  447. v = grad_and_vars[0][1]
  448.  
  449. grad_and_var = (grad, v)
  450. average_grads.append(grad_and_var)
  451.  
  452. return average_grads
  453.  
  454.  
  455. def train():
  456. with tf.device('/cpu:0'):
  457. ####################### Data input pipeline ###############################
  458. epoch_is_done = False
  459. # Images in dataset: 25,172
  460. # interpolated 1 -> 4401 - val ~ 12%
  461. # interpolated 2 -> 15,796
  462. # interpolated 4 -> 1974
  463. # interpolated 5 -> 4235 - test ~ 12%
  464. # interpolated 6 -> 7402
  465. # total -> 33,808
  466.  
  467. train_iter_init_ops = []
  468. next_batch_ops = []
  469. # define datasets
  470. train_dataset_1 = tf.data.experimental.make_csv_dataset([os.path.join(FLAGS.data_path, "interpolated_2_4_6_ONE.csv")],
  471. batch_size=BATCH_SIZE*SEQ_LEN,
  472. select_columns=[5,6],
  473. label_name='angle',
  474. shuffle=False,
  475. column_defaults=[tf.string, tf.float32])
  476.  
  477. # create an reinitializable iterator given the dataset structure
  478. iterator_1 = Iterator.from_structure(train_dataset_1.output_types, train_dataset_1.output_shapes)
  479. # Ops for initializing the two different iterators
  480. train_iter_init_op_1 = iterator_1.make_initializer(train_dataset_1)
  481. train_iter_init_ops.append(train_iter_init_op_1)
  482. next_batch_op_1 = iterator_1.get_next()
  483. next_batch_ops.append(next_batch_op_1)
  484.  
  485. train_dataset_2 = tf.data.experimental.make_csv_dataset([os.path.join(FLAGS.data_path, "interpolated_2_4_6_TWO.csv")],
  486. batch_size=BATCH_SIZE*SEQ_LEN,
  487. select_columns=[5,6],
  488. label_name='angle',
  489. shuffle=False,
  490. column_defaults=[tf.string, tf.float32])
  491.  
  492. # create an reinitializable iterator given the dataset structure
  493. iterator_2 = Iterator.from_structure(train_dataset_2.output_types, train_dataset_2.output_shapes)
  494. # Ops for initializing the two different iterators
  495. train_iter_init_op_2 = iterator_2.make_initializer(train_dataset_2)
  496. train_iter_init_ops.append(train_iter_init_op_2)
  497. next_batch_op_2 = iterator_2.get_next()
  498. next_batch_ops.append(next_batch_op_2)
  499.  
  500. train_dataset_3 = tf.data.experimental.make_csv_dataset([os.path.join(FLAGS.data_path, "interpolated_2_4_6_THREE.csv")],
  501. batch_size=BATCH_SIZE*SEQ_LEN,
  502. select_columns=[5,6],
  503. label_name='angle',
  504. shuffle=False,
  505. column_defaults=[tf.string, tf.float32])
  506.  
  507. # create an reinitializable iterator given the dataset structure
  508. iterator_3 = Iterator.from_structure(train_dataset_3.output_types, train_dataset_3.output_shapes)
  509. # Ops for initializing the two different iterators
  510. train_iter_init_op_3 = iterator_3.make_initializer(train_dataset_3)
  511. train_iter_init_ops.append(train_iter_init_op_3)
  512. next_batch_op_3 = iterator_3.get_next()
  513. next_batch_ops.append(next_batch_op_3)
  514.  
  515.  
  516. train_dataset_4 = tf.data.experimental.make_csv_dataset([os.path.join(FLAGS.data_path, "interpolated_2_4_6_FOUR.csv")],
  517. batch_size=BATCH_SIZE*SEQ_LEN,
  518. select_columns=[5,6],
  519. label_name='angle',
  520. shuffle=False,
  521. column_defaults=[tf.string, tf.float32])
  522.  
  523.  
  524. # create an reinitializable iterator given the dataset structure
  525. iterator_4 = Iterator.from_structure(train_dataset_4.output_types, train_dataset_4.output_shapes)
  526. # Ops for initializing the two different iterators
  527. train_iter_init_op_4 = iterator_4.make_initializer(train_dataset_4)
  528. train_iter_init_ops.append(train_iter_init_op_4)
  529. next_batch_op_4 = iterator_4.get_next()
  530. next_batch_ops.append(next_batch_op_4)
  531.  
  532.  
  533. #############################################################################
  534.  
  535. # learning placeholders
  536. keep_prob = tf.placeholder_with_default(input=1.0, shape=())
  537. aux_cost_weight = tf.placeholder_with_default(input=0.1, shape=())
  538. learn_rate = tf.placeholder_with_default(input=1e-4, shape=())
  539.  
  540. optimizer = tf.train.AdamOptimizer(learning_rate=learn_rate)
  541.  
  542. global_step = tf.get_variable(
  543. 'global_step', [],
  544. initializer=tf.constant_initializer(0), trainable=False)
  545.  
  546. # Calculate the gradients for each model tower.
  547. tower_grads = []
  548. with tf.variable_scope(tf.get_variable_scope()) as scope:
  549. for i in range(NUM_GPUS):
  550. with tf.device('/gpu:%d' % i):
  551. with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
  552. # Calculate the loss for one tower of the model. This function
  553. # constructs the entire model but shares the variables across
  554. # all towers.
  555.  
  556. img_batch = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DEPTH + SEQ_LEN, 227, 227, 3), name="img_batch_placeholder_%d" % i)
  557. label_batch = tf.placeholder(tf.float32, shape=(BATCH_SIZE, SEQ_LEN, OUTPUT_DIMS), name="label_batch_placeholder_%d" % i)
  558.  
  559.  
  560. loss = tower_loss(scope, img_batch, label_batch, keep_prob, aux_cost_weight)
  561.  
  562.  
  563. # Reuse variables for the next tower.
  564. tf.get_variable_scope().reuse_variables()
  565.  
  566. # Retain the summaries from the final tower.
  567. summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
  568.  
  569. # Retain the Batch Normalization updates operations only from the
  570. # final tower. Ideally, we should grab the updates from all towers
  571. # but these stats accumulate extremely fast so we can ignore the
  572. # other stats from the other towers without significant detriment.
  573. #batchnorm_updates = tf.get_collection(ops.GraphKeys.UPDATE_OPS, scope) #ot BDD proveri kvo stava
  574.  
  575.  
  576. # Calculate the gradients for the batch of data on this tower.
  577. gradvars = optimizer.compute_gradients(loss, tf.trainable_variables())
  578. # get values with existing gradients only
  579. vars_with_grads = [v for (g, v) in gradvars if g is not None]
  580. # recompute their gradients only
  581. gradvars = optimizer.compute_gradients(loss, vars_with_grads)
  582.  
  583. # grad clipping
  584. grads, vars = zip(*gradvars)
  585. grads, _ = tf.clip_by_global_norm(grads, 15.0)
  586.  
  587. # Keep track of the gradients across all towers.
  588. tower_grads.append(zip(grads, vars))
  589.  
  590.  
  591.  
  592. # We must calculate the mean of each gradient. Note that this is the
  593. # synchronization point across all towers.
  594. gradvars = average_gradients(tower_grads)
  595.  
  596. # Add schedule here !!!!!!
  597. # Add a summary to track the learning rate.
  598. summaries.append(tf.summary.scalar('learning_rate', learn_rate))
  599.  
  600. # Add histograms for gradients.
  601. #for grad, var in gradvars:
  602. # summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))
  603. for grad_var_summ in [tf.summary.histogram(var.op.name + '/gradients', grad) for (grad, var) in gradvars]:
  604. summaries.append(grad_var_summ)
  605.  
  606. # Apply the gradients to adjust the shared variables.
  607. apply_gradient_op = optimizer.apply_gradients(gradvars, global_step=global_step)
  608.  
  609. # Add histograms for trainable variables.
  610. for var in tf.trainable_variables():
  611. summaries.append(tf.summary.histogram(var.op.name, var))
  612.  
  613.  
  614. # Track the moving averages of all trainable variables.
  615. variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
  616. variables_averages_op = variable_averages.apply(tf.trainable_variables())
  617.  
  618. #batchnorm_updates_op = tf.group(*batchnorm_updates)
  619. # Group all updates to into a single train op.
  620. train_op = tf.group(apply_gradient_op, variables_averages_op)#, batchnorm_updates_op)
  621.  
  622. # final_states for LSTM cells
  623. final_cell_states = tf.get_collection('cell_final_states')
  624.  
  625.  
  626. # Create a saver.
  627. saver = tf.train.Saver()
  628.  
  629.  
  630. # Build the summary operation from the last tower summaries.
  631. summary_op = tf.summary.merge(summaries)
  632.  
  633. train_writer = tf.summary.FileWriter(train_filewriter_path)
  634.  
  635. # Build an initialization operation to run below.
  636. init_op = tf.global_variables_initializer()
  637.  
  638. ##################### Session ##################################
  639.  
  640.  
  641. # Start running operations on the Graph. allow_soft_placement must be set to
  642. # True to build towers on GPU, as some of the ops do not have GPU
  643. # implementations.
  644.  
  645. config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth = True),
  646. allow_soft_placement=True,
  647. log_device_placement=False)
  648. sess = tf.Session(config=config)
  649. sess.run(init_op)
  650.  
  651.  
  652. step_start = int(sess.run(global_step))
  653.  
  654. # Load previous model version
  655. model_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
  656. if model_checkpoint:
  657. print("Restoring from", model_checkpoint)
  658. saver.restore(sess=sess, save_path=model_checkpoint)
  659.  
  660.  
  661. print("Training...")
  662. prev_img_batches = []
  663. curr_final_cell_states = []
  664. # Loop over number of epochs
  665. for epoch in range(1, NUM_EPOCHS + 1):
  666. print("Epoch number: %d" %epoch)
  667.  
  668. # initialize all iterators
  669. for i in range(NUM_GPUS):
  670. sess.run(train_iter_init_ops[i])
  671. prev_img_batches.append(None)
  672. # curr_final_state_gt, curr_final_state_pred
  673. curr_final_cell_states.append((None, None))
  674.  
  675. # go through all batches in sets
  676. while True:
  677. try:
  678. feed_dict = {}
  679. feed_dict[learn_rate] = LEARN_RATE
  680. feed_dict[keep_prob] = KEEP_PROB_TRAIN
  681.  
  682. for i in range(NUM_GPUS):
  683. curr_final_state_gt, curr_final_state_pred = curr_final_cell_states[i]
  684. if curr_final_state_gt is not None:
  685. # first part of the RNN state tuple - Tensor of shape = BATCH_SIZE
  686. final_state_gt_result = tf.get_default_graph().get_tensor_by_name('tower_%d/predictor/rnn/while/Exit_3:0' % i)
  687. # the internal states for the 2 LSTM cells
  688. final_state_gt_internal_1 = tf.get_default_graph().get_tensor_by_name('tower_%d/predictor/rnn/while/Exit_4:0' % i)
  689. final_state_gt_internal_2 = tf.get_default_graph().get_tensor_by_name('tower_%d/predictor/rnn/while/Exit_5:0' % i)
  690.  
  691. feed_dict[final_state_gt_result] = curr_final_state_gt[0]
  692. feed_dict[final_state_gt_internal_1] = curr_final_state_gt[1][0]
  693. feed_dict[final_state_gt_internal_2] = curr_final_state_gt[1][1]
  694.  
  695.  
  696. if curr_final_state_pred is not None:
  697. # first part of the RNN state tuple - Tensor of shape = BATCH_SIZE
  698. initial_state_pred_result = tf.get_default_graph().get_tensor_by_name('tower_%d/Identity_3:0' % i)
  699. # the internal states for the 2 LSTM cells
  700. initial_state_pred_internal_1 = tf.get_default_graph().get_tensor_by_name('tower_%d/Identity_4:0' % i)
  701. initial_state_pred_internal_2 = tf.get_default_graph().get_tensor_by_name('tower_%d/Identity_5:0' % i)
  702.  
  703. feed_dict[initial_state_pred_result] = curr_final_state_pred[0]
  704. feed_dict[initial_state_pred_internal_1] = curr_final_state_pred[1][0]
  705. feed_dict[initial_state_pred_internal_2] = curr_final_state_pred[1][1]
  706.  
  707.  
  708. img_batch_placeholder = tf.get_default_graph().get_tensor_by_name("tower_%d/img_batch_placeholder_%d:0" % (i,i))
  709. label_batch_placeholder = tf.get_default_graph().get_tensor_by_name("tower_%d/label_batch_placeholder_%d:0" % (i,i))
  710.  
  711. img_batch_dict, label_batch = sess.run(next_batch_ops[i])
  712.  
  713. # organize data
  714. img_batch_paths_encoded = sess.run(tf.convert_to_tensor(img_batch_dict["filename"], dtype=tf.string))
  715. img_batch_paths = []
  716. for p in img_batch_paths_encoded:
  717. img_batch_paths.append(os.path.join(FLAGS.data_path, p.decode('UTF-8')).rstrip()) ##CHANGE .replace("/", "\")
  718.  
  719. img_batch_list = []
  720. for p in img_batch_paths:
  721. img_batch_list.append(parse_img(p))
  722.  
  723. img_batch = tf.stack(img_batch_list)
  724.  
  725. # add DEPTH preceding frames to every sequence
  726. img_batch = add_depth(prev_img_batches[i], curr_img_batch=img_batch)
  727. prev_img_batches[i] = img_batch
  728.  
  729. img_batch = tf.reshape(img_batch, [BATCH_SIZE, DEPTH + SEQ_LEN, 227, 227, CHANNELS])
  730. label_batch = tf.reshape(label_batch, [BATCH_SIZE, SEQ_LEN, OUTPUT_DIMS])
  731.  
  732.  
  733. # feed data
  734. feed_dict[img_batch_placeholder] = sess.run(img_batch)
  735. feed_dict[label_batch_placeholder] = sess.run(label_batch)
  736.  
  737.  
  738.  
  739. # perform training
  740. print("Running graph...")
  741. _, loss_value, summary, curr_step, curr_final_cell_states = sess.run([train_op, loss, summary_op, global_step, final_cell_states],
  742. feed_dict=feed_dict)
  743. print('global_step %d, loss = %.2f' % (curr_step, loss_value))
  744.  
  745. except tf.errors.OutOfRangeError:
  746. break
  747.  
  748.  
  749. break
  750. # save model after epoch
  751. train_writer.add_summary(summary, global_step)
  752. saver.save(sess, checkpoint_path, global_step=global_step)
  753. #if epoch > 14 or epoch == 5 or epoch == 10:
  754. # saver.save(sess, checkpoint_path, global_step=global_step)
  755.  
  756.  
  757. train()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement