Advertisement
Guest User

Untitled

a guest
May 26th, 2020
94
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import math
  5. import numpy as np
  6. import torch.multiprocessing as mp
  7. import torch.distributed as dist
  8.  
  9. def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=0, padding=1, isReLU=True):
  10.     if isReLU:
  11.         return nn.Sequential(
  12.             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
  13.                       padding=padding, bias=True),
  14.             nn.LeakyReLU(0.1, inplace=True)
  15.         )
  16.     else:
  17.         return nn.Sequential(
  18.             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
  19.                       dilation=dilation,
  20.                       padding=padding, bias=True),
  21.             nn.LeakyReLU(0.1, inplace=True)
  22.         )
  23.  
  24. class DefaultModel(nn.Module):
  25.     def __init__(self, cfg):
  26.         super(DefaultModel, self).__init__()
  27.         self.cfg = cfg
  28.         self.conv_1x1 = nn.Sequential(conv(3, 32, kernel_size=3, stride=1, dilation=0, padding=1),
  29.                                       nn.MaxPool2d(2),
  30.                                       conv(32, self.cfg.var.ndepth, kernel_size=3, stride=1, dilation=0, padding=1),
  31.                                       nn.MaxPool2d(2),
  32.                                       nn.BatchNorm2d(64)
  33.                                       )
  34.  
  35.     def num_parameters(self):
  36.         return sum(
  37.             [p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
  38.  
  39.     def init_weights(self):
  40.         for layer in self.named_modules():
  41.             if isinstance(layer, nn.Conv2d):
  42.                 nn.init.kaiming_normal_(layer.weight)
  43.                 if layer.bias is not None:
  44.                     nn.init.constant_(layer.bias, 0)
  45.  
  46.             elif isinstance(layer, nn.ConvTranspose2d):
  47.                 nn.init.kaiming_normal_(layer.weight)
  48.                 if layer.bias is not None:
  49.                     nn.init.constant_(layer.bias, 0)
  50.  
  51.     def forward(self, input):
  52.         images = input["rgb"][:, -1, :, :, :]
  53.         output = self.conv_1x1(images)
  54.         output_refined = F.interpolate(output, None, 4.)
  55.         output_lsm = F.log_softmax(output, dim=1)
  56.         output_refined_lsm = F.log_softmax(output_refined, dim=1)
  57.         return {"output": [output_lsm], "output_refined": [output_refined_lsm], "flow": None, "flow_refined": None}
  58.  
  59.  
  60. # Set Flags
  61. import os
  62. import time
  63. os.environ["MASTER_ADDR"] = "localhost"
  64. os.environ["MASTER_PORT"] = "8081"
  65. os.environ["WORLD_SIZE"] = "1"
  66. os.environ["RANK"] = str(0)
  67. from path import Path
  68. from easydict import EasyDict
  69. import json
  70.  
  71. def worker(id):
  72.  
  73.     with open("default.json") as f:
  74.         cfg = EasyDict(json.load(f))
  75.  
  76.     dist.init_process_group(backend="nccl", init_method="env://",
  77.                             world_size=1, rank=0)
  78.  
  79.     model = DefaultModel(cfg).cuda()
  80.  
  81.     # Manual
  82.     model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
  83.     optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.lr,
  84.                                  betas=(cfg.train.momentum, cfg.train.beta))
  85.  
  86.     for epoch in range(cfg.train.epoch_num):
  87.  
  88.         for iter in range(10):
  89.             input_left = {"rgb": torch.zeros((2, 2, 3, 100, 100)).cuda()}
  90.             input_right = {"rgb": torch.zeros((2, 2, 3, 100, 100)).cuda()}
  91.             truth_left = {"soft_labels": torch.ones((2, 64, 25, 25)).cuda()}
  92.             truth_right = {"soft_labels": torch.ones((2, 64, 25, 25)).cuda()}
  93.             # Model
  94.             output_left = model(input_left)
  95.             output_right = model(input_right)
  96.             #loss = loss_func([output_left, output_right], [truth_left, truth_right])
  97.  
  98.             loss = torch.sum(output_left["output"][0] - 0) + torch.sum(output_right["output"][0] - 0)
  99.  
  100.             print("Try")
  101.  
  102.             # Opt
  103.             optimizer.zero_grad()
  104.             loss.backward()
  105.             optimizer.step()
  106.  
  107.     print("wait")
  108.  
  109.     dist.destroy_process_group()
  110.  
  111.  
  112. if __name__ == '__main__':
  113.     # Spawn Worker
  114.     mp.spawn(worker, nprocs=1, args=())
  115.  
  116.  
  117.  
  118. #
Advertisement
RAW Paste Data Copied
Advertisement