import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
class AmplitudeDataset(Dataset):
    """Pairs of cropped input amplitudes and their WGS target amplitudes, stored as .npy files."""

    def __init__(self, input_dir_ampl, target_dir_ampl):
        self.input_dir_ampl = input_dir_ampl
        self.target_dir_ampl = target_dir_ampl
        # Input files are named tweezer_<id>_ampl_input.npy; keep only the <id> token.
        self.fileids = sorted([f.split('_')[1] for f in os.listdir(input_dir_ampl)])

    def __len__(self):
        return len(self.fileids)

    def __getitem__(self, idx):
        input_ampl = np.load(os.path.join(self.input_dir_ampl, f'tweezer_{self.fileids[idx]}_ampl_input.npy')).astype(np.float32)
        target_ampl = np.load(os.path.join(self.target_dir_ampl, f'tweezer_{self.fileids[idx]}_wgs_ampl.npy')).astype(np.float32)
        # Add a channel dimension so each sample has shape (1, H, W).
        return torch.tensor(input_ampl).unsqueeze(0), torch.tensor(target_ampl).unsqueeze(0)
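# Quick sanity check for the dataset above (a sketch, not part of the training
# run; the directory paths mirror the ones set in the __main__ block below and
# the tweezer_<id>_*.npy file pattern is inferred from the loading code):
#
#   ds = AmplitudeDataset('./WGS_CNN_Cropped/ampl_input', './WGS_CNN_Cropped/wgs_amplitude')
#   x, y = ds[0]
#   print(len(ds), x.shape, y.shape)   # both tensors are (1, H, W) after unsqueeze(0)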
class SingleOutputFCN(nn.Module):
    """Small fully convolutional network mapping one amplitude map to one amplitude map."""

    def __init__(self):
        super().__init__()
        # padding='same' keeps the spatial size unchanged through every layer (requires PyTorch >= 1.9).
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 1, kernel_size=3, padding='same'),
        )

    def forward(self, x):
        return self.net(x)
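# Minimal shape check for the network (a sketch; the 64x64 size is an assumed
# placeholder, since the actual crop size is not visible in this script). With
# padding='same', every convolution preserves height and width, so the output
# matches the input spatially:
#
#   dummy = torch.zeros(1, 1, 64, 64)
#   print(SingleOutputFCN()(dummy).shape)   # torch.Size([1, 1, 64, 64])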
def train(model, dataloader, device, epochs=4, lr=1e-3):
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[AMPL] Epoch {epoch+1}/{epochs} | Avg Loss: {total_loss / len(dataloader):.6f}")
def test(model, dataloader, device):
    model.eval()
    output_dir = './cnn_predictions/'
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        for i, (inputs, _) in enumerate(dataloader):
            inputs = inputs.to(device)
            outputs = model(inputs).cpu()
            input_np = inputs.cpu().numpy()
            output_np = outputs.numpy()
            for j in range(inputs.shape[0]):
                # Recover the global sample index; valid because the test loader does not shuffle.
                idx = i * dataloader.batch_size + j
                # A Subset indexes into the underlying dataset, so map back to the original file id.
                if isinstance(dataloader.dataset, Subset):
                    file_id = dataloader.dataset.dataset.fileids[dataloader.dataset.indices[idx]]
                else:
                    file_id = dataloader.dataset.fileids[idx]
                np.save(os.path.join(output_dir, f'tweezer_{file_id}_predicted_ampl.npy'), output_np[j, 0])
                np.save(os.path.join(output_dir, f'tweezer_{file_id}_input_ampl.npy'), input_np[j, 0])
if __name__ == '__main__':
    input_dir_ampl = './WGS_CNN_Cropped/ampl_input'
    target_dir_ampl = './WGS_CNN_Cropped/wgs_amplitude'

    dataset = AmplitudeDataset(input_dir_ampl, target_dir_ampl)
    # 80/20 train/test split over sample indices; fixed seed for reproducibility.
    train_ids, test_ids = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = Subset(dataset, train_ids)
    test_dataset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_dataset, batch_size=15, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=15, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SingleOutputFCN()

    print("\n=== Training Amplitude Model ===")
    train(model, train_loader, device)

    print("\n=== Testing Amplitude Model ===")
    test(model, test_loader, device)
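# After a run, test() leaves one tweezer_<id>_predicted_ampl.npy and one
# tweezer_<id>_input_ampl.npy per test sample in ./cnn_predictions/. A quick way
# to score a single prediction against its WGS target (sketch; '<id>' stands in
# for a real file id from the dataset):
#
#   pred = np.load('./cnn_predictions/tweezer_<id>_predicted_ampl.npy')
#   target = np.load('./WGS_CNN_Cropped/wgs_amplitude/tweezer_<id>_wgs_ampl.npy')
#   print(float(np.mean((pred - target) ** 2)))   # per-sample MSE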