Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from keras.models import Model
- from keras.layers import Conv2D, Input, UpSampling2D, Lambda, Layer
- from keras.optimizers import *
- from keras import backend as K
- from keras.applications import VGG19
- from ops import *
class UNet():
    """AdaIN-style style-transfer network.

    Architecture: a frozen, pretrained VGG19 encoder (truncated at
    block4_conv1 / relu4_1) feeds content and style features into an
    AdaIN layer; a trainable decoder maps the stylised features back to
    an RGB image.  Written against TF1-style Keras (graph mode,
    `tf.placeholder_with_default`, `tf.variable_scope`).

    Attributes:
        imgshape: (H, W, C) input shape; H and W are None, so the model
            is fully convolutional and accepts arbitrary image sizes.
        alpha: scalar placeholder blending content vs. stylised features
            (defaults to 1.0, i.e. full stylisation).
        encoder: frozen VGG19 feature extractor.
        model: the assembled content+style -> image model.
    """

    def __init__(self):
        self.imgshape = (None, None, 3)
        # Scalar style-strength knob, overridable at session.run time.
        self.alpha = tf.placeholder_with_default(1., shape=[], name='alpha')
        self.encoder = self.build_encoder()
        # The encoder is a fixed feature extractor; only the decoder trains.
        self.encoder.trainable = False
        self.model = self.build_model()
        # BUGFIX: summary() prints its table itself and returns None, so the
        # original print(self.model.summary()) also printed a stray "None".
        self.model.summary()

    def build_model(self):
        """Wire encoder -> AdaIN -> decoder into the end-to-end model.

        Returns:
            A Keras Model taking [content_image, style_image] and
            producing the stylised image.
        """
        cinput = Input(self.imgshape, name='content_input')
        sinput = Input(self.imgshape, name='style_input')
        content_encoded = self.encoder(cinput)
        style_encoded = self.encoder(sinput)
        # NOTE(review): the Lambda receives the whole 3-element list as its
        # single argument, so AdaIN (from ops) is assumed to accept
        # [content_feats, style_feats, alpha] — confirm against ops.AdaIN.
        intermediate = Lambda(lambda x: AdaIN(x))([content_encoded, style_encoded, self.alpha])
        decoder = self.build_decoder()
        output = decoder(intermediate)
        return Model([cinput, sinput], output)

    def build_encoder(self):
        """Return pretrained VGG19 truncated at block4_conv1 as the encoder."""
        vgg19_model = VGG19(include_top=False, weights='imagenet')
        content_layer = vgg19_model.get_layer('block4_conv1').output
        return Model(inputs=vgg19_model.input, outputs=content_layer, name='encoder_model')

    def build_decoder(self):
        """Build the decoder: 512-channel features -> 3-channel image.

        Mirrors the encoder, upsampling 8x via three UpSampling2D stages.
        Conv2DReflect (from ops) presumably reflection-pads before a
        'valid' convolution so spatial size is preserved — TODO confirm.

        Returns:
            A Keras Model named 'decoder_model' with its variables scoped
            under 'decoder_vars' (used elsewhere to select trainables).
        """
        layers = [  # HxW / InC->OutC
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 32x32 / 512->256
            UpSampling2D(),                                             # 32x32 -> 64x64
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(256, 3, padding='valid', activation='relu'),  # 64x64 / 256->256
            Conv2DReflect(128, 3, padding='valid', activation='relu'),  # 64x64 / 256->128
            UpSampling2D(),                                             # 64x64 -> 128x128
            Conv2DReflect(128, 3, padding='valid', activation='relu'),  # 128x128 / 128->128
            Conv2DReflect(64, 3, padding='valid', activation='relu'),   # 128x128 / 128->64
            UpSampling2D(),                                             # 128x128 -> 256x256
            Conv2DReflect(64, 3, padding='valid', activation='relu'),   # 256x256 / 64->64
            Conv2DReflect(3, 3, padding='valid', activation=None)       # 256x256 / 64->3
        ]
        # Renamed from `input`, which shadowed the builtin.
        decoder_input = Input((None, None, 512))
        x = decoder_input
        with tf.variable_scope('decoder_vars'):
            for layer in layers:
                x = layer(x)
        return Model(decoder_input, x, name='decoder_model')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement