from __future__ import annotations

import pickle
from datetime import datetime
from os import makedirs
from os.path import join
from typing import Optional, Tuple, Dict, Any

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras import losses, metrics
from tensorflow.keras.callbacks import TensorBoard, History, EarlyStopping
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Conv2DTranspose, Lambda, LSTM
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

from config import cfg

# No-op on TF 2.x, where eager execution is already the default; only matters
# if TF1-style graph mode has been enabled elsewhere.
tf.compat.v1.enable_eager_execution()

# NOTE: Can be upgraded to introduce Kapre integration.

class VAE:
    """
    VAE represents a deep convolutional variational autoencoder architecture
    with mirrored encoder and decoder components.
    """

    def __init__(self,
                 input_shape: Tuple[int, int, int] = cfg.SPECTROGRAM_SHAPE,
                 latent_space_dim: int = cfg.LATENT_SPACE_DIM,
                 is_variational: bool = cfg.IS_VARIATIONAL,
                 reconstruction_loss_weight: float = cfg.RECONSTRUCTION_LOSS_WEIGHT,
                 num_hidden_conv_layers: int = cfg.NUM_HIDDEN_CONV_LAYERS,
                 filter_size_conv_layers: int = cfg.FILTER_SIZE_CONV_LAYERS,
                 conv_params_middle: Optional[Dict[str, Any]] = None,
                 conv_params_end: Optional[Dict[str, Any]] = None,
                 use_mock_encoder: bool = cfg.USE_MOCK_ENCODER,
                 use_mock_decoder: bool = cfg.USE_MOCK_DECODER,
                 clip_prediction: bool = cfg.CLIP_PREDICTION,
                 loss_squared: bool = cfg.LOSS_SQUARED,
                 loss_mse_linear_weight: Optional[Tuple[float, float]] = cfg.LOSS_MSE_LINEAR_WEIGHT,
                 use_lstm: bool = cfg.USE_LSTM,
                 ):

        if conv_params_middle is None:
            conv_params_middle = cfg.CONV_PARAMS_MIDDLE
        if conv_params_end is None:
            conv_params_end = cfg.CONV_PARAMS_END
        assert len(input_shape) == 3

        # Input parameters
        self.input_shape: Tuple[int, int, int] = input_shape
        self.latent_space_dim: int = latent_space_dim
        self._is_variational: bool = is_variational
        # Gates the KL term of the combined loss; zeroed for a plain autoencoder.
        self._reconstruction_loss_weight: float = reconstruction_loss_weight if self._is_variational else 0.
        self._num_hidden_conv_layers: int = num_hidden_conv_layers
        self._filter_size_conv_layers: int = filter_size_conv_layers
        self._conv_params_middle: Dict[str, Any] = conv_params_middle
        self._conv_params_end: Dict[str, Any] = conv_params_end
        self._use_mock_encoder: bool = use_mock_encoder
        self._use_mock_decoder: bool = use_mock_decoder
        self._clip_prediction: bool = clip_prediction
        self._loss_squared: bool = loss_squared
        self._loss_mse_linear_weight: Optional[Tuple[float, float]] = loss_mse_linear_weight
        # Per-frequency-row loss weights, shaped so they broadcast over the
        # batched input; cast to float32 to match the model's dtype.
        if self._loss_mse_linear_weight is not None:
            self._loss_mse_weight_factor: np.ndarray = np.linspace(
                self._loss_mse_linear_weight[0],
                self._loss_mse_linear_weight[1],
                self.input_shape[0])[None, ..., None, None].astype("float32")
        else:
            self._loss_mse_weight_factor = np.ones((1, *self.input_shape), dtype="float32")
        self._use_lstm: bool = use_lstm

        # Build model
        self.encoder: Optional[Model] = None
        self.decoder: Optional[Model] = None
        self.model: Optional[Model] = None
        self._encoder_input_layer = None  # KerasTensor returned by Input()
        self._loss_mse: losses.Loss = losses.MeanSquaredError()
        self._loss_kl: losses.Loss = losses.KLDivergence()
        self._build()

        # Set up TensorBoard
        self._log_dir: str = join(cfg.LOGS_DIR, f"model_{datetime.now().strftime('%y%m%d')}")
        self._tensorboard_callback: TensorBoard = TensorBoard(log_dir=self._log_dir,
                                                              histogram_freq=1)

    def summary(self) -> None:
        self.model.summary(line_length=100)
        print()
        self.encoder.summary(line_length=100)
        print()
        self.decoder.summary(line_length=100)
        print()

    def compile(self, learning_rate: float) -> None:
        if cfg.VERBOSE:
            print(f"Compiling VAE with {learning_rate=}")
        self.model.compile(optimizer=Adam(learning_rate=learning_rate),
                           loss=self._calculate_combined_loss,
                           metrics=[metrics.MeanSquaredError(),
                                    metrics.KLDivergence()])

    def _calculate_combined_loss(self, y_target: tf.Tensor, y_predicted: tf.Tensor) -> tf.Tensor:
        if self._loss_squared:
            y_target = tf.math.square(y_target)
            y_predicted = tf.math.square(y_predicted)
        y_target = y_target * self._loss_mse_weight_factor
        y_predicted = y_predicted * self._loss_mse_weight_factor

        return self._loss_mse(y_target, y_predicted) + \
            self._loss_kl(y_target, y_predicted) * self._reconstruction_loss_weight
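
    # Shape note: when LOSS_MSE_LINEAR_WEIGHT is set, the weight factor has
    # shape (1, F, 1, 1) and broadcasts over targets of shape (batch, F, T, C),
    # scaling each frequency row linearly between the two endpoints before the
    # MSE and KL terms are computed.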

    def train(self, x_train: np.ndarray, batch_size: int, num_epochs: int, should_callback: bool = True) -> History:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       restore_best_weights=True)
        callbacks = [self._tensorboard_callback, early_stopping] if should_callback else [early_stopping]

        history = self.model.fit(x=x_train,
                                 y=x_train,
                                 batch_size=batch_size,
                                 epochs=num_epochs,
                                 shuffle=True,
                                 callbacks=callbacks,
                                 verbose=cfg.VERBOSE,
                                 validation_split=0.2  # Use 20% of the data for validation
                                 )
        return history

    def reconstruct(self, images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        assert images.shape[1:] == self.input_shape
        latent_representations = self.encoder.predict(images)
        reconstructed_images = self.decoder.predict(latent_representations)
        if self._clip_prediction:
            reconstructed_images = np.clip(reconstructed_images, a_min=cfg.NORM_RANGE[0], a_max=cfg.NORM_RANGE[1])
        return reconstructed_images, latent_representations

    def save(self) -> None:
        print(f"tensorboard --logdir {self._tensorboard_callback.log_dir}")
        print(f"Model saved to {self._log_dir}")

        makedirs(self._log_dir, exist_ok=True)
        # Save weights
        self.model.save_weights(join(self._log_dir, "weights.weights.h5"))
        # Save parameters
        self._save_parameters()

    def _save_parameters(self) -> None:
        parameters = {
            'input_shape': self.input_shape,
            'latent_space_dim': self.latent_space_dim,
            'is_variational': self._is_variational,
            'reconstruction_loss_weight': self._reconstruction_loss_weight,
            'num_hidden_conv_layers': self._num_hidden_conv_layers,
            'filter_size_conv_layers': self._filter_size_conv_layers,
            'conv_params_middle': self._conv_params_middle,
            'conv_params_end': self._conv_params_end,
            'use_mock_encoder': self._use_mock_encoder,
            'use_mock_decoder': self._use_mock_decoder,
            'clip_prediction': self._clip_prediction,
            'loss_squared': self._loss_squared,
            'loss_mse_linear_weight': self._loss_mse_linear_weight,
            'use_lstm': self._use_lstm,
        }

        with open(join(self._log_dir, "parameters.pkl"), "wb") as f:
            pickle.dump(parameters, f)

    @classmethod
    def load(cls, save_folder: str) -> VAE:
        save_folder = join(cfg.LOGS_DIR, save_folder)

        with open(join(save_folder, "parameters.pkl"), "rb") as f:
            parameters = pickle.load(f)
        vae = cls(**parameters)

        weights_path = join(save_folder, "weights.weights.h5")
        vae.model.load_weights(weights_path)
        print(f"Loaded VAE from {save_folder}")
        return vae

    def _build(self) -> None:
        self._build_encoder()
        self._build_decoder()
        self._build_autoencoder()

    def _build_autoencoder(self) -> None:
        decoder_output = self.decoder(self.encoder(self._encoder_input_layer))
        self.model = Model(self._encoder_input_layer, decoder_output, name="vae")

    def _build_encoder(self) -> None:
        self._encoder_input_layer = Input(shape=self.input_shape, name='encoder_input')

        # Convolution layers
        x = Conv2D(
            filters=self._filter_size_conv_layers,
            **self._conv_params_middle,
            name="encoder_conv2d_1")(self._encoder_input_layer)
        for i in range(self._num_hidden_conv_layers - 1):
            x = Conv2D(
                filters=self._filter_size_conv_layers,
                **self._conv_params_middle,
                name=f"encoder_conv2d_{i + 2}")(x)

        self._smallest_convolution_shape = x.shape[1:]  # Remember the pre-flatten shape for the decoder to mirror

        if self._use_mock_encoder:
            # Untrainable pass-through, useful for debugging the rest of the pipeline
            x = Flatten(name="mock_encoder_flatten")(self._encoder_input_layer)
            bottleneck = Dense(self.latent_space_dim, trainable=False, name="mock_encoder_dense")(x)
            self.encoder = Model(self._encoder_input_layer, bottleneck, name="mock_encoder")
        else:
            if self._use_lstm:
                if x.shape[1] > 1:
                    x = Conv2D(
                        filters=self._filter_size_conv_layers,
                        kernel_size=(x.shape[1], 1),
                        strides=(x.shape[1], 1),
                        padding=self._conv_params_middle['padding'],
                        activation=self._conv_params_middle['activation'],
                        kernel_initializer=self._conv_params_middle['kernel_initializer'],
                        name="encoder_conv2d_last")(x)
                x = Reshape(x.shape[2:], name="encoder_reshape_lstm")(x)
                x = LSTM(128, return_sequences=False, name='encoder_lstm')(x)
            x = Flatten(name="encoder_flatten_1")(x)
            bottleneck = self._add_bottleneck(x)
            self.encoder = Model(self._encoder_input_layer, bottleneck, name="encoder")
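
    # Note on the LSTM branch above: the last Conv2D spans the whole remaining
    # frequency axis (kernel and stride both (H, 1)), collapsing it to height 1
    # so that the Reshape hands the LSTM a (time, features) sequence.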

    def _add_bottleneck(self, x: tf.Tensor) -> tf.Tensor:
        """Add the bottleneck (Dense layer), with Gaussian sampling when variational."""
        if self._is_variational:
            # Implement a VARIATIONAL autoencoder
            self.mu = Dense(self.latent_space_dim, name="encoder_mu")(x)
            self.log_variance = Dense(self.latent_space_dim, name="encoder_log_variance")(x)

            def _sample_point_from_normal_distribution(args):
                # Reparameterization trick: z = mu + sigma * epsilon keeps the
                # sampling step differentiable w.r.t. mu and log_variance.
                mu, log_variance = args
                epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
                sampled_point = mu + K.exp(log_variance / 2) * epsilon
                return sampled_point

            x = Lambda(_sample_point_from_normal_distribution, name="encoder_output")([self.mu, self.log_variance])
        else:
            # Implement a NON-VARIATIONAL autoencoder
            x = Dense(self.latent_space_dim, name="encoder_output")(x)

        return x

    def _build_decoder(self) -> None:
        decoder_input = Input(shape=(self.latent_space_dim,), name='decoder_input')
        if self._use_mock_decoder:
            # Untrainable pass-through, useful for debugging the rest of the pipeline
            x = Dense(np.prod(self.input_shape), trainable=False, name="mock_decoder_dense")(decoder_input)
            x = Reshape(self.input_shape, name="mock_decoder_reshape")(x)
            self.decoder = Model(decoder_input, x, name="mock_decoder")
        else:
            x = Dense(128, name="decoder_dense_0")(decoder_input)
            x = Dense(np.prod(self._smallest_convolution_shape), name="decoder_dense_1")(x)
            x = Reshape(self._smallest_convolution_shape, name="decoder_reshape_1")(x)

            # Transposed-convolution layers mirroring the encoder
            for i in range(self._num_hidden_conv_layers - 1):
                x = Conv2DTranspose(
                    filters=self._filter_size_conv_layers,
                    **self._conv_params_middle,
                    name=f"decoder_conv2d_t_{i + 1}")(x)
            x = Conv2DTranspose(
                filters=self.input_shape[-1],
                **self._conv_params_middle,
                name=f"decoder_conv2d_t_{self._num_hidden_conv_layers}")(x)

            x = Conv2D(
                filters=self.input_shape[-1],
                **self._conv_params_end,
                name="decoder_output")(x)
            self.decoder = Model(decoder_input, x, name="decoder")

    def _plot_architecture(self) -> None:
        plot_model(self.model, to_file='ae_model.png', **cfg.ARCH_PLOT_ARGS)
        # expand_nested=True so the nested encoder/decoder models are drawn in full
        plot_model(self.model, to_file='ae_nested_model.png', expand_nested=True, **cfg.ARCH_PLOT_ARGS)
        plot_model(self.encoder, to_file='encoder_model.png', **cfg.ARCH_PLOT_ARGS)
        plot_model(self.decoder, to_file='decoder_model.png', **cfg.ARCH_PLOT_ARGS)


if __name__ == '__main__':
    autoencoder = VAE(
        input_shape=cfg.SPECTROGRAM_SHAPE,  # (Frequency, Time, Channels)
    )
    autoencoder.summary()
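
    # Minimal smoke-test sketch. The learning rate, batch size, epoch count and
    # the random dummy batch are illustrative assumptions, not values taken
    # from the original config:
    dummy_batch = np.random.rand(8, *cfg.SPECTROGRAM_SHAPE).astype("float32")
    autoencoder.compile(learning_rate=1e-4)
    autoencoder.train(dummy_batch, batch_size=2, num_epochs=1, should_callback=False)
    reconstructions, latents = autoencoder.reconstruct(dummy_batch[:2])
    print(reconstructions.shape, latents.shape)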