Guest User

Untitled

a guest
Oct 18th, 2018
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.29 KB | None | 0 0
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5.  
  6. from scene.object import Estimator
  7.  
  8.  
  9. def dot(x: torch.Tensor, y: torch.Tensor):
  10. return torch.einsum('abi,abi->ab', x, y).unsqueeze(-1)
  11.  
  12.  
  13. def ifelse(p: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
  14. p_ = p.long()
  15. p_ = torch.stack([p_ for i in range(3)], -1).unsqueeze(0)
  16.  
  17. return torch.gather(torch.stack([y, x]), 0, p_).squeeze()
  18.  
  19.  
  20. def cast_view_rays(size: (int, int), fov: float):
  21. aspect = size[0]/size[1]
  22. fov_ = math.sin(fov/2)
  23. fov_x = fov_*aspect
  24. fov_y = fov_
  25.  
  26. x = torch.linspace(-fov_x, fov_x, size[0])
  27. y = torch.linspace(-fov_y, fov_y, size[1])
  28. grid_x, grid_y = torch.meshgrid(x, y)
  29. grid = torch.stack([grid_x, grid_y, torch.zeros(size)], -1)
  30. grid_ = grid - torch.FloatTensor([0, 0, -1])
  31.  
  32. return F.normalize(grid_, dim=-1)
  33.  
  34.  
  35. class NormEstimate(nn.Module):
  36. def __init__(self, objects: Estimator, size: torch.Size, device: torch.device, d: float = 0.0001):
  37. super(NormEstimate, self).__init__()
  38. self.objects = objects
  39. self.dx = torch.FloatTensor([d, 0, 0]).to(device)
  40. self.dy = torch.FloatTensor([0, d, 0]).to(device)
  41. self.dz = torch.FloatTensor([0, 0, d]).to(device)
  42.  
  43. def forward(self, x: torch.Tensor, ):
  44. gx1 = self.objects(x - self.dx)
  45. gx2 = self.objects(x + self.dx)
  46. gy1 = self.objects(x - self.dy)
  47. gy2 = self.objects(x + self.dy)
  48. gz1 = self.objects(x - self.dz)
  49. gz2 = self.objects(x + self.dz)
  50.  
  51. grad = torch.stack([gx2 - gx1, gy2 - gy1, gz2 - gz1], -1)
  52.  
  53. return F.normalize(grad, dim=-1)
  54.  
  55.  
  56. def trace_ray(x: torch.Tensor,
  57. direction: torch.Tensor,
  58. objects: Estimator,
  59. maxiter: int,
  60. eps: float):
  61. w, h, _ = direction.shape
  62. n = torch.zeros((w, h), device=x.device)
  63.  
  64. for i in range(maxiter):
  65. r = objects(x)
  66. n = n + torch.gt(r - eps, 0).float()
  67. x = x + F.relu(r).unsqueeze(-1)*direction
  68.  
  69. return r, x
  70.  
  71.  
  72. def phong(light: torch.Tensor, eye: torch.Tensor, x: torch.Tensor, n: torch.Tensor):
  73. ambient = 1
  74. diffuse = torch.tensor([0.8275, 0.8275, 0.8275], dtype=torch.float, device=x.device)
  75. specularExponent = 10
  76. specularity = 0.5
  77.  
  78. l_ = F.normalize(light - x, dim=-1)
  79. e_ = F.normalize(eye - x, dim=-1)
  80. mag = dot(n, l_)
  81. r = 2*mag*n - l_
  82.  
  83. return (ambient*torch.tensor([0.9608, 0.9608, 0.9608], dtype=torch.float, device=x.device) +
  84. diffuse*F.relu(mag) +
  85. specularity*torch.pow(F.relu(dot(e_, r)), specularExponent))/(ambient + 1 + specularity)
  86.  
  87.  
  88. def simple_shading(eye: torch.Tensor,
  89. direction: torch.Tensor,
  90. objects: Estimator,
  91. light: torch.Tensor,
  92. maxiter: int,
  93. eps: float):
  94. r, x = trace_ray(eye, direction, objects, maxiter, eps)
  95. is_intersect = torch.lt(r, eps)
  96.  
  97. norm_estimate = NormEstimate(objects, x.size, x.device)
  98. n = norm_estimate(x)
  99. colour = phong(light, eye, x, n)
  100.  
  101. l_ = F.normalize(light - x, dim=-1)
  102. x = x + 2*eps*n
  103. r, _ = trace_ray(x, l_, objects, maxiter, eps)
  104. is_shadow = torch.lt(r, eps)
  105.  
  106. colour_ = ifelse(is_shadow, 0.4*colour, colour)
  107. colour_ = ifelse(is_intersect, colour_, torch.ones_like(colour))
  108.  
  109. return colour_
Add Comment
Please, Sign In to add comment