Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Plain convolutional classifier: four (conv -> ReLU -> 2x2 max-pool) stages
# followed by one fully connected layer. There are no residual/skip
# connections, so this is a VGG-like stack rather than a ResNet.
class MyModel(nn.Module):
    """Small CNN classifier for single-channel images.

    Each of the four stages applies a convolution, a ReLU and a 2x2/stride-2
    max-pool, widening the channels 1 -> 16 -> 32 -> 64 -> 128 while shrinking
    the spatial size. A parameter-free global average pool then reduces the
    remaining spatial extent to 1x1, so the 128-input linear head accepts any
    input large enough to survive the four pooling stages.

    Args:
        outputNeurons: Number of outputs of the final linear layer
            (default 10).
    """

    def __init__(self, outputNeurons=10):
        super().__init__()
        # Layer creation order is unchanged so seeded weight initialisation
        # matches the previous version of this class.
        # Input: 1 channel, output: 16 channels, filter size: 8x8.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, padding=2)
        # Input: 16 channels, output: 32 channels, filter size: 5x5.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=1)
        # Input: 32 channels, output: 64 channels, filter size: 3x3.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Input: 64 channels, output: 128 channels, filter size: 3x3.
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc = nn.Linear(128, outputNeurons)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        # Global average pool to 1x1 (no parameters, no buffers). For a 28x28
        # input the maps are already 1x1 at this point
        # (28 -> 25 -> 12 -> 10 -> 5 -> 5 -> 2 -> 2 -> 1), so it is an exact
        # identity. Larger inputs would otherwise leave 128*H*W features and
        # fail the 128-input fc layer.
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return the logits for a batch of images.

        Args:
            x: Input tensor; conv1 expects a single input channel. Presumably
                shaped (batch, 1, H, W) -- confirm against callers.

        Returns:
            Tensor of shape (batch, outputNeurons).
        """
        out = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = self.maxpool(self.relu(conv(out)))
        out = self.gap(out)
        # Keep the batch dimension, merge everything else into features.
        out = out.flatten(1)
        return self.fc(out)

    def __len__(self):
        """Return the total number of scalar parameters in the model.

        Every registered parameter is counted, whether or not it requires
        gradients. Because ``__len__`` is defined, ``bool(model)`` also
        depends on this count.
        """
        return sum(p.numel() for p in self.parameters())
# Sanity check at import time: report whether a default MyModel stays within
# the 150,000-parameter budget.
_PARAMETER_BUDGET = 150000
testInstance = MyModel()
_param_count = len(testInstance)
_verdict = "Invalid" if _param_count > _PARAMETER_BUDGET else "Valid"
print(f"{_verdict} Network, {_param_count} Parameters!")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement