Advertisement
Guest User

Untitled

a guest
Aug 27th, 2024
47
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.61 KB | None | 0 0
  1. from dataclasses import dataclass, field
  2. from os import getcwd, pardir
  3. from os.path import join
  4. from typing import Dict, Any, Tuple, Optional
  5.  
  6. import numpy as np
  7. from openpyxl.compat.singleton import Singleton
  8.  
  9.  
  10. @dataclass
  11. class Config(metaclass=Singleton):
  12. ##############
  13. # PREPROCESS #
  14. ##############
  15. N_FFT: int = 254
  16. HOP_LENGTH: int = 2 ** 7
  17. DURATION: float = 2.0
  18. SAMPLE_RATE: int = 44_100
  19. EPS: float = 1e-10
  20. SPECTROGRAMS_SAVE_DIR: str = "../datasets/drum_ai/spectrograms_db/"
  21. FILES_DIR: str = "../datasets/drum_ai/audio/"
  22. TOP_DB: float = 80.
  23. NORM_RANGE: Tuple[float, float] = (0., 1)
  24. DB_RANGE: Tuple[float, float] = (-TOP_DB, 0)
  25. SPECTROGRAM_SHAPE: Tuple[int, int, int] = (128, 690, 1)
  26. EMPTY_SPEC: np.ndarray = np.zeros(SPECTROGRAM_SHAPE)
  27.  
  28. ###############
  29. # AUTOENCODER #
  30. ###############
  31. IS_VARIATIONAL: bool = True
  32. RECONSTRUCTION_LOSS_WEIGHT: float = 1e-7
  33. NUM_HIDDEN_CONV_LAYERS: int = 6
  34. FILTER_SIZE_CONV_LAYERS: int = 4
  35. LATENT_SPACE_DIM: int = 8
  36. CONV_PARAMS_MIDDLE: Dict[str, Any] = field(default_factory=lambda: {
  37. 'kernel_size': (6, 3),
  38. 'strides': (2, 1),
  39. 'padding': 'same',
  40. 'activation': 'elu',
  41. 'kernel_initializer': 'glorot_normal',
  42. })
  43. CONV_PARAMS_END: Dict[str, Any] = field(default_factory=lambda: {
  44. 'kernel_size': 6,
  45. 'strides': 1,
  46. 'padding': 'same',
  47. 'activation': 'softplus',
  48. 'kernel_initializer': 'glorot_normal',
  49. })
  50. USE_MOCK_ENCODER: bool = False
  51. USE_MOCK_DECODER: bool = False
  52. ARCH_PLOT_ARGS: Dict[str, bool] = field(default_factory=lambda: {
  53. 'show_layer_activations': True,
  54. 'show_shapes': True,
  55. })
  56. LOGS_DIR: str = join(getcwd(), pardir, "logs")
  57. CLIP_PREDICTION: bool = True
  58. LOSS_SQUARED: bool = False
  59. LOSS_MSE_LINEAR_WEIGHT: Optional[Tuple[float, float]] = (2., 1.)
  60. USE_LSTM: bool = True
  61.  
  62. #########
  63. # TRAIN #
  64. #########
  65. VERBOSE: bool = True
  66. LEARNING_RATE: float = 1e-3
  67. BATCH_SIZE: int = 2 ** 6
  68. PRETRAIN_EPOCHS: int = 0
  69. EPOCHS: int = 2000
  70. SPECTROGRAMS_PATH: str = "../datasets/drum_ai/spectrograms_db/"
  71. PRETRAIN_DATA: np.ndarray = \
  72. np.tile(np.linspace(NORM_RANGE[1], NORM_RANGE[0], SPECTROGRAM_SHAPE[0])[..., None], SPECTROGRAM_SHAPE[1]).T[None, ..., None]
  73.  
  74. ############
  75. # GENERATE #
  76. ############
  77. SAVE_DIR_ORIGINAL: str = "../output/samples/original/"
  78. SAVE_DIR_GENERATED: str = "../output/samples/generated/"
  79. NORMALIZATION_FACTOR: float = 2 ** 16
  80.  
  81.  
  82. cfg = Config()
  83.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement