Untitled
a guest, Mar 27th, 2017
# Modified excerpt of MXNet's BaseModule.fit() (mxnet/module/base_module.py), with
# "KP" additions for per-epoch metric tracking and early stopping via epoch_end_callback.
# The module-level names used below (time, metric, Uniform, BatchEndParam, _as_list)
# come from that file's imports, e.g.:
#   import time
#   from mxnet import metric
#   from mxnet.initializer import Uniform
#   from mxnet.model import BatchEndParam
#   from mxnet.base import _as_list

def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None, initializer=Uniform(0.01),
        arg_params=None, aux_params=None, allow_missing=False,
        force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
        validation_metric=None, monitor=None):
    """Train the module parameters.

    Parameters
    ----------
    train_data : DataIter
    eval_data : DataIter
        If not `None`, will be used as the validation set, and the performance
        will be evaluated after each epoch.
    eval_metric : str or EvalMetric
        Default `'acc'`. The performance measure used to display during training.
        Other predefined metrics are:
        'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
    epoch_end_callback : function or list of functions
        Each callback will be called with the current `epoch`, `symbol`, `arg_params`
        and `aux_params`.
    batch_end_callback : function or list of functions
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Default `'local'`.
    optimizer : str or Optimizer
        Default `'sgd'`.
    optimizer_params : dict
        Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
        The default value is not a `dict`, just to avoid a pylint warning about dangerous
        default values.
    eval_end_callback : function or list of functions
        These will be called at the end of each full evaluation, with the metrics over
        the entire evaluation set.
    eval_batch_end_callback : function or list of functions
        These will be called at the end of each mini-batch during evaluation.
    initializer : Initializer
        Will be called to initialize the module parameters if they are not already
        initialized.
    arg_params : dict
        Default `None`. If not `None`, should be existing parameters from a trained
        model or loaded from a checkpoint (previously saved model). In this case,
        the value here will be used to initialize the module parameters, unless they
        are already initialized by the user via a call to `init_params` or `fit`.
        `arg_params` has higher priority than `initializer`.
    aux_params : dict
        Default `None`. Similar to `arg_params`, except for auxiliary states.
    allow_missing : bool
        Default `False`. Indicates whether missing parameters are allowed when `arg_params`
        and `aux_params` are not `None`. If this is `True`, the missing parameters
        will be initialized via the `initializer`.
    force_rebind : bool
        Default `False`. Whether to force rebinding the executors if they are already bound.
    force_init : bool
        Default `False`. Indicates whether to force initialization even if the
        parameters are already initialized.
    begin_epoch : int
        Default `0`. Indicates the starting epoch. Usually, when resuming from a
        checkpoint saved at a previous training phase at epoch N, this should be
        set to N+1.
    num_epoch : int
        Number of epochs to run training.

    Examples
    --------
    An example of using fit for training::

        >>> # Assume the training DataIter and validation DataIter are ready
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
        ...         optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
        ...         num_epoch=10)
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                     allow_missing=allow_missing, force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    # KP: per-epoch metric history, used for tracking accuracy and early stopping
    epoch_train_eval_metrics = {}
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # pre-fetch the next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            self.build_metrics(epoch_train_eval_metrics, epoch, name, val)

        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback, epoch=epoch)
            # TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
                self.build_metrics(epoch_train_eval_metrics, epoch, name, val)

        # KP: moved these from above the eval_data call so that both the train and
        # validation metrics are available to the epoch-end callbacks
        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                # KP: if a callback returns a falsy value, stop training early
                if not callback(epoch, self.symbol, arg_params, aux_params,
                                epoch_train_eval_metrics=epoch_train_eval_metrics):
                    return

        # end of one epoch, reset the data iterator for the next epoch
        train_data.reset()

def build_metrics(self, epoch_train_eval_metrics, epoch, name, val):
    '''
    KP: keep a per-epoch history of metric values (e.g. accuracy) so that
    epoch-end callbacks can use it, for example to decide on early stopping.
    :param epoch_train_eval_metrics: dict of {metric name: {epoch: [values]}}
    :param epoch: current epoch number
    :param name: metric name
    :param val: metric value recorded for this epoch
    :return: None (updates epoch_train_eval_metrics in place)
    '''
    if name not in epoch_train_eval_metrics:  # KP: first time this metric is seen
        epoch_train_eval_metrics[name] = {}
    if epoch not in epoch_train_eval_metrics[name]:  # KP: first value for this epoch
        epoch_train_eval_metrics[name][epoch] = []
    epoch_train_eval_metrics[name][epoch].append(val)
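
# ------------------------------------------------------------------------------
# Example (not part of the original paste): a minimal early-stopping callback
# sketch for the modified fit() above. Assumptions: the metric registers itself
# under the name 'accuracy' (as mxnet.metric.Accuracy does), and because
# build_metrics appends the train value first and the validation value second,
# the last entry for an epoch is the validation score when eval_data is given.
# The helper name make_early_stop_callback and its patience/min_delta parameters
# are hypothetical, introduced here only for illustration; they are not part of
# MXNet or of the paste.

def make_early_stop_callback(metric_name='accuracy', patience=3, min_delta=1e-4):
    best = {'value': None, 'epoch': -1}

    def _callback(epoch, symbol, arg_params, aux_params, epoch_train_eval_metrics=None):
        history = (epoch_train_eval_metrics or {}).get(metric_name, {})
        if epoch not in history or not history[epoch]:
            return True  # nothing recorded for this epoch yet, keep training
        current = history[epoch][-1]  # last value = validation score if eval_data was used
        if best['value'] is None or current > best['value'] + min_delta:
            best['value'] = current
            best['epoch'] = epoch
        # the modified fit() stops as soon as a callback returns a falsy value
        return (epoch - best['epoch']) < patience

    return _callback

# Possible usage, assuming mod, train_dataiter and val_dataiter already exist:
#   mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
#           epoch_end_callback=make_early_stop_callback(patience=3),
#           num_epoch=50)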