Advertisement
Guest User

Untitled

a guest
Aug 23rd, 2019
92
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.34 KB | None | 0 0
  1. import random
  2. from typing import Union
  3.  
  4. import cv2
  5. import numpy as np
  6. from albumentations import DualTransform, to_tuple
  7.  
  8.  
  9. def _parse_border_mode(border: Union[str, int]):
  10. if isinstance(border, str):
  11. mode = {
  12. 'reflect': cv2.BORDER_REFLECT101,
  13. }[border]
  14. value = None
  15. else:
  16. mode = cv2.BORDER_CONSTANT
  17. value = border
  18. return dict(borderMode=mode, borderValue=value)
  19.  
  20.  
  21. def _parse_interpolation(interpolation: str):
  22. return {
  23. 'lanczos': cv2.INTER_LANCZOS4,
  24. }[interpolation]
  25.  
  26.  
  27. class RandomPerspective(DualTransform):
  28. def __init__(
  29. self,
  30. corner_shift_limit=.2,
  31. clip_coords=True,
  32. border: Union[str, int] = 0,
  33. interpolation: str = 'lanczos',
  34. always_apply=False, p=0.5
  35. ):
  36. super().__init__(always_apply, p)
  37. self.clip_coords = clip_coords
  38. self.corner_shift_limit = to_tuple(corner_shift_limit)
  39. self.border_mode = _parse_border_mode(border)
  40. self.interpolation = _parse_interpolation(interpolation)
  41.  
  42. def get_params(self):
  43. """
  44. Sample 4 corner relative shifts (x, y).
  45. Returns: np.array of shape (4, 2)
  46. """
  47. shifts = np.array([
  48. random.uniform(
  49. *self.corner_shift_limit,
  50. ) for _ in range(8)
  51. ]).reshape(4, 2).astype(np.float32)
  52.  
  53. coords: np.ndarray = shifts + [
  54. [0, 0],
  55. [1, 0],
  56. [1, 1],
  57. [0, 1]
  58. ]
  59.  
  60. if self.clip_coords:
  61. coords = coords.clip(0, 1)
  62.  
  63. return dict(coords=coords)
  64.  
  65. def update_params(self, params, **kwargs):
  66. params = super().update_params(params, **kwargs)
  67. h, w = params["rows"], params["cols"]
  68. target_coords = params.pop("coords")
  69.  
  70. target_coords = (target_coords * [[w, h]]).astype(np.float32)
  71.  
  72. source_coords = np.array([
  73. [0, 0],
  74. [w, 0],
  75. [w, h],
  76. [0, h]
  77. ], dtype=np.float32)
  78.  
  79. params["matrix"] = \
  80. cv2.getPerspectiveTransform(source_coords, target_coords)
  81.  
  82. return params
  83.  
  84. # noinspection PyMethodOverriding
  85. def apply(self, image, matrix, interpolation, rows, cols, **kwargs):
  86. return cv2.warpPerspective(
  87. image, matrix, (cols, rows),
  88. flags=interpolation, **self.border_mode
  89. )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement