Guest User

Untitled

a guest
Feb 16th, 2019
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.16 KB | None | 0 0
  1. from torch.functional import F
  2. import torch.nn as nn
  3.  
  4.  
  5. class CNN(nn.Module):
  6. def __init__(self, dense, channels, kernels, paddings, drop_out,
  7. vocab_size, embedding_length, weights):
  8. super(CNN, self).__init__()
  9.  
  10. """
  11. Arguments
  12. ---------
  13. output_size : 2 = (pos, neg)
  14. in_channels : Number of input channels. Here it is 1 as the input data has dimension = (batch_size, num_seq, embedding_length)
  15. kernel_heights : A list consisting of 3 different kernel_heights. Convolution will be performed 3 times and finally results from each kernel_height will be concatenated.
  16. drop_out : Probability of retaining an activation node during dropout operation
  17. vocab_size : Size of the vocabulary containing unique words
  18. embedding_length : Embedding dimension of GloVe word embeddings
  19. weights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table
  20. --------
  21.  
  22. """
  23. self.dense = dense
  24. self.channels = channels
  25. self.kernels = kernels
  26. self.paddings = paddings
  27. self.vocab_size = vocab_size
  28. self.embedding_length = embedding_length
  29.  
  30. self.word_embeddings = self.__create_embedding(weights)
  31. self.conv1 = self.__create_convolution(0)
  32. self.conv2 = self.__create_convolution(1)
  33. self.conv3 = self.__create_convolution(2)
  34. self.lstm = nn.LSTM(input_size=self.dense[0], hidden_size=self.dense[1], dropout=drop_out)
  35. self.denses = self.__create_dense()
  36.  
  37. def __create_dense(self):
  38. return nn.Sequential(
  39. nn.Linear(self.dense[0], self.dense[1]),
  40. nn.ReLU(),
  41.  
  42. nn.Linear(self.dense[1], self.dense[2]),
  43. nn.ReLU(),
  44.  
  45. nn.Linear(self.dense[2], self.dense[3]),
  46. nn.ReLU(),
  47.  
  48. nn.Linear(self.dense[3], self.dense[4]),
  49. nn.ReLU(),
  50.  
  51. nn.Softmax()
  52. )
  53.  
  54. def __create_embedding(self, weights):
  55. word_embeddings = nn.Embedding(self.vocab_size, self.embedding_length)
  56. word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
  57. return word_embeddings
  58.  
  59. def __create_convolution(self, step):
  60. return nn.Sequential(
  61. nn.Conv2d(
  62. in_channels=self.channels[step],
  63. out_channels=self.channels[step+1],
  64. kernel_size=(self.kernels[step], self.embedding_length if step == 0 else 1),
  65. padding=(self.paddings[step], 0),
  66. ),
  67. nn.ReLU(),
  68. nn.MaxPool2d(kernel_size=(self.kernels[step], 1), padding=(1, 0), stride=(1, 1)),
  69. )
  70.  
  71. def forward(self, x, hidden=None):
  72. """
  73. The idea of the Convolutional Neural Netwok for Text Classification is very simple. We perform convolution operation on the embedding matrix
  74. whose shape for each batch is (num_seq, embedding_length) with kernel of varying height but constant width which is same as the embedding_length.
  75. We will be using ReLU activation after the convolution operation and then for each kernel height, we will use max_pool operation on each tensor
  76. and will filter all the maximum activation for every channel and then we will concatenate the resulting tensors. This output is then fully connected
  77. to the output layers consisting two units which basically gives us the logits for both positive and negative classes.
  78.  
  79. Parameters
  80. ----------
  81. x: input_sentences of shape = (batch_size, num_sequences)
  82. hidden: embedding from prev step
  83.  
  84. Returns
  85. -------
  86. Output of the linear layer containing logits for pos & neg class.
  87. logits.size() = (batch_size, output_size)
  88.  
  89. """
  90.  
  91. x = self.word_embeddings(x)
  92. max_out1 = self.conv1(x)
  93. max_out2 = self.conv2(max_out1)
  94. max_out3 = self.conv3(max_out2)
  95.  
  96. # all_out = torch.cat((max_out1, max_out2, max_out3), 0)
  97. flatten = max_out3.view(max_out3.shape[0], 1, max_out3.shape[1] * max_out3.shape[2] * max_out3.shape[3])
  98. lstm, hidden = self.lstm(flatten, hidden)
  99. output = self.denses(lstm)
  100. return output, hidden, max_out3
Add Comment
Please, Sign In to add comment