Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class EdnetDown(nn.Module):
    """One encoder stage: strided conv, optional InstanceNorm, LeakyReLU, optional Dropout.

    The 4x4 convolution with stride 2 and padding 1 maps a spatial size H to
    floor(H / 2) and changes the channel count from ``in_size`` to ``out_size``.

    Args:
        in_size: Input channel count.
        out_size: Output channel count.
        normalize: When true, an ``InstanceNorm2d`` follows the convolution.
        dropout: Dropout probability; a ``Dropout`` layer is appended only when
            this value is truthy (so the default 0.0 adds nothing).
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [
            nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)
        ]
        stages += [nn.InstanceNorm2d(out_size)] if normalize else []
        stages.append(nn.LeakyReLU(0.2))
        stages += [nn.Dropout(dropout)] if dropout else []
        # The attribute name `model` and the layer order determine state_dict keys.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage's layers and return the result."""
        return self.model(x)
class EdnetUp(nn.Module):
    """One decoder stage: transposed conv, InstanceNorm, ReLU, optional Dropout, then skip concat.

    The 4x4 transposed convolution with stride 2 and padding 1 doubles the
    spatial size. Its output is then concatenated with ``skip_input`` along
    dim 1, so the result has ``out_size`` plus the skip tensor's channels.

    Args:
        in_size: Input channel count.
        out_size: Channel count produced by the transposed convolution.
        dropout: Dropout probability; a ``Dropout`` layer is appended only when
            this value is truthy.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        blocks = [
            nn.ConvTranspose2d(
                in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            blocks.append(nn.Dropout(dropout))
        # The attribute name `model` and the layer order determine state_dict keys.
        self.model = nn.Sequential(*blocks)
        # NOTE(review): weights_init_normal is defined outside this view and is
        # called with the whole Sequential. If it follows the common per-module
        # convention (written for `module.apply(fn)`), this call may leave the
        # ConvTranspose2d weights untouched. Confirm against its definition.
        weights_init_normal(self.model)

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along dim 1."""
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)
class Ednet(nn.Module):
    """U-Net generator with skip connections from encoder to decoder.

    Eight ``EdnetDown`` stages each halve the spatial size. Seven ``EdnetUp``
    stages each double it again and concatenate the matching encoder
    activation along dim 1. ``final`` upsamples once more and maps the
    128 concatenated channels to ``out_channels``, followed by ``Tanh``.

    Input height and width must be multiples of 256. Each down stage
    (kernel 4, stride 2, padding 1) halves an even size exactly, while each
    up stage doubles. The concatenation inside ``EdnetUp`` therefore only
    lines up when every one of the eight levels has an even size.

    Args:
        n_channels: Channel count of the input tensor.
        n_filters: Not used; the stage widths (64/128/256/512) are fixed.
            Kept so existing call sites that pass it keep working.
        out_channels: Channel count of the output. Defaults to 3, the value
            that was previously hard-coded.
    """

    def __init__(self, n_channels=3, n_filters=32, out_channels=3):
        super().__init__()
        # Encoder. The outermost stage skips normalization. The innermost
        # stage reaches 1x1 for a 256x256 input, so it also skips
        # normalization.
        self.down1 = EdnetDown(n_channels, 64, normalize=False)
        self.down2 = EdnetDown(64, 128)
        self.down3 = EdnetDown(128, 256)
        self.down4 = EdnetDown(256, 512, dropout=0.5)
        self.down5 = EdnetDown(512, 512, dropout=0.5)
        self.down6 = EdnetDown(512, 512, dropout=0.5)
        self.down7 = EdnetDown(512, 512, dropout=0.5)
        self.down8 = EdnetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder. Each EdnetUp output is concatenated with a skip tensor,
        # so the next stage sees out_size + skip channels (e.g. 512 + 512 = 1024).
        self.up1 = EdnetUp(512, 512, dropout=0.5)
        self.up2 = EdnetUp(1024, 512, dropout=0.5)
        self.up3 = EdnetUp(1024, 512, dropout=0.5)
        self.up4 = EdnetUp(1024, 512, dropout=0.5)
        self.up5 = EdnetUp(1024, 256)
        self.up6 = EdnetUp(512, 128)
        self.up7 = EdnetUp(256, 64)

        # Upsample by 2 to H, then pad one pixel on the left/top (H + 1).
        # A 4x4 conv with padding 1 then maps back to exactly H.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` to a tensor with ``out_channels`` channels and the same spatial size."""
        # Encoder path: keep every activation for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Decoder path: pair each up stage with the mirrored encoder output.
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        return self.final(u7)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement