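# Minimal single-process DistributedDataParallel (DDP) repro: a small conv model
# is wrapped in DDP and trained on dummy tensors with a placeholder loss.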
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, padding=1, isReLU=True):
    # nn.Conv2d requires dilation >= 1, so the default is 1 (0 raises an error).
    if isReLU:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      dilation=dilation, padding=padding, bias=True),
            nn.LeakyReLU(0.1, inplace=True)
        )
    else:
        # No activation here, otherwise the isReLU flag would have no effect.
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                      dilation=dilation, padding=padding, bias=True)
        )
class DefaultModel(nn.Module):
    def __init__(self, cfg):
        super(DefaultModel, self).__init__()
        self.cfg = cfg
        # Two conv+pool stages: 3 -> 32 -> ndepth channels at 1/4 input resolution.
        self.conv_1x1 = nn.Sequential(
            conv(3, 32, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.MaxPool2d(2),
            conv(32, self.cfg.var.ndepth, kernel_size=3, stride=1, dilation=1, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(self.cfg.var.ndepth)  # was hard-coded to 64, i.e. assumed ndepth == 64
        )
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    def init_weights(self):
        # named_modules() yields (name, module) tuples; the original iterated over
        # the tuples directly, so the isinstance checks never matched.
        for _, layer in self.named_modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
    def forward(self, inputs):
        # inputs["rgb"] is (B, T, 3, H, W); take the last frame of the sequence.
        images = inputs["rgb"][:, -1, :, :, :]
        output = self.conv_1x1(images)
        # Undo the two MaxPool2d halvings: upsample the logits by a factor of 4.
        output_refined = F.interpolate(output, scale_factor=4.0)
        output_lsm = F.log_softmax(output, dim=1)
        output_refined_lsm = F.log_softmax(output_refined, dim=1)
        return {"output": [output_lsm], "output_refined": [output_refined_lsm],
                "flow": None, "flow_refined": None}
# Set flags for single-machine distributed training.
import os
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8081"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = str(0)
from easydict import EasyDict
import json
def worker(rank):
    with open("default.json") as f:
        cfg = EasyDict(json.load(f))
    # mp.spawn passes the process index as the first argument; use it as the
    # distributed rank and pin this process to the matching CUDA device.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", init_method="env://",
                            world_size=1, rank=rank)
    model = DefaultModel(cfg).cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], find_unused_parameters=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.lr,
                                 betas=(cfg.train.momentum, cfg.train.beta))
    for epoch in range(cfg.train.epoch_num):
        for step in range(10):  # renamed from "iter", which shadows the builtin
            # Dummy (B, T, C, H, W) RGB inputs and soft-label targets.
            input_left = {"rgb": torch.zeros((2, 2, 3, 100, 100)).cuda()}
            input_right = {"rgb": torch.zeros((2, 2, 3, 100, 100)).cuda()}
            truth_left = {"soft_labels": torch.ones((2, 64, 25, 25)).cuda()}
            truth_right = {"soft_labels": torch.ones((2, 64, 25, 25)).cuda()}
            # Model
            output_left = model(input_left)
            output_right = model(input_right)
            # loss = loss_func([output_left, output_right], [truth_left, truth_right])
            # Placeholder loss so that both forward passes contribute gradients.
            loss = torch.sum(output_left["output"][0]) + torch.sum(output_right["output"][0])
            # Opt
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch {epoch} step {step} loss {loss.item():.4f}")
    dist.destroy_process_group()
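# Scaling note: for N GPUs, set WORLD_SIZE to str(N), spawn with nprocs=N, and
# pass world_size=N to init_process_group; the per-process rank is already taken
# from the spawn index above.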
if __name__ == '__main__':
    # Spawn one worker; mp.spawn passes the process index (0) as "rank".
    mp.spawn(worker, nprocs=1, args=())
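# A hedged example of the expected "default.json"; only the key names are taken
# from the code above, the values are assumptions:
#   {
#     "var":   {"ndepth": 64},
#     "train": {"lr": 1e-4, "momentum": 0.9, "beta": 0.999, "epoch_num": 1}
#   }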