import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import f1_score
from copy import deepcopy

# You can copy/paste helper functions from the task description here.

# Do not change function signature
def init_model():
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = MazeGNN(input_dim=2, hidden_dim=64, output_dim=2, dropout=0.3).to(device)
  return model

def train_model(model, train_generator):
  dataset = train_generator(n_samples=10000)
  criterion = torch.nn.NLLLoss()  # model outputs log-probabilities (log_softmax)
  optimizer = optim.Adam(model.parameters())

  # Hold out the last 20% of the generated data for validation.
  val_split = 0.2
  train_size = int((1 - val_split) * len(dataset))
  train_loader = DataLoader(dataset[:train_size], batch_size=1, shuffle=True)
  val_set = DataLoader(dataset[train_size:], batch_size=1)
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  best_score = None  # (negated graph accuracy, running loss); lower is better
  best_model = None
  epochs = 4

  for epoch in range(epochs):
      # eval_model() switches to eval mode, so re-enable dropout every epoch.
      model.train()
      running_loss = 0.0
      for i, data in enumerate(train_loader):
          optimizer.zero_grad()
          data = data.to(device)

          # could change additional parameters here
          pred = model(data, data.num_nodes)

          loss = criterion(pred, data.y.to(torch.long))

          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          optimizer.step()

          running_loss += loss.item()
      ss = eval_model(model, val_set)

      # The graph accuracy is the last token of the evaluation summary string.
      graph_val = float(ss.split(" ")[-1])
      print(f'Epoch: {epoch + 1} loss: {running_loss / len(train_loader.dataset):.5f} \t {ss}')
      # Keep the checkpoint with the best graph accuracy; break ties on loss.
      comp = (-graph_val, running_loss)
      if best_score is None or comp < best_score:
          best_score = comp
          best_model = deepcopy(model)
          print("store new best model", comp)
  return best_model
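
# --- Optional sanity-check helpers (not part of the submission) ---

# A minimal driver sketch. `_make_toy_generator` is a hypothetical stand-in for
# the task's real train_generator, shown only to illustrate how init_model and
# train_model plug together; the real generator's Data layout may differ.
def _make_toy_generator():
  from torch_geometric.data import Data

  def gen(n_samples=100):
    data_list = []
    for _ in range(n_samples):
      # 4-node path graph with random placeholder features and labels.
      edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                                 [1, 0, 2, 1, 3, 2]])
      x = torch.rand(4, 2)
      y = torch.randint(0, 2, (4,))
      data_list.append(Data(x=x, edge_index=edge_index, y=y))
    return data_list
  return gen

# Usage sketch:
#   model = init_model()
#   model = train_model(model, _make_toy_generator())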

# --- Helper Functions (Copied from description) ---

def eval_model(model, dataset, mode=None):
  model.eval()
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  acc = 0          # correctly classified nodes
  tot_nodes = 0
  tot_graphs = 0
  perf = 0         # perfectly solved graphs
  gpred = []       # per-node predictions, all graphs
  gsol = []        # per-node ground truth, all graphs

  for step, batch in enumerate(dataset):
    if batch is None: continue
    n = batch.num_nodes

    # Optionally restrict evaluation to a maze-size bucket (skip before the
    # forward pass to save compute).
    if mode == "small":
      if n > 4*4:
        continue
    elif mode == "medium":
      if n <= 4*4 or n > 8*8:
        continue
    elif mode == "large":
      if n <= 8*8 or n > 16*16:
        continue
    elif mode == "xlarge":
      if n <= 16*16 or n > 32*32:
        continue

    with torch.no_grad():
      batch = batch.to(device)
      try:
        pred = model(batch, n)
      except TypeError:
        # Fall back to models whose forward() takes only the batch.
        pred = model(batch)

    y_pred = torch.argmax(pred, dim=1)
    tot_nodes += n
    tot_graphs += batch.num_graphs

    if hasattr(batch, 'y') and batch.y is not None and len(batch.y) == n:
      graph_acc = torch.sum(y_pred == batch.y).item()
      acc += graph_acc
      for p in y_pred:
        gpred.append(int(p.item()))
      for p in batch.y:
        gsol.append(int(p.item()))
      if graph_acc == n:
        perf += 1
    else:
      print(f"Warning: Missing or malformed ground truth for graph {step}. Skipping accuracy calculation for this graph.")

  if tot_nodes == 0 or tot_graphs == 0:
    return "node accuracy: N/A | node f1 score: N/A | graph accuracy: N/A (No valid graphs processed)"

  gpred_tensor = torch.tensor(gpred, device='cpu')
  gsol_tensor = torch.tensor(gsol, device='cpu')

  f1score = f1_score(gsol_tensor.numpy(), gpred_tensor.numpy(), average='binary', zero_division=0)

  return f"node accuracy: {acc/tot_nodes:.3f} | node f1 score: {f1score:.3f} | graph accuracy: {perf/tot_graphs:.3f}"
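
# Hedged usage sketch: given a DataLoader like the `val_set` built in
# train_model, report metrics per maze-size bucket. The mode names are the
# ones eval_model understands.
def _eval_by_size(model, loader):
  for mode in (None, "small", "medium", "large", "xlarge"):
    print(mode or "all", "->", eval_model(model, loader, mode=mode))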

class MazeConv(MessagePassing):
  def __init__(self, hidden_dim, dropout=0.2):
    super().__init__(aggr='add')  # sum-aggregate incoming messages
    self.dropout = dropout

    # Message MLP over the concatenated (target, source) node features.
    self.mlp_message = nn.Sequential(
      nn.Linear(2 * hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Dropout(self.dropout),
      nn.Linear(hidden_dim, hidden_dim),
    )

    self.mlp_update = nn.Sequential(
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Dropout(self.dropout)
    )

    self.norm = nn.LayerNorm(hidden_dim)

  def forward(self, x, edge_index):
    aggregated_messages = self.propagate(edge_index, x=x)

    # Residual connection, then update MLP and layer norm.
    out = x + aggregated_messages
    out = self.mlp_update(out)
    out = self.norm(out)

    return out

  def message(self, x_j, x_i):
    # x_i: target-node features, x_j: source-node features.
    edge_features = torch.cat([x_i, x_j], dim=-1)

    msg = self.mlp_message(edge_features)
    return msg
class MazeGNN(torch.nn.Module):
  def __init__(self, input_dim=2, hidden_dim=64, output_dim=2, dropout=0.3):
    super().__init__()
    self.dropout = dropout
    self.hidden_dim = hidden_dim
    self.input_dim = input_dim
    self.encoder = self.get_mlp(input_dim, hidden_dim * 2, hidden_dim)
    self.decoder = self.get_mlp(hidden_dim, hidden_dim * 2, output_dim, last_relu=False)
    # A single conv applied repeatedly in forward(), i.e. weights are shared
    # across message-passing iterations.
    self.conv = MazeConv(hidden_dim, dropout=self.dropout)
    self.pre_conv_mlp = self.get_mlp(hidden_dim + input_dim, hidden_dim * 2, hidden_dim)

  def get_mlp(self, input_dim, hidden_layer_dim, output_dim, last_relu=True):
    modules = [
      torch.nn.Linear(input_dim, int(hidden_layer_dim)),
      torch.nn.ReLU(),
      torch.nn.Dropout(self.dropout),
      torch.nn.Linear(int(hidden_layer_dim), output_dim)
    ]
    if last_relu:
      modules.append(torch.nn.ReLU())
    return torch.nn.Sequential(*modules)

  def forward(self, data, num_nodes):
    # num_nodes is unused; the argument is kept so the signature matches the
    # model(batch, n) call in eval_model.
    x, edge_index = data.x, data.edge_index
    original_input = x
    x = self.encoder(x)
    num_layers = 32  # fixed number of message-passing iterations

    for i in range(num_layers):
      # Re-inject the raw input features before each iteration.
      combined_features = torch.cat([x, original_input], dim=-1)
      processed_features = self.pre_conv_mlp(combined_features)
      x = self.conv(processed_features, edge_index)

    x = self.decoder(x)
    return F.log_softmax(x, dim=1)  # log-probabilities for NLLLoss
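
# End-to-end smoke test on a tiny synthetic graph. The 2-dim features and
# binary labels mirror what init_model assumes; the task's real maze encoding
# may differ.
def _smoke_test():
  from torch_geometric.data import Data
  model = MazeGNN(input_dim=2, hidden_dim=64, output_dim=2)
  edge_index = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]])  # 3-node path graph
  data = Data(x=torch.rand(3, 2), edge_index=edge_index)
  out = model(data, data.num_nodes)
  assert out.shape == (3, 2)  # one log-probability pair per node
  return out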