Advertisement
nordlaender

condensed forward backward cython

May 14th, 2016
274
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.95 KB | None | 0 0
  1.  
  2. cimport cython
  3. cimport numpy as np
  4.  
  5. import numpy as np
  6. import cython
  7.  
  8. DTYPE=np.float
  9. ctypedef np.float_t DTYPE_t
  10.  
  11.  
  12. @cython.boundscheck(False)
  13. @cython.nonecheck(False)
  14. def forward(np.ndarray x, np.ndarray y,
  15.             int filter_size,int padding, int stride,
  16.             np.ndarray W, np.ndarray b,
  17.             np.ndarray in_cols,
  18.             int batch_size,int in_colors,int out_colors,int in_width,int in_height,int out_width,int out_height):
  19.     """
  20.    the forward method from a ConvNode.
  21.    moved outside to compile it separately as a cython module
  22.    :param x: input activations
  23.    :param y: output activations
  24.    :param filter_size: size of the receptive field
  25.    :param padding: zero padding for the input
  26.    :param stride: distance between receptive fields
  27.    :param W: weight matrix for the linear node within
  28.    :param b: bias for the linear node within
  29.    :param batch_size: batch_size
  30.    :param in_colors: number of input colors, first dimension of W
  31.    :param out_colors: number of output colors, second dimension of W
  32.    :param in_width: input activation image width
  33.    :param in_height: input activation image height
  34.    :param out_width: output activation image width
  35.    :param out_height: output activation image height
  36.    :return: nothing
  37.    """
  38.  
  39.     # im2col: x -> in_cols
  40.     # padding
  41.     cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((batch_size, in_colors, in_width + padding*2, in_height + padding*2))
  42.     if padding>0:
  43.         x_padded[:, :, padding:in_width+padding, padding:in_height+padding] = x
  44.     else:
  45.         x_padded[:]=x
  46.  
  47.     # allocating new field
  48.     cdef np.ndarray[DTYPE_t, ndim=4] rec_fields = np.empty((filter_size**2* in_colors, batch_size, out_width, out_height))
  49.  
  50.     # copying receptive fields
  51.     cdef int w,h
  52.     for w, h in np.ndindex((out_width, out_height)):
  53.         rec_fields[:, :, w, h] = x_padded[:, :, w*stride:w*stride + filter_size, h*stride:h*stride + filter_size] \
  54.             .reshape((batch_size, filter_size**2* in_colors)) \
  55.             .T
  56.  
  57.     in_cols[:,:] = rec_fields.reshape((filter_size**2 * in_colors, batch_size * out_width * out_height))
  58.  
  59.     # linear node: in_cols -> out_cols
  60.     #cdef np.ndarray[DTYPE_t, ndim=2] out_cols=np.dot(W,in_cols)+b
  61.  
  62.     # col2im: out_cols -> out_image -> y
  63.     y[:,:,:,:]=(np.dot(W,in_cols)+b).reshape((out_colors, batch_size, out_width, out_height)).transpose(1,0,2,3)
  64.  
  65.  
  66. @cython.boundscheck(False)
  67. @cython.nonecheck(False)
  68. def backward(np.ndarray d_x,np.ndarray d_y,
  69.             int filter_size,int padding, int stride,
  70.             np.ndarray W, np.ndarray b,
  71.             np.ndarray d_W, np.ndarray d_b,
  72.             np.ndarray in_cols,
  73.             int batch_size,int in_colors, int out_colors,int in_width,int in_height,int out_width, int out_height):
  74.  
  75.     # col2im: d_y -> d_out_cols
  76.     cdef np.ndarray[DTYPE_t, ndim=2] d_out_cols =  d_y.transpose(1, 0, 2, 3).reshape((out_colors, batch_size * out_width * out_height))
  77.  
  78.     # linear node: d_out_cols -> d_in_cols
  79.     d_W[:] = np.dot(d_out_cols, in_cols.T)
  80.     d_b[:] = np.sum(d_out_cols)
  81.  
  82.     # im2col: d_in_cols -> d_x
  83.     cdef np.ndarray[DTYPE_t, ndim=4] d_rec_fields = np.dot(W.T, d_out_cols).reshape((filter_size**2 * in_colors, batch_size, out_width, out_height))
  84.     cdef np.ndarray[DTYPE_t, ndim=4] d_x_padded = np.zeros((batch_size, in_colors, in_width + 2*padding, in_height + 2*padding))
  85.  
  86.     cdef int w,h
  87.     for w, h in np.ndindex((out_width, out_height)):
  88.         d_x_padded[:, :, w*stride:w*stride + filter_size, h*stride:h*stride + filter_size] += d_rec_fields[:, :, w, h].T.reshape(
  89.             (batch_size, in_colors, filter_size, filter_size))
  90.  
  91.     if padding>0:
  92.         # slicing with non-negative indices
  93.         # goes from [padding:-padding] = [padding: length-padding]
  94.         # where length=in_width+2*padding
  95.         d_x[:] = d_x_padded[:, :, padding:in_width+padding, padding:in_height+padding]
  96.     else:
  97.         d_x[:]=d_x_padded[:,:,:,:]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement