import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import librosa
import librosa.display
import numpy as np
import os
import random
import matplotlib.pyplot as plt
from pathlib import Path   # Better path handling than raw strings
import warnings            # To suppress library warnings
import io                  # For stderr capture
import contextlib          # For stderr capture
import soundfile as sf     # For saving audio
import traceback           # For detailed error printing
# --- 1. Configuration ---
DATASET_PATH = Path('/kaggle/input/fma-small/fma_small/fma_small')  # <--- VERIFY THIS PATH!
TRAINING_OUTPUT_DIR = Path('./training_output_unet_norm')           # Output directory
TRAINING_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)              # Ensure it exists
SAMPLE_RATE = 22050
N_FFT = 512
HOP_LENGTH = 128
SPECTROGRAM_HEIGHT = 128   # Target height after resizing
SPECTROGRAM_WIDTH = 128    # Target width after resizing
MASK_SIZE = (32, 32)       # Size of the square mask in the resized spectrogram
LOAD_DURATION = 5.0        # Seconds of audio to load from each file
BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE_GENERATOR = 0.0002
LEARNING_RATE_DISCRIMINATOR = 0.0001
BETA_1 = 0.5
CHECKPOINT_SAVE_INTERVAL = 10
GRIFFIN_LIM_ITERS = 32
LAMBDA_L1 = 100
# --- Normalization Parameters ---
MIN_DB = -80.0
MAX_DB = 0.0
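# Note: with librosa.amplitude_to_db(..., ref=np.max) and the default top_db=80.0,
# spectrogram values land in [-80, 0] dB, so MIN_DB/MAX_DB above cover the full
# range and the linear map to [-1, 1] below rarely clips. (This assumes top_db is
# left at its default; a different top_db would need a matching MIN_DB.)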
# --- Device Setup ---
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
    if torch.cuda.device_count() >= 2:
        print(f"Using primary device: {DEVICE}")
        print(f"Utilizing {torch.cuda.device_count()} GPUs with DataParallel.")
        use_dataparallel = True
    else:
        print(f"Using single GPU: {DEVICE}")
        use_dataparallel = False
else:
    DEVICE = torch.device("cpu")
    print(f"Using CPU: {DEVICE}")
    use_dataparallel = False
# --- 2. Dataset Class (Returns Mask) ---
class AudioInpaintingDataset(Dataset):
    def __init__(self, dataset_dir, sr, n_fft, hop_length, spec_height, spec_width,
                 mask_size, load_duration, min_db=MIN_DB, max_db=MAX_DB):
        self.dataset_dir = Path(dataset_dir)
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.spec_height = spec_height  # Target height for network input
        self.spec_width = spec_width    # Target width for network input
        self.mask_h, self.mask_w = mask_size  # Mask dimensions
        self.load_duration = load_duration
        self.min_db = min_db
        self.max_db = max_db
        self.audio_files = self._find_and_validate_files()
        if not self.audio_files:
            raise FileNotFoundError(f"No valid MP3 files found >= {self.load_duration}s in {self.dataset_dir}")
        print(f"Dataset initialized with {len(self.audio_files)} valid audio files.")

    def _find_and_validate_files(self):
        """Collect MP3 paths and keep only those at least load_duration seconds long."""
        print(f"Searching for MP3 files in: {self.dataset_dir}")
        candidate_files = list(self.dataset_dir.rglob('*.mp3'))
        print(f"Found {len(candidate_files)} potential MP3 files. Starting pre-check for >= {self.load_duration}s duration...")
        valid_audio_files = []
        skipped_short_count = 0
        skipped_error_count = 0
        max_verbose_skips = 5
        short_msgs_printed = 0
        error_msgs_printed = 0
        stderr_capture = io.StringIO()
        with contextlib.redirect_stderr(stderr_capture), warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            warnings.simplefilter("ignore", category=FutureWarning)
            for i, file_path in enumerate(candidate_files):
                if (i + 1) % 500 == 0:
                    print(f"  Pre-checking file {i+1}/{len(candidate_files)}...")
                try:
                    actual_duration = librosa.get_duration(path=file_path)
                    if actual_duration >= self.load_duration:
                        valid_audio_files.append(file_path)
                    else:
                        skipped_short_count += 1
                        if short_msgs_printed < max_verbose_skips:
                            print(f"  Skipping short: {file_path} ({actual_duration:.2f}s)")
                            short_msgs_printed += 1
                        elif short_msgs_printed == max_verbose_skips:
                            print("  (Further short file skip messages suppressed)")
                            short_msgs_printed += 1
                except Exception as e:
                    skipped_error_count += 1
                    if error_msgs_printed < max_verbose_skips:
                        print(f"  Skipping error: {file_path} ({type(e).__name__})")
                        error_msgs_printed += 1
                    elif error_msgs_printed == max_verbose_skips:
                        print("  (Further error messages suppressed)")
                        error_msgs_printed += 1
                    continue
        print("-" * 30 + f"\nPre-check complete. Found {len(valid_audio_files)} valid files. "
              f"Skipped {skipped_short_count} (short), {skipped_error_count} (error).\n" + "-" * 30)
        captured_stderr = stderr_capture.getvalue()
        if captured_stderr:
            print("\n--- Captured Stderr during pre-check ---\n" + captured_stderr.strip() + "\n--------------------------------------\n")
        return valid_audio_files
    def __len__(self):
        return len(self.audio_files)

    def normalize_spectrogram(self, spec_db):
        """Linearly map dB values in [min_db, max_db] to [-1, 1], clamping outliers."""
        scaled = (spec_db - self.min_db) / (self.max_db - self.min_db) * 2.0 - 1.0
        return torch.clamp(scaled, -1.0, 1.0)

    def create_mask(self, spectrogram_shape):
        """Create a binary mask (1 = keep, 0 = hole) and return the hole's coordinates."""
        mask = torch.ones(spectrogram_shape)
        h, w = spectrogram_shape
        mask_h_actual = min(h, self.mask_h)
        mask_w_actual = min(w, self.mask_w)
        if h <= 0 or w <= 0 or mask_h_actual <= 0 or mask_w_actual <= 0:
            print(f"Warning: Cannot apply mask to shape {spectrogram_shape}. Returning full mask.")
            return mask, (0, 0, 0, 0)  # Full mask with zero coordinates
        if mask_h_actual < self.mask_h or mask_w_actual < self.mask_w:
            print(f"Warning: Spectrogram shape {spectrogram_shape} too small for mask "
                  f"{(self.mask_h, self.mask_w)}. Applying smaller mask {(mask_h_actual, mask_w_actual)}.")
        start_row = random.randint(0, h - mask_h_actual)
        start_col = random.randint(0, w - mask_w_actual)
        mask[start_row : start_row + mask_h_actual, start_col : start_col + mask_w_actual] = 0.0
        coords = (start_row, start_col, mask_h_actual, mask_w_actual)  # (y, x, height, width)
        return mask, coords
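    # Illustrative example (a sketch, not executed during training):
    #   mask, (r, c, mh, mw) = self.create_mask((128, 128))
    #   mask is a [128, 128] tensor of ones with one random 32x32 block set to 0;
    #   (r, c) is the top-left corner of the hole, (mh, mw) its size.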
    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        stderr_capture = io.StringIO()
        with contextlib.redirect_stderr(stderr_capture), warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            warnings.simplefilter("ignore", category=FutureWarning)
            try:
                y, sr_loaded = librosa.load(audio_path, sr=self.sr, duration=self.load_duration)
                if sr_loaded != self.sr:
                    warnings.warn(f"SR mismatch: {sr_loaded} != {self.sr}")
                if y is None or len(y) == 0:
                    raise ValueError("Loaded audio empty")
                # Pad or trim to exactly load_duration seconds
                target_samples = int(self.load_duration * self.sr)
                if len(y) < target_samples:
                    y = np.pad(y, (0, target_samples - len(y)), mode='constant')
                elif len(y) > target_samples:
                    y = y[:target_samples]
                spectrogram = librosa.stft(y, n_fft=self.n_fft, hop_length=self.hop_length, center=True)
                spectrogram_db = librosa.amplitude_to_db(np.abs(spectrogram), ref=np.max)
                original_shape = spectrogram_db.shape  # (F_orig, T_orig)
                spectrogram_tensor = torch.tensor(spectrogram_db).unsqueeze(0)  # [1, F_orig, T_orig]
                spectrogram_resized = torch.nn.functional.interpolate(
                    spectrogram_tensor.unsqueeze(0), size=(self.spec_height, self.spec_width),
                    mode='bilinear', align_corners=False
                ).squeeze(0)  # [1, H, W]
                target_spectrogram_normalized = self.normalize_spectrogram(spectrogram_resized)  # [1, H, W]
                # Create mask for the RESIZED dimensions
                mask_hw, mask_coords = self.create_mask(target_spectrogram_normalized.shape[1:])  # (H, W)
                mask_tensor = mask_hw.unsqueeze(0)  # [1, H, W]
                masked_spectrogram = target_spectrogram_normalized * mask_tensor  # [1, H, W]
                # Return everything needed downstream, including the mask and path for full reconstruction
                return masked_spectrogram, target_spectrogram_normalized, mask_tensor, original_shape, audio_path
            except Exception as e:
                print(f"WARNING: Error in __getitem__ for {audio_path}: {e}. Skipping.")
                # print(f"Traceback: {traceback.format_exc()}")  # Uncomment for debugging
                item_stderr = stderr_capture.getvalue()
                if item_stderr:
                    print(f"--- Stderr for {audio_path} ---\n{item_stderr.strip()}\n---")
                return None, None, None, None, None  # Match the normal return structure
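# For reference, a sketch of the shapes each successful __getitem__ returns,
# assuming the defaults above (SAMPLE_RATE=22050, N_FFT=512, HOP_LENGTH=128):
#   masked_spectrogram / target: [1, 128, 128] in [-1, 1]
#   mask_tensor:                 [1, 128, 128] of {0, 1}
#   original_shape:              (257, 862)  # (N_FFT//2 + 1, 1 + 110250//HOP_LENGTH)
#   audio_path:                  path to the source MP3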
# --- Custom Collate Function (Handles Mask Tensor and Path) ---
def collate_fn_skip_none(batch):
    batch = [item for item in batch if all(i is not None for i in item)]  # Drop failed items
    if not batch:
        return None
    masked_specs = [item[0] for item in batch]
    target_specs = [item[1] for item in batch]
    masks = [item[2] for item in batch]
    original_shapes = [item[3] for item in batch]
    audio_paths = [item[4] for item in batch]
    masked_specs_collated = torch.utils.data.dataloader.default_collate(masked_specs)
    target_specs_collated = torch.utils.data.dataloader.default_collate(target_specs)
    masks_collated = torch.utils.data.dataloader.default_collate(masks)
    # Tensors are stacked; shapes and paths stay as plain Python lists
    return masked_specs_collated, target_specs_collated, masks_collated, original_shapes, audio_paths
# --- 3. Data Loader ---
try:
    train_dataset = AudioInpaintingDataset(
        DATASET_PATH, SAMPLE_RATE, N_FFT, HOP_LENGTH,
        SPECTROGRAM_HEIGHT, SPECTROGRAM_WIDTH, MASK_SIZE, LOAD_DURATION,
        min_db=MIN_DB, max_db=MAX_DB
    )
    if len(train_dataset) == 0:
        raise RuntimeError("Dataset is empty after validation.")
    train_dataloader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2,  # Kept low; raise if the input pipeline becomes the bottleneck
        pin_memory=(DEVICE.type == 'cuda'), drop_last=True,
        collate_fn=collate_fn_skip_none  # Skips items that failed to load
    )
    print(f"DataLoader initialized with {len(train_dataloader)} batches.")
except (FileNotFoundError, RuntimeError) as e:
    print(f"Error initializing dataset/dataloader: {e}")
    exit()
except Exception as e:
    print(f"Unexpected error during dataloader init: {e}")
    traceback.print_exc()
    exit()
# --- 4. Generator Model (U-Net) ---
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)  # Concatenate skip connection along channels
        return x

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(UNetGenerator, self).__init__()
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.2)
        self.down5 = UNetDown(512, 512, dropout=0.2)
        self.down6 = UNetDown(512, 512, dropout=0.0)
        self.down7 = UNetDown(512, 512, normalize=False, dropout=0.5)
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.2)
        self.up3 = UNetUp(1024, 512, dropout=0.2)
        self.up4 = UNetUp(1024, 256, dropout=0.0)
        self.up5 = UNetUp(512, 128, dropout=0.0)
        self.up6 = UNetUp(256, 64, dropout=0.0)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in [-1, 1] to match the normalized spectrograms
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        u1 = self.up1(d7, d6)
        u2 = self.up2(u1, d5)
        u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3)
        u5 = self.up5(u4, d2)
        u6 = self.up6(u5, d1)
        return self.final_up(u6)
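# Shape trace through the generator for a [B, 1, 128, 128] input (matches the
# defaults above; each down block halves H and W):
#   d1: 64x64x64 -> d2: 128x32x32 -> d3: 256x16x16 -> d4: 512x8x8
#   d5: 512x4x4  -> d6: 512x2x2   -> d7: 512x1x1 (1x1 bottleneck, hence
#   normalize=False there: InstanceNorm needs more than one spatial element).
#   The up path mirrors this with skip concatenations, ending at [B, 1, 128, 128].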
# --- 5. Discriminator Model (PatchGAN) ---
class Discriminator(nn.Module):
    def __init__(self, in_channels=1):  # Single-channel spectrogram input
        super(Discriminator, self).__init__()

        def block(in_ch, out_ch, normalize=True):
            layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1)]
            layers.append(nn.InstanceNorm2d(out_ch) if normalize else nn.Identity())
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.Conv2d(512, 1, 4, 1, 1)  # Final conv to per-patch logits
        )

    def forward(self, x):
        return self.model(x)
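# With 128x128 inputs, the four stride-2 convs reduce the map to 8x8, and the
# final kernel-4/stride-1/pad-1 conv yields a 7x7 grid of patch logits, so the
# BCEWithLogits loss below averages 49 local real/fake decisions rather than
# producing a single score per spectrogram.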
# --- 6. Loss Functions & Optimizers ---
adversarial_loss = nn.BCEWithLogitsLoss().to(DEVICE)
content_loss_l1 = nn.L1Loss().to(DEVICE)
lambda_l1 = LAMBDA_L1

def calculate_discriminator_loss(real_output, fake_output):
    real_loss = adversarial_loss(real_output, torch.ones_like(real_output, device=DEVICE))
    fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output, device=DEVICE))
    return (real_loss + fake_loss) / 2

def calculate_generator_loss(fake_output_disc, generated_spectrogram, target_spectrogram, mask_inv):
    # mask_inv selects the *masked* region (where mask is 0, inv is 1)
    adv_loss = adversarial_loss(fake_output_disc, torch.ones_like(fake_output_disc, device=DEVICE))
    content_loss_val = content_loss_l1(generated_spectrogram * mask_inv, target_spectrogram * mask_inv)
    total_loss = adv_loss + lambda_l1 * content_loss_val
    return total_loss, adv_loss, content_loss_val
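# In equation form (pix2pix-style objective; m is the binary mask, so (1 - m)
# is 1 inside the hole):
#   L_D = [BCE(D(y), 1) + BCE(D(G(x)), 0)] / 2
#   L_G = BCE(D(G(x)), 1) + lambda_l1 * L1((1 - m) * G(x), (1 - m) * y)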
# --- Initialize Models & Optimizers ---
generator = UNetGenerator(in_channels=1, out_channels=1).to(DEVICE)
discriminator = Discriminator(in_channels=1).to(DEVICE)
if use_dataparallel:
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)
    print("Models wrapped in nn.DataParallel.")
optimizer_generator = optim.Adam(generator.parameters(), lr=LEARNING_RATE_GENERATOR, betas=(BETA_1, 0.999))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE_DISCRIMINATOR, betas=(BETA_1, 0.999))
# --- 7. Helper Functions for Saving (Adjusted for dB input reconstruction) ---
def save_spectrogram_plot(spectrogram_data, title, filepath):
    plt.figure(figsize=(6, 6))
    if isinstance(spectrogram_data, torch.Tensor):
        spectrogram_data = spectrogram_data.detach().squeeze().cpu().numpy()
    img = librosa.display.specshow(spectrogram_data, sr=SAMPLE_RATE, hop_length=HOP_LENGTH,
                                   x_axis='time', y_axis='linear', cmap='magma')
    plt.colorbar(img, format='%+.1f dB' if 'Original' in title else '%+.2f')  # dB vs. normalized units
    plt.title(title)
    plt.tight_layout()  # Apply after plotting so the layout accounts for the colorbar
    plt.savefig(filepath)
    plt.close()

def denormalize_spectrogram(spec_norm, min_db=MIN_DB, max_db=MAX_DB):
    """Invert normalize_spectrogram: map [-1, 1] back to [min_db, max_db] dB."""
    scaled_01 = (spec_norm + 1.0) / 2.0
    spec_db = scaled_01 * (max_db - min_db) + min_db
    return spec_db
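# Round-trip check (a sketch): normalize then denormalize is the identity for
# values inside [MIN_DB, MAX_DB], e.g. -40 dB -> 0.0 -> -40 dB; values outside
# that range are clamped during normalization and cannot be recovered.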
def reconstruct_audio_from_spectrogram(
        spectrogram_norm=None,   # Input: normalized spectrogram in [-1, 1]
        input_spec_db=None,      # Input: OR a dB spectrogram directly
        original_shape=None,     # Optional: target shape (F_orig, T_orig) for resizing the norm spec
        n_iter=GRIFFIN_LIM_ITERS):
    """
    Reconstructs audio from either a NORMALIZED spectrogram (optionally resizing it)
    OR directly from a provided dB spectrogram.
    """
    try:
        if input_spec_db is not None:
            # Use the provided dB spectrogram directly
            if isinstance(input_spec_db, torch.Tensor):
                spectrogram_db_final = input_spec_db.detach().squeeze().cpu().numpy()
            else:
                spectrogram_db_final = np.squeeze(input_spec_db)
            print(f"Reconstructing directly from provided dB spectrogram, shape: {spectrogram_db_final.shape}")
        elif spectrogram_norm is not None:
            # Process the normalized spectrogram
            if isinstance(spectrogram_norm, torch.Tensor):
                spectrogram_norm_np = spectrogram_norm.detach().squeeze().cpu().numpy()
            else:
                spectrogram_norm_np = np.squeeze(spectrogram_norm)
            # --- De-normalize first ---
            spectrogram_db = denormalize_spectrogram(spectrogram_norm_np, MIN_DB, MAX_DB)
            # --- Optional: resize back to the original dimensions ---
            if original_shape is not None:
                current_shape = spectrogram_db.shape
                target_F, target_T = original_shape
                if current_shape != original_shape:
                    print(f"Resizing de-normalized spectrogram from {current_shape} to {original_shape}...")
                    spec_db_tensor = torch.tensor(spectrogram_db).unsqueeze(0).unsqueeze(0)
                    spec_resized_tensor = torch.nn.functional.interpolate(
                        spec_db_tensor, size=(target_F, target_T), mode='bilinear', align_corners=False)
                    spectrogram_db_final = spec_resized_tensor.squeeze().numpy()
                else:
                    spectrogram_db_final = spectrogram_db  # No resize needed
            else:
                spectrogram_db_final = spectrogram_db  # Use as-is
            print(f"Reconstructing from de-normalized/resized spec, final shape: {spectrogram_db_final.shape}")
        else:
            raise ValueError("Must provide either spectrogram_norm or input_spec_db")
        # --- Convert dB back to linear amplitude ---
        spectrogram_amp = librosa.db_to_amplitude(spectrogram_db_final, ref=1.0)  # ref=1.0: dB values treated as absolute
        # --- Estimate phase using Griffin-Lim ---
        num_frames = spectrogram_amp.shape[1]
        # Length hint: for a center=True STFT the true signal length lies within
        # one hop below num_frames * HOP_LENGTH, so this overestimates by < 1 hop.
        expected_length = int(num_frames * HOP_LENGTH)
        estimated_audio = librosa.griffinlim(spectrogram_amp,
                                             n_iter=n_iter,
                                             hop_length=HOP_LENGTH,
                                             n_fft=N_FFT,
                                             length=expected_length)
        print(f"Reconstructed audio length: {len(estimated_audio)} samples ({len(estimated_audio)/SAMPLE_RATE:.2f}s)")
        return estimated_audio
    except Exception as e:
        print(f"Error during audio reconstruction: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Print shapes for debugging if available
        if 'spectrogram_norm_np' in locals():
            print(f"Input norm spec shape: {spectrogram_norm_np.shape}")
        if 'spectrogram_db' in locals():
            print(f"De-norm spec shape: {spectrogram_db.shape}")
        if 'spectrogram_db_final' in locals():
            print(f"Final spec shape for GL: {spectrogram_db_final.shape}")
        if original_shape:
            print(f"Target original shape: {original_shape}")
        return None
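# Example usage (a sketch; `gen_norm` is a [1, H, W] generator output on CPU and
# `orig_shape` the (F_orig, T_orig) tuple returned by the dataset):
#   audio = reconstruct_audio_from_spectrogram(spectrogram_norm=gen_norm,
#                                              original_shape=orig_shape)
#   if audio is not None:
#       sf.write('reconstructed.wav', audio, SAMPLE_RATE)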
def generate_and_save_samples(epoch, generator_model, test_inputs, output_dir):
    """
    Generates plots and saves three audio versions:
    1. Short audio from the direct generator output (HxW spec).
    2. Long audio from resizing the generator output back to the original shape (F_orig x T_orig).
    3. FULL inpainted audio by combining original + generated (F_orig x T_orig).
    """
    generator_model.eval()
    # test_inputs: (masked_spec[1,1,H,W], target_spec[1,1,H,W], mask[1,1,H,W], [orig_shape], [audio_path])
    test_masked, test_target_normalized, test_mask_resized, test_original_shapes, test_audio_paths = test_inputs
    # Use the first item in the batch for visualization/saving
    test_original_shape = test_original_shapes[0] if test_original_shapes else None
    test_audio_path = test_audio_paths[0] if test_audio_paths else None
    if not test_original_shape or not test_audio_path:
        print("Warning: Missing original shape or audio path for test sample. Cannot generate full inpainted audio.")
        # The short/resized versions can still be generated if the masked input is available
    if test_masked is None:
        print("Error: test_masked is None. Cannot generate any samples.")
        generator_model.train()
        return
    test_masked_dev = test_masked.to(DEVICE)
    with torch.no_grad():
        generated_spectrogram_normalized = generator_model(test_masked_dev)  # Normalized [1, 1, H, W]
    # --- Move results to CPU ---
    test_masked_cpu = test_masked[0, 0].cpu()                          # [H, W] normalized
    generated_norm_cpu = generated_spectrogram_normalized[0, 0].cpu()  # [H, W] normalized
    test_target_norm_cpu = test_target_normalized[0, 0].cpu()          # [H, W] normalized
    test_mask_resized_cpu = test_mask_resized[0, 0].cpu()              # [H, W] mask (0 or 1)
    print(f"\n--- Generating Samples for Epoch {epoch} ---")
    print(f"Network Output (Generated Norm HxW): {generated_norm_cpu.shape}")
    print(f"Test Mask (Resized HxW): {test_mask_resized_cpu.shape}")
    if test_original_shape:
        print(f"Original Spec Shape (F_orig x T_orig): {test_original_shape}")
    if test_audio_path:
        print(f"Original Audio Path: {test_audio_path}")
    # --- Save Spectrogram Plots ---
    save_spectrogram_plot(test_masked_cpu, f"Masked Input (Norm) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_1_masked_input_norm.png")
    save_spectrogram_plot(generated_norm_cpu, f"Generated Output (Norm) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_2_generated_output_norm.png")
    save_spectrogram_plot(test_target_norm_cpu, f"Target (Norm, Resized) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_3_target_resized_norm.png")
    save_spectrogram_plot(test_mask_resized_cpu, f"Test Mask (Resized HxW) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_4_mask_resized.png")
    # --- 1. Reconstruct and Save SHORT Audio (from the HxW norm spec) ---
    print("\n1. Reconstructing SHORT audio (from network HxW output)...")
    audio_short = reconstruct_audio_from_spectrogram(
        spectrogram_norm=generated_norm_cpu,
        original_shape=None  # Don't resize
    )
    if audio_short is not None:
        try:
            sf.write(output_dir / f"epoch_{epoch:03d}_audio_short.wav", audio_short, SAMPLE_RATE)
        except Exception as e:
            print(f"Error saving SHORT audio: {e}")
    else:
        print("Skipping SHORT audio saving due to reconstruction error.")
    # --- 2. Reconstruct and Save LONG RESIZED Audio (HxW norm spec -> F_orig x T_orig) ---
    if test_original_shape:
        print("\n2. Reconstructing LONG RESIZED audio (resizing HxW output to original shape)...")
        audio_long_resized = reconstruct_audio_from_spectrogram(
            spectrogram_norm=generated_norm_cpu,
            original_shape=test_original_shape  # Resize back
        )
        if audio_long_resized is not None:
            try:
                sf.write(output_dir / f"epoch_{epoch:03d}_audio_long_resized.wav", audio_long_resized, SAMPLE_RATE)
            except Exception as e:
                print(f"Error saving LONG RESIZED audio: {e}")
        else:
            print("Skipping LONG RESIZED audio saving due to reconstruction error.")
    else:
        print("Skipping LONG RESIZED audio reconstruction: original shape missing.")
    # --- 3. Reconstruct and Save FULL INPAINTED Audio ---
    if test_original_shape and test_audio_path:
        print("\n3. Reconstructing FULL INPAINTED audio (combining original and generated)...")
        try:
            # --- a) Reload the original audio and compute its full dB spectrogram ---
            print(f"  Reloading original audio from: {test_audio_path}")
            y_orig, _ = librosa.load(test_audio_path, sr=SAMPLE_RATE, duration=LOAD_DURATION)
            target_samples = int(LOAD_DURATION * SAMPLE_RATE)
            if len(y_orig) < target_samples:
                y_orig = np.pad(y_orig, (0, target_samples - len(y_orig)), mode='constant')
            elif len(y_orig) > target_samples:
                y_orig = y_orig[:target_samples]
            spec_orig = librosa.stft(y_orig, n_fft=N_FFT, hop_length=HOP_LENGTH, center=True)
            spec_orig_db = librosa.amplitude_to_db(np.abs(spec_orig), ref=np.max)
            spec_orig_db_tensor = torch.tensor(spec_orig_db)  # [F_orig, T_orig]
            save_spectrogram_plot(spec_orig_db_tensor, f"Original Full Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_5_original_full_spec_db.png")
            # --- b) De-normalize and resize the generated HxW spectrogram to the original shape ---
            print("  De-normalizing and resizing generated HxW spectrogram...")
            generated_db_cpu = denormalize_spectrogram(generated_norm_cpu)  # [H, W] dB
            generated_db_tensor = generated_db_cpu.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            generated_db_resized_tensor = torch.nn.functional.interpolate(
                generated_db_tensor, size=test_original_shape, mode='bilinear', align_corners=False
            ).squeeze()  # [F_orig, T_orig] dB
            save_spectrogram_plot(generated_db_resized_tensor, f"Generated Resized Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_6_generated_resized_spec_db.png")
            # --- c) Resize the HxW mask to the original shape (nearest neighbor keeps it binary) ---
            print("  Resizing HxW mask to original shape...")
            mask_resized_tensor = test_mask_resized_cpu.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
            mask_orig_shape_tensor = torch.nn.functional.interpolate(
                mask_resized_tensor, size=test_original_shape, mode='nearest'
            ).squeeze()  # [F_orig, T_orig] of {0, 1}
            save_spectrogram_plot(mask_orig_shape_tensor, f"Mask Resized to Orig Shape - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_7_mask_orig_shape.png")
            # --- d) Combine original and generated using the resized mask ---
            # combined = original * mask + generated * (1 - mask):
            # keep the original where mask is 1, use the generated content where mask is 0
            print("  Combining original and generated spectrograms...")
            combined_spec_db_tensor = (spec_orig_db_tensor * mask_orig_shape_tensor) + \
                                      (generated_db_resized_tensor * (1.0 - mask_orig_shape_tensor))
            save_spectrogram_plot(combined_spec_db_tensor, f"Combined Full Spec (dB) - Epoch {epoch}", output_dir / f"epoch_{epoch:03d}_8_combined_spec_db.png")
            # --- e) Reconstruct audio from the combined dB spectrogram ---
            print("  Reconstructing audio from combined spectrogram...")
            audio_full_inpainted = reconstruct_audio_from_spectrogram(
                input_spec_db=combined_spec_db_tensor  # Pass the dB spec directly
            )
            if audio_full_inpainted is not None:
                try:
                    sf.write(output_dir / f"epoch_{epoch:03d}_audio_full_inpainted.wav", audio_full_inpainted, SAMPLE_RATE)
                except Exception as e:
                    print(f"Error saving FULL INPAINTED audio: {e}")
            else:
                print("Skipping FULL INPAINTED audio saving due to reconstruction error.")
        except Exception as e:
            print(f"Error during FULL INPAINTED audio creation: {e}")
            print(f"Traceback: {traceback.format_exc()}")
    else:
        print("Skipping FULL INPAINTED audio reconstruction: original shape or path missing.")
    print("-" * 30)
    generator_model.train()  # Set back to training mode
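# Note on step 3: the compositing in (d) is the standard inpainting paste-back,
#   S_combined = m * S_original + (1 - m) * S_generated,
# applied in the dB domain at full (F_orig x T_orig) resolution, so everything
# outside the hole keeps the original magnitudes and only the hole is synthesized.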
# --- 8. Training Loop ---
def train(dataloader, epochs, save_interval):
    history = {'gen_loss_adv': [], 'gen_loss_l1': [], 'disc_loss': []}
    print("Starting Training Loop...")
    for epoch in range(epochs):
        gen_loss_adv_epoch = 0.0
        gen_loss_l1_epoch = 0.0
        disc_loss_epoch = 0.0
        items_processed = 0
        generator.train()
        discriminator.train()
        for i, batch_data in enumerate(dataloader):
            if batch_data is None:
                continue
            # Unpack batch data: (masked, target, mask_tensor, shapes, paths)
            masked_batch, target_batch, mask_batch, _, _ = batch_data
            masked_batch = masked_batch.to(DEVICE)
            target_batch = target_batch.to(DEVICE)
            mask_batch = mask_batch.to(DEVICE)  # [B, 1, H, W], values 0 or 1
            if masked_batch.nelement() == 0:
                continue  # Skip empty batch
            batch_size = masked_batch.size(0)
            items_processed += batch_size
            # Inverse mask for the L1 loss (where mask is 0 -> inv is 1)
            mask_batch_inv = 1.0 - mask_batch
            # --- Train Discriminator ---
            optimizer_discriminator.zero_grad(set_to_none=True)
            with torch.no_grad():
                generated_specs_d = generator(masked_batch)
            real_output = discriminator(target_batch)       # D sees the normalized target
            fake_output = discriminator(generated_specs_d)  # D sees the normalized generated spec
            disc_loss = calculate_discriminator_loss(real_output, fake_output)
            disc_loss.backward()
            optimizer_discriminator.step()
            # --- Train Generator ---
            optimizer_generator.zero_grad(set_to_none=True)
            generated_specs_g = generator(masked_batch)     # [B, 1, H, W] normalized
            fake_output_gen = discriminator(generated_specs_g)
            # Adversarial loss plus L1 restricted to the masked region
            gen_loss, adv_loss, l1_loss = calculate_generator_loss(
                fake_output_gen, generated_specs_g, target_batch, mask_batch_inv
            )
            gen_loss.backward()
            optimizer_generator.step()
            # Accumulate losses (weighted by batch size for correct epoch averages)
            gen_loss_adv_epoch += adv_loss.item() * batch_size
            gen_loss_l1_epoch += l1_loss.item() * batch_size
            disc_loss_epoch += disc_loss.item() * batch_size
            if (i + 1) % 50 == 0:
                print(f"[{epoch+1}/{epochs}][{i+1}/{len(dataloader)}] "
                      f"D L: {disc_loss.item():.4f} | G L: {gen_loss.item():.4f} "
                      f"(Adv: {adv_loss.item():.4f}, L1: {l1_loss.item():.4f})")
        # --- End of Epoch ---
        if items_processed > 0:
            avg_d_loss = disc_loss_epoch / items_processed
            avg_g_adv = gen_loss_adv_epoch / items_processed
            avg_g_l1 = gen_loss_l1_epoch / items_processed
            history['disc_loss'].append(avg_d_loss)
            history['gen_loss_adv'].append(avg_g_adv)
            history['gen_loss_l1'].append(avg_g_l1)
            print(f"Epoch {epoch+1}/{epochs} Avg Losses -> D: {avg_d_loss:.4f}, G_Adv: {avg_g_adv:.4f}, G_L1: {avg_g_l1:.4f}")
        else:
            print(f"Epoch {epoch+1}/{epochs} finished - 0 items processed.")
        # --- Save Checkpoints and Samples ---
        if (epoch + 1) % save_interval == 0 or (epoch + 1) == epochs:
            epoch_num_padded = f"{epoch+1:03d}"
            print(f"Saving outputs for epoch {epoch+1} to {TRAINING_OUTPUT_DIR}")
            g_path = TRAINING_OUTPUT_DIR / f'generator_epoch_{epoch_num_padded}.pth'
            d_path = TRAINING_OUTPUT_DIR / f'discriminator_epoch_{epoch_num_padded}.pth'
            # Unwrap DataParallel so checkpoints load cleanly on a single device
            model_to_save_g = generator.module if use_dataparallel else generator
            model_to_save_d = discriminator.module if use_dataparallel else discriminator
            torch.save(model_to_save_g.state_dict(), g_path)
            torch.save(model_to_save_d.state_dict(), d_path)
            if test_input_tuple is not None:
                generate_and_save_samples(epoch + 1, generator, test_input_tuple, TRAINING_OUTPUT_DIR)
            else:
                print("Skipping sample generation: fixed test input not available.")
    # --- End of Training Plot ---
    plt.figure(figsize=(12, 6))
    plt.plot(history['disc_loss'], label='Discriminator Loss')
    plt.plot(history['gen_loss_adv'], label='Gen Adversarial Loss')
    plt.plot(history['gen_loss_l1'], label='Gen L1 Loss (Unscaled)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    plt.grid(True)
    plt.savefig(TRAINING_OUTPUT_DIR / 'training_losses.png')
    plt.show()
# --- 9. Prepare Fixed Test Input (Once) ---
print("Preparing fixed test input for visualization...")
test_input_tuple = None  # Will hold (masked[1,1,H,W], target[1,1,H,W], mask[1,1,H,W], [orig_shape], [audio_path])
try:
    if train_dataset and len(train_dataset.audio_files) > 0:
        found_test_item = False
        attempts = 0
        max_attempts = min(len(train_dataset), 20)  # Try up to 20 random items
        while not found_test_item and attempts < max_attempts:
            attempts += 1
            test_idx = random.randrange(len(train_dataset))
            print(f"Attempt {attempts}/{max_attempts}: Trying index {test_idx}...")
            # Fetch all items from the dataset's __getitem__
            test_masked_spec, test_target_spec, test_mask, test_original_shape, test_audio_path = train_dataset[test_idx]
            # Check that every component is valid
            if all(item is not None for item in [test_masked_spec, test_target_spec, test_mask, test_original_shape, test_audio_path]):
                # Unsqueeze to create a batch of 1 and store the tuple
                test_input_tuple = (
                    test_masked_spec.unsqueeze(0),   # [1, 1, H, W]
                    test_target_spec.unsqueeze(0),   # [1, 1, H, W]
                    test_mask.unsqueeze(0),          # [1, 1, H, W]
                    [test_original_shape],           # List containing one shape tuple
                    [test_audio_path]                # List containing one path
                )
                print(f"Fixed test input prepared successfully from: {test_audio_path}")
                found_test_item = True
            else:
                print(f"Failed to process file at index {test_idx}. Retrying...")
        if not found_test_item:
            print(f"Failed to prepare test input after {max_attempts} attempts.")
    else:
        print("Cannot prepare test input: Dataset empty or not initialized.")
except Exception as e:
    print(f"Error preparing test input: {e}")
    print(f"Traceback: {traceback.format_exc()}")
# --- 10. Run Training ---
if __name__ == "__main__":
    try:
        if train_dataloader is None or len(train_dataloader) == 0:
            raise RuntimeError("DataLoader invalid/empty.")
        if test_input_tuple is None:
            print("WARNING: Test input missing. Sample generation will be skipped.")
        train(train_dataloader, EPOCHS, CHECKPOINT_SAVE_INTERVAL)
        print("\nTraining finished.")
        print(f"Outputs saved in: {TRAINING_OUTPUT_DIR}")
    except Exception as e:
        print(f"\nA critical error occurred during training: {e}")
        print(f"Traceback: {traceback.format_exc()}")