Advertisement
Guest User

Untitled

a guest
Aug 24th, 2019
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.62 KB | None | 0 0
  1. from gym.envs.box2d.bipedal_walker import (
  2. BipedalWalker, VIEWPORT_H, VIEWPORT_W, SCALE, TERRAIN_HEIGHT, TERRAIN_STEP
  3. )
  4. from Box2D.b2 import circleShape
  5.  
  6. import cv2
  7. import numpy as np
  8. import copy
  9. import uuid
  10.  
  11.  
  12. class Surface(object):
  13. __slots__ = [
  14. 'center', 'shape', 'mask_key', '_img', 'height', 'width', 'name'
  15. ]
  16.  
  17. def __init__(
  18. self,
  19. center_x=0,
  20. center_y=0,
  21. height=1,
  22. width=1,
  23. mask_key=23,
  24. name=None
  25. ):
  26. self.center = int(center_x), int(center_y)
  27. self.shape = self.height, self.width = (int(height), int(width))
  28. self.mask_key = int(mask_key & 255)
  29. self.name = name or str(uuid.uuid4())
  30. self._img = np.zeros((height, width, 4), dtype='uint8')
  31.  
  32. def zero(self):
  33. self.pixels_ref().fill(0)
  34. return self
  35.  
  36. def fill(self, color):
  37. color_to_uint = lambda color: np.uint32(
  38. (color[0] << 16) | (color[1] << 8) | (color[2])
  39. )
  40. assert isinstance(color, tuple) and len(color) == 3
  41. self.pixels_ref().fill(color_to_uint(color) | (self.mask_key << 24))
  42. return self
  43.  
  44. def blit(self, surface):
  45. center_x, center_y = int(surface.center[0]), int(surface.center[1])
  46. half_height, half_width = int(surface.height / 2
  47. ), int(surface.width / 2)
  48. left = min(half_width, center_y)
  49. right = min(half_width, self.width - center_y)
  50. top = min(half_height, center_x)
  51. down = min(half_height, self.height - center_x)
  52. mask = surface.get_mask()
  53. negative_mask = 1 - mask
  54. img_src = surface.pixels_ref()
  55. screen_img = self.pixels_ref()
  56. screen_img[(center_x - top):(center_x + down), (center_y - left):(center_y + right)] *= \
  57. negative_mask[(half_height - top):(half_height + down), (half_width - left):(half_width + right)]
  58. screen_img[(center_x - top):(center_x + down), (center_y - left):(center_y + right)] += \
  59. img_src[(half_height - top):(half_height + down), (half_width - left):(half_width + right)] \
  60. * mask[(half_height - top):(half_height + down), (half_width - left):(half_width + right)]
  61. return self
  62.  
  63. def get_mask(self):
  64. mask = self._img[:, :, 3] == self.mask_key
  65. return mask.astype(np.uint32)
  66.  
  67. def raw_data(self):
  68. return self._img
  69.  
  70. def pixels_ref(self):
  71. shape = (self.height, self.width)
  72. pxls = self._img.view('uint32').reshape(shape)
  73. return pxls
  74.  
  75. def pixels_cpy(self):
  76. return self.pixels_ref().copy()
  77.  
  78. def copy(self):
  79. return copy.deepcopy(self)
  80.  
  81. def polygon(self, points, color):
  82. color = (color[2], color[1], color[0], self.mask_key)
  83. points = np.array(points).astype(np.int).reshape((len(points), 2))
  84. points = np.flip(points, 1)
  85. cv2.fillConvexPoly(self._img, points, color)
  86.  
  87. def circle(self, center, radius, color, thickness=1):
  88. color = (color[2], color[1], color[0], self.mask_key)
  89. center = (int(center[1]), int(center[0]))
  90. self._img = cv2.circle(
  91. self._img, center, radius, color, thickness=thickness
  92. )
  93.  
  94. def line(self, p1, p2, color, thickness=1):
  95. color = (color[2], color[1], color[0], self.mask_key)
  96. p1 = (int(p1[1]), int(p1[0]))
  97. p2 = (int(p2[1]), int(p2[0]))
  98. self._img = cv2.line(self._img, p1, p2, color, thickness=thickness)
  99.  
  100. def put_text(self, pos, text):
  101. font = cv2.FONT_HERSHEY_SIMPLEX
  102. pos = (int(pos[1]), int(pos[0]))
  103. font_size = 0.4
  104. color = (255, 255, 255)
  105. thickness = 1
  106. cv2.putText(self._img, text, pos, font, font_size, color, thickness)
  107.  
  108. def display(self, time=0):
  109. cv2.imshow(self.name, self._img)
  110. key = cv2.waitKey(time) & 0xFF
  111. ESC = 27
  112. if key == ESC:
  113. cv2.destroyAllWindows()
  114. exit()
  115. return key
  116.  
  117.  
  118. def scale_color(color_in_1):
  119. return tuple(int(c * 255) for c in color_in_1)
  120.  
  121.  
  122. class OpencvViewer(object):
  123. def __init__(self, width, height):
  124. self.width = width
  125. self.height = height
  126. self.surface = Surface(height=height, width=width)
  127. self.translation = 0, 0
  128. self.scale = 1, 1
  129. self.frame = np.empty((height, width, 4), dtype=np.uint8)
  130. self.frame.fill(255)
  131.  
  132. def set_bounds(self, left, right, bottom, top):
  133. assert right > left and top > bottom
  134. scalex = self.width / (right - left)
  135. scaley = self.height / (top - bottom)
  136. self.translation = -left, -bottom
  137. self.scale = scalex, scaley
  138.  
  139. def draw_circle(self, radius=10, res=30, filled=True, **attrs):
  140. raise NotImplementedError
  141.  
  142. def translate(self, point):
  143. point1 = point[0] + self.translation[0], point[1] + \
  144. self.translation[1]
  145. point2 = point1[0] * self.scale[0], point1[1] * \
  146. self.scale[1]
  147. return self.height - point2[1], point2[0]
  148.  
  149. def draw_polygon(self, v, filled=True, **attrs):
  150. v = [self.translate(p) for p in v]
  151. color = scale_color(attrs["color"])
  152. self.surface.polygon(v, color)
  153.  
  154. def draw_polyline(self, v, **attrs):
  155. color = scale_color(attrs["color"])
  156. thickness = attrs['thickness'] if 'thickness' in attrs \
  157. else attrs['linewidth']
  158. for point1, point2 in zip(v[:-1], v[1:]):
  159. point1 = self.translate(tuple(point1))
  160. point2 = self.translate(tuple(point2))
  161. self.surface.line(point1, point2, color, thickness)
  162.  
  163. def draw_line(self, start, end, **attrs):
  164. start = self.translate(start)
  165. end = self.translate(end)
  166. self.surface.line(start, end, **attrs)
  167.  
  168. def render(self, return_rgb_array):
  169. self.frame.fill(255)
  170. if not return_rgb_array:
  171. self.surface.display(1)
  172. frame = self.surface.raw_data()
  173. return frame[:, :, 2::-1]
  174.  
  175. def close(self):
  176. del self.surface
  177.  
  178.  
  179. class BipedalWalkerWrapper(BipedalWalker):
  180. def render(self, mode='human'):
  181. # This function is almost identical to the original one but the
  182. # importing of pyglet is avoided.
  183. if self.viewer is None:
  184. self.viewer = OpencvViewer(VIEWPORT_W, VIEWPORT_H)
  185. self.viewer.set_bounds(
  186. self.scroll, VIEWPORT_W / SCALE + self.scroll, 0,
  187. VIEWPORT_H / SCALE
  188. )
  189.  
  190. self.viewer.draw_polygon(
  191. [
  192. (self.scroll, 0),
  193. (self.scroll + VIEWPORT_W / SCALE, 0),
  194. (self.scroll + VIEWPORT_W / SCALE, VIEWPORT_H / SCALE),
  195. (self.scroll, VIEWPORT_H / SCALE),
  196. ],
  197. color=(0.9, 0.9, 1.0)
  198. )
  199. for poly, x1, x2 in self.cloud_poly:
  200. if x2 < self.scroll / 2: continue
  201. if x1 > self.scroll / 2 + VIEWPORT_W / SCALE: continue
  202. self.viewer.draw_polygon(
  203. [(p[0] + self.scroll / 2, p[1]) for p in poly],
  204. color=(1, 1, 1)
  205. )
  206. for poly, color in self.terrain_poly:
  207. if poly[1][0] < self.scroll: continue
  208. if poly[0][0] > self.scroll + VIEWPORT_W / SCALE: continue
  209. self.viewer.draw_polygon(poly, color=color)
  210.  
  211. self.lidar_render = (self.lidar_render + 1) % 100
  212. i = self.lidar_render
  213. if i < 2 * len(self.lidar):
  214. l = self.lidar[i] if i < len(self.lidar
  215. ) else self.lidar[len(self.lidar) -
  216. i - 1]
  217. self.viewer.draw_polyline(
  218. [l.p1, l.p2], color=(1, 0, 0), linewidth=1
  219. )
  220.  
  221. for obj in self.drawlist:
  222. for f in obj.fixtures:
  223. trans = f.body.transform
  224. if type(f.shape) is circleShape:
  225. raise NotImplementedError
  226. # t = rendering.Transform(translation=trans*f.shape.pos)
  227. # self.viewer.draw_circle(f.shape.radius, 30, color=obj.color1).add_attr(t)
  228. # self.viewer.draw_circle(f.shape.radius, 30, color=obj.color2, filled=False, linewidth=2).add_attr(t)
  229. else:
  230. path = [trans * v for v in f.shape.vertices]
  231. self.viewer.draw_polygon(path, color=obj.color1)
  232. path.append(path[0])
  233. self.viewer.draw_polyline(
  234. path, color=obj.color2, linewidth=2
  235. )
  236.  
  237. flagy1 = TERRAIN_HEIGHT
  238. flagy2 = flagy1 + 50 / SCALE
  239. x = TERRAIN_STEP * 3
  240. self.viewer.draw_polyline(
  241. [(x, flagy1), (x, flagy2)], color=(0, 0, 0), linewidth=2
  242. )
  243. f = [
  244. (x, flagy2), (x, flagy2 - 10 / SCALE),
  245. (x + 25 / SCALE, flagy2 - 5 / SCALE)
  246. ]
  247. self.viewer.draw_polygon(f, color=(0.9, 0.2, 0))
  248. self.viewer.draw_polyline(f + [f[0]], color=(0, 0, 0), linewidth=2)
  249.  
  250. return self.viewer.render(return_rgb_array=mode == 'rgb_array')
  251.  
  252.  
  253. if __name__ == '__main__':
  254. """
  255. Usage:
  256. 1. Ask administrator to install xvfb
  257. 2. run:
  258. xvfb-run -s "-screen 0 600x400x24" python test_render.py
  259.  
  260. Use mode="human" to see the pop-up OpenCV window.
  261. Use mode="rgb_array" to get the (X, X, 4) ndarray.
  262. """
  263. env = BipedalWalkerWrapper()
  264. env.reset()
  265. cnt = 0
  266. while True:
  267. cnt += 1
  268. frame = env.render(mode='rgb_array')
  269. print('Current Time Step: {}, frame Shape: {}'.format(cnt, frame.shape))
  270. action = env.action_space.sample()
  271. observation, reward, done, info = env.step(action)
  272. if done:
  273. print("Done!")
  274. break
  275. env.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement