import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_bn(bn):
    """Set a BatchNorm layer's scale (weight) to 1."""
    bn.weight.data.fill_(1.)


class Attention(nn.Module):
    def __init__(self, n_in, n_out, att_activation, cla_activation):
        super(Attention, self).__init__()

        self.att_activation = att_activation
        self.cla_activation = cla_activation

        # 1x1 convolutions over (channels, time_steps, 1) feature maps:
        # `att` produces attention weights, `cla` per-step classifications.
        self.att = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.cla = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def activate(self, x, activation):

        if activation == 'linear':
            return x

        elif activation == 'relu':
            return F.relu(x)

        elif activation == 'sigmoid':
            return torch.sigmoid(x)

        elif activation == 'softmax':
            return F.softmax(x, dim=1)

        else:
            raise ValueError('Unknown activation: {}'.format(activation))

    def forward(self, x):
        """input: (samples_num, in_channels, time_steps, 1)
        """

        att = self.att(x)
        att = self.activate(att, self.att_activation)

        cla = self.cla(x)
        cla = self.activate(cla, self.cla_activation)

        att = att[:, :, :, 0]   # (samples_num, classes_num, time_steps)
        cla = cla[:, :, :, 0]   # (samples_num, classes_num, time_steps)

        # Normalize the attention weights over time, then take the
        # attention-weighted average of the classifications.
        epsilon = 1e-7
        att = torch.clamp(att, epsilon, 1. - epsilon)

        norm_att = att / torch.sum(att, dim=2)[:, :, None]
        x = torch.sum(norm_att * cla, dim=2)

        return x

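# Attention.forward above implements attention pooling over time: for each
# class k,
#
#     y_k = sum_t [att_{k,t} / sum_{t'} att_{k,t'}] * cla_{k,t}
#
# i.e. the `att` branch yields per-time-step weights (clamped so the
# normalization stays stable) and the `cla` branch yields per-time-step
# predictions, which are averaged under those weights.
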
def init_layer(layer):
    """Initialize a Conv2d or Linear layer with He-style uniform weights."""
    if layer.weight.ndimension() == 4:
        (n_out, n_in, height, width) = layer.weight.size()
        n = n_in * height * width
    elif layer.weight.ndimension() == 2:
        (n_out, n) = layer.weight.size()

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    layer.weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)

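# Note on init_layer: a uniform distribution on (-a, a) has standard deviation
# a / sqrt(3), so drawing from U(-std * sqrt(3), std * sqrt(3)) gives weights
# with std = sqrt(2 / n) -- He initialization, matched to the ReLU units used
# below.
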
class EmbeddingLayers(nn.Module):

    def __init__(self, freq_bins, hidden_units, drop_rate):
        super(EmbeddingLayers, self).__init__()

        self.freq_bins = freq_bins
        self.hidden_units = hidden_units
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv2d(
            in_channels=freq_bins, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv2 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv3 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.bn0 = nn.BatchNorm2d(freq_bins)
        self.bn1 = nn.BatchNorm2d(hidden_units)
        self.bn2 = nn.BatchNorm2d(hidden_units)
        self.bn3 = nn.BatchNorm2d(hidden_units)

        self.init_weights()

    def init_weights(self):

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)

        init_bn(self.bn0)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)

    def forward(self, input, return_layers=False):
        """input: (samples_num, time_steps, freq_bins)
        """

        drop_rate = self.drop_rate

        # (samples_num, freq_bins, time_steps)
        x = input.transpose(1, 2)

        # Add an extra dimension for using Conv2d
        # (samples_num, freq_bins, time_steps, 1)
        x = x[:, :, :, None].contiguous()

        a0 = self.bn0(x)
        a1 = F.dropout(F.relu(self.bn1(self.conv1(a0))),
                       p=drop_rate,
                       training=self.training)

        a2 = F.dropout(F.relu(self.bn2(self.conv2(a1))),
                       p=drop_rate,
                       training=self.training)

        emb = F.dropout(F.relu(self.bn3(self.conv3(a2))),
                        p=drop_rate,
                        training=self.training)

        if return_layers is False:
            # (samples_num, hidden_units, time_steps, 1)
            return emb

        else:
            return [a0, a1, a2, emb]

class FeatureLevelSingleAttention(nn.Module):

    def __init__(self, freq_bins, classes_num, hidden_units, drop_rate):

        super(FeatureLevelSingleAttention, self).__init__()

        self.emb = EmbeddingLayers(freq_bins, hidden_units, drop_rate)

        self.attention = Attention(
            hidden_units,
            hidden_units,
            att_activation='sigmoid',
            cla_activation='linear')

        self.fc_final = nn.Linear(hidden_units, classes_num)
        self.bn_attention = nn.BatchNorm1d(hidden_units)

        self.drop_rate = drop_rate

        self.init_weights()

    def init_weights(self):

        init_layer(self.fc_final)
        init_bn(self.bn_attention)

    def forward(self, input):
        """input: (samples_num, time_steps, freq_bins)
        """
        drop_rate = self.drop_rate

        # (samples_num, hidden_units, time_steps, 1)
        b1 = self.emb(input)

        # (samples_num, hidden_units)
        b2 = self.attention(b1)
        b2 = F.dropout(
            F.relu(
                self.bn_attention(b2)),
            p=drop_rate,
            training=self.training)

        # (samples_num, classes_num)
        output = torch.sigmoid(self.fc_final(b2))

        return output
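

# Minimal usage sketch: the shapes here (64 mel bins, 240 frames, 527 classes,
# 1024 hidden units) are illustrative assumptions, not values prescribed by
# the classes above.
if __name__ == '__main__':
    model = FeatureLevelSingleAttention(
        freq_bins=64, classes_num=527, hidden_units=1024, drop_rate=0.5)

    # Dummy batch of log-mel-like features: (samples_num, time_steps, freq_bins)
    x = torch.randn(8, 240, 64)

    model.eval()
    with torch.no_grad():
        emb = model.emb(x)   # (8, 1024, 240, 1)
        y = model(x)         # (8, 527), each entry in (0, 1)

    print(emb.shape, y.shape)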