Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyNet(nn.Module):
    """Small CNN image classifier: four conv stages -> global average pool -> linear head.

    Input:  (N, 3, H, W) float tensor. H and W must be large enough to survive
            two 2x2 max-pools followed by two unpadded 3x3 convs (H, W >= 16 is safe).
    Output: (N, 10) tensor of class probabilities (softmax applied; rows sum to 1).

    NOTE(review): forward() returns softmax probabilities. If this model is
    trained with nn.CrossEntropyLoss (which applies log-softmax internally),
    softmax would be applied twice — confirm against the training loop.
    """

    def __init__(self):
        super().__init__()  # zero-arg super(): Python 3 idiom, same behavior
        # Nested nn.Sequential structure is kept exactly as in the original so
        # that state_dict keys remain compatible.
        self.cnn = nn.Sequential(
            # Stage 1: 3 -> 16 channels; padded conv keeps H/W, pool halves them.
            nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                nn.BatchNorm2d(16),
            ),
            # Stage 2: 16 -> 32 channels; H/W halved again.
            nn.Sequential(
                nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                nn.BatchNorm2d(32),
            ),
            # Stage 3: 32 -> 64 channels; no padding, so H/W shrink by 2.
            nn.Sequential(
                nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.BatchNorm2d(64),
            ),
            # Stage 4: 64 -> 128 channels; no padding, H/W shrink by 2 more.
            nn.Sequential(
                nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
                nn.ReLU(),
                nn.BatchNorm2d(128),
            ),
            # Global average pool collapses any remaining spatial extent to 1x1,
            # making the classifier input size independent of H and W.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Run the network; returns (N, 10) softmax probabilities."""
        x = self.cnn(x)           # (N, 128, 1, 1)
        x = torch.flatten(x, 1)   # (N, 128)
        x = self.fc(x)            # (N, 10) logits
        x = F.softmax(x, dim=1)   # probabilities (see NOTE in class docstring)
        return x
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement