Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class Rocket_E_NN(nn.Module):
    """Convolutional encoder mapping (B, nc, 64, 64) images to (B, z_dim) codes.

    Five kernel-4 convolutions shrink a 64x64 input to a 1x1 map with 256
    channels. A final linear layer projects that map to ``z_dim`` values; no
    activation is applied to the output.

    Args:
        z_dim: Size of the output code. Defaults to 2.
        nc: Number of input image channels. Defaults to 3.

    Inputs must be exactly 64x64 spatially. Any other size leaves a spatial
    map larger than 1x1 after the last convolution, and the
    flatten -> ``nn.Linear(256, z_dim)`` step then raises a shape error.
    """

    def __init__(self, z_dim: int = 2, nc: int = 3):
        super().__init__()
        self.z_dim = z_dim
        self.nc = nc
        # Layer order and count are fixed so state_dict keys
        # (encoder.0 ... encoder.11) match checkpoints of earlier versions.
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),  # B, 32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B, 32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B, 64, 8, 8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),  # B, 64, 4, 4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),  # B, 256, 1, 1
            nn.ReLU(True),
            # Flatten keeps the batch dimension fixed. A wrongly sized input
            # therefore fails at the Linear layer instead of being reshaped
            # into extra batch rows.
            nn.Flatten(),  # B, 256
            nn.Linear(256, z_dim),  # B, z_dim
        )

    def forward(self, x):
        """Encode a (B, nc, 64, 64) batch into a (B, z_dim) tensor."""
        z = self.encoder(x)
        return z
class Rocket_D_NN(nn.Module):
    """Deconvolutional decoder mapping (B, z_dim) codes to (B, nc, 64, 64) maps.

    A linear layer lifts the code to 256 features, which are reshaped to a
    1x1 spatial map. Transposed convolutions then upsample it to 64x64.

    The last layer has no activation, so outputs are unbounded. Presumably
    a sigmoid or a logits-based loss is applied by the caller; TODO confirm
    against the training code.

    Args:
        z_dim: Size of the input code. Defaults to 2.
        nc: Number of output channels. Defaults to 3.
    """

    def __init__(self, z_dim: int = 2, nc: int = 3):
        super().__init__()
        self.z_dim = z_dim
        self.nc = nc
        # Layer order and count are fixed so state_dict keys
        # (decoder.0 ... decoder.11) match checkpoints of earlier versions.
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),  # B, 256
            nn.Unflatten(1, (256, 1, 1)),  # B, 256, 1, 1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),  # B, 64, 4, 4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),  # B, 64, 8, 8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B, 32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),  # B, 32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

    def forward(self, z):
        """Decode a (B, z_dim) batch into a (B, nc, 64, 64) tensor."""
        x = self.decoder(z)
        return x
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement