import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_bn(bn):
    """Set a BatchNorm layer's scale (weight) to 1."""
    bn.weight.data.fill_(1.)


class Attention(nn.Module):
    def __init__(self, n_in, n_out, att_activation, cla_activation):
        super(Attention, self).__init__()

        self.att_activation = att_activation
        self.cla_activation = cla_activation

        # 1x1 convolutions over (channels, time_steps, 1) feature maps:
        # `att` produces attention weights, `cla` per-step classifications.
        self.att = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.cla = nn.Conv2d(
            in_channels=n_in, out_channels=n_out,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def activate(self, x, activation):

        if activation == 'linear':
            return x

        elif activation == 'relu':
            return F.relu(x)

        elif activation == 'sigmoid':
            return torch.sigmoid(x)

        elif activation == 'softmax':
            return F.softmax(x, dim=1)

        else:
            raise ValueError('Unknown activation: {}'.format(activation))

    def forward(self, x):
        """input: (samples_num, in_channels, time_steps, 1)
        """

        att = self.att(x)
        att = self.activate(att, self.att_activation)

        cla = self.cla(x)
        cla = self.activate(cla, self.cla_activation)

        att = att[:, :, :, 0]   # (samples_num, classes_num, time_steps)
        cla = cla[:, :, :, 0]   # (samples_num, classes_num, time_steps)

        # Normalize the attention weights over time, then take the
        # attention-weighted average of the classifications.
        epsilon = 1e-7
        att = torch.clamp(att, epsilon, 1. - epsilon)

        norm_att = att / torch.sum(att, dim=2)[:, :, None]
        x = torch.sum(norm_att * cla, dim=2)

        return x

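# Attention.forward above implements attention pooling over time: for each
# class k,
#
#     y_k = sum_t [att_{k,t} / sum_{t'} att_{k,t'}] * cla_{k,t}
#
# i.e. the `att` branch yields per-time-step weights (clamped so the
# normalization stays stable) and the `cla` branch yields per-time-step
# predictions, which are averaged under those weights.
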
def init_layer(layer):
    """Initialize a Conv2d or Linear layer with He-style uniform weights."""
    if layer.weight.ndimension() == 4:
        (n_out, n_in, height, width) = layer.weight.size()
        n = n_in * height * width
    elif layer.weight.ndimension() == 2:
        (n_out, n) = layer.weight.size()

    std = math.sqrt(2. / n)
    scale = std * math.sqrt(3.)
    layer.weight.data.uniform_(-scale, scale)

    if layer.bias is not None:
        layer.bias.data.fill_(0.)

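# Note on init_layer: a uniform distribution on (-a, a) has standard deviation
# a / sqrt(3), so drawing from U(-std * sqrt(3), std * sqrt(3)) gives weights
# with std = sqrt(2 / n) -- He initialization, matched to the ReLU units used
# below.
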
class EmbeddingLayers(nn.Module):

    def __init__(self, freq_bins, hidden_units, drop_rate):
        super(EmbeddingLayers, self).__init__()

        self.freq_bins = freq_bins
        self.hidden_units = hidden_units
        self.drop_rate = drop_rate

        self.conv1 = nn.Conv2d(
            in_channels=freq_bins, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv2 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.conv3 = nn.Conv2d(
            in_channels=hidden_units, out_channels=hidden_units,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False)

        self.bn0 = nn.BatchNorm2d(freq_bins)
        self.bn1 = nn.BatchNorm2d(hidden_units)
        self.bn2 = nn.BatchNorm2d(hidden_units)
        self.bn3 = nn.BatchNorm2d(hidden_units)

        self.init_weights()

    def init_weights(self):

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)

        init_bn(self.bn0)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)

    def forward(self, input, return_layers=False):
        """input: (samples_num, time_steps, freq_bins)
        """

        drop_rate = self.drop_rate

        # (samples_num, freq_bins, time_steps)
        x = input.transpose(1, 2)

        # Add an extra dimension for using Conv2d
        # (samples_num, freq_bins, time_steps, 1)
        x = x[:, :, :, None].contiguous()

        a0 = self.bn0(x)
        a1 = F.dropout(F.relu(self.bn1(self.conv1(a0))),
                       p=drop_rate,
                       training=self.training)

        a2 = F.dropout(F.relu(self.bn2(self.conv2(a1))),
                       p=drop_rate,
                       training=self.training)

        emb = F.dropout(F.relu(self.bn3(self.conv3(a2))),
                        p=drop_rate,
                        training=self.training)

        if return_layers is False:
            # (samples_num, hidden_units, time_steps, 1)
            return emb

        else:
            return [a0, a1, a2, emb]

class FeatureLevelSingleAttention(nn.Module):

    def __init__(self, freq_bins, classes_num, hidden_units, drop_rate):

        super(FeatureLevelSingleAttention, self).__init__()

        self.emb = EmbeddingLayers(freq_bins, hidden_units, drop_rate)

        self.attention = Attention(
            hidden_units,
            hidden_units,
            att_activation='sigmoid',
            cla_activation='linear')

        self.fc_final = nn.Linear(hidden_units, classes_num)
        self.bn_attention = nn.BatchNorm1d(hidden_units)

        self.drop_rate = drop_rate

        self.init_weights()

    def init_weights(self):

        init_layer(self.fc_final)
        init_bn(self.bn_attention)

    def forward(self, input):
        """input: (samples_num, time_steps, freq_bins)
        """
        drop_rate = self.drop_rate

        # (samples_num, hidden_units, time_steps, 1)
        b1 = self.emb(input)

        # (samples_num, hidden_units)
        b2 = self.attention(b1)
        b2 = F.dropout(
            F.relu(
                self.bn_attention(b2)),
            p=drop_rate,
            training=self.training)

        # (samples_num, classes_num)
        output = torch.sigmoid(self.fc_final(b2))

        return output
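

# Minimal usage sketch: the shapes here (64 mel bins, 240 frames, 527 classes,
# 1024 hidden units) are illustrative assumptions, not values prescribed by
# the classes above.
if __name__ == '__main__':
    model = FeatureLevelSingleAttention(
        freq_bins=64, classes_num=527, hidden_units=1024, drop_rate=0.5)

    # Dummy batch of log-mel-like features: (samples_num, time_steps, freq_bins)
    x = torch.randn(8, 240, 64)

    model.eval()
    with torch.no_grad():
        emb = model.emb(x)   # (8, 1024, 240, 1)
        y = model(x)         # (8, 527), each entry in (0, 1)

    print(emb.shape, y.shape)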