Advertisement
Guest User

Untitled

a guest
Jun 19th, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.84 KB | None | 0 0
  1. import sys, getopt
  2. import os
  3. import pdb
  4.  
  5. import tensorflow as tf
  6.  
  7. usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
  8. '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run ' \
  9. '--output_dir=dir/to/output'
  10.  
  11.  
  12. def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir):
  13. checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
  14. with tf.Session() as sess:
  15. for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
  16. # Load the variable
  17. var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
  18.  
  19. # Set the new name
  20. new_name = var_name
  21. if None not in [replace_from, replace_to]:
  22. new_name = new_name.replace(replace_from, replace_to)
  23. if add_prefix:
  24. new_name = add_prefix + new_name
  25. if dry_run:
  26. if var_name != new_name:
  27. print('%s would be renamed to %s.' % (var_name, new_name))
  28. else:
  29. print('%s would not change.' % (var_name))
  30. else:
  31. if var_name != new_name:
  32. print('Renaming %s to %s.' % (var_name, new_name))
  33. # Rename the variable
  34. var = tf.Variable(var, name=new_name)
  35.  
  36. if not dry_run:
  37. # Save the variables
  38. saver = tf.train.Saver()
  39. sess.run(tf.global_variables_initializer())
  40. out_file = os.path.join(output_dir, os.path.split(checkpoint.model_checkpoint_path)[-1])
  41. saver.save(sess, out_file)
  42. print("Model save to %s", out_file)
  43.  
  44.  
  45. def main(argv):
  46. checkpoint_dir = None
  47. replace_from = None
  48. replace_to = None
  49. add_prefix = None
  50. dry_run = False
  51.  
  52. try:
  53. opts, args = getopt.getopt(argv, 'h', ['help=', 'checkpoint_dir=', 'replace_from=',
  54. 'replace_to=', 'add_prefix=', 'dry_run', 'output_dir='])
  55. except getopt.GetoptError:
  56. print(usage_str)
  57. sys.exit(2)
  58. for opt, arg in opts:
  59. if opt in ('-h', '--help'):
  60. print(usage_str)
  61. sys.exit()
  62. elif opt == '--checkpoint_dir':
  63. checkpoint_dir = arg
  64. elif opt == '--replace_from':
  65. replace_from = arg
  66. elif opt == '--replace_to':
  67. replace_to = arg
  68. elif opt == '--add_prefix':
  69. add_prefix = arg
  70. elif opt == '--dry_run':
  71. dry_run = True
  72. elif opt == '--output_dir':
  73. output_dir = arg
  74.  
  75. if not checkpoint_dir:
  76. print('Please specify a checkpoint_dir. Usage:')
  77. print(usage_str)
  78. sys.exit(2)
  79.  
  80. rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir)
  81.  
  82.  
  83. if __name__ == '__main__':
  84. main(sys.argv[1:])
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement