Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from scene.object import Estimator
def dot(x: torch.Tensor, y: torch.Tensor):
    """Return the dot product of ``x`` and ``y`` along their last axis.

    The reduced axis is kept with size 1, so the result can be broadcast
    straight back against the vector-valued inputs (e.g. ``dot(n, l) * n``).
    Inputs may have any number of leading dimensions as long as they
    broadcast against each other.

    Args:
        x: Tensor of vectors stored along the last axis.
        y: Tensor of vectors stored along the last axis.

    Returns:
        Tensor with the broadcast leading shape of the inputs and a
        trailing axis of size 1.
    """
    return (x * y).sum(dim=-1, keepdim=True)
def ifelse(p: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    """Select between ``x`` and ``y`` element-wise according to mask ``p``.

    ``p`` carries one flag per vector: it has one axis fewer than ``x`` and
    ``y`` and is broadcast over their last (channel) axis.

    Args:
        p: Boolean (or 0/1) mask; non-zero entries pick ``x``.
        x: Values used where ``p`` is set.
        y: Values used where ``p`` is not set.

    Returns:
        Tensor shaped like ``x``/``y`` holding ``x`` where ``p`` is set and
        ``y`` elsewhere.
    """
    # Add a trailing axis so the per-vector mask broadcasts over however many
    # channels the inputs have; torch.where keeps every size-1 axis intact,
    # so 1xN and Nx1 inputs retain their full shape.
    mask = p.bool().unsqueeze(-1)
    return torch.where(mask, x, y)
def cast_view_rays(size: 'tuple[int, int]', fov: float):
    """Build a grid of unit view-ray directions for an image of ``size``.

    Sample points are laid out on the plane ``z = 0`` and the rays are the
    normalised vectors from the point ``(0, 0, -1)`` to each sample, so every
    ray has a positive z component.

    Args:
        size: Number of samples along the first and second grid axes.
        fov: Angle in radians controlling the extent of the sample grid.

    Returns:
        Float tensor of shape ``(size[0], size[1], 3)`` of unit vectors; the
        x component varies along axis 0 and the y component along axis 1.
    """
    count_x, count_y = size
    aspect = count_x / count_y
    # NOTE(review): a pinhole camera with the image plane at unit distance
    # would normally use tan(fov / 2) here; sin() is kept so existing renders
    # do not change -- confirm which one is intended.
    half_extent = math.sin(fov / 2)
    xs = torch.linspace(-half_extent * aspect, half_extent * aspect, count_x)
    ys = torch.linspace(-half_extent, half_extent, count_y)
    # Explicit 'ij' indexing keeps the (size[0], size[1]) layout this code
    # relies on and avoids the deprecation warning for the implicit default.
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='ij')
    plane = torch.stack([grid_x, grid_y, torch.zeros_like(grid_x)], -1)
    origin = torch.tensor([0.0, 0.0, -1.0])
    return F.normalize(plane - origin, dim=-1)
class NormEstimate(nn.Module):
    """Estimate unit surface normals of a scalar field by central differences.

    The field ``objects`` is sampled at ``x - d`` and ``x + d`` along each
    coordinate axis; the three differences form a gradient vector that is
    normalised to unit length (so the ``1 / (2 d)`` scale factor is not needed).
    """

    def __init__(self, objects: 'Estimator', size: torch.Size, device: torch.device, d: float = 0.0001):
        """Store the field and build the three axis-aligned offsets.

        Args:
            objects: Callable mapping a tensor of points to one scalar per point.
            size: Unused; accepted only to keep the existing call signature.
            device: Device on which the offset vectors are created.
            d: Finite-difference step along each axis.
        """
        super().__init__()
        self.objects = objects
        # Buffers (rather than plain attributes) follow the module through
        # .to()/.cuda(); persistent=False keeps them out of state_dict so the
        # saved state is the same as before.
        offsets = (('dx', [d, 0.0, 0.0]), ('dy', [0.0, d, 0.0]), ('dz', [0.0, 0.0, d]))
        for name, offset in offsets:
            self.register_buffer(
                name,
                torch.tensor(offset, dtype=torch.float, device=device),
                persistent=False,
            )

    def forward(self, x: torch.Tensor):
        """Return the normalised finite-difference gradient of the field at ``x``.

        Args:
            x: Tensor of points with the three coordinates on the last axis.

        Returns:
            Tensor with one unit vector per input point on a new last axis.
        """
        components = []
        for step in (self.dx, self.dy, self.dz):
            below = self.objects(x - step)
            above = self.objects(x + step)
            components.append(above - below)
        grad = torch.stack(components, -1)
        return F.normalize(grad, dim=-1)
def trace_ray(x: torch.Tensor,
              direction: torch.Tensor,
              objects: 'Estimator',
              maxiter: int,
              eps: float):
    """March points along ``direction`` using the values returned by ``objects``.

    On each of ``maxiter`` iterations the field ``objects`` is evaluated at the
    current points and every point advances along its direction by the
    non-negative part of that value (presumably a signed distance -- confirm
    against the ``Estimator`` implementation).

    Args:
        x: Start point(s); broadcast against ``direction``.
        direction: Ray directions with the three coordinates on the last axis.
        objects: Callable mapping points to one scalar per point.
        maxiter: Number of marching steps; must be at least 1.
        eps: Hit tolerance. Not used inside the march itself; callers compare
            the returned field values against it.

    Returns:
        Tuple ``(r, x)``: the field values from the final iteration (sampled
        before the final step was applied) and the points after the final step.

    Raises:
        ValueError: If ``maxiter`` is smaller than 1.
    """
    if maxiter < 1:
        # Without at least one iteration there is no field value to return.
        raise ValueError(f"maxiter must be at least 1, got {maxiter}")
    for _ in range(maxiter):
        r = objects(x)
        # relu() prevents stepping backwards when a point is already inside.
        x = x + F.relu(r).unsqueeze(-1) * direction
    return r, x
def phong(light: torch.Tensor, eye: torch.Tensor, x: torch.Tensor, n: torch.Tensor):
    """Shade surface points ``x`` with an ambient + diffuse + specular model.

    Args:
        light: Position of the light.
        eye: Position of the viewer.
        x: Surface points being shaded.
        n: Surface normals at ``x``.

    Returns:
        Per-point colour: the sum of the three terms divided by
        ``ambient + 1 + specularity``.
    """
    ambient = 1
    specularity = 0.5
    shininess = 10

    def grey(level):
        # Three identical channels, created on the same device as the points.
        return torch.tensor([level, level, level], dtype=torch.float, device=x.device)

    diffuse_colour = grey(0.8275)
    ambient_colour = grey(0.9608)

    to_light = F.normalize(light - x, dim=-1)
    to_eye = F.normalize(eye - x, dim=-1)

    n_dot_l = dot(n, to_light)
    # Direction of the light vector mirrored about the normal.
    reflected = 2*n_dot_l*n - to_light

    ambient_term = ambient*ambient_colour
    diffuse_term = diffuse_colour*F.relu(n_dot_l)
    specular_term = specularity*torch.pow(F.relu(dot(to_eye, reflected)), shininess)

    total = ambient_term + diffuse_term + specular_term
    return total/(ambient + 1 + specularity)
def simple_shading(eye: torch.Tensor,
                   direction: torch.Tensor,
                   objects: 'Estimator',
                   light: torch.Tensor,
                   maxiter: int,
                   eps: float):
    """Render colours for a set of view rays with Phong shading and shadows.

    Primary rays are marched from ``eye``; the resulting points are shaded
    with ``phong`` and a second march towards the light decides which points
    are darkened as being in shadow. Rays whose final field value is not below
    ``eps`` are painted white.

    Args:
        eye: Viewer position, used as the start of every primary ray.
        direction: Primary ray directions.
        objects: Scene field passed to ``trace_ray`` and ``NormEstimate``.
        light: Light position.
        maxiter: Number of marching steps for both primary and shadow rays.
        eps: Hit tolerance; also scales the offset applied before the shadow march.

    Returns:
        Colour tensor with the same shape as the output of ``phong``.
    """
    hit_r, x = trace_ray(eye, direction, objects, maxiter, eps)
    is_intersect = torch.lt(hit_r, eps)

    # Pass the actual shape rather than the bound method ``x.size``.
    norm_estimate = NormEstimate(objects, x.shape, x.device)
    n = norm_estimate(x)
    colour = phong(light, eye, x, n)

    # Direction to the light is taken from the un-offset points.
    l_ = F.normalize(light - x, dim=-1)
    # Nudge the start points along the normal so the shadow march does not
    # immediately register the surface it starts on.
    x = x + 2*eps*n
    shadow_r, _ = trace_ray(x, l_, objects, maxiter, eps)
    is_shadow = torch.lt(shadow_r, eps)

    colour_ = ifelse(is_shadow, 0.4*colour, colour)
    colour_ = ifelse(is_intersect, colour_, torch.ones_like(colour))
    return colour_
Add Comment
Please, Sign In to add comment