Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from lib.nasnet.nasnet import build_nasnet_mobile
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` around a NASNet-mobile network built with tf.slim.

    Args:
        features: Input tensor(s) from the input_fn; passed as the first
            argument to ``build_nasnet_mobile``.
        labels: Targets from the input_fn (presumably consumed by the elided
            loss computation -- confirm).
        mode: One of ``tf.estimator.ModeKeys``; selects which spec is built.
        params: Hyper-parameter dict supplied by the Estimator (unused in the
            visible code).

    Returns:
        tf.estimator.EstimatorSpec: ``predictions`` only in PREDICT mode,
        ``loss`` only in EVAL mode, and ``loss`` plus an Adam ``train_op``
        in TRAIN mode.
    """
    ...
    # Build the network (based on tf.slim). ``is_training`` is True only in
    # TRAIN mode, presumably switching layers such as batch norm / dropout
    # into their training behaviour.
    # NOTE(review): the positional ``2`` is presumably ``num_classes`` as in
    # the upstream tf-models signature -- confirm against lib.nasnet.nasnet.
    net_out, cells_out = build_nasnet_mobile(
        features, 2, is_training=mode == tf.estimator.ModeKeys.TRAIN)
    predictions = ...
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = ...
    # EVAL only needs the loss. Returning here avoids building the optimizer,
    # its slot variables and a train_op that the Estimator would ignore.
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

    optimizer = tf.train.AdamOptimizer()
    # Ops registered in UPDATE_OPS (with tf.slim, typically batch-norm
    # moving-average updates) are not dependencies of the loss, so force
    # them to run before every optimization step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(loss=loss, train_op=train_op, mode=mode)
def main():
    """Construct a ``tf.estimator.Estimator`` that trains/evaluates ``model_fn``.

    The session is configured with GPU ``allow_growth`` (allocate GPU memory
    on demand rather than all up front) and ``allow_soft_placement`` (fall
    back to another device when an op has no kernel on the requested one).
    """
    ...
    # Populate the proto fields through constructor keywords; this produces
    # the same ConfigProto as assigning the fields one by one.
    gpu_opts = tf.GPUOptions(allow_growth=True)
    sess_cfg = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=gpu_opts,
    )
    # NOTE(review): ``model_dir`` must be defined in the elided setup above;
    # ``estimator`` is not used in the visible code -- the snippet appears
    # truncated.
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=tf.estimator.RunConfig(session_config=sess_cfg),
    )
Add Comment
Please, Sign In to add comment