Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import collections
import logging
import os
import sys

import commentjson
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf
# Make the bundled git submodules under ./incl importable.
# NOTE(review): 'incl' is resolved relative to the current working directory,
# so this script presumably has to be launched from the repository root.
sys.path.insert(1, 'incl')

try:
    # Check whether setup was done correctly: these imports only succeed
    # once the git submodules have been checked out.
    import tensorvision.utils as tv_utils
    import tensorvision.core as core
    import tensorvision.train
    import tensorflow_fcn.utils
except ImportError:
    # You forgot to initialize submodules. `logging` must be imported at
    # file level; without it this branch would die with a NameError instead
    # of printing the instructions below.
    logging.error("Could not import the submodules.")
    logging.error("Please execute: "
                  "'git submodule update --init --recursive'")
    # sys.exit is always available, unlike the site-provided exit() builtin
    # (absent e.g. under `python -S` or in frozen executables).
    sys.exit(1)
# Load the trained network from a TensorVision run directory, restore its
# weights, and freeze the inference graph into a single GraphDef (.pb) file
# in which all variables are replaced by constants.
#
# Usage: python <this script> [LOGDIR [OUTPUT_PB]]
# Without arguments the defaults below are used.
#
# TODO: loading currently goes through tensorvision's hypes/modules helpers;
# figure out how to load an arbitrary checkpoint from the run directory
# without depending on tensorvision.
logdir = sys.argv[1] if len(sys.argv) > 1 else 'CHANGE_THIS'
output_path = sys.argv[2] if len(sys.argv) > 2 else 'test_frozen_model.pb'

# Op(s) whose outputs are kept in the frozen graph; every node not needed to
# compute them is pruned by convert_variables_to_constants.
# NOTE(review): this name depends on how the model's decoder names its ops;
# confirm it against the graph if the model definition changes.
output_node_names = ['Validation/decoder/Softmax']

hypes = tv_utils.load_hypes_from_logdir(logdir)
modules = tv_utils.load_modules_from_logdir(logdir)

# Input placeholder for a single float image. expand_dims adds a leading
# axis of size 1 (presumably the inference graph expects a batch).
# The placeholder is created exactly once so its op gets the default name
# (previously a stray duplicate made the real input 'Placeholder_1').
# NOTE(review): consumers of the .pb that fed 'Placeholder_1:0' must switch
# to this placeholder's name ('Placeholder' in a fresh graph); confirm.
image_pl = tf.placeholder(tf.float32)
image = tf.expand_dims(image_pl, 0)
# Builds the inference ops on top of `image` in the default graph, which is
# the graph the session below operates on.
pred = core.build_inference_graph(hypes, modules, image=image)

# The session is only needed to restore the weights and fold them into
# constants; the context manager guarantees it is closed afterwards.
with tf.Session() as sess:
    saver = tf.train.Saver()
    core.load_weights(logdir, sess, saver)
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)

# Strip nodes that are only needed during training (e.g. Identity,
# CheckNumerics) from the frozen graph.
frozen_graph_def = tf.graph_util.remove_training_nodes(frozen_graph_def)

# Serialize the frozen GraphDef to disk.
with tf.gfile.GFile(output_path, 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
Add Comment
Please, Sign In to add comment