Abhisek92

Convolve_Trick.py

Jun 12th, 2019
283
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.06 KB | None | 0 0
  1. import numpy as np
  2. from numpy.lib.stride_tricks import as_strided
  3.  
  4.  
  5. def convoly(array_in, kernel):
  6.     if len(array_in.shape) == len(kernel.shape):
  7.         parent_shape = tuple((np.array(array_in.shape) - (np.array(kernel.shape) - 1)).tolist())
  8.         shaper = parent_shape + kernel.shape
  9.         strider = 2 * array_in.strides
  10.         expanded_input = as_strided(
  11.             array_in,
  12.             shape=shaper,
  13.             strides=strider,
  14.             writeable=False,
  15.         )
  16.         return expanded_input
  17.  
  18. def image_filter(img_array, kernel):
  19.     expanded_input = convoly(img_array, kernel)
  20.     filtered_img = np.einsum(
  21.         'xyij,ij->xy',
  22.         expanded_input,
  23.         kernel,
  24.     )
  25.     return filtered_img
  26.  
  27. def convol_flat(array_in, kernel):
  28.     expanded_array = convoly(array_in, kernel)
  29.     out_shape = list(expanded_array.shape[:-1])
  30.     out_shape[-1] = expanded_array.shape[-1] * expanded_array.shape[-2]
  31.     out_shape = tuple(out_shape)
  32.     #print(out_shape)
  33.     out_array = expanded_array.reshape(out_shape)
  34.     return out_array
Add Comment
Please, Sign In to add comment