SHARE
TWEET

Untitled

a guest Apr 20th, 2019 67 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/bin/env python
  2.  
  3. import chainer
  4. import chainer.functions as F
  5. import numpy as np
  6. import trimesh
  7. import trimesh.transformations as tf
  8.  
  9. import imgviz
  10.  
  11. import objslampp
  12.  
  13.  
  14. np.random.seed(0)
  15.  
  16. pcd_file = objslampp.datasets.YCBVideoModels()\
  17.     .get_model(class_id=20)['points_xyz']
  18. points = np.loadtxt(pcd_file, dtype=np.float32)
  19. indices = np.random.permutation(len(points))[:1000]
  20. points = points[indices]
  21.  
  22. dim = 16
  23. pitch = max(points.max(axis=0) - points.min(axis=0)) * 1.1 / dim
  24. origin = (- pitch * dim / 2,) * 3
  25.  
  26. points_reference = points
  27.  
  28. print('voxelization start')
  29. print(f'pitch: {pitch}')
  30. print(f'dim: {dim}')
  31. grid_reference = objslampp.functions.occupancy_grid_3d(
  32.     points,
  33.     pitch=pitch,
  34.     origin=origin,
  35.     dimension=(dim, dim, dim),
  36.     connectivity=2,
  37. ).array
  38. if 0:
  39.     trimesh.Scene(trimesh.voxel.Voxel(grid_reference, pitch, origin).as_boxes()).show()
  40.  
  41. index = 0
  42. quaternion = None
  43. while True:
  44.     index += 1
  45.  
  46.     print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  47.     if quaternion is None:
  48.         quaternion = tf.random_quaternion().astype(np.float32)
  49.         quaternion = chainer.Variable(quaternion)
  50.     else:
  51.         quaternion = chainer.Variable(quaternion.array)
  52.  
  53.     print(type(quaternion), quaternion, quaternion.grad)
  54.  
  55.     transform = objslampp.functions.quaternion_matrix(quaternion[None])
  56.     points_observed = objslampp.functions.transform_points(points_reference, transform)[0]
  57.  
  58.     add = objslampp.metrics.average_distance(
  59.         [points_reference], [np.eye(4)], [transform.array[0]]
  60.     )[0]
  61.     print('add:', add)
  62.  
  63.     if 1:
  64.         geom1 = trimesh.PointCloud(vertices=points_reference, colors=[(1., 0, 0)]*len(points_reference))
  65.         geom2 = trimesh.PointCloud(vertices=points_observed.array, colors=[(0, 1., 0)]*len(points_reference))
  66.         scene = trimesh.Scene([geom1, geom2])
  67.         scene.set_camera(angles=(np.deg2rad(0), np.deg2rad(0), 0), distance=0.3)
  68.         image = objslampp.extra.trimesh.save_image(scene)
  69.         imgviz.io.imsave(f'points_{index:08d}.jpg', image[:, :, :3])
  70.  
  71.     grid_observed = objslampp.functions.occupancy_grid_3d(
  72.         points_observed,
  73.         pitch=pitch,
  74.         origin=origin,
  75.         dimension=(dim, dim, dim),
  76.         connectivity=2,
  77.     )
  78.     if 1:
  79.         grid = grid_observed.array
  80.         colors = imgviz.depth2rgb(grid.reshape(1, -1), min_value=0, max_value=1)
  81.         colors = colors.reshape(dim, dim, dim, 3)
  82.         colors = np.concatenate((colors, np.full((dim, dim, dim, 1), 127)), axis=3)
  83.         geom = trimesh.voxel.Voxel(grid, pitch, origin).as_boxes()
  84.         geom.apply_translation((pitch / 2, pitch / 2, pitch / 2))
  85.         I, J, K = zip(*np.argwhere(grid))
  86.         geom.visual.face_colors = colors[I, J, K].repeat(12, axis=0)
  87.         geom1 = geom
  88.  
  89.         grid = grid_reference
  90.         colors = imgviz.depth2rgb(grid.reshape(1, -1), min_value=0, max_value=1)
  91.         colors = colors.reshape(dim, dim, dim, 3)
  92.         colors = np.concatenate((colors, np.full((dim, dim, dim, 1), 127)), axis=3)
  93.         geom = trimesh.voxel.Voxel(grid, pitch, origin).as_boxes()
  94.         geom.apply_translation((pitch / 2, pitch / 2, pitch / 2))
  95.         I, J, K = zip(*np.argwhere(grid))
  96.         geom.visual.face_colors = colors[I, J, K].repeat(12, axis=0)
  97.         geom2 = geom
  98.  
  99.         scene = trimesh.Scene([geom1, geom2])
  100.         scene.set_camera(angles=(np.deg2rad(0), np.deg2rad(0), 0), distance=0.3)
  101.         image = objslampp.extra.trimesh.save_image(scene)
  102.         imgviz.io.imsave(f'occupancy_{index:08d}.jpg', image[:, :, :3])
  103.  
  104.     loss = F.mean_squared_error(grid_reference, grid_observed)
  105.     loss.backward()
  106.     print('loss:', float(loss.array))
  107.     del loss, grid_observed
  108.  
  109.     print(quaternion.grad)
  110.     lr = 1
  111.     quaternion = quaternion - lr * quaternion.grad
  112.     quaternion.cleargrad()
  113.     print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top