Advertisement
nhoxhaizxc123456

encoders.py

Apr 16th, 2024
586
0
11 hours
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.17 KB | None | 0 0
  1. #encoders.py
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import functools
  6. from .layers import SNConv2d, Attention
  7. from torch.nn import init
  8. from networks.components import get_layer
  9.  
  10.  
  11. class ClassConditionNorm(nn.Module):
  12.     def __init__(self,
  13.             output_size,
  14.             input_size,
  15.             which_linear=functools.partial(nn.Linear, bias=False),
  16.             eps=1e-5,
  17.             norm_style='bn'):
  18.         super().__init__()
  19.         self.output_size, self.input_size = output_size, input_size
  20.         # Prepare gain and bias layers
  21.         self.gain = which_linear(input_size, output_size)
  22.         self.bias = which_linear(input_size, output_size)
  23.         # epsilon to avoid dividing by 0
  24.         self.eps = eps
  25.         self.norm_style = norm_style
  26.  
  27.         self.register_buffer('stored_mean', torch.zeros(output_size))
  28.         self.register_buffer('stored_var', torch.ones(output_size))
  29.  
  30.     def forward(self, x, y):
  31.         # Calculate class-conditional gains and biases
  32.         gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
  33.         bias = self.bias(y).view(y.size(0), -1, 1, 1)
  34.  
  35.         if self.norm_style == 'bn':
  36.             out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
  37.                               self.training, 0.1, self.eps)
  38.         elif self.norm_style == 'in':
  39.             out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
  40.                               self.training, 0.1, self.eps)
  41.  
  42.         return out * gain + bias
  43.  
  44. # For Encoder T
  45. class Netv2(nn.Module):
  46.     def __init__(self, args):
  47.         super(Netv2, self).__init__()
  48.  
  49.  
  50.         print('Init Net')
  51.         self.depth = args.g_depth
  52.         norms = [args.g_norm, args.g_norm]
  53.  
  54.         ## Declare params
  55.         _tmp = [2**(k)*32 for k in range(self.depth-1)]
  56.         _tmp += [_tmp[-1]]
  57.         pdims = [args.g_in_channels[0]] + _tmp
  58.         # sdims = [args.g_in_channels[1]] + _tmp
  59.         pddims = [2048, 1024, args.g_out_channels[0]]
  60.         # sddims = [k*2 for k in _tmp[::-1]]
  61.         # sddims += [sddims[-1]]
  62.         enc_kernels = [[5,5]] + [[3,3] for k in range(self.depth-1)]
  63.  
  64.         print("P_Dims: {}\nPD_Dims: {}\nEnc_Kernels: {}".format(pdims, pddims, enc_kernels))
  65.  
  66.         ## Encoder
  67.        
  68.        
  69.         ### Encoder T
  70.         for i in range(self.depth):
  71.             setattr(self, 'pconv_{}'.format(i+1), get_layer('basic')(pdims[i],pdims[i+1],kernels=enc_kernels[i], subsampling=args.g_downsampler if i > 0 else 'none', norms=norms))
  72.  
  73.         ## Linear Decoder
  74.         img_size = args.crop_size[0]//2**(self.depth-1)
  75.         self.linear_in = img_size*img_size*pdims[-1]
  76.  
  77.         self.pdfc1 = nn.Sequential(
  78.             nn.Linear(self.linear_in, pddims[0]),
  79.             nn.ReLU(),
  80.             nn.Linear(pddims[0], pddims[1]),
  81.             nn.ReLU()
  82.         )
  83.  
  84.         self.pdfc2 = nn.Sequential(
  85.             nn.Linear(pddims[1], pddims[2]),
  86.             nn.Tanh()
  87.         )
  88.  
  89.     def forward(self, X, R):
  90.         sources = [None]
  91.         pout = torch.cat([X, R], 1)
  92.         # sout = X
  93.  
  94.         ## Encoder
  95.         for i in range(self.depth):
  96.             pout = getattr(self, 'pconv_{:d}'.format(i + 1))(pout)
  97.             # sout = getattr(self, 'sconv_{:d}'.format(i + 1))(sout)
  98.             # sources.append(torch.cat([pout, sout], 1))
  99.             sources.append(pout)
  100.         ## Linear Decoder
  101.         p_emb = self.pdfc1(pout.view(-1, self.linear_in))
  102.  
  103.         p_vec = self.pdfc2(p_emb)
  104.  
  105.         ## Decoder
  106.         # sdout = sources.pop()
  107.  
  108.         # for i in range(self.depth):
  109.         #     # print(X.size())
  110.  
  111.         #     sdout = getattr(self, 'sdconv_{:d}'.format(i + 1))(sdout, _skip_feat=sources[-i-1] if i < self.depth-1 else None)
  112.  
  113.         # return self.final_conv(sdout), p_vec, p_emb
  114.         return sources, p_vec, p_emb
  115.    
  116.     def estimate_preset(self, _input):
  117.         pout = _input
  118.  
  119.         ## Encoder
  120.         for i in range(self.depth):
  121.             pout = getattr(self, 'pconv_{:d}'.format(i + 1))(pout)
  122.  
  123.         ## Linear Decoder
  124.         p_emb = self.pdfc1(pout.view(-1, self.linear_in))
  125.         p_vec = self.pdfc2(p_emb)
  126.         return None, p_vec, p_emb
  127.  
  128. class ResConvBlock(nn.Module):
  129.     def __init__(self,
  130.             ch_in,
  131.             ch_out,
  132.             ch_c=128,
  133.             is_down=False,
  134.             dropout=0.2,
  135.             activation='relu',
  136.             pool='avg',
  137.             norm='batch',
  138.             use_res=True,
  139.             **kwargs):
  140.         super().__init__()
  141.  
  142.         self.is_down = is_down
  143.         self.has_condition = False
  144.         self.use_res = use_res
  145.  
  146.         # Convolution
  147.         if self.use_res:
  148.             self.conv = nn.Conv2d(ch_in, ch_out,
  149.                     kernel_size=1,
  150.                     stride=1,
  151.                     padding=0)
  152.  
  153.         self.conv_1 = nn.Conv2d(ch_in, ch_out,
  154.                 kernel_size=3,
  155.                 stride=1,
  156.                 padding=1)
  157.         self.conv_2 = nn.Conv2d(ch_out, ch_out,
  158.                 kernel_size=3,
  159.                 stride=1,
  160.                 padding=1)
  161.  
  162.         # Normalization
  163.         if norm == 'batch':
  164.             self.normalize_1 = nn.BatchNorm2d(ch_in)
  165.             self.normalize_2 = nn.BatchNorm2d(ch_out)
  166.         elif norm == 'id':
  167.             self.normalize_1 = nn.Identity()
  168.             self.normalize_2 = nn.Identity()
  169.         elif norm == 'instance':
  170.             self.normalize_1 = nn.InstanceNorm2d(ch_in)
  171.             self.normalize_2 = nn.InstanceNorm2d(ch_out)
  172.         elif norm == 'layer':
  173.             self.normalize_1 = nn.LayerNorm(kwargs['l_norm_shape_1'])
  174.             self.normalize_2 = nn.LayerNorm(kwargs['l_norm_shape_2'])
  175.         elif norm == 'adain':
  176.             self.has_condition = True
  177.             self.normalize_1 = ClassConditionNorm(ch_in, ch_c, norm_style='in')
  178.             self.normalize_2 = ClassConditionNorm(ch_out, ch_c, norm_style='in')
  179.         elif norm == 'adabatch':
  180.             self.has_condition = True
  181.             self.normalize_1 = ClassConditionNorm(ch_in, ch_c, norm_style='bn')
  182.             self.normalize_2 = ClassConditionNorm(ch_out, ch_c, norm_style='bn')
  183.         else:
  184.             raise Exception('Invalid Normalization')
  185.  
  186.         # Nonlinearity
  187.         self.activation = None
  188.         if activation == 'relu':
  189.             self.activation = lambda x: F.relu(x, True)
  190.         elif activation == 'sigmoid':
  191.             self.activation = F.sigmoid
  192.         elif activation == 'lrelu':
  193.             slope = kwargs['l_slope']
  194.             self.activation = lambda x: F.leaky_relu(x, slope, True)
  195.         else:
  196.             raise Exception('Invalid Nonlinearity')
  197.  
  198.         # Pooling
  199.         self.pool = None
  200.         if pool == 'avg':
  201.             self.pool = lambda x: F.avg_pool2d(x, kernel_size=2)
  202.         elif pool == 'max':
  203.             self.pool = lambda x: F.max_pool2d(x, kernel_size=2)
  204.         elif pool == 'min':
  205.             self.pool = lambda x: F.min_pool2d(x, kernel_size=2)
  206.         else:
  207.             raise Exception('Invalid Pooling')
  208.  
  209.         # Dropout
  210.         if dropout is not None:
  211.             self.dropout = nn.Dropout(dropout)
  212.         else:
  213.             self.dropout=None
  214.  
  215.     def forward(self, x, c=None):
  216.  
  217.         # Residual Path
  218.         x_ = x
  219.  
  220.         if self.has_condition:
  221.             x_ = self.normalize_1(x_, c)
  222.         else:
  223.             x_ = self.normalize_1(x_)
  224.         x_ = self.activation(x_)
  225.  
  226.         if self.is_down:
  227.             x_ = self.pool(x_)
  228.  
  229.         x_ = self.conv_1(x_)
  230.  
  231.         if self.has_condition:
  232.             x_ = self.normalize_2(x_, c)
  233.         else:
  234.             x_ = self.normalize_2(x_)
  235.         x_ = self.activation(x_)
  236.         x_ = self.conv_2(x_)
  237.  
  238.         # Main Path
  239.         if self.use_res:
  240.             if self.is_down:
  241.                 x = self.pool(x)
  242.             x = self.conv(x)
  243.         else:
  244.             x = 0
  245.  
  246.         # Merge
  247.         x = x + x_
  248.  
  249.         if self.dropout is not None:
  250.             x = self.dropout(x)
  251.  
  252.         return x
  253.  
  254. class EncoderF_Res(nn.Module):
  255.  
  256.     def __init__(self,
  257.                  ch_in=1,
  258.                  ch_out=768,
  259.                  ch_unit=96,
  260.                  norm='batch',
  261.                  activation='relu',
  262.                  init='ortho',
  263.                  use_att=False,
  264.                  use_res=True):
  265.         super().__init__()
  266.  
  267.         self.init = init
  268.         self.use_att = use_att
  269.  
  270.         kwargs = {}
  271.         if activation == 'lrelu':
  272.             kwargs['l_slope'] = 0.2
  273.  
  274.         if use_att:
  275.             print('Adding attention layer in E at resolution %d' % (64))
  276.             conv4att = functools.partial(
  277.                 SNConv2d,
  278.                 kernel_size=3,
  279.                 padding=1,
  280.                 num_svs=1,
  281.                 num_itrs=1,
  282.                 eps=1e-06)
  283.             self.att = Attention(384, conv4att)
  284.  
  285.         # output is 96 x 256 x 256
  286.         self.res1 = ResConvBlock(ch_in, ch_unit * 1,
  287.                                  is_down=False,
  288.                                  activation=activation,
  289.                                  norm=norm,
  290.                                  use_res=use_res,
  291.                                  **kwargs)
  292.         # output is 192 x 128 x 128
  293.         self.res2 = ResConvBlock(ch_unit * 1, ch_unit * 2,
  294.                                  is_down=True,
  295.                                  activation=activation,
  296.                                  norm=norm,
  297.                                  use_res=use_res,
  298.                                  **kwargs)
  299.         # output is  384 x 64 x 64
  300.         self.res3 = ResConvBlock(ch_unit * 2, ch_unit * 4,
  301.                                  is_down=True,
  302.                                  activation=activation,
  303.                                  norm=norm,
  304.                                  use_res=use_res,
  305.                                  **kwargs)
  306.         # output is  768 x 32 x 32
  307.         self.res4 = ResConvBlock(ch_unit * 4, ch_unit * 8,
  308.                                  is_down=True,
  309.                                  activation=activation,
  310.                                  norm=norm,
  311.                                  use_res=use_res,
  312.                                  **kwargs)
  313.         # output is  768 x 16 x 16
  314.         self.res5 = ResConvBlock(ch_unit * 8, ch_unit * 8,
  315.                                  is_down=True,
  316.                                  activation=activation,
  317.                                  norm=norm,
  318.                                  use_res=use_res,
  319.                                  dropout=None,
  320.                                  **kwargs)
  321.  
  322.         self.init_weights()
  323.  
  324.     def forward(self, x, c=None):
  325.         x = self.res1(x, c)
  326.         x = self.res2(x, c)
  327.         x = self.res3(x, c)
  328.         if self.use_att:
  329.             x = self.att(x)
  330.         x = self.res4(x, c)
  331.         x = self.res5(x, c)
  332.         return x
  333.  
  334.  
  335.     def forward_with_cp(self, x, cp):
  336.         x = self.res1(x, cp[0])
  337.         x = self.res2(x, cp[1])
  338.         x = self.res3(x, cp[2])
  339.         if self.use_att:
  340.             x = self.att(x)
  341.         x = self.res4(x, cp[3])
  342.         x = self.res5(x, cp[4])
  343.         return x
  344.  
  345.  
  346.     def init_weights(self):
  347.         for module in self.modules():
  348.             if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)
  349.                     or isinstance(module, nn.Embedding)):
  350.                 if self.init == 'ortho':
  351.                     init.orthogonal_(module.weight)
  352.                 elif self.init == 'N02':
  353.                     init.normal_(module.weight, 0, 0.02)
  354.                 elif self.init in ['glorot', 'xavier']:
  355.                     init.xavier_uniform_(module.weight)
  356.                 else:
  357.                     pass
  358.                     # print('Init style not recognized...')
  359.  
  360.  
  361. # z: ([batch, 17])
  362. # h: ([batch, 24576])
  363. # index 0 : ([batch, 1536, 4, 4])
  364. # index 1 : ([batch, 1536, 8, 8])
  365. # index 2 : ([batch, 768, 16, 16])
  366. # index 3 : ([batch, 768, 32, 32])
  367. # index 4 : ([batch, 384, 64, 64])
  368. # index 5 : ([batch, 192, 128, 128])
  369. # index 6: ([batch, 96, 256, 256])
  370. # result: ([batch, 3 256, 256])
  371. if __name__ == '__main__':
  372.     model = EncoderF_Res(use_att=True)
  373.     model.float()
  374.     y = model(torch.randn(4,1,256,256))
  375.     print(y.shape)
  376.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement