import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_bn(bn):
    """Initialize a BatchNorm layer: set the scale (gamma) to 1."""
    bn.weight.data.fill_(1.)

class Attention(nn.Module):
    def __init__(self, n_in, n_out, att_activation, cla_activation):
        super(Attention, self).__init__()

        self.att_activation = att_activation
        self.cla_activation = cla_activation

        # 1x1 convolutions producing attention weights and class scores
        self.att = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.cla = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def activate(self, x, activation):

        if activation == 'linear':
            return x

        elif activation == 'relu':
            return F.relu(x)

        elif activation == 'sigmoid':
            return torch.sigmoid(x)

        elif activation == 'softmax':
            return F.softmax(x, dim=1)

        else:
            raise ValueError('Unknown activation: {}'.format(activation))

    def forward(self, x):
        """input: (samples_num, n_in, time_steps, 1)
        """

        att = self.att(x)
        att = self.activate(att, self.att_activation)

        cla = self.cla(x)
        cla = self.activate(cla, self.cla_activation)

        att = att[:, :, :, 0]   # (samples_num, n_out, time_steps)
        cla = cla[:, :, :, 0]   # (samples_num, n_out, time_steps)

        # Normalize the attention weights over the time axis and use them
        # to pool the classification branch across time
        epsilon = 1e-7
        att = torch.clamp(att, epsilon, 1. - epsilon)

        norm_att = att / torch.sum(att, dim=2)[:, :, None]
        x = torch.sum(norm_att * cla, dim=2)

        return x

def init_layer(layer):
    """Initialize a Conv2d or Linear layer with a fan-in scaled uniform
    distribution (Kaiming-style) and zero bias."""
    if layer.weight.ndimension() == 4:
        (n_out, n_in, height, width) = layer.weight.size()
        n = n_in * height * width
    elif layer.weight.ndimension() == 2:
        (n_out, n) = layer.weight.size()

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    layer.weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)

class EmbeddingLayers(nn.Module):

    def __init__(self, freq_bins, hidden_units, drop_rate):
        super(EmbeddingLayers, self).__init__()

        self.freq_bins = freq_bins
        self.hidden_units = hidden_units
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv2d(
            in_channels=freq_bins, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv2 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv3 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.bn0 = nn.BatchNorm2d(freq_bins)
        self.bn1 = nn.BatchNorm2d(hidden_units)
        self.bn2 = nn.BatchNorm2d(hidden_units)
        self.bn3 = nn.BatchNorm2d(hidden_units)

        self.init_weights()

    def init_weights(self):

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)

        init_bn(self.bn0)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)

    def forward(self, input, return_layers=False):
        """input: (samples_num, time_steps, freq_bins)
        """

        drop_rate = self.drop_rate

        # (samples_num, freq_bins, time_steps)
        x = input.transpose(1, 2)

        # Add an extra dimension for using Conv2d
        # (samples_num, freq_bins, time_steps, 1)
        x = x[:, :, :, None].contiguous()

        a0 = self.bn0(x)
        a1 = F.dropout(F.relu(self.bn1(self.conv1(a0))),
                       p=drop_rate,
                       training=self.training)

        a2 = F.dropout(F.relu(self.bn2(self.conv2(a1))),
                       p=drop_rate,
                       training=self.training)

        emb = F.dropout(F.relu(self.bn3(self.conv3(a2))),
                        p=drop_rate,
                        training=self.training)

        if not return_layers:
            # (samples_num, hidden_units, time_steps, 1)
            return emb

        else:
            return [a0, a1, a2, emb]

class FeatureLevelSingleAttention(nn.Module):

    def __init__(self, freq_bins, classes_num, hidden_units, drop_rate):

        super(FeatureLevelSingleAttention, self).__init__()

        self.emb = EmbeddingLayers(freq_bins, hidden_units, drop_rate)

        self.attention = Attention(
            hidden_units,
            hidden_units,
            att_activation='sigmoid',
            cla_activation='linear')

        self.fc_final = nn.Linear(hidden_units, classes_num)
        self.bn_attention = nn.BatchNorm1d(hidden_units)

        self.drop_rate = drop_rate

        self.init_weights()

    def init_weights(self):

        init_layer(self.fc_final)
        init_bn(self.bn_attention)

    def forward(self, input):
        """input: (samples_num, time_steps, freq_bins)
        """
        drop_rate = self.drop_rate

        # (samples_num, hidden_units, time_steps, 1)
        b1 = self.emb(input)

        # Attention pooling over time: (samples_num, hidden_units)
        b2 = self.attention(b1)
        b2 = F.dropout(F.relu(self.bn_attention(b2)),
                       p=drop_rate,
                       training=self.training)

        # (samples_num, classes_num)
        output = torch.sigmoid(self.fc_final(b2))

        return output
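

# Minimal usage sketch. The hyperparameters below (batch_size, time_steps,
# freq_bins, classes_num, hidden_units, drop_rate) are assumed values, chosen
# only to demonstrate the expected input and output tensor shapes.
if __name__ == '__main__':
    batch_size, time_steps, freq_bins = 8, 240, 64
    classes_num, hidden_units, drop_rate = 10, 512, 0.5

    model = FeatureLevelSingleAttention(
        freq_bins, classes_num, hidden_units, drop_rate)
    model.eval()   # use running BatchNorm statistics, disable dropout

    # Dummy spectrogram-like batch: (samples_num, time_steps, freq_bins)
    x = torch.randn(batch_size, time_steps, freq_bins)

    with torch.no_grad():
        y = model(x)

    print(y.shape)   # expected: torch.Size([8, 10]), per-class probabilities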