Advertisement
Guest User

Untitled

a guest
Jan 24th, 2017
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.44 KB | None | 0 0
  1. import sys, getopt
  2.  
  3. import tensorflow as tf
  4.  
  5. usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
  6. '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run=True'
  7.  
  8.  
  9. def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
  10. checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
  11. with tf.Session() as sess:
  12. for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
  13. # Load the variable
  14. var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
  15.  
  16. # Set the new name
  17. new_name = var_name
  18. if None not in [replace_from, replace_to]:
  19. new_name = new_name.replace(replace_from, replace_to)
  20. if add_prefix:
  21. new_name = add_prefix + new_name
  22.  
  23. if new_name == var_name:
  24. continue
  25.  
  26. if dry_run:
  27. print('%s would be renamed to %s.' % (var_name, new_name))
  28. else:
  29. print('Renaming %s to %s.' % (var_name, new_name))
  30. # Rename the variable
  31. var = tf.Variable(var, name=new_name)
  32.  
  33. if not dry_run:
  34. # Save the variables
  35. saver = tf.train.Saver()
  36. sess.run(tf.global_variables_initializer())
  37. saver.save(sess, checkpoint.model_checkpoint_path)
  38.  
  39.  
  40. def main(argv):
  41. checkpoint_dir = None
  42. replace_from = None
  43. replace_to = None
  44. add_prefix = None
  45. dry_run = False
  46.  
  47. try:
  48. opts, args = getopt.getopt(argv, 'h', ['help=', 'checkpoint_dir=', 'replace_from=',
  49. 'replace_to=', 'add_prefix=', 'dry_run'])
  50. except getopt.GetoptError:
  51. print(usage_str)
  52. sys.exit(2)
  53. for opt, arg in opts:
  54. if opt in ('-h', '--help'):
  55. print(usage_str)
  56. sys.exit()
  57. elif opt == '--checkpoint_dir':
  58. checkpoint_dir = arg
  59. elif opt == '--replace_from':
  60. replace_from = arg
  61. elif opt == '--replace_to':
  62. replace_to = arg
  63. elif opt == '--add_prefix':
  64. add_prefix = arg
  65. elif opt == '--dry_run':
  66. dry_run = True
  67.  
  68. if not checkpoint_dir:
  69. print('Please specify a checkpoint_dir. Usage:')
  70. print(usage_str)
  71. sys.exit(2)
  72.  
  73. rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
  74.  
  75.  
  76. if __name__ == '__main__':
  77. main(sys.argv[1:])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement