SHARE
TWEET

Untitled

a guest Jun 19th, 2019 62 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import sys, getopt
  2. import os
  3. import pdb
  4.  
  5. import tensorflow as tf
  6.  
  7. usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
  8.             '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run ' \
  9.             '--output_dir=dir/to/output'
  10.  
  11.  
  12. def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir):
  13.     checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
  14.     with tf.Session() as sess:
  15.         for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
  16.             # Load the variable
  17.             var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
  18.  
  19.             # Set the new name
  20.             new_name = var_name
  21.             if None not in [replace_from, replace_to]:
  22.                 new_name = new_name.replace(replace_from, replace_to)
  23.             if add_prefix:
  24.                 new_name = add_prefix + new_name
  25.             if dry_run:
  26.                 if var_name != new_name:
  27.                     print('%s would be renamed to %s.' % (var_name, new_name))
  28.                 else:
  29.                     print('%s would not change.' % (var_name))
  30.             else:
  31.                 if var_name != new_name:
  32.                     print('Renaming %s to %s.' % (var_name, new_name))
  33.                     # Rename the variable
  34.                 var = tf.Variable(var, name=new_name)
  35.  
  36.         if not dry_run:
  37.             # Save the variables
  38.             saver = tf.train.Saver()
  39.             sess.run(tf.global_variables_initializer())
  40.             out_file = os.path.join(output_dir, os.path.split(checkpoint.model_checkpoint_path)[-1])
  41.             saver.save(sess, out_file)
  42.             print("Model save to %s", out_file)
  43.  
  44.  
  45. def main(argv):
  46.     checkpoint_dir = None
  47.     replace_from = None
  48.     replace_to = None
  49.     add_prefix = None
  50.     dry_run = False
  51.  
  52.     try:
  53.         opts, args = getopt.getopt(argv, 'h', ['help=', 'checkpoint_dir=', 'replace_from=',
  54.                                                'replace_to=', 'add_prefix=', 'dry_run', 'output_dir='])
  55.     except getopt.GetoptError:
  56.         print(usage_str)
  57.         sys.exit(2)
  58.     for opt, arg in opts:
  59.         if opt in ('-h', '--help'):
  60.             print(usage_str)
  61.             sys.exit()
  62.         elif opt == '--checkpoint_dir':
  63.             checkpoint_dir = arg
  64.         elif opt == '--replace_from':
  65.             replace_from = arg
  66.         elif opt == '--replace_to':
  67.             replace_to = arg
  68.         elif opt == '--add_prefix':
  69.             add_prefix = arg
  70.         elif opt == '--dry_run':
  71.             dry_run = True
  72.         elif opt == '--output_dir':
  73.             output_dir = arg
  74.  
  75.     if not checkpoint_dir:
  76.         print('Please specify a checkpoint_dir. Usage:')
  77.         print(usage_str)
  78.         sys.exit(2)
  79.  
  80.     rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir)
  81.  
  82.  
  83. if __name__ == '__main__':
  84.     main(sys.argv[1:])
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top