Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class AutoEncoder(nn.Module):
    """Convolutional autoencoder whose decoder un-pools with the encoder's max-pool indices.

    The encoder is four conv/ReLU/max-pool stages followed by a conv/ReLU
    stage. The decoder mirrors it: each transposed conv inverts the spatial
    size change of the matching conv (same kernel size and padding, stride 1),
    and each ``MaxUnpool2d`` undoes the matching pool using the recorded
    indices and pre-pool size. Output spatial size therefore equals input
    spatial size.

    Input is expected as (N, 1, H, W) (the first conv has one input channel).
    The output has 1 channel and passes through a final ReLU, so it is
    non-negative.

    Note: ``encoder`` stores per-call state on the module (``all_indices`` and
    ``all_outputs``) that ``decoder`` consumes. Call ``encoder`` before
    ``decoder``. ``forward`` does both.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Each pooled stage ends in MaxPool2d(return_indices=True),
        # so calling the whole Sequential would return (output, indices).
        self.encoder_layer1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5, padding=4),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, return_indices=True),
        )
        self.encoder_layer2 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=4, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, return_indices=True),
        )
        self.encoder_layer3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=4, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, return_indices=True),
        )
        self.encoder_layer4 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, return_indices=True),
        )
        self.encoder_layer5 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=2),
            nn.ReLU(),
        )
        # Decoder. decoder_layerK inverts the size change of encoder_layer(6-K).
        self.decoder_layer1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=2),
            nn.ReLU(),
        )
        self.decoder_layer2 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, padding=2),
            nn.ReLU(),
        )
        self.decoder_layer3 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, padding=3),
            nn.ReLU(),
        )
        self.decoder_layer4 = nn.Sequential(
            nn.ConvTranspose2d(16, 8, kernel_size=4, padding=3),
            nn.ReLU(),
        )
        self.decoder_layer5 = nn.Sequential(
            nn.ConvTranspose2d(8, 1, kernel_size=5, padding=4),
            nn.ReLU(),
        )
        # UnPool
        self.unpool = nn.MaxUnpool2d(2, 2)
        # Filled by encoder(), consumed by decoder(): one entry per pooled
        # stage, in encoder order (stage 1 first).
        self.all_indices = []  # max-pool indices of each pooled stage
        self.all_outputs = []  # pre-pool (post conv/ReLU) shape of each pooled stage

    def _encode_pooled(self, stage, X):
        """Run one conv/ReLU/max-pool stage and record what unpooling needs.

        Appends the pre-pool shape to ``self.all_outputs`` and the pooling
        indices to ``self.all_indices``. Returns the pooled tensor.
        """
        *pre_pool, pool = stage
        for module in pre_pool:
            X = module(X)
        # MaxUnpool2d must be told the pre-pool size: when that size is odd,
        # pooling floors it and the default unpooled size is one short.
        self.all_outputs.append(X.shape)
        X, indices = pool(X)
        self.all_indices.append(indices)
        return X

    def encoder(self, X):
        """Encode ``X`` and cache the pooling indices/sizes for ``decoder``.

        The caches are reset on every call, so they never grow beyond one
        entry per pooled stage.
        """
        self.all_indices = []
        self.all_outputs = []
        X = self._encode_pooled(self.encoder_layer1, X)
        X = self._encode_pooled(self.encoder_layer2, X)
        X = self._encode_pooled(self.encoder_layer3, X)
        X = self._encode_pooled(self.encoder_layer4, X)
        X = self.encoder_layer5(X)
        return X

    def decoder(self, X):
        """Decode ``X`` using the indices/sizes cached by the latest ``encoder`` call.

        Raises:
            RuntimeError: if ``encoder`` has not populated the caches.
        """
        if len(self.all_indices) < 4 or len(self.all_outputs) < 4:
            raise RuntimeError("decoder() requires a preceding encoder() call")
        X = self.decoder_layer1(X)  # inverts encoder_layer5
        # Un-pool in reverse encoder order; each unpool precedes the transposed
        # conv that inverts the matching encoder conv.
        X = self.unpool(X, self.all_indices[-1], output_size=self.all_outputs[-1])
        X = self.decoder_layer2(X)  # inverts encoder_layer4's conv
        X = self.unpool(X, self.all_indices[-2], output_size=self.all_outputs[-2])
        X = self.decoder_layer3(X)  # inverts encoder_layer3's conv
        X = self.unpool(X, self.all_indices[-3], output_size=self.all_outputs[-3])
        X = self.decoder_layer4(X)  # inverts encoder_layer2's conv
        X = self.unpool(X, self.all_indices[-4], output_size=self.all_outputs[-4])
        X = self.decoder_layer5(X)  # inverts encoder_layer1's conv
        return X

    def forward(self, X):
        """Reconstruct ``X`` (encode, then decode)."""
        return self.decoder(self.encoder(X))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement