Advertisement
Guest User

Untitled

a guest
May 19th, 2019
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.20 KB | None | 0 0
  1. import functools
  2. import json
  3. import os
  4. import tensorflow as tf
  5.  
  6. from object_detection.builders import dataset_builder
  7. from object_detection.builders import graph_rewriter_builder
  8. from object_detection.builders import model_builder
  9. from object_detection.legacy import trainer
  10. from object_detection.utils import config_util
  11.  
  12. tf.logging.set_verbosity(tf.logging.INFO)
  13.  
  14. flags = tf.app.flags
  15. flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
  16. flags.DEFINE_integer('task', 0, 'task id')
  17. flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
  18. flags.DEFINE_boolean('clone_on_cpu', False,
  19.                      'Force clones to be deployed on CPU.  Note that even if '
  20.                      'set to False (allowing ops to run on gpu), some ops may '
  21.                      'still be run on the CPU if they have no GPU kernel.')
  22. flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
  23.                      'replicas.')
  24. flags.DEFINE_integer('ps_tasks', 0,
  25.                      'Number of parameter server tasks. If None, does not use '
  26.                      'a parameter server.')
  27. flags.DEFINE_string('train_dir', '',
  28.                     'Directory to save the checkpoints and training summaries.')
  29.  
  30. flags.DEFINE_string('pipeline_config_path', '',
  31.                     'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
  32.                     'file. If provided, other configs are ignored')
  33.  
  34. flags.DEFINE_string('train_config_path', '',
  35.                     'Path to a train_pb2.TrainConfig config file.')
  36. flags.DEFINE_string('input_config_path', '',
  37.                     'Path to an input_reader_pb2.InputReader config file.')
  38. flags.DEFINE_string('model_config_path', '',
  39.                     'Path to a model_pb2.DetectionModel config file.')
  40.  
  41. FLAGS = flags.FLAGS
  42.  
  43.  
  44. @tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
  45. def main(_):
  46.   assert FLAGS.train_dir, '`train_dir` is missing.'
  47.   if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
  48.   if FLAGS.pipeline_config_path:
  49.     configs = config_util.get_configs_from_pipeline_file(
  50.         FLAGS.pipeline_config_path)
  51.     if FLAGS.task == 0:
  52.       tf.gfile.Copy(FLAGS.pipeline_config_path,
  53.                     os.path.join(FLAGS.train_dir, 'pipeline.config'),
  54.                     overwrite=True)
  55.   else:
  56.     configs = config_util.get_configs_from_multiple_files(
  57.         model_config_path=FLAGS.model_config_path,
  58.         train_config_path=FLAGS.train_config_path,
  59.         train_input_config_path=FLAGS.input_config_path)
  60.     if FLAGS.task == 0:
  61.       for name, config in [('model.config', FLAGS.model_config_path),
  62.                            ('train.config', FLAGS.train_config_path),
  63.                            ('input.config', FLAGS.input_config_path)]:
  64.         tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
  65.                       overwrite=True)
  66.  
  67.   model_config = configs['model']
  68.   train_config = configs['train_config']
  69.   input_config = configs['train_input_config']
  70.  
  71.   model_fn = functools.partial(
  72.       model_builder.build,
  73.       model_config=model_config,
  74.       is_training=True)
  75.  
  76.   def get_next(config):
  77.     return dataset_builder.make_initializable_iterator(
  78.         dataset_builder.build(config)).get_next()
  79.  
  80.   create_input_dict_fn = functools.partial(get_next, input_config)
  81.  
  82.   env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  83.   cluster_data = env.get('cluster', None)
  84.   cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  85.   task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  86.   task_info = type('TaskSpec', (object,), task_data)
  87.  
  88.   # Parameters for a single worker.
  89.   ps_tasks = 0
  90.   worker_replicas = 1
  91.   worker_job_name = 'lonely_worker'
  92.   task = 0
  93.   is_chief = True
  94.   master = ''
  95.  
  96.   if cluster_data and 'worker' in cluster_data:
  97.     # Number of total worker replicas include "worker"s and the "master".
  98.     worker_replicas = len(cluster_data['worker']) + 1
  99.   if cluster_data and 'ps' in cluster_data:
  100.     ps_tasks = len(cluster_data['ps'])
  101.  
  102.   if worker_replicas > 1 and ps_tasks < 1:
  103.     raise ValueError('At least 1 ps task is needed for distributed training.')
  104.  
  105.   if worker_replicas >= 1 and ps_tasks > 0:
  106.     # Set up distributed training.
  107.     server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
  108.                              job_name=task_info.type,
  109.                              task_index=task_info.index)
  110.     if task_info.type == 'ps':
  111.       server.join()
  112.       return
  113.  
  114.     worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
  115.     task = task_info.index
  116.     is_chief = (task_info.type == 'master')
  117.     master = server.target
  118.  
  119.   graph_rewriter_fn = None
  120.   if 'graph_rewriter_config' in configs:
  121.     graph_rewriter_fn = graph_rewriter_builder.build(
  122.         configs['graph_rewriter_config'], is_training=True)
  123.  
  124.   trainer.train(
  125.       create_input_dict_fn,
  126.       model_fn,
  127.       train_config,
  128.       master,
  129.       task,
  130.       FLAGS.num_clones,
  131.       worker_replicas,
  132.       FLAGS.clone_on_cpu,
  133.       ps_tasks,
  134.       worker_job_name,
  135.       is_chief,
  136.       FLAGS.train_dir,
  137.       graph_hook_fn=graph_rewriter_fn)
  138.  
  139.  
  140. if __name__ == '__main__':
  141.   tf.app.run()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement