Advertisement
Guest User

Untitled

a guest
Jan 19th, 2019
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.45 KB | None | 0 0
  1. #KL-divergence annealingのメモ(先にこちらを実行してから上のコードを実行する事)
  2. from keras import backend as K
  3.  
  4. #hp_lambdaはkl_lossの係数として用意する。
  5. hp_lambda = K.variable(0) # default values
  6.  
  7. from keras import callbacks
  8.  
  9. class AneelingCallback(callbacks.Callback): #callbacks.Callbackはhttps://keras.io/ja/callbacks/#callbackを参照したい。
  10. '''Aneeling theano shared variable.
  11. # Arguments
  12. schedule(関数): a function that takes an epoch index as input
  13. (integer, indexed from 0) and returns a new
  14. learning rate as output (float).
  15. '''
  16. def __init__(self, schedule, variable):
  17. super(AneelingCallback, self).__init__()
  18. self.schedule = schedule
  19. self.variable = variable #hp_lambdaにあたるもの
  20.  
  21. def on_epoch_begin(self, epoch, logs={}):
  22. assert hasattr(self.model.optimizer, 'lr'),
  23. 'Optimizer must have a "lr" attribute.'
  24. value = self.schedule(epoch)
  25. assert type(value) == float, 'The output of the "schedule" function should be float.'
  26. K.set_value(self.variable, value) #上のvalueで得た値をvariableにセットする。
  27.      print(K.eval(self.variable)) #後で消す行。kl_lossの係数をepochごとに知りたい...。
  28.  
  29. def schedule(epoch):
  30. return 0.5 * epoch
  31.  
  32. aneeling_callback = AneelingCallback(schedule, hp_lambda)
  33.  
  34. from __future__ import absolute_import
  35. from __future__ import division
  36. from __future__ import print_function
  37.  
  38. from keras.layers import Lambda
  39. from keras.losses import mse, binary_crossentropy
  40. from keras.utils import plot_model
  41.  
  42. import matplotlib.pyplot as plt
  43. import argparse
  44. import os
  45.  
  46.  
  47. # reparameterization trick
  48. # instead of sampling from Q(z|X), sample eps = N(0,I)
  49. # z = z_mean + sqrt(var)*eps
  50. def sampling(args):
  51. """Reparameterization trick by sampling fr an isotropic unit Gaussian.
  52. # Arguments:
  53. args (tensor): mean and log of variance of Q(z|X)
  54. # Returns:
  55. z (tensor): sampled latent vector
  56. """
  57.  
  58. z_mean, z_log_var = args
  59. batch = K.shape(z_mean)[0]
  60. dim = K.int_shape(z_mean)[1]
  61. # by default, random_normal has mean=0 and std=1.0
  62. epsilon = K.random_normal(shape=(batch, dim))
  63. return z_mean + K.exp(0.5 * z_log_var) * epsilon
  64. #関数定義終わり
  65.  
  66.  
  67.  
  68.  
  69. from keras.datasets import mnist
  70. import numpy as np
  71. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  72.  
  73. image_size = x_train.shape[1]
  74. original_dim = image_size * image_size
  75. x_train = np.reshape(x_train, [-1, original_dim])
  76. x_test = np.reshape(x_test, [-1, original_dim])
  77. x_train = x_train.astype('float32') / 255
  78. x_test = x_test.astype('float32') / 255
  79.  
  80. from keras.models import Model
  81. from keras.layers import Input, Dense
  82. # network parameters
  83. input_shape = (original_dim, )
  84.  
  85. intermediate_dim = 256
  86. batch_size = 8
  87. latent_dim = 128
  88. epochs = 1 #とりあえず
  89.  
  90. # VAE model = encoder + decoder
  91. # build encoder model
  92. inputs = Input(shape=input_shape, name='encoder_input')
  93. x = Dense(intermediate_dim, activation='relu')(inputs)
  94. z_mean = Dense(latent_dim, name='z_mean')(x)
  95. z_log_var = Dense(latent_dim, name='z_log_var')(x)
  96.  
  97. z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
  98.  
  99. # instantiate encoder model
  100. encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
  101.  
  102. # build decoder model
  103. latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
  104. x = Dense(intermediate_dim, activation='relu')(latent_inputs)
  105. outputs = Dense(original_dim, activation='sigmoid')(x)
  106.  
  107. # instantiate decoder model
  108. decoder = Model(latent_inputs, outputs, name='decoder')
  109.  
  110. # instantiate VAE model
  111. #encoder(inputs)[2]はzのこと
  112. outputs = decoder(encoder(inputs)[2])
  113. vae = Model(inputs, outputs, name='vae_mlp')
  114.  
  115. # VAE loss = mse_loss or xent_loss + kl_loss
  116. reconstruction_loss = binary_crossentropy(inputs,
  117. outputs)
  118. reconstruction_loss *= original_dim
  119. kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
  120. kl_loss = K.sum(kl_loss, axis=-1)
  121. kl_loss *= -aneeling_callback.variable #hp_lambda...?
  122. # VAE loss = mse_loss or xent_loss + kl_loss
  123. vae_loss = K.mean(reconstruction_loss + kl_loss)
  124. vae.add_loss(vae_loss)
  125. vae.compile(optimizer='adam')
  126.  
  127. print(vae.summary())
  128.  
  129. # train the autoencoder
  130. vae.fit(x_train,
  131. epochs=epochs,
  132. batch_size=batch_size,
  133. callbacks=[aneeling_callback],
  134. validation_data=(x_test, None))
  135.  
  136. line 25
  137. K.eval(self.variable) #後で消す行
  138. ^
  139. SyntaxError: invalid character in identifier
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement