SHARE
TWEET

Untitled

a guest May 24th, 2019 111 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import numpy as np
  3. from torch.autograd import Variable
  4. from torch.optim.lr_scheduler import StepLR
  5. # from pinhole_camera import perspective_matrix, viewport_matrix
  6. from data_def import PCAModel, Mesh
  7. from landmarks import detect_landmark
  8. import dlib
  9. import h5py
  10. import math
  11. import trimesh
  12. import os
  13. import pyrender
  14. import cv2
  15.  
# Input photograph; loaded once at import time and used by train().
IMG = dlib.load_rgb_image("faces/elon-musk-frontal.png")

# Camera field of view and aspect ratio.
# NOTE(review): TOP below is computed with tan(FOVY/2), while
# perspective_matrix() uses tan(FOVY/2*pi) -- confirm which unit FOVY is in.
FOVY = 0.5
ASPECT_RATIO = 1

# Far and near clipping distances; defaults of perspective_matrix().
FAR, NEAR = 2000, 300

# Frustum bounds derived from FOVY, NEAR and ASPECT_RATIO; used as the default
# arguments of viewport_matrix() and perspective_matrix().
TOP = math.tan(FOVY/2) * NEAR
RIGHT = TOP * ASPECT_RATIO
BOTTOM = -TOP
LEFT = -TOP * ASPECT_RATIO


# Location of the morphable-model file and of the landmark index file.
# NOTE(review): FACE_LANDMARKS is opened without REL_DATA_PATH in
# get_face_landmarks() -- confirm that this is intended.
REL_DATA_PATH = './data/'
MODEL_NAME = 'model2017-1_face12_nomouth.h5'
FACE_LANDMARKS = 'Landmarks68_model2017-1_face12_nomouth.anl'
# Newly created tensors default to 32-bit floats on the CPU.
torch.set_default_tensor_type(torch.FloatTensor)

# Default number of identity / expression PCA components kept by the loaders.
N_ID = 30
N_EXPR = 20
  36.  
  37.  
  38. def get_face_data(n_id=N_ID, n_expr=N_EXPR):
  39.     bfm = h5py.File(REL_DATA_PATH + MODEL_NAME, 'r')
  40.  
  41.     shape_mean = np.asarray(bfm['shape/model/mean'], dtype=np.float32).reshape((-1, 3))
  42.     shape_pca_basis = np.asarray(bfm['shape/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 199))
  43.     shape_pca_var = np.asarray(bfm['shape/model/pcaVariance'], dtype=np.float32).reshape((199))
  44.  
  45.     expression_mean = np.asarray(bfm['expression/model/mean'], dtype=np.float32).reshape((-1, 3))
  46.     expression_pca_basis = np.asarray(bfm['expression/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 100))
  47.     expression_pca_var = np.asarray(bfm['expression/model/pcaVariance'], dtype=np.float32).reshape((100))
  48.  
  49.     triangles = np.asarray(bfm['shape/representer/cells'], dtype=np.int32).T
  50.     color_mean = np.asarray(bfm['color/model/mean'], dtype=np.float32).reshape((-1, 3))
  51.  
  52.     mu_id = torch.from_numpy(shape_mean)
  53.     pca_basis_id = torch.from_numpy(shape_pca_basis[:, :, :n_id])
  54.     std_id = torch.from_numpy(np.sqrt(shape_pca_var[:n_id]))
  55.  
  56.     mu_expr = torch.from_numpy(expression_mean)
  57.     pca_basis_expr = torch.from_numpy(expression_pca_basis[:, :, :n_expr])
  58.     std_expr = torch.from_numpy(np.sqrt(expression_pca_var[:n_expr]))
  59.  
  60.     id_model = PCAModel(mu_id, pca_basis_id, std_id)
  61.     expr_model = PCAModel(mu_expr, pca_basis_expr, std_expr)
  62.  
  63.     return id_model, expr_model, triangles, color_mean
  64.  
  65.  
  66. def get_face_landmarks(n_id=N_ID, n_expr=N_EXPR):
  67.     bfm = h5py.File(REL_DATA_PATH + MODEL_NAME, 'r')
  68.  
  69.     shape_mean = np.asarray(bfm['shape/model/mean'], dtype=np.float32).reshape((-1, 3))
  70.     shape_pca_basis = np.asarray(bfm['shape/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 199))
  71.     shape_pca_var = np.asarray(bfm['shape/model/pcaVariance'], dtype=np.float32).reshape((199))
  72.  
  73.     expression_mean = np.asarray(bfm['expression/model/mean'], dtype=np.float32).reshape((-1, 3))
  74.     expression_pca_basis = np.asarray(bfm['expression/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 100))
  75.     expression_pca_var = np.asarray(bfm['expression/model/pcaVariance'], dtype=np.float32).reshape((100))
  76.  
  77.     landmarks = np.ndarray.astype(np.loadtxt(FACE_LANDMARKS), int)
  78.  
  79.     mu_id = torch.from_numpy(shape_mean[landmarks])
  80.     pca_basis_id = torch.from_numpy(shape_pca_basis[landmarks, :, :n_id])
  81.     std_id = torch.from_numpy(np.sqrt(shape_pca_var[:n_id]))
  82.  
  83.     mu_expr = torch.from_numpy(expression_mean[landmarks])
  84.     pca_basis_expr = torch.from_numpy(expression_pca_basis[landmarks, :, :n_expr])
  85.     std_expr = torch.from_numpy(np.sqrt(expression_pca_var[:n_expr]))
  86.  
  87.     id_model = PCAModel(mu_id, pca_basis_id, std_id)
  88.     expr_model = PCAModel(mu_expr, pca_basis_expr, std_expr)
  89.  
  90.     return id_model, expr_model
  91.  
  92.  
  93. def view_mesh_render(mesh):
  94.     mesh = trimesh.base.Trimesh(
  95.         vertices=mesh.vertices,
  96.         faces=mesh.triangles,
  97.         vertex_colors=mesh.colors)
  98.  
  99.     pmesh = pyrender.Mesh.from_trimesh(mesh)
  100.     scene = pyrender.Scene()
  101.     scene.add(pmesh)
  102.     pyrender.Viewer(scene, use_raymond_lighting=True)
  103.  
  104.  
  105. def mesh_to_png(file_name, mesh):
  106.     mesh = trimesh.base.Trimesh(
  107.         vertices=mesh.vertices,
  108.         faces=mesh.triangles,
  109.         vertex_colors=mesh.colors)
  110.  
  111.     png = mesh.scene().save_image()
  112.     with open(file_name, 'wb') as f:
  113.         f.write(png)
  114.  
  115.  
  116. def viewport_matrix(top=TOP, right=RIGHT, bottom=BOTTOM, left=LEFT):
  117.     ''' I don't know how this works, just followed the tutorial at:
  118.    http://glasnost.itcarlow.ie/~powerk/GeneralGraphicsNotes/projection/viewport_transformation.html
  119.  
  120.    Given by the TA's:
  121.    cx = W / 2, cy = H / 2
  122.    viewport = [[cx, 0,  0,   cx],
  123.                [0, -cy, 0,   cy],
  124.                [0,  0,  0.5, 0.5],
  125.                [0,  0,  0,   1]]
  126.    This should be the same as our implementation
  127.    '''
  128.     scaling = torch.Tensor([(right - left) / 2, (top - bottom) / 2, 1 / 2, 1])
  129.     translation = torch.Tensor([(right + left) / 2, (top + bottom) / 2, 1 / 2])
  130.  
  131.     T = torch.eye(4)
  132.     T[:3, 3] = translation
  133.     # TODO: did i change this correctly?
  134.     S = torch.diag(scaling)
  135.  
  136.     return T @ S
  137.  
  138.  
  139. def perspective_matrix(top=TOP, right=RIGHT, bottom=BOTTOM, left=LEFT, far=FAR, near=NEAR):
  140.     '''Given by the TA's, don't know what to do with this parameter
  141.    fovy = 0.5
  142.    '''
  143.  
  144.     # Shorten for readability of matrix
  145.     t, r, b, l, f, n = top, right, bottom, left, far, near
  146.  
  147.     P = torch.Tensor([[(2 * n) / (r - l), 0, (r + l) / (r - l), 0],
  148.                       [0, (2 * n) / (t - b), (t + b) / (t - b), 0],
  149.                         [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
  150.                         [0, 0, -1, 0]])
  151.  
  152.     P = torch.Tensor([[(2 * n) / (t), 0, 0, 0],
  153.                     [0, (2 * n) / (t - b), 0, 0],
  154.                   [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
  155.                   [0, 0, -1, 0]])
  156.  
  157.     P = torch.Tensor([[math.tan(FOVY/2*math.pi), 0, 0, 0],
  158.                       [0, math.tan(FOVY/2*math.pi), 0, 0],
  159.                       [0, 0, -(f) / (f - n), -(f * n) / (f - n)],
  160.                       [0, 0, -1, 0]])
  161.  
  162.     return P
  163.  
  164.  
  165. def rotation_tensor(theta, phi, psi):
  166.     rot_x = torch.Tensor([(1, 0, 0),
  167.                         (0, theta.cos(), -theta.sin()),
  168.                         (0, theta.sin(), theta.cos())])
  169.     rot_y = torch.Tensor([(phi.cos(), 0, phi.sin()),
  170.                         (0, 1, 0),
  171.                         (-phi.sin(), 0, phi.cos())])
  172.     rot_z = torch.Tensor([(psi.cos(), -psi.sin(), 0),
  173.                         (psi.sin(), psi.cos(), 0),
  174.                         (0, 0, 1)])
  175.     # TODO: correct?
  176.     return rot_x @ rot_y @ rot_z
  177.  
  178.  
  179. def rigid_transform(rot_vec, translation):
  180.     T = torch.eye(4)
  181.     T[:3, :3] = rotation_tensor(rot_vec[0], rot_vec[1], rot_vec[2])
  182.     T[3, :3] = translation
  183.     # TODO: is this correct?
  184.     return T
  185.  
  186.  
  187. def get_G(alpha, delta, id_model, expr_model):
  188.     E_id = id_model.pc @ (alpha * id_model.std).t()
  189.     E_expr = expr_model.pc @ (delta * expr_model.std).t()
  190.     G = id_model.mean + E_id + expr_model.mean + E_expr
  191.     return G
  192.  
  193.  
  194. def transform_points(G, omega, t, transform_view=True):
  195.     R = rotation_tensor(omega[0], omega[1], omega[2])
  196.     p = (R @ G.t()).t() + t.t().repeat(G.shape[0], 1)
  197.  
  198.     if not transform_view:
  199.         return p
  200.  
  201.     p_4d = torch.cat((p, torch.ones((p.shape[0], 1))), 1)
  202.     P = perspective_matrix()
  203.     V = viewport_matrix()
  204.     result_4d = (V @ P @ p_4d.t())
  205.     d = result_4d[3, :]
  206.     result = result_4d / d
  207.     result = result.t()
  208.     return result
  209.  
  210.  
  211. def annotate_landmarks(image, landmarks, gt):
  212.     """
  213.    Given image and a set of landmark points, annotates the points for viewing
  214.    :param image: Input image
  215.    :type image: np.array
  216.    :param landmarks: set of facial landmark points
  217.    :type landmarks: [(float, float)]
  218.    :return: Resulting annotated image
  219.    :rtype: np.array
  220.    """
  221.     image = image.copy()
  222.     # landmarks += torch.tensor((image.shape[0]/2, image.shape[1]/2))
  223.     for idx, point in enumerate(landmarks):
  224.         pos = (point[0], point[1])
  225.         cv2.circle(image, pos, 3, color=(255, 255, 0))
  226.     for idx, point in enumerate(gt):
  227.         pos = (point[0], point[1])
  228.         cv2.circle(image, pos, 3, color=(0, 255, 0))
  229.     return image
  230.  
  231.  
  232. def overlay_mesh_on_img(img, gt, id_model_lm, expr_model_lm, alpha, delta, omega, t):
  233.     G = get_G(alpha, delta, id_model_lm, expr_model_lm)
  234.     points = transform_points(G, omega, t, transform_view=True)
  235.     win = dlib.image_window()
  236.     win.clear_overlay()
  237.     angle = torch.tensor((np.pi)).type(torch.FloatTensor)
  238.     # points = result @ rotation_tensor(angle*0, angle*0, 0*angle)
  239.     img_lm = annotate_landmarks(img, points[:,:2], gt)
  240.     win.set_image(img_lm)
  241.     # win.add_overlay(result)
  242.     dlib.hit_enter_to_continue()
  243.  
  244.  
  245. def train(max_steps=5000):
  246.     n_id, n_expr = 30, 20
  247.     id_model_lm, expr_model_lm = get_face_landmarks(n_id, n_expr)
  248.     ground_truth = torch.from_numpy(detect_landmark(IMG)).type(torch.FloatTensor)
  249.  
  250.     lr = 0.1
  251.     # alpha = Variable(torch.rand(n_id,), requires_grad=True)
  252.     # delta = Variable(torch.rand(n_expr,), requires_grad=True)
  253.     # omega = Variable(torch.rand(3,), requires_grad=True)
  254.     # t = Variable(torch.rand(3,), requires_grad=True)
  255.     alpha = Variable(torch.zeros(n_id, ), requires_grad=True)
  256.     delta = Variable(torch.zeros(n_expr, ), requires_grad=True)
  257.     omega = Variable(torch.zeros(3, ), requires_grad=True)
  258.     t = Variable(torch.zeros(3, ), requires_grad=True)
  259.     print('Initial alpha = ', alpha)
  260.     print('Initial delta = ', delta)
  261.     print('Initial omega = ', omega)
  262.     print('Initial t = ', t)
  263.     opt = torch.optim.Adam([omega, t], lr=lr)
  264.     scheduler = StepLR(opt, step_size=max_steps, gamma=0.9995)
  265.     lambda_alpha = 1000.0
  266.     lambda_delta = 1000.0
  267.  
  268.     loss_list = []
  269.     for i in range(max_steps):
  270.         # G = torch.ones((id_model_lm.pc.shape[0], 4))
  271.         # G[:, :3] = get_G(alpha, delta, id_model_lm, expr_model_lm)
  272.         G = get_G(alpha, delta, id_model_lm, expr_model_lm)
  273.         result = transform_points(G, omega, t)
  274.  
  275.         opt.zero_grad()
  276.         # result = torch.cat((result[:,1:2], result[:,0:1]), dim=1)
  277.         loss_lan = torch.nn.functional.mse_loss(result[:, :2], ground_truth)
  278.         loss_reg = lambda_alpha * (alpha**2).sum() + lambda_delta * (delta**2).sum()
  279.         loss = loss_lan + loss_reg
  280.         loss.backward()
  281.         torch.nn.utils.clip_grad_norm_([alpha, delta, omega, t], max_norm=5.0)
  282.         opt.step()
  283.         scheduler.step()
  284.         loss_list.append(loss.item())
  285.         # print("-------------------")
  286.         # print(loss_lan)
  287.         # print(loss_reg)
  288.         print("Loss:", loss.item())
  289.         # if loss.item() < 500 and loss_list[-1] > loss_list[-2]:
  290.         #     break
  291.  
  292.     print("Alpha:", alpha.mean().item(), "Delta:", delta.mean().item())
  293.     overlay_mesh_on_img(IMG, ground_truth, id_model_lm, expr_model_lm, alpha, delta, omega, t)
  294.  
  295.     id_model, expr_model, triangles, color_mean = get_face_data(n_id, n_expr)
  296.     # G = torch.ones((id_model.pc.shape[0], 4))
  297.     # G[:, :3] = get_G(alpha, delta, id_model, expr_model)
  298.     G = get_G(alpha, delta, id_model, expr_model)
  299.     result = transform_points(G, omega, t, transform_view=False)
  300.     r = torch.cat((0*torch.ones(result.shape[0], 1), 0*torch.ones(result.shape[0],1), 0*torch.ones(result.shape[0],1)), 1)
  301.     result += r
  302.     mesh = Mesh(result.detach(), color_mean, triangles)
  303.     view_mesh_render(mesh)
  304.     os.makedirs("ex4_images/", exist_ok=True)
  305.     mesh_to_png("ex4_images/image", mesh)
  306.  
  307.  
# Run the landmark fitting with the default number of steps when executed as a script.
if __name__ == '__main__':
    train()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top