Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- from torch.nn.init import kaiming_normal_, constant_
- import matplotlib.pyplot as plt
def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Conv2d -> (optional BatchNorm2d) -> ReLU with 'same'-style padding.

    When batchNorm is True the convolution bias is omitted, since the
    following BatchNorm2d would cancel it anyway.
    """
    pad = (kernel_size - 1) // 2
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=stride, padding=pad, bias=not batchNorm),
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)
def predict_flow(in_planes):
    """3x3 prediction head mapping features to a 2-channel flow field (no bias)."""
    return nn.Conv2d(
        in_planes, 2,
        kernel_size=3, stride=1, padding=1, bias=False,
    )
def deconv(in_planes, out_planes):
    """2x spatial upsampling: transposed conv (k=4, s=2, p=1) + LeakyReLU(0.1)."""
    upsample = nn.ConvTranspose2d(
        in_planes, out_planes,
        kernel_size=4, stride=2, padding=1, bias=False,
    )
    return nn.Sequential(upsample, nn.LeakyReLU(0.1, inplace=True))
def crop_like(input, target):
    """Crop `input`'s trailing spatial dims (H, W) to match `target`'s.

    Returns `input` unchanged (same object) when the spatial sizes
    already agree; otherwise returns a top-left crop. Assumes `target`
    is no larger than `input` in H and W.
    """
    if input.size()[2:] == target.size()[2:]:
        return input
    h, w = target.size(2), target.size(3)
    return input[:, :, :h, :w]
class GridNet(nn.Module):
    """Seven-layer CNN encoder with two fully-connected branches.

    The encoder (conv1..conv6, each stride 2) downsamples the input by
    2**7 overall; its flattened output (expected to be 2048 features,
    i.e. a 512-channel 2x2 map) feeds two branches:

    - "odom" branch  -> translation head ``t`` (20) and rotation head
      ``rot2`` (112);
    - "drift" branch -> ``drift_x`` (40), ``drift_y`` (40) and
      ``drift_rot`` (60) heads.

    forward() returns the 5-tuple (t, rot, drift_x, drift_y, drift_rot).
    """

    expansion = 1

    def __init__(self, batchNorm=True):
        """Build the network.

        Args:
            batchNorm: whether the conv blocks include BatchNorm2d.
        """
        super(GridNet, self).__init__()
        # BUG FIX: the original hard-coded ``self.batchNorm = True``,
        # silently ignoring the constructor argument.
        self.batchNorm = batchNorm
        self.conv1 = conv(self.batchNorm, 3, 16, kernel_size=7, stride=2)
        self.conv2 = conv(self.batchNorm, 16, 32, kernel_size=5, stride=2)
        self.conv3 = conv(self.batchNorm, 32, 64, kernel_size=3, stride=2)
        self.conv3_1 = conv(self.batchNorm, 64, 128, stride=2)
        self.conv4 = conv(self.batchNorm, 128, 256, stride=2)
        self.conv5 = conv(self.batchNorm, 256, 256, stride=2)
        self.conv6 = conv(self.batchNorm, 256, 512, stride=2)
        # Branch trunks: 2048 flattened encoder features -> 1024.
        self.drift = nn.Sequential(nn.Dropout(0.5), nn.Linear(2048, 1024), nn.ReLU())
        self.odom = nn.Sequential(nn.Dropout(0.5), nn.Linear(2048, 1024), nn.ReLU())
        # Per-branch 1024 -> 512 sub-trunks.
        self.trans_drift = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
        self.rot_drift = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
        self.trans = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
        self.rot1 = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
        # Output heads.
        self.rot2 = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 112))
        self.t = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 20), nn.BatchNorm1d(20))
        self.drift_rot = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 60))
        self.drift_x = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 40))
        self.drift_y = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 40), nn.BatchNorm1d(40))

    @staticmethod
    def _freeze(*modules):
        """Disable gradient updates for every parameter of the given modules."""
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def freeze_drift(self):
        """Freeze the drift output heads (trunks stay trainable, as before)."""
        print("freezing drift")
        self._freeze(self.drift_x, self.drift_y, self.drift_rot)

    def freeze_odom(self):
        """Freeze the odometry branch.

        BUG FIX: the original iterated ``self.x``, ``self.y``, ``self.rot``
        and ``self.xy`` — attributes that are never defined — so every call
        raised AttributeError. NOTE(review): frozen set below is the
        odometry-branch modules that actually exist (heads ``t``/``rot2``
        plus sub-trunks ``trans``/``rot1``, mirroring the original's
        inclusion of ``rot1``) — confirm the intended set with the authors.
        """
        print("freezing odom")
        self._freeze(self.t, self.rot2, self.trans, self.rot1)

    def forward(self, input):
        """Run the encoder and both branches.

        Args:
            input: image batch of shape (N, 3, H, W); H and W must reduce
                to a 2048-element flattened conv6 output (e.g. 256x256 —
                TODO confirm expected input size against the caller).

        Returns:
            Tuple (t, rot, drift_x, drift_y, drift_rot) with feature sizes
            (20, 112, 40, 40, 60) per sample.
        """
        out_conv1 = self.conv1(input)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv3_1 = self.conv3_1(out_conv3)
        out_conv4 = self.conv4(out_conv3_1)
        out_conv5 = self.conv5(out_conv4)
        out_conv6 = self.conv6(out_conv5)
        # Flatten to (N, 2048) for the fully-connected branches.
        out = out_conv6.view(out_conv6.size(0), -1)
        out_drift = self.drift(out)
        out_odom = self.odom(out)
        out_trans = self.trans(out_odom)
        out_rot = self.rot1(out_odom)
        out_drift_trans = self.trans_drift(out_drift)
        out_drift_rot = self.rot_drift(out_drift)
        t = self.t(out_trans)
        rot = self.rot2(out_rot)
        drift_x = self.drift_x(out_drift_trans)
        drift_y = self.drift_y(out_drift_trans)
        drift_rot = self.drift_rot(out_drift_rot)
        return t, rot, drift_x, drift_y, drift_rot

    def weight_parameters(self):
        """Return all parameters whose registered name contains 'weight'."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return all parameters whose registered name contains 'bias'."""
        return [param for name, param in self.named_parameters() if 'bias' in name]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement