import torch
import torch.nn as nn
from torch.autograd import Variable  # pre-0.4 PyTorch API, kept from the original
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from util import *  # expected to supply USE_CUDA (used in init_hidden below)

class RES(nn.Module):
    """Residual conv block: a 1x1 entry conv followed by n_layers of
    (conv -> batchnorm -> ReLU), with a skip connection around the stack."""

    def __init__(self, in_ch, out_ch, kernel_size, n_layers):
        super(RES, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.n_layers = n_layers

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        convs = []
        for i in range(n_layers):
            # padding of kernel_size // 2 keeps the spatial size unchanged
            convs.append(nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size,
                                   stride=1, padding=kernel_size // 2))
            convs.append(nn.BatchNorm2d(out_ch))
            convs.append(nn.ReLU())

        self.convs = nn.Sequential(*convs)

        # 1x1 projection so the skip connection matches out_ch when the
        # channel counts differ.
        self.proj = None
        if in_ch != out_ch:
            self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)

    def forward(self, input):
        resi = input
        input = self.conv1(input)
        input = self.bn1(input)
        input = self.relu(input)
        input = self.convs(input)

        if self.proj is not None:
            resi = self.proj(resi)

        output = resi + input
        return output
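
# A quick shape check for RES (illustrative only, not part of the original
# paste): the 1x1 projection lets the skip connection survive a 1 -> 16
# channel change while the spatial dims are preserved.
#
#   x = Variable(torch.randn(2, 1, 10, 128))  # (batch, in_ch, H, W)
#   y = RES(1, 16, 3, 6)(x)                   # -> (2, 16, 10, 128)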

class RESR(nn.Module):
    """Four-layer LSTM whose padded output is treated as a one-channel image,
    refined by a stack of RES blocks, then projected to output_size."""

    def __init__(self, input_size, output_size, batch_size, dropout):
        super(RESR, self).__init__()
        self.input_size = input_size
        self.hidden_size = 128
        self.n_layers = 4
        self.output_size = output_size
        self.batch_size = batch_size
        self.dropout = dropout

        self.lstm = nn.LSTM(input_size,
                            self.hidden_size,
                            self.n_layers,
                            batch_first=True,
                            dropout=self.dropout)

        self.res1 = RES(1, 16, 3, 6)
        self.res2 = RES(16, 8, 3, 2)
        self.res3 = RES(8, 4, 3, 2)
        self.res4 = RES(4, 2, 3, 2)

        # The two channels left after res4 are concatenated along the feature
        # axis, giving 2 * hidden_size = 256 features per timestep.
        self.W = nn.Linear(2 * self.hidden_size, self.output_size)

    def forward(self, input, hc, lens):
        # input: (batch x maxlen x feat)
        # lens must be sorted in decreasing order for pack_padded_sequence
        input_p = pack_padded_sequence(input, lens, batch_first=True)
        output_p, hc = self.lstm(input_p, hc)
        output, _ = pad_packed_sequence(output_p, batch_first=True)
        # output: (batch x maxlen x 128)

        output = output.unsqueeze(1)
        # output: (batch x 1 x maxlen x 128)
        output = self.res1(output)
        # output: (batch x 16 x maxlen x 128)
        output = self.res2(output)
        # output: (batch x 8 x maxlen x 128)
        output = self.res3(output)
        # output: (batch x 4 x maxlen x 128)
        output = self.res4(output)
        # output: (batch x 2 x maxlen x 128)

        output = torch.cat((output[:, 0, :, :], output[:, 1, :, :]), dim=2)
        # output: (batch x maxlen x 256)

        output = self.W(output)

        return output, hc

    def init_hidden(self):
        h0 = Variable(torch.zeros(self.n_layers, self.batch_size, self.hidden_size))
        c0 = Variable(torch.zeros(self.n_layers, self.batch_size, self.hidden_size))
        if USE_CUDA:
            h0, c0 = h0.cuda(), c0.cuda()
        return (h0, c0)
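
# Minimal smoke test (an illustrative sketch, not part of the original paste).
# It assumes util defines USE_CUDA; the sizes and lengths below are arbitrary.
if __name__ == "__main__":
    batch_size, maxlen, feat, n_classes = 4, 12, 20, 10
    model = RESR(input_size=feat, output_size=n_classes,
                 batch_size=batch_size, dropout=0.2)
    x = Variable(torch.randn(batch_size, maxlen, feat))
    if USE_CUDA:
        model, x = model.cuda(), x.cuda()
    hc = model.init_hidden()
    lens = [12, 9, 7, 5]  # one true length per batch element, descending
    out, hc = model(x, hc, lens)
    print(out.size())  # expected: (batch_size, maxlen, n_classes) = (4, 12, 10)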