Advertisement
Guest User

Untitled

a guest
May 27th, 2025
39
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.77 KB | None | 0 0
  1. import os
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torch.utils.data import Dataset, DataLoader, Subset
  7. from sklearn.model_selection import train_test_split
  8.  
  9. class AmplitudeDataset(Dataset):
  10.     def __init__(self, input_dir_ampl, target_dir_ampl):
  11.         self.input_dir_ampl = input_dir_ampl
  12.         self.target_dir_ampl = target_dir_ampl
  13.         self.fileids = sorted([f.split('_')[1] for f in os.listdir(input_dir_ampl)])
  14.  
  15.     def __len__(self):
  16.         return len(self.fileids)
  17.  
  18.     def __getitem__(self, idx):
  19.         input_ampl = np.load(os.path.join(self.input_dir_ampl, f'tweezer_{self.fileids[idx]}_ampl_input.npy')).astype(np.float32)
  20.         target_ampl = np.load(os.path.join(self.target_dir_ampl, f'tweezer_{self.fileids[idx]}_wgs_ampl.npy')).astype(np.float32)
  21.  
  22.         return torch.tensor(input_ampl).unsqueeze(0), torch.tensor(target_ampl).unsqueeze(0)
  23.  
  24. class SingleOutputFCN(nn.Module):
  25.     def __init__(self):
  26.         super(SingleOutputFCN, self).__init__()
  27.         self.net = nn.Sequential(
  28.             nn.Conv2d(1, 16, kernel_size=3, padding='same'),
  29.             nn.LeakyReLU(0.1),
  30.             nn.Conv2d(16, 16, kernel_size=3, padding='same'),
  31.             nn.LeakyReLU(0.1),
  32.             nn.Conv2d(16, 1, kernel_size=3, padding='same'),
  33.         )
  34.  
  35.     def forward(self, x):
  36.         return self.net(x)
  37.  
  38. def train(model, dataloader, device, epochs=4, lr=1e-3):
  39.     model.to(device)
  40.     model.train()
  41.     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  42.     criterion = nn.MSELoss()
  43.  
  44.     for epoch in range(epochs):
  45.         total_loss = 0
  46.         for inputs, targets in dataloader:
  47.             inputs = inputs.to(device)
  48.             targets = targets.to(device)
  49.  
  50.             optimizer.zero_grad()
  51.             pred = model(inputs)
  52.             loss = criterion(pred, targets)
  53.             loss.backward()
  54.             optimizer.step()
  55.  
  56.             total_loss += loss.item()
  57.             pred_std = pred.std().item()
  58.  
  59.         print(f"[AMPL] Epoch {epoch+1}/{epochs} | Avg Loss: {total_loss / len(dataloader):.6f}")
  60.  
  61. def test(model, dataloader, device):
  62.     model.eval()
  63.     output_dir = './cnn_predictions/'
  64.     os.makedirs(output_dir, exist_ok=True)
  65.  
  66.     with torch.no_grad():
  67.         for i, (inputs, _) in enumerate(dataloader):
  68.             inputs = inputs.to(device)
  69.             outputs = model(inputs).cpu()
  70.             input_np = inputs.cpu().numpy()
  71.             output_np = outputs.cpu().numpy()
  72.  
  73.  
  74.             for j in range(inputs.shape[0]):
  75.                 idx = i * dataloader.batch_size + j
  76.                 file_id = dataloader.dataset.dataset.fileids[dataloader.dataset.indices[idx]] if isinstance(dataloader.dataset, Subset) else dataloader.dataset.fileids[idx]
  77.  
  78.                 np.save(os.path.join(output_dir, f'tweezer_{file_id}_predicted_ampl.npy'), output_np[j, 0])
  79.                 np.save(os.path.join(output_dir, f'tweezer_{file_id}_input_ampl.npy'), input_np[j, 0])
  80.  
  81. if __name__ == '__main__':
  82.     input_dir_ampl = './WGS_CNN_Cropped/ampl_input'
  83.     target_dir_ampl = './WGS_CNN_Cropped/wgs_amplitude'
  84.  
  85.     dataset = AmplitudeDataset(input_dir_ampl, target_dir_ampl)
  86.     train_ids, test_ids = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
  87.     train_dataset = Subset(dataset, train_ids)
  88.     test_dataset = Subset(dataset, test_ids)
  89.  
  90.     train_loader = DataLoader(train_dataset, batch_size=15, shuffle=True)
  91.     test_loader = DataLoader(test_dataset, batch_size=15, shuffle=False)
  92.  
  93.     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  94.  
  95.     model = SingleOutputFCN()
  96.     print("\n=== Training Amplitude Model ===")
  97.     train(model, train_loader, device)
  98.  
  99.     print("\n=== Testing Amplitude Model ===")
  100.     test(model, test_loader, device)
  101.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement