Advertisement
Guest User

Untitled

a guest
Nov 18th, 2017
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.14 KB | None | 0 0
  1. import numpy as np
  2.  
  3.  
  4. def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
  5. # First figure out what the size of the output should be
  6. N, C, H, W = x_shape
  7. assert (H + 2 * padding - field_height) % stride == 0
  8. assert (W + 2 * padding - field_height) % stride == 0
  9. out_height = int((H + 2 * padding - field_height) / stride + 1)
  10. out_width = int((W + 2 * padding - field_width) / stride + 1)
  11.  
  12. i0 = np.repeat(np.arange(field_height), field_width)
  13. i0 = np.tile(i0, C)
  14. i1 = stride * np.repeat(np.arange(out_height), out_width)
  15. j0 = np.tile(np.arange(field_width), field_height * C)
  16. j1 = stride * np.tile(np.arange(out_width), out_height)
  17. i = i0.reshape(-1, 1) + i1.reshape(1, -1)
  18. j = j0.reshape(-1, 1) + j1.reshape(1, -1)
  19.  
  20. k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
  21.  
  22. return (k.astype(int), i.astype(int), j.astype(int))
  23.  
  24.  
  25. def im2col_indices(x, field_height, field_width, padding=1, stride=1):
  26. """ An implementation of im2col based on some fancy indexing """
  27. # Zero-pad the input
  28. p = padding
  29. x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
  30.  
  31. k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride)
  32.  
  33. cols = x_padded[:, k, i, j]
  34. C = x.shape[1]
  35. cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
  36. return cols
  37.  
  38.  
  39. def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
  40. stride=1):
  41. """ An implementation of col2im based on fancy indexing and np.add.at """
  42. N, C, H, W = x_shape
  43. H_padded, W_padded = H + 2 * padding, W + 2 * padding
  44. x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
  45. k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride)
  46. cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
  47. cols_reshaped = cols_reshaped.transpose(2, 0, 1)
  48. np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
  49. if padding == 0:
  50. return x_padded
  51. return x_padded[:, :, padding:-padding, padding:-padding]
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement