Guest User

Untitled

a guest
Mar 21st, 2018
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.22 KB | None | 0 0
  1. from lib.nasnet.nasnet import build_nasnet_mobile
  2.  
  3. def model_fn(features, labels, mode, params):
  4. ...
  5. # build model (based on tf.slim)
  6. net_out, cells_out = build_nasnet_mobile(
  7. features, 2, is_training=mode == tf.estimator.ModeKeys.TRAIN)
  8.  
  9. predictions = ...
  10. if mode == tf.estimator.ModeKeys.PREDICT:
  11. return tf.estimator.EstimatorSpec(mode=mode,
  12. predictions=predictions)
  13.  
  14. loss = ...
  15.  
  16. optimizer = tf.train.AdamOptimizer()
  17. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  18. with tf.control_dependencies(update_ops):
  19. train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  20.  
  21. return tf.estimator.EstimatorSpec(loss=loss,
  22. train_op=train_op,
  23. mode=mode)
  24.  
  25.  
  26. def main():
  27. ...
  28. session_config = tf.ConfigProto()
  29. session_config.gpu_options.allow_growth = True
  30. session_config.allow_soft_placement = True
  31.  
  32. config = tf.estimator.RunConfig(session_config=session_config)
  33. estimator = tf.estimator.Estimator(model_fn=model_fn,
  34. model_dir=model_dir,
  35. config=config)
Add Comment
Please, Sign In to add comment