Advertisement
Guest User

Untitled

a guest
Jan 20th, 2017
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.69 KB | None | 0 0
  1. #!/usr/bin/env python
  2. from __future__ import print_function
  3.  
  4. import numpy as np
  5. import PIL.Image
  6. import argparse
  7. import tensorflow as tf
  8. import os
  9. import sys
  10. import shutil
  11. import gc
  12. BLACK = -2
  13. WHITE = 2
  14. IMAGE_PREFIX = 'img_'
  15.  
  16. def in_circle(x, y):
  17. dx = x - 358
  18. dy = y - 358
  19. r2 = dx * dx + dy * dy
  20. return r2 < 2
  21.  
  22.  
  23. def save_image(img, fn):
  24. a = np.uint8(np.clip(img, 0, 1) * 255)
  25. PIL.Image.fromarray(a).save(open(fn, 'wb'))
  26.  
  27.  
  28. def T(graph, layer):
  29. '''Helper for getting layer output tensor'''
  30. return graph.get_tensor_by_name("import/%s:0"%layer)
  31.  
  32.  
  33. def update_image(g, img1):
  34. g1 = np.mean(g, 2)
  35. todo = sorted(((g1[x][y], x, y) for x in range(600) for y in range(600)), key=lambda t: abs(t[0]))
  36. flips = 0
  37. while flips < 8999 and todo:
  38. score, x, y = todo.pop()
  39. if score > 0 and img1[x][y][0] == BLACK:
  40. img1[x][y] = [WHITE, WHITE, WHITE]
  41. flips += 1
  42. elif score < 0 and img1[x][y][0] == WHITE:
  43. img1[x][y] = [BLACK, BLACK, BLACK]
  44. flips += 1
  45. return flips
  46.  
  47.  
  48. def parse_channel(channel_def):
  49. for chunk in channel_def.split(','):
  50. chunk = chunk.strip()
  51. if '-' in chunk:
  52. begin, end = (int(x) for x in chunk.split('-'))
  53. for i in range(begin, end + 1):
  54. yield i
  55. else:
  56. yield int(chunk)
  57.  
  58. def create_drawing(model_fn, layer, channel, drawings):
  59. graph = tf.Graph()
  60. sess = tf.InteractiveSession(graph=graph)
  61. with tf.gfile.FastGFile(model_fn, 'rb') as f:
  62. graph_def = tf.GraphDef()
  63. graph_def.ParseFromString(f.read())
  64. t_input = tf.placeholder(np.float32, name='input') # define the input tensor
  65. imagenet_mean = 117.0
  66. t_preprocessed = tf.expand_dims(t_input - imagenet_mean, 0)
  67. tf.import_graph_def(graph_def, {'input': t_preprocessed})
  68.  
  69. t_obj = T(graph, layer)[:,channel]
  70. t_score = tf.reduce_mean(t_obj)
  71. t_grad = tf.gradients(t_score, t_input)[0]
  72.  
  73. checkpoint_path = os.path.join(drawings, 'checkpoints')
  74. if os.path.isdir(checkpoint_path):
  75. shutil.rmtree(checkpoint_path)
  76. os.mkdir(checkpoint_path)
  77. image = np.array([[([BLACK, BLACK, BLACK] if in_circle(x, y) else [WHITE, WHITE, WHITE]) for x in range(600)] for y in range(600)])
  78.  
  79.  
  80. score = 0
  81. for index in range(0, 1001):
  82. if index > 0:
  83. g, score = sess.run([t_grad, t_score], {t_input:image})
  84. update_image(g, image)
  85. gc.collect()
  86. save_image(image, 'dog.jpg')
  87.  
  88. if index % 25 == 0:
  89. print(index, score)
  90. g, score = sess.run([t_grad, t_score], {t_input:image})
  91. update_image(g, image)
  92. save_image(image, 'dog.jpg')
  93. gc.collect()
  94. def path_for_index(index, path):
  95. return os.path.join(path, '%s%d.png' % (IMAGE_PREFIX, index))
  96.  
  97.  
  98. if __name__ == '__main__':
  99. parser = argparse.ArgumentParser(description='Doodle images')
  100. parser.add_argument('--model_path', type=str, default='tensorflow_inception_graph.pb',
  101. help='Where the unpacked model dump is')
  102. parser.add_argument('--layer', type=str, default='output2',
  103. help='Which layer to pick the feature from')
  104. parser.add_argument('--channel', type=str, default='130',
  105. help='Which channel contains the feature')
  106. parser.add_argument('--drawings', type=str, default='drawings',
  107. help='Where to store the drawings')
  108. args = parser.parse_args()
  109.  
  110. if not os.path.isfile(args.model_path):
  111. print('%s is not a file. You can download inception from:' % args.model_path)
  112. print('https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip')
  113. sys.exit(1)
  114.  
  115. for channel in parse_channel(args.channel):
  116. create_drawing(args.model_path, args.layer, channel, args.drawings)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement