Advertisement
Guest User

Untitled

a guest
May 29th, 2016
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Lua 1.41 KB | None | 0 0
  1. function mycnn(char_dim, num_widths, basic_width, dw, seq_length, c2w_seq_length, c2w_dropout, batch_size)
  2.     local inputs = {}
  3.     local outputs = {}
  4.     table.insert(inputs, nn.Identity()())
  5.     local conv1, _conv1, mp1, _mp1, proj, _proj = {}, {}, {}, {}, {}, {}
  6.  
  7.     for i = 1, num_widths do
  8.         local nkernels = math.min(200, basic_width * i)
  9.         -- input : seq_length*batch_size X c2w_seq_length X char_dim
  10.         --conv1[i] = nn.TemporalConvolutionFB(char_dim, nkernels, i, dw)(inputs[1])
  11.         local input = nn.View(1, -1, char_dim):setNumInputDims(2)(inputs[1])
  12.         conv1[i] = cudnn.SpatialConvolution(1, nkernels, char_dim, i)(input)
  13.         -- output : seq_length*batch_size X [(c2w_seq_length - i + 1) / dw + 1] X nkernels
  14.  
  15.         _conv1[i] = nn.Tanh()(conv1[i])
  16.  
  17.         --mp1[i] = nn.TemporalMaxPooling( outputFrame(c2w_seq_length, i, dw), dw )(_conv1[i])
  18.         mp1[i] = cudnn.SpatialMaxPooling(1, outputFrame(c2w_seq_length, i, dw))(_conv1[i])
  19.         -- output : seq_length*batch_size X 1 X nkernels
  20.  
  21.         --_mp1[i] = nn.Reshape(seq_length*batch_size, nkernels)(mp1[i])
  22.         _mp1[i] = nn.Squeeze()(mp1[i])
  23.  
  24.         --_proj[i] = nn.Dropout(c2w_dropout)(_mp1[i]) -- TODO : add dropout?
  25.  
  26.         table.insert(outputs, _mp1[i])
  27.     end
  28.  
  29.     local real_output = nn.JoinTable(2)(outputs)
  30.  
  31.     local module = nn.gModule(inputs, {real_output})
  32.     return transfer_data(module)
  33. end
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement