Advertisement
Guest User

Untitled

a guest
Apr 25th, 2019
1,312
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.63 KB | None | 0 0
  1. from keras.models import Model
  2. from keras.layers import Conv2D, Input, UpSampling2D, Lambda, Layer
  3. from keras.optimizers import *
  4. from keras import backend as K
  5. from keras.applications import VGG19
  6. from ops import *
  7.  
  8. class UNet():
  9. def __init__(self):
  10. self.imgshape = (None, None, 3)
  11. self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')
  12. self.encoder = self.build_encoder()
  13. self.encoder.trainable = False
  14. self.model = self.build_model()
  15. print(self.model.summary())
  16.  
  17. def build_model(self):
  18. cinput = Input(self.imgshape, name='content_input')
  19. sinput = Input(self.imgshape, name='style_input')
  20. content_encoded = self.encoder(cinput)
  21. style_encoded = self.encoder(sinput)
  22. intermediate = Lambda(lambda x: AdaIN(x))([content_encoded, style_encoded, self.alpha])
  23. decoder = self.build_decoder()
  24. output = decoder(intermediate)
  25. return Model([cinput, sinput], output)
  26.  
  27. def build_encoder(self):
  28. vgg19_model = VGG19(include_top=False, weights='imagenet')
  29. content_layer = vgg19_model.get_layer('block4_conv1').output
  30. return Model(inputs=vgg19_model.input, outputs=content_layer, name='encoder_model')
  31.  
  32. def build_decoder(self):
  33. layers = [ # HxW / InC->OutC
  34. Conv2DReflect(256, 3, padding='valid', activation='relu'), # 32x32 / 512->256
  35. UpSampling2D(), # 32x32 -> 64x64
  36. Conv2DReflect(256, 3, padding='valid', activation='relu'), # 64x64 / 256->256
  37. Conv2DReflect(256, 3, padding='valid', activation='relu'), # 64x64 / 256->256
  38. Conv2DReflect(256, 3, padding='valid', activation='relu'), # 64x64 / 256->256
  39. Conv2DReflect(128, 3, padding='valid', activation='relu'), # 64x64 / 256->128
  40. UpSampling2D(), # 64x64 -> 128x128
  41. Conv2DReflect(128, 3, padding='valid', activation='relu'), # 128x128 / 128->128
  42. Conv2DReflect(64, 3, padding='valid', activation='relu'), # 128x128 / 128->64
  43. UpSampling2D(), # 128x128 -> 256x256
  44. Conv2DReflect(64, 3, padding='valid', activation='relu'), # 256x256 / 64->64
  45. Conv2DReflect(3, 3, padding='valid', activation=None) # 256x256 / 64->3
  46. ]
  47. input = Input((None,None,512))
  48. x = input
  49. with tf.variable_scope('decoder_vars'):
  50. for layer in layers:
  51. x = layer(x)
  52. return Model(input, x, name='decoder_model')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement