import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from sklearn.metrics import f1_score
from copy import deepcopy
# You can copy/paste helper functions from the task description here.

# Do not change function signature
def init_model():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MazeGNN(input_dim=2, hidden_dim=64, output_dim=2, dropout=0.3).to(device)
    return model
def train_model(model, train_generator):
    dataset = train_generator(n_samples=10000)

    criterion = torch.nn.NLLLoss()          # model outputs log-probabilities
    optimizer = optim.Adam(model.parameters())

    # Hold out the last 20% of the samples for validation.
    val_split = 0.2
    train_size = int((1 - val_split) * len(dataset))
    train_loader = DataLoader(dataset[:train_size], batch_size=1, shuffle=True)
    val_set = DataLoader(dataset[train_size:], batch_size=1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_score = None
    best_model = None
    epochs = 4
    for epoch in range(epochs):
        # eval_model switches the model to eval mode, so re-enable
        # training mode (dropout) at the start of every epoch.
        model.train()
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            data = data.to(device)
            pred = model(data, data.num_nodes)
            loss = criterion(pred, data.y.to(torch.long))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()

        # Parse the graph accuracy (the last token of the metric string)
        # and keep the checkpoint with the best (accuracy, loss) pair.
        ss = eval_model(model, val_set)
        graph_val = float(ss.split(" ")[-1])
        print(f'Epoch: {epoch + 1} loss: {running_loss / len(train_loader.dataset):.5f} \t {ss}')
        comp = (-graph_val, running_loss)
        if best_score is None or comp < best_score:
            best_score = comp
            best_model = deepcopy(model)
            print("store new best model", comp)
    return best_model
# --- Helper Functions (Copied from description) ---
def eval_model(model, dataset, mode=None):
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    acc = 0
    tot_nodes = 0
    tot_graphs = 0
    perf = 0
    gpred = []
    gsol = []
    for step, batch in enumerate(dataset):
        if batch is None:
            continue
        n = batch.num_nodes
        # Optionally restrict evaluation to a single maze-size bucket.
        if mode == "small" and n > 4 * 4:
            continue
        elif mode == "medium" and (n <= 4 * 4 or n > 8 * 8):
            continue
        elif mode == "large" and (n <= 8 * 8 or n > 16 * 16):
            continue
        elif mode == "xlarge" and (n <= 16 * 16 or n > 32 * 32):
            continue
        with torch.no_grad():
            batch = batch.to(device)
            try:
                pred = model(batch, n)
            except TypeError:
                pred = model(batch)
        y_pred = torch.argmax(pred, dim=1)
        tot_nodes += n
        tot_graphs += batch.num_graphs
        if hasattr(batch, 'y') and batch.y is not None and len(batch.y) == n:
            graph_acc = torch.sum(y_pred == batch.y).item()
            acc += graph_acc
            gpred.extend(y_pred.tolist())
            gsol.extend(batch.y.tolist())
            if graph_acc == n:   # every node in the graph classified correctly
                perf += 1
        else:
            print(f"Warning: Missing or malformed ground truth for graph {step}. Skipping accuracy calculation for this graph.")
    if tot_nodes == 0 or tot_graphs == 0:
        return "node accuracy: N/A | node f1 score: N/A | graph accuracy: N/A (No valid graphs processed)"
    f1score = f1_score(gsol, gpred, average='binary', zero_division=0)
    return f"node accuracy: {acc/tot_nodes:.3f} | node f1 score: {f1score:.3f} | graph accuracy: {perf/tot_graphs:.3f}"
class MazeConv(MessagePassing):
    """One round of message passing: each message is an MLP over the
    concatenated (receiver, sender) features, aggregated with sum."""

    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__(aggr='add')
        self.dropout = dropout
        self.mlp_message = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.mlp_update = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout)
        )
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, edge_index):
        aggregated_messages = self.propagate(edge_index, x=x)
        out = x + aggregated_messages          # residual connection
        out = self.mlp_update(out)
        out = self.norm(out)
        return out

    def message(self, x_j, x_i):
        # x_i: receiver features, x_j: sender features, one row per edge
        edge_features = torch.cat([x_i, x_j], dim=-1)
        return self.mlp_message(edge_features)
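
# A minimal sketch (not part of the task skeleton) of the shapes MazeConv
# expects: node features [N, hidden_dim] and edge_index [2, E], returning
# updated features [N, hidden_dim]. Because the shape is preserved, MazeGNN
# below can apply the same layer recurrently with tied weights.
def _mazeconv_shape_check():
    conv = MazeConv(hidden_dim=64)
    x = torch.rand(5, 64)                        # 5 nodes
    edge_index = torch.tensor([[0, 1, 2, 3],     # 4 directed edges
                               [1, 2, 3, 4]])
    out = conv(x, edge_index)
    assert out.shape == (5, 64)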
class MazeGNN(torch.nn.Module):
    def __init__(self, input_dim=2, hidden_dim=64, output_dim=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.encoder = self.get_mlp(input_dim, hidden_dim * 2, hidden_dim)
        self.decoder = self.get_mlp(hidden_dim, hidden_dim * 2, output_dim, last_relu=False)
        self.conv = MazeConv(hidden_dim, dropout=self.dropout)
        self.pre_conv_mlp = self.get_mlp(hidden_dim + input_dim, hidden_dim * 2, hidden_dim)

    def get_mlp(self, input_dim, hidden_layer_dim, output_dim, last_relu=True):
        modules = [
            torch.nn.Linear(input_dim, int(hidden_layer_dim)),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(int(hidden_layer_dim), output_dim)
        ]
        if last_relu:
            modules.append(torch.nn.ReLU())
        return torch.nn.Sequential(*modules)

    def forward(self, data, num_nodes):
        # num_nodes is unused but kept so the training loop's
        # model(data, data.num_nodes) call still matches this signature.
        x, edge_index = data.x, data.edge_index
        original_input = x
        x = self.encoder(x)
        # Apply the same (weight-tied) conv layer recurrently; each step
        # re-injects the raw input features before message passing.
        num_layers = 32
        for i in range(num_layers):
            combined_features = torch.cat([x, original_input], dim=-1)
            processed_features = self.pre_conv_mlp(combined_features)
            x = self.conv(processed_features, edge_index)
        x = self.decoder(x)
        return F.log_softmax(x, dim=1)   # pairs with NLLLoss in train_model
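
# A minimal end-to-end sketch. The real train_generator comes from the task
# harness; `toy_generator` below is a stand-in that only illustrates the
# assumed Data layout (x: [N, 2] node features, edge_index: [2, E],
# y: [N] binary labels). Sample counts and graph sizes here are arbitrary.
if __name__ == "__main__":
    from torch_geometric.data import Data

    def toy_generator(n_samples=100):
        samples = []
        for _ in range(n_samples):
            n = 16                                # e.g. a 4x4 maze
            src = torch.arange(n - 1)
            dst = torch.arange(1, n)
            edge_index = torch.stack([torch.cat([src, dst]),
                                      torch.cat([dst, src])])  # undirected chain
            samples.append(Data(x=torch.rand(n, 2),
                                edge_index=edge_index,
                                y=torch.randint(0, 2, (n,))))
        return samples

    model = init_model()
    best = train_model(model, toy_generator)
    print(eval_model(best, DataLoader(toy_generator(20), batch_size=1)))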