Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import numpy as np
- from torch.autograd import Variable
- from torch.optim.lr_scheduler import StepLR
- # from pinhole_camera import perspective_matrix, viewport_matrix
- from data_def import PCAModel, Mesh
- from landmarks import detect_landmark
- import dlib
- import h5py
- import math
- import trimesh
- import os
- import pyrender
- import cv2
# Test image the face model will be fitted to.
# NOTE: loading happens at import time (module-level side effect).
IMG = dlib.load_rgb_image("faces/elon-musk-frontal.png")
# Vertical field of view (radians) and view-frustum parameters used by the
# perspective / viewport matrices below.
FOVY = 0.5
ASPECT_RATIO = 1
FAR, NEAR = 2000, 300
TOP = math.tan(FOVY/2) * NEAR
RIGHT = TOP * ASPECT_RATIO
BOTTOM = -TOP
# NOTE(review): LEFT is written as -TOP * ASPECT_RATIO, i.e. -RIGHT; harmless
# while ASPECT_RATIO == 1 but asymmetric with how RIGHT is defined -- confirm.
LEFT = -TOP * ASPECT_RATIO
# Location of the Basel Face Model data and of the file listing the vertex
# indices of the 68 facial landmarks.
REL_DATA_PATH = './data/'
MODEL_NAME = 'model2017-1_face12_nomouth.h5'
FACE_LANDMARKS = 'Landmarks68_model2017-1_face12_nomouth.anl'
torch.set_default_tensor_type(torch.FloatTensor)
# Number of identity / expression PCA components kept by default.
N_ID = 30
N_EXPR = 20
def get_face_data(n_id=N_ID, n_expr=N_EXPR):
    """Load the Basel Face Model and build full-resolution PCA models.

    :param n_id: number of identity principal components to keep (of 199)
    :param n_expr: number of expression principal components to keep (of 100)
    :return: (id_model, expr_model, triangles, color_mean) where the models
             are PCAModel instances, triangles is an (T, 3) int32 array of
             vertex indices, and color_mean is the (N, 3) per-vertex mean
             color.
    """
    # Context manager so the HDF5 handle is always released -- the original
    # opened the file and never closed it.  np.asarray reads each dataset
    # eagerly, so everything is in memory before the file closes.
    with h5py.File(REL_DATA_PATH + MODEL_NAME, 'r') as bfm:
        shape_mean = np.asarray(bfm['shape/model/mean'], dtype=np.float32).reshape((-1, 3))
        shape_pca_basis = np.asarray(bfm['shape/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 199))
        shape_pca_var = np.asarray(bfm['shape/model/pcaVariance'], dtype=np.float32).reshape((199))
        expression_mean = np.asarray(bfm['expression/model/mean'], dtype=np.float32).reshape((-1, 3))
        expression_pca_basis = np.asarray(bfm['expression/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 100))
        expression_pca_var = np.asarray(bfm['expression/model/pcaVariance'], dtype=np.float32).reshape((100))
        triangles = np.asarray(bfm['shape/representer/cells'], dtype=np.int32).T
        color_mean = np.asarray(bfm['color/model/mean'], dtype=np.float32).reshape((-1, 3))
    mu_id = torch.from_numpy(shape_mean)
    pca_basis_id = torch.from_numpy(shape_pca_basis[:, :, :n_id])
    # Per-component standard deviations: sqrt of the stored PCA variances.
    std_id = torch.from_numpy(np.sqrt(shape_pca_var[:n_id]))
    mu_expr = torch.from_numpy(expression_mean)
    pca_basis_expr = torch.from_numpy(expression_pca_basis[:, :, :n_expr])
    std_expr = torch.from_numpy(np.sqrt(expression_pca_var[:n_expr]))
    id_model = PCAModel(mu_id, pca_basis_id, std_id)
    expr_model = PCAModel(mu_expr, pca_basis_expr, std_expr)
    return id_model, expr_model, triangles, color_mean
def get_face_landmarks(n_id=N_ID, n_expr=N_EXPR):
    """Like get_face_data, but restricted to the 68 landmark vertices.

    :param n_id: number of identity principal components to keep (of 199)
    :param n_expr: number of expression principal components to keep (of 100)
    :return: (id_model, expr_model) PCAModels whose means and bases contain
             only the rows for the landmark vertices listed in FACE_LANDMARKS.
    """
    # Context manager so the HDF5 handle is released (the original leaked it).
    with h5py.File(REL_DATA_PATH + MODEL_NAME, 'r') as bfm:
        shape_mean = np.asarray(bfm['shape/model/mean'], dtype=np.float32).reshape((-1, 3))
        shape_pca_basis = np.asarray(bfm['shape/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 199))
        shape_pca_var = np.asarray(bfm['shape/model/pcaVariance'], dtype=np.float32).reshape((199))
        expression_mean = np.asarray(bfm['expression/model/mean'], dtype=np.float32).reshape((-1, 3))
        expression_pca_basis = np.asarray(bfm['expression/model/pcaBasis'], dtype=np.float32).reshape((-1, 3, 100))
        expression_pca_var = np.asarray(bfm['expression/model/pcaVariance'], dtype=np.float32).reshape((100))
    # Vertex indices of the 68 landmarks.
    # NOTE(review): unlike the model file, this path is NOT prefixed with
    # REL_DATA_PATH -- confirm the .anl file really lives in the CWD.
    landmarks = np.loadtxt(FACE_LANDMARKS).astype(int)
    mu_id = torch.from_numpy(shape_mean[landmarks])
    pca_basis_id = torch.from_numpy(shape_pca_basis[landmarks, :, :n_id])
    std_id = torch.from_numpy(np.sqrt(shape_pca_var[:n_id]))
    mu_expr = torch.from_numpy(expression_mean[landmarks])
    pca_basis_expr = torch.from_numpy(expression_pca_basis[landmarks, :, :n_expr])
    std_expr = torch.from_numpy(np.sqrt(expression_pca_var[:n_expr]))
    id_model = PCAModel(mu_id, pca_basis_id, std_id)
    expr_model = PCAModel(mu_expr, pca_basis_expr, std_expr)
    return id_model, expr_model
def view_mesh_render(mesh):
    """Show the given Mesh in an interactive pyrender viewer window."""
    tri = trimesh.base.Trimesh(vertices=mesh.vertices,
                               faces=mesh.triangles,
                               vertex_colors=mesh.colors)
    scene = pyrender.Scene()
    scene.add(pyrender.Mesh.from_trimesh(tri))
    pyrender.Viewer(scene, use_raymond_lighting=True)
def mesh_to_png(file_name, mesh):
    """Render the Mesh off-screen via trimesh and write the PNG bytes to file_name."""
    tri = trimesh.base.Trimesh(vertices=mesh.vertices,
                               faces=mesh.triangles,
                               vertex_colors=mesh.colors)
    png_bytes = tri.scene().save_image()
    with open(file_name, 'wb') as out:
        out.write(png_bytes)
def viewport_matrix(top=TOP, right=RIGHT, bottom=BOTTOM, left=LEFT):
    """Viewport transform mapping NDC coordinates to window coordinates.

    Follows the derivation at:
    http://glasnost.itcarlow.ie/~powerk/GeneralGraphicsNotes/projection/viewport_transformation.html

    The TAs' reference matrix is
        cx = W / 2, cy = H / 2
        [[cx,   0,   0,  cx],
         [ 0, -cy,   0,  cy],
         [ 0,   0, 0.5, 0.5],
         [ 0,   0,   0,   1]]
    NOTE(review): the reference flips the y axis (-cy) while this version
    scales y by +(top - bottom)/2 -- confirm the two really agree as the
    original comment claimed.
    """
    # Scale maps x and y from [-1, 1] to half-window extents and depth z
    # from [-1, 1] to [0, 1]; the translation then centers the window.
    S = torch.diag(torch.Tensor([(right - left) / 2, (top - bottom) / 2, 1 / 2, 1]))
    T = torch.eye(4)
    T[0, 3] = (right + left) / 2
    T[1, 3] = (top + bottom) / 2
    T[2, 3] = 1 / 2
    # Translate after scaling.
    return T @ S
def perspective_matrix(top=TOP, right=RIGHT, bottom=BOTTOM, left=LEFT, far=FAR, near=NEAR):
    """Return the 4x4 perspective projection matrix used by the fit.

    The original body built three candidate matrices in sequence; the first
    two were immediately overwritten (dead code), so only the final,
    actually-returned formulation is kept here.

    NOTE(review): this formulation depends only on the module-level FOVY and
    on far/near -- top/right/bottom/left are accepted for interface
    compatibility but do not affect the result.  Also, a textbook projection
    puts 1/tan(fovy/2) on the diagonal, not tan(FOVY/2*pi); confirm the
    intent before reusing this elsewhere.
    """
    f, n = far, near
    P = torch.Tensor([[math.tan(FOVY / 2 * math.pi), 0, 0, 0],
                      [0, math.tan(FOVY / 2 * math.pi), 0, 0],
                      [0, 0, -(f) / (f - n), -(f * n) / (f - n)],
                      [0, 0, -1, 0]])
    return P
def rotation_tensor(theta, phi, psi):
    """Build the 3x3 rotation matrix R = Rx(theta) @ Ry(phi) @ Rz(psi).

    :param theta: rotation about x (radians, 0-dim tensor)
    :param phi: rotation about y (radians, 0-dim tensor)
    :param psi: rotation about z (radians, 0-dim tensor)
    :return: (3, 3) rotation tensor

    The matrices are assembled with torch.stack instead of
    torch.Tensor([...]) because the latter copies the element values and
    silently DETACHES them from the autograd graph -- with the original
    implementation the pose angles optimized in train() could never receive
    gradients through the rotation.
    """
    one = torch.ones_like(theta)
    zero = torch.zeros_like(theta)
    rot_x = torch.stack([
        torch.stack([one, zero, zero]),
        torch.stack([zero, theta.cos(), -theta.sin()]),
        torch.stack([zero, theta.sin(), theta.cos()]),
    ])
    rot_y = torch.stack([
        torch.stack([phi.cos(), zero, phi.sin()]),
        torch.stack([zero, one, zero]),
        torch.stack([-phi.sin(), zero, phi.cos()]),
    ])
    rot_z = torch.stack([
        torch.stack([psi.cos(), -psi.sin(), zero]),
        torch.stack([psi.sin(), psi.cos(), zero]),
        torch.stack([zero, zero, one]),
    ])
    return rot_x @ rot_y @ rot_z
def rigid_transform(rot_vec, translation):
    """Compose a 4x4 homogeneous rigid-body transform (rotation + translation).

    :param rot_vec: length-3 tensor of Euler angles (x, y, z rotations)
    :param translation: length-3 translation vector
    :return: (4, 4) homogeneous transform for column vectors (M @ p)

    The rest of the file uses the column-vector convention (transform_points
    computes M @ p; viewport_matrix puts its translation in T[:3, 3]), so
    the translation belongs in the last COLUMN.  The original wrote it into
    the bottom row (T[3, :3]), which corrupts the homogeneous row and only
    matches the row-vector convention -- fixed here.
    """
    T = torch.eye(4)
    T[:3, :3] = rotation_tensor(rot_vec[0], rot_vec[1], rot_vec[2])
    T[:3, 3] = translation
    return T
def get_G(alpha, delta, id_model, expr_model):
    """Reconstruct per-vertex geometry from PCA coefficients.

    G = mean_id + pc_id @ (alpha * std_id) + mean_expr + pc_expr @ (delta * std_expr)

    :param alpha: identity coefficients, shape (n_id,)
    :param delta: expression coefficients, shape (n_expr,)
    :param id_model: PCAModel with .mean (N, 3), .pc (N, 3, n_id), .std (n_id,)
    :param expr_model: PCAModel with the analogous expression fields
    :return: (N, 3) tensor of vertex positions
    """
    # Rescale the unit-variance coefficients by each component's standard
    # deviation, then project through the PCA bases.  (.t() on a 1-D tensor
    # is a no-op, kept for parity with the original.)
    weighted_id = (alpha * id_model.std).t()
    weighted_expr = (delta * expr_model.std).t()
    offset_id = id_model.pc @ weighted_id
    offset_expr = expr_model.pc @ weighted_expr
    return id_model.mean + offset_id + expr_model.mean + offset_expr
def transform_points(G, omega, t, transform_view=True):
    """Apply the rigid pose (omega, t) to the (N, 3) point set G.

    :param G: (N, 3) tensor of points
    :param omega: length-3 tensor of Euler angles
    :param t: length-3 translation tensor
    :param transform_view: when True, additionally project through the
        perspective and viewport matrices and dehomogenize, returning
        (N, 4) screen-space coordinates; when False, return the posed
        (N, 3) world-space points.
    """
    R = rotation_tensor(omega[0], omega[1], omega[2])
    posed = (R @ G.t()).t() + t.t().repeat(G.shape[0], 1)
    if not transform_view:
        return posed
    # Homogeneous coordinates: append a column of ones.
    homo = torch.cat((posed, torch.ones((posed.shape[0], 1))), 1)
    projected = viewport_matrix() @ perspective_matrix() @ homo.t()
    # Perspective divide by w (row 3), then back to one row per point.
    w = projected[3, :]
    return (projected / w).t()
def annotate_landmarks(image, landmarks, gt):
    """
    Given image and a set of landmark points, annotates the points for viewing
    :param image: Input image
    :type image: np.array
    :param landmarks: predicted facial landmark points (drawn in yellow)
    :type landmarks: [(float, float)]
    :param gt: ground-truth landmark points (drawn in green)
    :type gt: [(float, float)]
    :return: Resulting annotated image
    :rtype: np.array
    """
    image = image.copy()
    # cv2.circle requires integer pixel coordinates; the points arrive as
    # float tensor rows (see overlay_mesh_on_img), so cast explicitly --
    # passing floats raises an error in OpenCV.
    for point in landmarks:
        pos = (int(point[0]), int(point[1]))
        cv2.circle(image, pos, 3, color=(255, 255, 0))
    for point in gt:
        pos = (int(point[0]), int(point[1]))
        cv2.circle(image, pos, 3, color=(0, 255, 0))
    return image
def overlay_mesh_on_img(img, gt, id_model_lm, expr_model_lm, alpha, delta, omega, t):
    """Project the fitted landmark model into image space and display it.

    Draws the model's projected landmark points together with the detected
    ground-truth landmarks in a dlib window, then blocks until the user
    presses enter.  (An unused local rotation angle and dead commented-out
    code were removed from the original.)
    """
    G = get_G(alpha, delta, id_model_lm, expr_model_lm)
    points = transform_points(G, omega, t, transform_view=True)
    win = dlib.image_window()
    win.clear_overlay()
    # Only the x, y components of the projected homogeneous points are drawn.
    img_lm = annotate_landmarks(img, points[:, :2], gt)
    win.set_image(img_lm)
    dlib.hit_enter_to_continue()
def train(max_steps=5000):
    """Fit the face model's pose to the landmarks detected in IMG.

    Runs Adam for max_steps iterations on a landmark MSE loss plus an L2
    regularizer on the shape coefficients, then shows the fitted landmarks
    over the image and renders / saves the full mesh.
    """
    n_id, n_expr = 30, 20
    id_model_lm, expr_model_lm = get_face_landmarks(n_id, n_expr)
    # Detected 2-D landmarks in the photo: the optimization target.
    ground_truth = torch.from_numpy(detect_landmark(IMG)).type(torch.FloatTensor)
    lr = 0.1
    # alpha = Variable(torch.rand(n_id,), requires_grad=True)
    # delta = Variable(torch.rand(n_expr,), requires_grad=True)
    # omega = Variable(torch.rand(3,), requires_grad=True)
    # t = Variable(torch.rand(3,), requires_grad=True)
    # All latent parameters start at zero (mean face, no pose).
    alpha = Variable(torch.zeros(n_id, ), requires_grad=True)
    delta = Variable(torch.zeros(n_expr, ), requires_grad=True)
    omega = Variable(torch.zeros(3, ), requires_grad=True)
    t = Variable(torch.zeros(3, ), requires_grad=True)
    print('Initial alpha = ', alpha)
    print('Initial delta = ', delta)
    print('Initial omega = ', omega)
    print('Initial t = ', t)
    # NOTE(review): only omega and t are handed to the optimizer, so alpha
    # and delta never move from zero even though the loss regularizes them
    # and their gradients are clipped below -- confirm this pose-only fit
    # is intentional.
    opt = torch.optim.Adam([omega, t], lr=lr)
    # NOTE(review): step_size == max_steps means the learning rate decays
    # only once, after the final iteration -- effectively no scheduling.
    scheduler = StepLR(opt, step_size=max_steps, gamma=0.9995)
    # Regularization weights for the identity / expression coefficients.
    lambda_alpha = 1000.0
    lambda_delta = 1000.0
    loss_list = []
    for i in range(max_steps):
        # G = torch.ones((id_model_lm.pc.shape[0], 4))
        # G[:, :3] = get_G(alpha, delta, id_model_lm, expr_model_lm)
        G = get_G(alpha, delta, id_model_lm, expr_model_lm)
        # Project the landmark vertices to screen space.
        result = transform_points(G, omega, t)
        opt.zero_grad()
        # result = torch.cat((result[:,1:2], result[:,0:1]), dim=1)
        # Landmark term: 2-D error between projected and detected points.
        loss_lan = torch.nn.functional.mse_loss(result[:, :2], ground_truth)
        # L2 prior keeping the shape coefficients small.
        loss_reg = lambda_alpha * (alpha**2).sum() + lambda_delta * (delta**2).sum()
        loss = loss_lan + loss_reg
        loss.backward()
        torch.nn.utils.clip_grad_norm_([alpha, delta, omega, t], max_norm=5.0)
        opt.step()
        scheduler.step()
        loss_list.append(loss.item())
        # print("-------------------")
        # print(loss_lan)
        # print(loss_reg)
        print("Loss:", loss.item())
        # if loss.item() < 500 and loss_list[-1] > loss_list[-2]:
        #     break
    print("Alpha:", alpha.mean().item(), "Delta:", delta.mean().item())
    overlay_mesh_on_img(IMG, ground_truth, id_model_lm, expr_model_lm, alpha, delta, omega, t)
    # Rebuild the full-resolution model with the fitted coefficients.
    id_model, expr_model, triangles, color_mean = get_face_data(n_id, n_expr)
    # G = torch.ones((id_model.pc.shape[0], 4))
    # G[:, :3] = get_G(alpha, delta, id_model, expr_model)
    G = get_G(alpha, delta, id_model, expr_model)
    result = transform_points(G, omega, t, transform_view=False)
    # NOTE(review): r is all zeros (0 * ones), so this addition is a no-op;
    # presumably left over from experimenting with a manual offset.
    r = torch.cat((0*torch.ones(result.shape[0], 1), 0*torch.ones(result.shape[0],1), 0*torch.ones(result.shape[0],1)), 1)
    result += r
    mesh = Mesh(result.detach(), color_mean, triangles)
    view_mesh_render(mesh)
    os.makedirs("ex4_images/", exist_ok=True)
    mesh_to_png("ex4_images/image", mesh)


if __name__ == '__main__':
    train()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement