Advertisement
Guest User

Untitled

a guest
Apr 20th, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.79 KB | None | 0 0
  1. #!/usr/bin/env python
  2.  
  3. import chainer
  4. import chainer.functions as F
  5. import numpy as np
  6. import trimesh
  7. import trimesh.transformations as tf
  8.  
  9. import imgviz
  10.  
  11. import objslampp
  12.  
  13.  
  14. np.random.seed(0)
  15.  
  16. pcd_file = objslampp.datasets.YCBVideoModels()\
  17. .get_model(class_id=20)['points_xyz']
  18. points = np.loadtxt(pcd_file, dtype=np.float32)
  19. indices = np.random.permutation(len(points))[:1000]
  20. points = points[indices]
  21.  
  22. dim = 16
  23. pitch = max(points.max(axis=0) - points.min(axis=0)) * 1.1 / dim
  24. origin = (- pitch * dim / 2,) * 3
  25.  
  26. points_reference = points
  27.  
  28. print('voxelization start')
  29. print(f'pitch: {pitch}')
  30. print(f'dim: {dim}')
  31. grid_reference = objslampp.functions.occupancy_grid_3d(
  32. points,
  33. pitch=pitch,
  34. origin=origin,
  35. dimension=(dim, dim, dim),
  36. connectivity=2,
  37. ).array
  38. if 0:
  39. trimesh.Scene(trimesh.voxel.Voxel(grid_reference, pitch, origin).as_boxes()).show()
  40.  
  41. index = 0
  42. quaternion = None
  43. while True:
  44. index += 1
  45.  
  46. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  47. if quaternion is None:
  48. quaternion = tf.random_quaternion().astype(np.float32)
  49. quaternion = chainer.Variable(quaternion)
  50. else:
  51. quaternion = chainer.Variable(quaternion.array)
  52.  
  53. print(type(quaternion), quaternion, quaternion.grad)
  54.  
  55. transform = objslampp.functions.quaternion_matrix(quaternion[None])
  56. points_observed = objslampp.functions.transform_points(points_reference, transform)[0]
  57.  
  58. add = objslampp.metrics.average_distance(
  59. [points_reference], [np.eye(4)], [transform.array[0]]
  60. )[0]
  61. print('add:', add)
  62.  
  63. if 1:
  64. geom1 = trimesh.PointCloud(vertices=points_reference, colors=[(1., 0, 0)]*len(points_reference))
  65. geom2 = trimesh.PointCloud(vertices=points_observed.array, colors=[(0, 1., 0)]*len(points_reference))
  66. scene = trimesh.Scene([geom1, geom2])
  67. scene.set_camera(angles=(np.deg2rad(0), np.deg2rad(0), 0), distance=0.3)
  68. image = objslampp.extra.trimesh.save_image(scene)
  69. imgviz.io.imsave(f'points_{index:08d}.jpg', image[:, :, :3])
  70.  
  71. grid_observed = objslampp.functions.occupancy_grid_3d(
  72. points_observed,
  73. pitch=pitch,
  74. origin=origin,
  75. dimension=(dim, dim, dim),
  76. connectivity=2,
  77. )
  78. if 1:
  79. grid = grid_observed.array
  80. colors = imgviz.depth2rgb(grid.reshape(1, -1), min_value=0, max_value=1)
  81. colors = colors.reshape(dim, dim, dim, 3)
  82. colors = np.concatenate((colors, np.full((dim, dim, dim, 1), 127)), axis=3)
  83. geom = trimesh.voxel.Voxel(grid, pitch, origin).as_boxes()
  84. geom.apply_translation((pitch / 2, pitch / 2, pitch / 2))
  85. I, J, K = zip(*np.argwhere(grid))
  86. geom.visual.face_colors = colors[I, J, K].repeat(12, axis=0)
  87. geom1 = geom
  88.  
  89. grid = grid_reference
  90. colors = imgviz.depth2rgb(grid.reshape(1, -1), min_value=0, max_value=1)
  91. colors = colors.reshape(dim, dim, dim, 3)
  92. colors = np.concatenate((colors, np.full((dim, dim, dim, 1), 127)), axis=3)
  93. geom = trimesh.voxel.Voxel(grid, pitch, origin).as_boxes()
  94. geom.apply_translation((pitch / 2, pitch / 2, pitch / 2))
  95. I, J, K = zip(*np.argwhere(grid))
  96. geom.visual.face_colors = colors[I, J, K].repeat(12, axis=0)
  97. geom2 = geom
  98.  
  99. scene = trimesh.Scene([geom1, geom2])
  100. scene.set_camera(angles=(np.deg2rad(0), np.deg2rad(0), 0), distance=0.3)
  101. image = objslampp.extra.trimesh.save_image(scene)
  102. imgviz.io.imsave(f'occupancy_{index:08d}.jpg', image[:, :, :3])
  103.  
  104. loss = F.mean_squared_error(grid_reference, grid_observed)
  105. loss.backward()
  106. print('loss:', float(loss.array))
  107. del loss, grid_observed
  108.  
  109. print(quaternion.grad)
  110. lr = 1
  111. quaternion = quaternion - lr * quaternion.grad
  112. quaternion.cleargrad()
  113. print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement