lamiastella

renderer

Sep 24th, 2020
1,319
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. (base) mona@mona:~/research/3danimals/SMALViewer$ cat smal/smal3d_renderer.py
  2. import sys, os
  3. sys.path.append(os.path.dirname(sys.path[0]))
  4.  
  5. import torch
  6. import torch.nn as nn
  7. import neural_renderer as nr
  8. import torch.nn.functional as F
  9. from SMPL.smal_torch_batch import SMALModel
  10. from smal.joint_catalog import SMALJointInfo
  11.  
  12. import pickle as pkl
  13.  
  14. import numpy as np
  15. import cv2
  16. import matplotlib.pyplot as plt
  17.  
  18. class SMAL3DRenderer(nn.Module):
  19.     def __init__(self, image_size, z_distance = 2.5, elevation = 89.9, azimuth = 0.0):
  20.         super(SMAL3DRenderer, self).__init__()
  21.        
  22.         self.smal_model = SMALModel()
  23.         self.image_size = image_size
  24.         self.smal_info = SMALJointInfo()
  25.  
  26.         self.renderer = nr.Renderer(camera_mode='look_at')
  27.         self.renderer.eye = nr.get_points_from_angles(z_distance, elevation, azimuth)
  28.  
  29.         self.renderer.image_size = image_size
  30.         self.renderer.light_intensity_ambient = 1.0
  31.  
  32.         with open("smal/dog_texture.pkl", 'rb') as f:
  33.             self.textures = pkl.load(f).cuda()
  34.  
  35.     def forward(self, batch_params):
  36.         batch_size = batch_params['betas'].shape[0]
  37.        
  38.         verts, joints_3d = self.smal_model(
  39.             batch_params['betas'],
  40.             torch.cat((batch_params['global_rotation'], batch_params['joint_rotations']), dim = 1),
  41.             batch_params['trans'])
  42.    
  43.         faces = self.smal_model.faces.unsqueeze(0).expand(batch_size, -1, -1)
  44.         textures = self.textures.unsqueeze(0).expand(batch_size, -1, -1, -1, -1, -1)
  45.  
  46.         rendered_joints = self.renderer.render_points(joints_3d[:, self.smal_info.include_classes])
  47.         rendered_silhouettes = self.renderer.render_silhouettes(verts, faces)
  48.         rendered_silhouettes = rendered_silhouettes.unsqueeze(1)
  49.  
  50.         rendered_images = self.renderer.render(verts, faces, textures)    
  51.         rendered_images = torch.clamp(rendered_images[0], 0.0, 1.0)
  52.  
  53.         return rendered_images, rendered_silhouettes, rendered_joints, verts, joints_3d(base)
RAW Paste Data