Advertisement
ridicul0us

Theano LeCun LCN

Jan 20th, 2015
771
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.51 KB | None | 0 0
  1. import theano
  2. import theano.tensor as T
  3. import numpy as np
  4. from theano.tensor.nnet import conv
  5. import matplotlib.pyplot as plt
  6. import pylab
  7.  
  8. def gaussian_filter(kernel_shape):
  9.     x = np.zeros((kernel_shape, kernel_shape), dtype='float32')
  10.  
  11.     def gauss(x, y, sigma=2.0):
  12.         Z = 2 * np.pi * sigma ** 2
  13.         return  1. / Z * np.exp(-(x ** 2 + y ** 2) / (2. * sigma ** 2))
  14.  
  15.     mid = np.floor(kernel_shape / 2.)
  16.     for i in xrange(0, kernel_shape):
  17.         for j in xrange(0, kernel_shape):
  18.             x[i, j] = gauss(i - mid, j - mid)
  19.  
  20.     return x / np.sum(x)
  21.  
  22.  
  23. def lecun_lcn(input, img_shape, kernel_shape, threshold=1e-4):
  24.     input = input.reshape(input.shape[0], 1, img_shape[0], img_shape[1])
  25.     X = T.matrix(dtype=theano.config.floatX)
  26.     X = X.reshape(input.shape)
  27.  
  28.     filter_shape = (1, 1, kernel_shape, kernel_shape)
  29.     filters = gaussian_filter(kernel_shape).reshape(filter_shape)
  30.  
  31.     convout = conv.conv2d(input=X,
  32.                           filters=filters,
  33.                           image_shape=(input.shape[0], 1, img_shape[0], img_shape[1]),
  34.                           filter_shape=filter_shape,
  35.                           border_mode='full')
  36.  
  37.     # For each pixel, remove mean of 9x9 neighborhood
  38.     mid = int(np.floor(kernel_shape / 2.))
  39.     centered_X = X - convout[:, :, mid:-mid, mid:-mid]
  40.     centered_X = X - convout[:, :, mid:-mid, mid:-mid]
  41.  
  42.     # Scale down norm of 9x9 patch if norm is bigger than 1
  43.     sum_sqr_XX = conv.conv2d(input=centered_X ** 2,
  44.                              filters=filters,
  45.                              image_shape=(input.shape[0], 1, img_shape[0], img_shape[1]),
  46.                              filter_shape=filter_shape,
  47.                              border_mode='full')
  48.  
  49.     denom = T.sqrt(sum_sqr_XX[:, :, mid:-mid, mid:-mid])
  50.     per_img_mean = denom.mean(axis=[1, 2])
  51.     divisor = T.largest(per_img_mean.dimshuffle(0, 'x', 'x', 1), denom)
  52.     divisor = T.maximum(divisor, threshold)
  53.  
  54.     new_X = centered_X / divisor
  55.     new_X = new_X.dimshuffle(0, 2, 3, 1)
  56.     new_X = new_X.flatten(ndim=3)
  57.  
  58.     f = theano.function([X], new_X)
  59.     return f(input)
  60.  
  61. if __name__=='__main__':
  62.     x_img = plt.imread("..//data//Lenna.png") #change as needed
  63.  
  64.     x_img = x_img.reshape(1, x_img.shape[0], x_img.shape[1], x_img.shape[2])
  65.     for d in range(3):
  66.             x_img[:, :, :, d] = lecun_lcn(x_img[:, :, :, d], (x_img.shape[1], x_img.shape[2]), 9)
  67.     x_img = x_img[0]
  68.  
  69.     pylab.gray()
  70.     pylab.axis('off'); pylab.imshow(x_img)
  71.     pylab.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement