import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Basic 1-D residual block: two 3x3 convolutions with a skip connection."""

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Optional projection applied to the identity path when the block
        # changes the channel count or the temporal resolution.
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet1d(nn.Module):
    """ResNet-18-style network for 1-D signals: a convolutional stem, four
    stages of two residual blocks each, global average pooling, dropout,
    and a linear classifier."""

    def __init__(self, channels=[6, 64, 64, 128, 256, 512, 2], dropout=0.85):
        super(ResNet1d, self).__init__()

        # Stem: wide strided convolution plus max pooling (4x downsampling).
        self.conv1 = nn.Conv1d(channels[0], channels[1], kernel_size=7,
                               stride=2, padding=3, bias=True)
        self.bn = nn.BatchNorm1d(channels[1])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Stage 1 keeps resolution and channel count.
        self.layer1 = nn.Sequential(
            ResidualBlock(channels[1], channels[2]),
            ResidualBlock(channels[2], channels[2]))

        # Stages 2-4 each halve the temporal resolution and change the
        # channel count, so the skip path needs a 1x1 conv projection.
        self.downsample1 = nn.Sequential(
            nn.Conv1d(channels[2], channels[3], kernel_size=1, stride=2, bias=False),
            nn.BatchNorm1d(channels[3]))
        self.layer2 = nn.Sequential(
            ResidualBlock(channels[2], channels[3], stride=2, downsample=self.downsample1),
            ResidualBlock(channels[3], channels[3]))

        self.downsample2 = nn.Sequential(
            nn.Conv1d(channels[3], channels[4], kernel_size=1, stride=2, bias=False),
            nn.BatchNorm1d(channels[4]))
        self.layer3 = nn.Sequential(
            ResidualBlock(channels[3], channels[4], stride=2, downsample=self.downsample2),
            ResidualBlock(channels[4], channels[4]))

        self.downsample3 = nn.Sequential(
            nn.Conv1d(channels[4], channels[5], kernel_size=1, stride=2, bias=False),
            nn.BatchNorm1d(channels[5]))
        self.layer4 = nn.Sequential(
            ResidualBlock(channels[4], channels[5], stride=2, downsample=self.downsample3),
            ResidualBlock(channels[5], channels[5]))

        # Head: global average pooling over time, dropout, linear classifier.
        self.avgpool = nn.AdaptiveAvgPool1d(output_size=1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(channels[5], channels[6])

    def forward(self, x):
        # x: (batch, channels[0], length)
        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)           # (batch, channels[5], 1)
        out = out.view(out.size(0), -1)   # flatten to (batch, channels[5])
        out = self.dropout(out)
        out = self.fc(out)                # class logits
        return out
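

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original paste; shapes are assumptions).
# The default channels start at 6 and end at 2, so we feed a hypothetical
# batch of 6-channel 1-D signals and expect 2 logits per sample. The length
# 1000 is arbitrary; it just has to survive the five stride-2 halvings.
if __name__ == "__main__":
    # A standalone block with a projection on the skip path: 64 -> 128
    # channels while stride 2 halves the temporal resolution, so the 1x1
    # conv makes the identity branch match before the addition.
    proj = nn.Sequential(
        nn.Conv1d(64, 128, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm1d(128))
    block = ResidualBlock(64, 128, stride=2, downsample=proj)
    print(block(torch.randn(4, 64, 250)).shape)   # torch.Size([4, 128, 125])

    # Full model forward pass on a dummy batch of 4 signals.
    model = ResNet1d()
    model.eval()   # disable dropout for a deterministic pass
    with torch.no_grad():
        logits = model(torch.randn(4, 6, 1000))
    print(logits.shape)                           # torch.Size([4, 2])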