Guest User

Untitled

a guest
May 23rd, 2021
82
28 days
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import argparse
  2. import time
  3. import numpy as np
  4. import networkx as nx
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. import dgl
  9. import sys
  10. import random
  11. from pathlib import Path
  12. from dgl.data import register_data_args
  13.  
  14. import gcn
  15. from gcn import GCN
  16. from MISDataset import MISDataset
  17.  
  18. def main():
  19.     self_loop = True
  20.     cuda = True
  21.     prob_maps = 2
  22.  
  23.     # load datasets
  24.     training_graphs = []
  25.     validation_graphs = []
  26.  
  27.     graph_path = Path(__file__).parent / "many_graphs"
  28.     pathlist = graph_path.rglob('*.gpickle')
  29.  
  30.     training_size = 38000
  31.     validation_size = 2000
  32.  
  33.     pathlist = list(pathlist)
  34.     random.shuffle(pathlist)
  35.     pathlist = pathlist[:training_size+validation_size]
  36.  
  37.     for idx, graph in enumerate(pathlist):
  38.         ds = MISDataset(graph.resolve())
  39.         g = ds[0]
  40.         if cuda:
  41.             g = g.int().to(0)
  42.         if self_loop:
  43.             g = dgl.remove_self_loop(g)
  44.             g = dgl.add_self_loop(g)
  45.         else:
  46.             g = dgl.remove_self_loop(g)
  47.  
  48.         if idx >= training_size:
  49.             validation_graphs.append(g)
  50.         else:
  51.             training_graphs.append(g)
  52.  
  53.     print(f"Loaded {len(training_graphs)} graphs for training and {len(validation_graphs)} graphs for validation.")
  54.  
  55.     model = GCN(1, # 1 input feature - the weight
  56.                 32, #32 dimensions in hidden layers
  57.                 prob_maps,  #probability maps
  58.                 20, #20 hidden layers
  59.                 F.relu,
  60.                 0)
  61.  
  62.     if cuda:
  63.         model.cuda()
  64.  
  65.     loss_fcn = gcn.hindsight_loss
  66.  
  67.     # use optimizer
  68.     optimizer = torch.optim.Adam(model.parameters(),
  69.                                  lr=0.001,
  70.                                  weight_decay=5e-4)
  71.  
  72.     dur = []
  73.     num_epochs = 5
  74.     status_update_every = 100
  75.     for epoch in range(num_epochs + 1):
  76.         epoch_losses = list()
  77.         for gidx, graph in enumerate(training_graphs):
  78.             features = graph.ndata['weight']
  79.             labels = graph.ndata['label']
  80.             model.train()
  81.  
  82.             # forward
  83.             output = model(graph, features)
  84.             loss = loss_fcn(output, labels)
  85.             epoch_losses.append(float(loss))
  86.  
  87.             optimizer.zero_grad()
  88.             loss.backward()
  89.             optimizer.step()
  90.  
  91.         torch.save(model.state_dict(), f"model{prob_maps}_{epoch}.torch")
  92.  
  93.     torch.save(model.state_dict(), "final_model.torch")
  94.  
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
RAW Paste Data