import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa
import librosa.display
import numpy as np
import os
import random
import matplotlib.pyplot as plt
from pathlib import Path  # Use pathlib for better path handling
import warnings           # To suppress warnings
import io                 # For stderr capture
import contextlib         # For stderr capture
import soundfile as sf    # For saving audio
import traceback          # For detailed error printing

# --- 1. Configuration ---
DATASET_PATH = Path('/kaggle/input/fma-small/fma_small/fma_small')  # <--- VERIFY THIS PATH!
TRAINING_OUTPUT_DIR = Path('./training_output_unet_norm')  # <--- Changed output dir name
TRAINING_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # Ensure it exists

SAMPLE_RATE = 22050
N_FFT = 512
HOP_LENGTH = 128
SPECTROGRAM_HEIGHT = 128  # Target height after resizing
SPECTROGRAM_WIDTH = 128   # Target width after resizing
MASK_SIZE = (32, 32)      # Size of the square mask in the resized spectrogram
LOAD_DURATION = 5.0       # How many seconds of audio to load initially

BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE_GENERATOR = 0.0002
LEARNING_RATE_DISCRIMINATOR = 0.0001
BETA_1 = 0.5
CHECKPOINT_SAVE_INTERVAL = 10
GRIFFIN_LIM_ITERS = 32
LAMBDA_L1 = 100

# --- Normalization Parameters ---
MIN_DB = -80.0
MAX_DB = 0.0

# --- Device Setup ---
if torch.cuda.is_available():
    if torch.cuda.device_count() >= 2:
        DEVICE = torch.device("cuda:0")
        print(f"Using primary device: {DEVICE}")
        print(f"Utilizing {torch.cuda.device_count()} GPUs with DataParallel.")
        use_dataparallel = True
    else:
        DEVICE = torch.device("cuda:0")
        print(f"Using single GPU: {DEVICE}")
        use_dataparallel = False
else:
    DEVICE = torch.device("cpu")
    print(f"Using CPU: {DEVICE}")
    use_dataparallel = False

# --- 2. Dataset Class (Returns Mask) ---
class AudioInpaintingDataset(Dataset):
    def __init__(self, dataset_dir, sr, n_fft, hop_length, spec_height, spec_width, mask_size, load_duration, min_db=MIN_DB, max_db=MAX_DB):
        self.dataset_dir = Path(dataset_dir)
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.spec_height = spec_height  # Target height for network input
        self.spec_width = spec_width    # Target width for network input
        self.mask_h, self.mask_w = mask_size  # Store mask dimensions
        self.load_duration = load_duration
        self.min_db = min_db
        self.max_db = max_db
        self.audio_files = self._find_and_validate_files()

        if not self.audio_files:
            raise FileNotFoundError(f"No valid MP3 files found >= {self.load_duration}s in {self.dataset_dir}")
        print(f"Dataset initialized with {len(self.audio_files)} valid audio files.")

    def _find_and_validate_files(self):
        # (Validation code remains the same as before)
        print(f"Searching for MP3 files in: {self.dataset_dir}")
        candidate_files = list(self.dataset_dir.rglob('*.mp3'))
        print(f"Found {len(candidate_files)} potential MP3 files. Starting pre-check for >= {self.load_duration}s duration...")
        valid_audio_files = []
        skipped_short_count = 0; skipped_error_count = 0
        max_verbose_skips = 5; short_msgs_printed = 0; error_msgs_printed = 0
        stderr_capture = io.StringIO()
        with contextlib.redirect_stderr(stderr_capture), warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning); warnings.simplefilter("ignore", category=FutureWarning)
            for i, file_path in enumerate(candidate_files):
                if (i + 1) % 500 == 0: print(f"  Pre-checking file {i+1}/{len(candidate_files)}...")
                try:
                    actual_duration = librosa.get_duration(path=file_path)
                    if actual_duration >= self.load_duration: valid_audio_files.append(file_path)
                    else:
                        skipped_short_count += 1
                        if short_msgs_printed < max_verbose_skips: print(f"  Skipping short: {file_path} ({actual_duration:.2f}s)"); short_msgs_printed += 1
                        elif short_msgs_printed == max_verbose_skips: print("  (Further short file skip messages suppressed)"); short_msgs_printed += 1
                except Exception as e:
                    skipped_error_count += 1
                    if error_msgs_printed < max_verbose_skips: print(f"  Skipping error: {file_path} ({type(e).__name__})"); error_msgs_printed += 1
                    elif error_msgs_printed == max_verbose_skips: print("  (Further error messages suppressed)"); error_msgs_printed += 1
                    continue
        print("-" * 30 + f"\nPre-check complete. Found {len(valid_audio_files)} valid files. Skipped {skipped_short_count} (short), {skipped_error_count} (error).\n" + "-" * 30)
        captured_stderr = stderr_capture.getvalue()
        if captured_stderr: print("\n--- Captured Stderr during pre-check ---\n" + captured_stderr.strip() + "\n--------------------------------------\n")
        return valid_audio_files

    def __len__(self):
        return len(self.audio_files)

    def normalize_spectrogram(self, spec_db):
        # Map [min_db, max_db] linearly to [-1, 1], clamping anything outside that range
        scaled = (spec_db - self.min_db) / (self.max_db - self.min_db) * 2.0 - 1.0
        return torch.clamp(scaled, -1.0, 1.0)

    def create_mask(self, spectrogram_shape):
        """Creates the mask AND returns its coordinates."""
        mask = torch.ones(spectrogram_shape)
        h, w = spectrogram_shape
        mask_h_actual = min(h, self.mask_h)
        mask_w_actual = min(w, self.mask_w)

        if h <= 0 or w <= 0 or mask_h_actual <= 0 or mask_w_actual <= 0:
            print(f"Warning: Cannot apply mask to shape {spectrogram_shape}. Returning full mask.")
            return mask, (0, 0, 0, 0)  # Return mask and zero coordinates

        if mask_h_actual < self.mask_h or mask_w_actual < self.mask_w:
            print(f"Warning: Spectrogram shape {spectrogram_shape} too small for mask {(self.mask_h, self.mask_w)}. Applying smaller mask {(mask_h_actual, mask_w_actual)}.")

        start_row = random.randint(0, h - mask_h_actual)
        start_col = random.randint(0, w - mask_w_actual)
        mask[start_row : start_row + mask_h_actual, start_col : start_col + mask_w_actual] = 0.0
        coords = (start_row, start_col, mask_h_actual, mask_w_actual)  # (y, x, height, width)
        return mask, coords

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        stderr_capture = io.StringIO()
        with contextlib.redirect_stderr(stderr_capture), warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning); warnings.simplefilter("ignore", category=FutureWarning)
            try:
                y, sr_loaded = librosa.load(audio_path, sr=self.sr, duration=self.load_duration)
                if sr_loaded != self.sr: warnings.warn(f"SR mismatch: {sr_loaded} != {self.sr}")
                if y is None or len(y) == 0: raise ValueError("Loaded audio empty")

                target_samples = int(self.load_duration * self.sr)
                if len(y) < target_samples: y = np.pad(y, (0, target_samples - len(y)), mode='constant')
                elif len(y) > target_samples: y = y[:target_samples]

                spectrogram = librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length, center=True)
                spectrogram_db = librosa.amplitude_to_db(np.abs(spectrogram), ref=np.max)
                original_shape = spectrogram_db.shape  # (F_orig, T_orig)

                spectrogram_tensor = torch.tensor(spectrogram_db).unsqueeze(0)  # [1, F_orig, T_orig]
                spectrogram_resized = torch.nn.functional.interpolate(
                    spectrogram_tensor.unsqueeze(0), size=(self.spec_height, self.spec_width),
                    mode='bilinear', align_corners=False
                ).squeeze(0)  # [1, H, W]

                target_spectrogram_normalized = self.normalize_spectrogram(spectrogram_resized)  # [1, H, W]

                # Create mask for RESIZED dimensions
                mask_hw, mask_coords = self.create_mask(target_spectrogram_normalized.shape[1:])  # Shape is (H, W)
                mask_tensor = mask_hw.unsqueeze(0)  # [1, H, W]
                masked_spectrogram = target_spectrogram_normalized * mask_tensor  # Apply mask [1, H, W]

                # Return necessary items including the mask tensor and path for full reconstruction
                return masked_spectrogram, target_spectrogram_normalized, mask_tensor, original_shape, audio_path

            except Exception as e:
                print(f"WARNING: Error in __getitem__ for {audio_path}: {e}. Skipping.")
                # print(f"Traceback: {traceback.format_exc()}")  # Uncomment for debugging
                item_stderr = stderr_capture.getvalue()
                if item_stderr: print(f"--- Stderr for {audio_path} ---\n{item_stderr.strip()}\n---")
                return None, None, None, None, None  # Match return structure

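# Hypothetical usage sketch (assumes DATASET_PATH is valid; not executed here):
#   ds = AudioInpaintingDataset(DATASET_PATH, SAMPLE_RATE, N_FFT, HOP_LENGTH,
#                               SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH, MASK_SIZE, LOAD_DURATION)
#   masked, target, mask, orig_shape, path = ds[0]
# masked/target/mask are each [1, 128, 128]; orig_shape is (257, 862) for 5 s at
# 22050 Hz with n_fft=512 (512 // 2 + 1 = 257 bins) and hop=128
# (1 + 110250 // 128 = 862 frames with center=True).
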
# --- Custom Collate Function (Handles Mask Tensor and Path) ---
def collate_fn_skip_none(batch):
    batch = [item for item in batch if all(i is not None for i in item)]  # Check all items in tuple
    if not batch: return None
    masked_specs = [item[0] for item in batch]
    target_specs = [item[1] for item in batch]
    masks = [item[2] for item in batch]  # Collect mask tensors
    original_shapes = [item[3] for item in batch]
    audio_paths = [item[4] for item in batch]  # Collect paths

    masked_specs_collated = torch.utils.data.dataloader.default_collate(masked_specs)
    target_specs_collated = torch.utils.data.dataloader.default_collate(target_specs)
    masks_collated = torch.utils.data.dataloader.default_collate(masks)  # Collate masks

    # Return collated tensors and lists for shapes/paths
    return masked_specs_collated, target_specs_collated, masks_collated, original_shapes, audio_paths


# --- 3. Data Loader ---
try:
    train_dataset = AudioInpaintingDataset(
        DATASET_PATH, SAMPLE_RATE, N_FFT, HOP_LENGTH,
        SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH, MASK_SIZE, LOAD_DURATION,
        min_db=MIN_DB, max_db=MAX_DB
    )
    if len(train_dataset) == 0: raise RuntimeError("Dataset is empty after validation.")
    train_dataloader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,  # modest worker count
        pin_memory=True if DEVICE.type == 'cuda' else False, drop_last=True,
        collate_fn=collate_fn_skip_none  # Use the modified collate function
    )
    print(f"DataLoader initialized with {len(train_dataloader)} batches.")
except (FileNotFoundError, RuntimeError) as e: print(f"Error initializing dataset/dataloader: {e}"); exit()
except Exception as e: print(f"Unexpected error during dataloader init: {e}"); traceback.print_exc(); exit()

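# Note: collate_fn_skip_none returns None when every item in a batch failed to
# load; the training loop below guards against this with `if batch_data is None: continue`.
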
# --- 4. Generator Model (U-Net - Remains the Same) ---
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
        if normalize: layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        if dropout > 0.0: layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x): return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        if dropout > 0.0: layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x


class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNetGenerator, self).__init__()
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.2)
        self.down5 = UNetDown(512, 512, dropout=0.2)
        self.down6 = UNetDown(512, 512, dropout=0.0)
        self.down7 = UNetDown(512, 512, normalize=False, dropout=0.5)
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.2)
        self.up3 = UNetUp(1024, 512, dropout=0.2)
        self.up4 = UNetUp(1024, 256, dropout=0.0)
        self.up5 = UNetUp(512, 128, dropout=0.0)
        self.up6 = UNetUp(256, 64, dropout=0.0)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, x):
        d1 = self.down1(x); d2 = self.down2(d1); d3 = self.down3(d2); d4 = self.down4(d3)
        d5 = self.down5(d4); d6 = self.down6(d5); d7 = self.down7(d6)
        u1 = self.up1(d7, d6); u2 = self.up2(u1, d5); u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3); u5 = self.up5(u4, d2); u6 = self.up6(u5, d1)
        return self.final_up(u6)
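
# Shape sanity check (a minimal sketch, safe to remove): the 7 down blocks halve
# 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1, and the symmetric up path with skip
# concatenations restores 128x128.
_g_check = UNetGenerator()
with torch.no_grad():
    assert _g_check(torch.zeros(1, 1, SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH)).shape \
        == torch.Size([1, 1, SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH])
del _g_check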

# --- 5. Discriminator Model (PatchGAN - Remains the Same) ---
class Discriminator(nn.Module):
    def __init__(self, in_channels=1):  # Takes 1 channel input
        super(Discriminator, self).__init__()
        def block(i, o, n=True):
            layers = [nn.Conv2d(i, o, 4, 2, 1)]
            layers.append(nn.InstanceNorm2d(o) if n else nn.Identity())
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(*block(in_channels, 64, False), *block(64, 128), *block(128, 256), *block(256, 512), nn.Conv2d(512, 1, 4, 1, 1))

    def forward(self, x): return self.model(x)

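# For 128x128 inputs the four stride-2 convs reduce to 8x8 and the final
# stride-1 conv yields a 7x7 grid of logits, so each output scores one
# overlapping local patch rather than the whole spectrogram (PatchGAN).
# Quick check (safe to remove):
_d_check = Discriminator()
with torch.no_grad():
    assert _d_check(torch.zeros(1, 1, SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH)).shape \
        == torch.Size([1, 1, 7, 7])
del _d_check
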
# --- 6. Loss Functions & Optimizers (Remains the Same) ---
adversarial_loss = nn.BCEWithLogitsLoss().to(DEVICE)
content_loss_l1 = nn.L1Loss().to(DEVICE)
lambda_l1 = LAMBDA_L1

def calculate_discriminator_loss(real_output, fake_output):
    real_loss = adversarial_loss(real_output, torch.ones_like(real_output, device=DEVICE))
    fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output, device=DEVICE))
    return (real_loss + fake_loss) / 2

def calculate_generator_loss(fake_output_disc, generated_spectrogram, target_spectrogram, mask_inv):
    # mask_inv selects the *masked* region (where mask is 0, inv is 1)
    adv_loss = adversarial_loss(fake_output_disc, torch.ones_like(fake_output_disc, device=DEVICE))
    content_loss_val = content_loss_l1(generated_spectrogram * mask_inv, target_spectrogram * mask_inv)
    total_loss = adv_loss + lambda_l1 * content_loss_val
    return total_loss, adv_loss, content_loss_val

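# The objective follows the pix2pix recipe: G minimizes BCE(D(G(x)), 1) plus
# LAMBDA_L1 * L1 restricted to the hole. Note that nn.L1Loss averages over the
# whole tensor, so with a 32x32 hole in a 128x128 spectrogram only ~1/16 of the
# elements are non-zero; the effective per-pixel weight inside the hole is
# therefore LAMBDA_L1 / 16 rather than LAMBDA_L1.
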
# --- Initialize Models & Optimizers ---
generator = UNetGenerator(in_channels=1, out_channels=1).to(DEVICE)
discriminator = Discriminator(in_channels=1).to(DEVICE)  # Takes 1 channel

if use_dataparallel:
    generator = nn.DataParallel(generator); discriminator = nn.DataParallel(discriminator)
    print("Models wrapped in nn.DataParallel.")

optimizer_generator = optim.Adam(generator.parameters(), lr=LEARNING_RATE_GENERATOR, betas=(BETA_1, 0.999))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE_DISCRIMINATOR, betas=(BETA_1, 0.999))

# --- 7. Helper Functions for Saving (Adjusted for dB input reconstruction) ---

def save_spectrogram_plot(spectrogram_data, title, filepath):
    plt.figure(figsize=(6, 6))
    if isinstance(spectrogram_data, torch.Tensor): spectrogram_data = spectrogram_data.detach().squeeze().cpu().numpy()
    img = librosa.display.specshow(spectrogram_data, sr=SAMPLE_RATE, hop_length=HOP_LENGTH, x_axis='time', y_axis='linear', cmap='magma')
    # dB-scaled plots get a dB colorbar label; normalized [-1, 1] plots get a plain one
    plt.colorbar(img, format='%+.1f dB' if 'dB' in title or 'Original' in title else '%+.2f')
    plt.title(title)
    plt.tight_layout()  # apply layout after the content is drawn
    plt.savefig(filepath); plt.close()

def denormalize_spectrogram(spec_norm, min_db=MIN_DB, max_db=MAX_DB):
    # Inverse of the dataset's normalize_spectrogram: [-1, 1] back to [min_db, max_db]
    scaled_01 = (spec_norm + 1.0) / 2.0
    spec_db = scaled_01 * (max_db - min_db) + min_db
    return spec_db

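# Round-trip sanity check (a minimal sketch): normalization followed by
# denormalize_spectrogram is the identity for values already in [MIN_DB, MAX_DB].
_x_check = torch.tensor([-80.0, -40.0, 0.0])
_norm_check = (_x_check - MIN_DB) / (MAX_DB - MIN_DB) * 2.0 - 1.0  # -> [-1.0, 0.0, 1.0]
assert torch.allclose(denormalize_spectrogram(_norm_check), _x_check)
del _x_check, _norm_check
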
def reconstruct_audio_from_spectrogram(
        spectrogram_norm=None,        # Input: Normalized spectrogram [-1, 1]
        input_spec_db=None,           # Input: OR dB spectrogram directly
        original_shape=None,          # Optional: Target shape (F_orig, T_orig) for resizing norm spec
        n_iter=GRIFFIN_LIM_ITERS):
    """
    Reconstructs audio from either a NORMALIZED spectrogram (optionally resizing it)
    OR directly from a provided dB spectrogram.
    """
    try:
        if input_spec_db is not None:
            # Use provided dB spectrogram directly
            if isinstance(input_spec_db, torch.Tensor):
                spectrogram_db_final = input_spec_db.detach().squeeze().cpu().numpy()
            else:
                spectrogram_db_final = np.squeeze(input_spec_db)
            print(f"Reconstructing directly from provided dB spectrogram, shape: {spectrogram_db_final.shape}")

        elif spectrogram_norm is not None:
            # Process normalized spectrogram
            if isinstance(spectrogram_norm, torch.Tensor):
                spectrogram_norm_np = spectrogram_norm.detach().squeeze().cpu().numpy()
            else:
                spectrogram_norm_np = np.squeeze(spectrogram_norm)

            # --- De-normalize first ---
            spectrogram_db = denormalize_spectrogram(spectrogram_norm_np, MIN_DB, MAX_DB)

            # --- Optional: Resize back to original dimensions ---
            if original_shape is not None:
                current_shape = spectrogram_db.shape
                target_F, target_T = original_shape
                if current_shape != original_shape:
                    print(f"Resizing de-normalized spectrogram from {current_shape} to {original_shape}...")
                    spec_db_tensor = torch.tensor(spectrogram_db).unsqueeze(0).unsqueeze(0)
                    spec_resized_tensor = torch.nn.functional.interpolate(
                        spec_db_tensor, size=(target_F, target_T), mode='bilinear', align_corners=False)
                    spectrogram_db_final = spec_resized_tensor.squeeze().numpy()
                else: spectrogram_db_final = spectrogram_db  # No resize needed
            else: spectrogram_db_final = spectrogram_db  # Use as is

            print(f"Reconstructing from de-normalized/resized spec, final shape: {spectrogram_db_final.shape}")

        else:
            raise ValueError("Must provide either spectrogram_norm or input_spec_db")

        # --- Convert dB back to linear amplitude ---
        spectrogram_amp = librosa.db_to_amplitude(spectrogram_db_final, ref=1.0)  # ref=1.0 as we scaled dB absolutely

        # --- Estimate phase using Griffin-Lim ---
        num_frames = spectrogram_amp.shape[1]
        expected_length = (num_frames - 1) * HOP_LENGTH  # natural inverse length for a center=True STFT

        estimated_audio = librosa.griffinlim(spectrogram_amp,
                                             n_iter=n_iter,
                                             hop_length=HOP_LENGTH,
                                             n_fft=N_FFT,
                                             length=expected_length)  # Give length hint

        print(f"Reconstructed audio length: {len(estimated_audio)} samples ({len(estimated_audio)/SAMPLE_RATE:.2f}s)")
        return estimated_audio

    except Exception as e:
        print(f"Error during audio reconstruction: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Print shapes for debugging if available
        if 'spectrogram_norm_np' in locals(): print(f"Input norm spec shape: {spectrogram_norm_np.shape}")
        if 'spectrogram_db' in locals(): print(f"De-norm spec shape: {spectrogram_db.shape}")
        if 'spectrogram_db_final' in locals(): print(f"Final spec shape for GL: {spectrogram_db_final.shape}")
        if original_shape: print(f"Target original shape: {original_shape}")
        return None

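# Length-hint arithmetic (for the 5 s / 22050 Hz configuration above):
# librosa.stft(center=True) yields 1 + 110250 // 128 = 862 frames, so the
# natural inverse length is (862 - 1) * 128 = 110208 samples, 42 short of the
# original 110250; griffinlim's `length` argument pads/trims accordingly.
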
def generate_and_save_samples(epoch, generator_model, test_inputs, output_dir):
    """
    Generates plots and saves three audio versions:
    1. Short audio from direct generator output (HxW spec).
    2. Long audio from resizing generator output back to original shape (F_orig x T_orig spec).
    3. FULL inpainted audio by combining original + generated (F_orig x T_orig spec).
    """
    generator_model.eval()
    # test_inputs: (masked_spec[1,1,H,W], target_spec[1,1,H,W], mask[1,1,H,W], [orig_shape], [audio_path])
    test_masked, test_target_normalized, test_mask_resized, test_original_shapes, test_audio_paths = test_inputs

    # Use the first item in the batch for visualization/saving
    test_original_shape = test_original_shapes[0] if test_original_shapes else None
    test_audio_path = test_audio_paths[0] if test_audio_paths else None

    if not test_original_shape or not test_audio_path:
        print("Warning: Missing original shape or audio path for test sample. Cannot generate full inpainted audio.")
        # We can still generate short/resized versions if masked/generated are available
        if test_masked is None:  # Check if we even have input
            print("Error: test_masked is None. Cannot generate any samples.")
            generator_model.train()
            return

    test_masked_dev = test_masked.to(DEVICE)

    with torch.no_grad():
        generated_spectrogram_normalized = generator_model(test_masked_dev)  # Output is normalized [1, 1, H, W]

    # --- Move results to CPU ---
    test_masked_cpu = test_masked[0, 0].cpu()                          # [H, W] normalized
    generated_norm_cpu = generated_spectrogram_normalized[0, 0].cpu()  # [H, W] normalized
    test_target_norm_cpu = test_target_normalized[0, 0].cpu()          # [H, W] normalized
    test_mask_resized_cpu = test_mask_resized[0, 0].cpu()              # [H, W] mask (0 or 1)

    print(f"\n--- Generating Samples for Epoch {epoch} ---")
    print(f"Network Output (Generated Norm HxW): {generated_norm_cpu.shape}")
    print(f"Test Mask (Resized HxW): {test_mask_resized_cpu.shape}")
    if test_original_shape: print(f"Original Spec Shape (F_orig x T_orig): {test_original_shape}")
    if test_audio_path: print(f"Original Audio Path: {test_audio_path}")

    # --- Save Spectrogram Plots ---
    save_spectrogram_plot(test_masked_cpu, f"Masked Input (Norm) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_1_masked_input_norm.png")
    save_spectrogram_plot(generated_norm_cpu, f"Generated Output (Norm) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_2_generated_output_norm.png")
    save_spectrogram_plot(test_target_norm_cpu, f"Target (Norm, Resized) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_3_target_resized_norm.png")
    save_spectrogram_plot(test_mask_resized_cpu, f"Test Mask (Resized HxW) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_4_mask_resized.png")

    # --- 1. Reconstruct and Save SHORT Audio (from HxW norm spec) ---
    print("\n1. Reconstructing SHORT audio (from network HxW output)...")
    audio_short = reconstruct_audio_from_spectrogram(
        spectrogram_norm=generated_norm_cpu,
        original_shape=None  # Don't resize
    )
    if audio_short is not None:
        try: sf.write(output_dir / f"epoch_{epoch:03d}_audio_short.wav", audio_short, SAMPLE_RATE)
        except Exception as e: print(f"Error saving SHORT audio: {e}")
    else: print("Skipping SHORT audio saving due to reconstruction error.")

    # --- 2. Reconstruct and Save LONG RESIZED Audio (Resize HxW norm spec to F_orig x T_orig) ---
    if test_original_shape:
        print("\n2. Reconstructing LONG RESIZED audio (resizing HxW output to original shape)...")
        audio_long_resized = reconstruct_audio_from_spectrogram(
            spectrogram_norm=generated_norm_cpu,
            original_shape=test_original_shape  # Resize back
        )
        if audio_long_resized is not None:
            try: sf.write(output_dir / f"epoch_{epoch:03d}_audio_long_resized.wav", audio_long_resized, SAMPLE_RATE)
            except Exception as e: print(f"Error saving LONG RESIZED audio: {e}")
        else: print("Skipping LONG RESIZED audio saving due to reconstruction error.")
    else: print("Skipping LONG RESIZED audio reconstruction: original shape missing.")

    # --- 3. Reconstruct and Save FULL INPAINTED Audio ---
    if test_original_shape and test_audio_path:
        print("\n3. Reconstructing FULL INPAINTED audio (combining original and generated)...")
        try:
            # --- a) Reload original audio and compute its full dB spectrogram ---
            print(f"   Reloading original audio from: {test_audio_path}")
            y_orig, _ = librosa.load(test_audio_path, sr=SAMPLE_RATE, duration=LOAD_DURATION)
            target_samples = int(LOAD_DURATION * SAMPLE_RATE)
            if len(y_orig) < target_samples: y_orig = np.pad(y_orig, (0, target_samples - len(y_orig)), mode='constant')
            elif len(y_orig) > target_samples: y_orig = y_orig[:target_samples]

            spec_orig = librosa.stft(y_orig, n_fft=N_FFT, hop_length=HOP_LENGTH, center=True)
            spec_orig_db = librosa.amplitude_to_db(np.abs(spec_orig), ref=np.max)
            spec_orig_db_tensor = torch.tensor(spec_orig_db)  # Keep as tensor [F_orig, T_orig]
            save_spectrogram_plot(spec_orig_db_tensor, f"Original Full Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_5_original_full_spec_db.png")

            # --- b) De-normalize and resize generated HxW spectrogram to original shape ---
            print("   De-normalizing and resizing generated HxW spectrogram...")
            generated_db_cpu = denormalize_spectrogram(generated_norm_cpu)  # [H, W] dB
            generated_db_tensor = generated_db_cpu.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            generated_db_resized_tensor = torch.nn.functional.interpolate(
                generated_db_tensor, size=test_original_shape, mode='bilinear', align_corners=False
            ).squeeze()  # [F_orig, T_orig] dB
            save_spectrogram_plot(generated_db_resized_tensor, f"Generated Resized Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_6_generated_resized_spec_db.png")

            # --- c) Resize the HxW mask to original shape (use nearest neighbor) ---
            print("   Resizing HxW mask to original shape...")
            mask_resized_tensor = test_mask_resized_cpu.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            mask_orig_shape_tensor = torch.nn.functional.interpolate(
                mask_resized_tensor, size=test_original_shape, mode='nearest'  # Use nearest for mask
            ).squeeze()  # [F_orig, T_orig] (0s and 1s)
            save_spectrogram_plot(mask_orig_shape_tensor, f"Mask Resized to Orig Shape - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_7_mask_orig_shape.png")

            # --- d) Combine original and generated using the resized mask ---
            # output = original * mask + generated * (1 - mask):
            # keep the original where mask_orig is 1, use the generated fill where it is 0
            print("   Combining original and generated spectrograms...")
            combined_spec_db_tensor = (spec_orig_db_tensor * mask_orig_shape_tensor) + \
                                      (generated_db_resized_tensor * (1.0 - mask_orig_shape_tensor))
            save_spectrogram_plot(combined_spec_db_tensor, f"Combined Full Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_8_combined_spec_db.png")

            # --- e) Reconstruct audio from the combined dB spectrogram ---
            print("   Reconstructing audio from combined spectrogram...")
            audio_full_inpainted = reconstruct_audio_from_spectrogram(
                input_spec_db=combined_spec_db_tensor  # Pass dB spec directly
            )

            if audio_full_inpainted is not None:
                try: sf.write(output_dir / f"epoch_{epoch:03d}_audio_full_inpainted.wav", audio_full_inpainted, SAMPLE_RATE)
                except Exception as e: print(f"Error saving FULL INPAINTED audio: {e}")
            else: print("Skipping FULL INPAINTED audio saving due to reconstruction error.")

        except Exception as e:
            print(f"Error during FULL INPAINTED audio creation: {e}")
            print(f"Traceback: {traceback.format_exc()}")

    else: print("Skipping FULL INPAINTED audio reconstruction: original shape or path missing.")

    print("-" * 30)
    generator_model.train()  # Set back to training mode

# --- 8. Training Loop ---
def train(dataloader, epochs, save_interval):
    history = {'gen_loss_adv': [], 'gen_loss_l1': [], 'disc_loss': []}
    print("Starting Training Loop...")

    for epoch in range(epochs):
        gen_loss_adv_epoch = 0.0; gen_loss_l1_epoch = 0.0; disc_loss_epoch = 0.0
        items_processed = 0
        generator.train(); discriminator.train()

        for i, batch_data in enumerate(dataloader):
            if batch_data is None: continue

            # Unpack batch data (masked, target, mask_tensor, shapes, paths)
            masked_batch, target_batch, mask_batch, _, _ = batch_data  # Use mask_batch now
            masked_batch = masked_batch.to(DEVICE)
            target_batch = target_batch.to(DEVICE)
            mask_batch = mask_batch.to(DEVICE)  # [B, 1, H, W], values are 0 or 1

            if masked_batch.nelement() == 0: continue  # Skip empty batch

            batch_size = masked_batch.size(0)
            items_processed += batch_size

            # Inverse mask for L1 loss calculation (where mask is 0 -> inv is 1)
            mask_batch_inv = 1.0 - mask_batch

            # --- Train Discriminator ---
            optimizer_discriminator.zero_grad(set_to_none=True)
            with torch.no_grad(): generated_specs_d = generator(masked_batch)
            real_output = discriminator(target_batch)        # D sees normalized target
            fake_output = discriminator(generated_specs_d)   # D sees normalized generated
            disc_loss = calculate_discriminator_loss(real_output, fake_output)
            disc_loss.backward()
            optimizer_discriminator.step()

            # --- Train Generator ---
            optimizer_generator.zero_grad(set_to_none=True)
            generated_specs_g = generator(masked_batch)         # [B, 1, H, W] normalized
            fake_output_gen = discriminator(generated_specs_g)  # D sees normalized generated
            # Calculate loss (Adv + L1 in the masked region)
            gen_loss, adv_loss, l1_loss = calculate_generator_loss(
                fake_output_gen, generated_specs_g, target_batch, mask_batch_inv
            )
            gen_loss.backward()
            optimizer_generator.step()

            # Accumulate losses
            gen_loss_adv_epoch += adv_loss.item() * batch_size
            gen_loss_l1_epoch += l1_loss.item() * batch_size
            disc_loss_epoch += disc_loss.item() * batch_size

            if (i + 1) % 50 == 0: print(f"[{epoch+1}/{epochs}][{i+1}/{len(dataloader)}] D L: {disc_loss.item():.4f} | G L: {gen_loss.item():.4f} (Adv: {adv_loss.item():.4f}, L1: {l1_loss.item():.4f})")

        # --- End of Epoch ---
        if items_processed > 0:
            avg_d_loss = disc_loss_epoch / items_processed
            avg_g_adv = gen_loss_adv_epoch / items_processed
            avg_g_l1 = gen_loss_l1_epoch / items_processed
            history['disc_loss'].append(avg_d_loss); history['gen_loss_adv'].append(avg_g_adv); history['gen_loss_l1'].append(avg_g_l1)
            print(f"Epoch {epoch+1}/{epochs} Avg Losses -> D: {avg_d_loss:.4f}, G_Adv: {avg_g_adv:.4f}, G_L1: {avg_g_l1:.4f}")
        else: print(f"Epoch {epoch+1}/{epochs} finished - 0 items processed.")

        # --- Save Checkpoints and Samples ---
        if (epoch + 1) % save_interval == 0 or (epoch + 1) == epochs:
            epoch_num_padded = f"{epoch+1:03d}"
            print(f"Saving outputs for epoch {epoch+1} to {TRAINING_OUTPUT_DIR}")
            g_path = TRAINING_OUTPUT_DIR / f'generator_epoch_{epoch_num_padded}.pth'
            d_path = TRAINING_OUTPUT_DIR / f'discriminator_epoch_{epoch_num_padded}.pth'
            model_to_save_g = generator.module if use_dataparallel else generator
            model_to_save_d = discriminator.module if use_dataparallel else discriminator
            torch.save(model_to_save_g.state_dict(), g_path)
            torch.save(model_to_save_d.state_dict(), d_path)

            if test_input_tuple is not None:
                generate_and_save_samples(epoch + 1, generator, test_input_tuple, TRAINING_OUTPUT_DIR)
            else: print("Skipping sample generation: fixed test input not available.")

    # --- End of Training Plot ---
    plt.figure(figsize=(12, 6))
    plt.plot(history['disc_loss'], label='Discriminator Loss'); plt.plot(history['gen_loss_adv'], label='Gen Adversarial Loss')
    plt.plot(history['gen_loss_l1'], label='Gen L1 Loss (Unscaled)'); plt.xlabel('Epoch'); plt.ylabel('Loss')
    plt.legend(); plt.title('Training Losses'); plt.grid(True)
    plt.savefig(TRAINING_OUTPUT_DIR / 'training_losses.png'); plt.show()

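# Note on the discriminator step inside train(): the fake batch is produced
# under torch.no_grad(), so it is detached from the generator and
# disc_loss.backward() only updates D; the generator then needs its own forward
# pass (generated_specs_g) to receive gradients on the same inputs.
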
# --- 9. Prepare Fixed Test Input (Once) ---
print("Preparing fixed test input for visualization...")
test_input_tuple = None  # Will hold (masked[1,1,H,W], target[1,1,H,W], mask[1,1,H,W], [orig_shape], [audio_path])

try:
    if train_dataset and len(train_dataset.audio_files) > 0:
        found_test_item = False; attempts = 0; max_attempts = min(len(train_dataset), 20)  # Try up to 20 times
        while not found_test_item and attempts < max_attempts:
            attempts += 1
            test_idx = random.randrange(len(train_dataset))
            print(f"Attempt {attempts}/{max_attempts}: Trying index {test_idx}...")
            # Get all items from dataset __getitem__
            test_masked_spec, test_target_spec, test_mask, test_original_shape, test_audio_path = train_dataset[test_idx]

            # Check if all components are valid
            if all(item is not None for item in [test_masked_spec, test_target_spec, test_mask, test_original_shape, test_audio_path]):
                # Unsqueeze to create a batch of 1 and store in tuple
                test_input_tuple = (
                    test_masked_spec.unsqueeze(0),  # [1, 1, H, W]
                    test_target_spec.unsqueeze(0),  # [1, 1, H, W]
                    test_mask.unsqueeze(0),         # [1, 1, H, W]
                    [test_original_shape],          # List containing one shape tuple
                    [test_audio_path]               # List containing one path string
                )
                print(f"Fixed test input prepared successfully from: {test_audio_path}")
                found_test_item = True
            else:
                print(f"Failed to process file at index {test_idx}. Retrying...")
        if not found_test_item: print(f"Failed to prepare test input after {max_attempts} attempts.")
    else: print("Cannot prepare test input: Dataset empty or not initialized.")
except Exception as e:
    print(f"Error preparing test input: {e}")
    print(f"Traceback: {traceback.format_exc()}")

# --- 10. Run Training ---
if __name__ == "__main__":
    try:
        if train_dataloader is None or len(train_dataloader) == 0: raise RuntimeError("DataLoader invalid/empty.")
        if test_input_tuple is None: print("WARNING: Test input missing. Sample generation will be skipped.")
        train(train_dataloader, EPOCHS, CHECKPOINT_SAVE_INTERVAL)
        print("\nTraining finished.")
        print(f"Outputs saved in: {TRAINING_OUTPUT_DIR}")
    except Exception as e:
        print(f"\nA critical error occurred during training: {e}")
        print(f"Traceback: {traceback.format_exc()}")