# Patched fit() for an MXNet Module (in the spirit of
# mxnet.module.BaseModule.fit), extended by KP to collect per-epoch metric
# values and to support early stopping from epoch_end_callback. When this
# code lives outside mxnet/module/base_module.py, it needs imports along
# these lines (paths as of the classic Module API):
import time

from mxnet import metric
from mxnet.base import _as_list
from mxnet.initializer import Uniform
from mxnet.model import BatchEndParam


def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None, initializer=Uniform(0.01),
        arg_params=None, aux_params=None, allow_missing=False,
        force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
        validation_metric=None, monitor=None):
- """Train the module parameters.
- Parameters
- ----------
- train_data : DataIter
- eval_data : DataIter
- If not `None`, will be used as validation set and evaluate the performance
- after each epoch.
- eval_metric : str or EvalMetric
- Default `'accuracy'`. The performance measure used to display during training.
- Other possible predefined metrics are:
- 'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'
- epoch_end_callback : function or list of function
- Each callback will be called with the current `epoch`, `symbol`, `arg_params`
- and `aux_params`.
- batch_end_callback : function or list of function
- Each callback will be called with a `BatchEndParam`.
- kvstore : str or KVStore
- Default `'local'`.
- optimizer : str or Optimizer
- Default `'sgd'`
- optimizer_params : dict
- Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
- The default value is not a `dict`, just to avoid pylint warning on dangerous
- default values.
- eval_end_callback : function or list of function
- These will be called at the end of each full evaluation, with the metrics over
- the entire evaluation set.
- eval_batch_end_callback : function or list of function
- These will be called at the end of each minibatch during evaluation
- initializer : Initializer
- Will be called to initialize the module parameters if not already initialized.
- arg_params : dict
- Default `None`, if not `None`, should be existing parameters from a trained
- model or loaded from a checkpoint (previously saved model). In this case,
- the value here will be used to initialize the module parameters, unless they
- are already initialized by the user via a call to `init_params` or `fit`.
- `arg_params` has higher priority to `initializer`.
- aux_params : dict
- Default `None`. Similar to `arg_params`, except for auxiliary states.
- allow_missing : bool
- Default `False`. Indicate whether we allow missing parameters when `arg_params`
- and `aux_params` are not `None`. If this is `True`, then the missing parameters
- will be initialized via the `initializer`.
- force_rebind : bool
- Default `False`. Whether to force rebinding the executors if already binded.
- force_init : bool
- Default `False`. Indicate whether we should force initialization even if the
- parameters are already initialized.
- begin_epoch : int
- Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
- checkpoint saved at a previous training phase at epoch N, then we should specify
- this value as N+1.
- num_epoch : int
- Number of epochs to run training.
- Examples
- --------
- An example of using fit for training::
- >>> #Assume training dataIter and validation dataIter are ready
- >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
- optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
- num_epoch=10)
- """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params,
                     aux_params=aux_params, allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)
    # KP: collect metric values per metric name and epoch so that
    # epoch_end_callback can implement early stopping on them
    epoch_train_eval_metrics = {}
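    # The resulting structure is {metric_name: {epoch: [values]}}. For each
    # epoch, build_metrics() appends the training value first and, when
    # eval_data is given, the validation value second.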
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)  # next() instead of .next() for Python 3
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # pre-fetch the next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True
            self.update_metric(eval_metric, data_batch.label)
            if monitor is not None:
                monitor.toc_print()
            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1
        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            self.build_metrics(epoch_train_eval_metrics, epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)
        # ----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            # TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
                self.build_metrics(epoch_train_eval_metrics, epoch, name, val)
        # KP: moved the epoch-end callbacks below the eval_data block so that
        # both the training and the validation metrics are available to them
        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                # KP: a callback must return a truthy value to continue
                # training; a falsy return value triggers an early stop
                if not callback(epoch, self.symbol, arg_params, aux_params,
                                epoch_train_eval_metrics=epoch_train_eval_metrics):
                    return

        # end of one epoch: reset the data iterator for the next epoch
        train_data.reset()
def build_metrics(self, epoch_train_eval_metrics, epoch, name, val):
    '''
    Keep track of metric values (e.g. accuracies) per metric name and epoch.

    :param epoch_train_eval_metrics: dict of {name: {epoch: [values]}} to update
    :param epoch: current epoch number
    :param name: metric name, e.g. 'accuracy'
    :param val: metric value to record
    :return: None; epoch_train_eval_metrics is updated in place
    '''
    if name not in epoch_train_eval_metrics:  # KP: use `in`; dict.has_key is Python 2 only
        epoch_train_eval_metrics[name] = {}
    if epoch not in epoch_train_eval_metrics[name]:
        epoch_train_eval_metrics[name][epoch] = []
    epoch_train_eval_metrics[name][epoch].append(val)
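
# --- Usage sketch (illustrative, not part of the patch) ---
# A minimal early-stopping callback built on the hook above. The helper name
# make_early_stop_callback and the patience logic are assumptions for
# illustration; only the callback signature and the layout of
# epoch_train_eval_metrics come from the patched fit()/build_metrics().
def make_early_stop_callback(metric_name='accuracy', patience=3):
    """Stop fit() once metric_name has not improved for `patience` epochs."""
    state = {'best': float('-inf'), 'bad_epochs': 0}

    def _callback(epoch, symbol, arg_params, aux_params,
                  epoch_train_eval_metrics=None):
        values = epoch_train_eval_metrics.get(metric_name, {}).get(epoch, [])
        if not values:
            return True  # nothing recorded for this metric yet; keep going
        # build_metrics appends the train value first and the validation
        # value last, so values[-1] is the validation score when available
        current = values[-1]
        if current > state['best']:
            state['best'] = current
            state['bad_epochs'] = 0
        else:
            state['bad_epochs'] += 1
        # fit() stops as soon as any epoch_end_callback returns a falsy value
        return state['bad_epochs'] < patience

    return _callback

# Example wiring (assuming train/val DataIters as in the fit() docstring):
# mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
#         epoch_end_callback=make_early_stop_callback('accuracy', patience=3),
#         num_epoch=50)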