Advertisement
Guest User

Untitled

a guest
Aug 24th, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.55 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.init import kaiming_normal_, constant_
  4. import matplotlib.pyplot as plt
  5.  
  6. def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
  7. if batchNorm:
  8. return nn.Sequential(
  9. nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
  10. nn.BatchNorm2d(out_planes),
  11. #nn.ReLU()
  12. nn.ReLU()
  13. )
  14. else:
  15. return nn.Sequential(
  16. nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
  17. nn.ReLU()
  18. )
  19.  
  20.  
  21. def predict_flow(in_planes):
  22. return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=False)
  23.  
  24.  
  25. def deconv(in_planes, out_planes):
  26. return nn.Sequential(
  27. nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False),
  28. nn.LeakyReLU(0.1,inplace=True)
  29. )
  30.  
  31.  
  32. def crop_like(input, target):
  33. if input.size()[2:] == target.size()[2:]:
  34. return input
  35. else:
  36. return input[:, :, :target.size(2), :target.size(3)]
  37.  
  38. class GridNet(nn.Module):
  39. expansion = 1
  40.  
  41. def __init__(self,batchNorm=True):
  42. super(GridNet,self).__init__()
  43.  
  44. self.batchNorm = True
  45.  
  46. self.conv1 = conv(self.batchNorm, 3, 16, kernel_size=7, stride=2)
  47. self.conv2 = conv(self.batchNorm, 16, 32, kernel_size=5, stride=2)
  48. self.conv3 = conv(self.batchNorm, 32, 64, kernel_size=3, stride=2)
  49. self.conv3_1 = conv(self.batchNorm, 64, 128, stride=2)
  50. self.conv4 = conv(self.batchNorm, 128, 256, stride=2)
  51. self.conv5 = conv(self.batchNorm, 256, 256, stride=2)
  52. self.conv6 = conv(self.batchNorm, 256, 512, stride=2)
  53.  
  54.  
  55. self.drift = nn.Sequential(nn.Dropout(0.5),nn.Linear(2048,1024), nn.ReLU())
  56. self.odom = nn.Sequential(nn.Dropout(0.5),nn.Linear(2048,1024), nn.ReLU())
  57.  
  58. self.trans_drift = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
  59. self.rot_drift = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
  60.  
  61. self.trans = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
  62. self.rot1 = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 512), nn.ReLU())
  63.  
  64. self.rot2 = nn.Sequential(nn.Dropout(0.5), nn.Linear(512, 112))
  65. self.t = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, 20), nn.BatchNorm1d(20))
  66.  
  67. self.drift_rot = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, 60))
  68. self.drift_x = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, 40))
  69. self.drift_y = nn.Sequential(nn.Dropout(0.5),nn.Linear(512, 40), nn.BatchNorm1d(40))
  70.  
  71.  
  72. def freeze_drift(self):
  73. print("freezing drift")
  74. for param in self.drift_x.parameters():
  75. param.requires_grad = False
  76.  
  77. for param in self.drift_y.parameters():
  78. param.requires_grad = False
  79.  
  80. for param in self.drift_rot.parameters():
  81. param.requires_grad = False
  82. def freeze_odom(self):
  83. print("freezing odom")
  84. for param in self.x.parameters():
  85. param.requires_grad = False
  86.  
  87. for param in self.y.parameters():
  88. param.requires_grad = False
  89.  
  90. for param in self.rot.parameters():
  91. param.requires_grad = False
  92.  
  93. for param in self.rot1.parameters():
  94. param.requires_grad = False
  95.  
  96. for param in self.xy.parameters():
  97. param.requires_grad = False
  98.  
  99. def forward(self, input):
  100.  
  101. out_conv1 = (self.conv1(input))
  102. out_conv2 = self.conv2(out_conv1)
  103.  
  104. out_conv3 = self.conv3(out_conv2)
  105.  
  106. out_conv3_1 = self.conv3_1(out_conv3)
  107. out_conv4 = self.conv4(out_conv3_1)
  108.  
  109. out_conv5 = self.conv5(out_conv4)
  110.  
  111. out_conv6= self.conv6(out_conv5)
  112.  
  113. out = out_conv6.view(out_conv6.size(0), -1)
  114.  
  115. out_drift = self.drift(out)
  116. out_odom = self.odom(out)
  117.  
  118. out_trans = self.trans(out_odom)
  119. out_rot = self.rot1(out_odom)
  120.  
  121. out_drift_trans = self.trans_drift(out_drift)
  122. out_drift_rot = self.rot_drift(out_drift)
  123.  
  124. t = self.t(out_trans)
  125. rot = self.rot2(out_rot)
  126. drift_x = self.drift_x(out_drift_trans)
  127. drift_y = self.drift_y(out_drift_trans)
  128. drift_rot = self.drift_rot(out_drift_rot)
  129.  
  130. return t,rot, drift_x, drift_y, drift_rot
  131.  
  132. def weight_parameters(self):
  133. return [param for name, param in self.named_parameters() if 'weight' in name]
  134.  
  135. def bias_parameters(self):
  136. return [param for name, param in self.named_parameters() if 'bias' in name]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement