Advertisement
noodleham

3D GS Training Plots Main

May 13th, 2024 (edited)
521
0
169 days
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 34.31 KB | Software | 0 0
  1. #
  2. # Copyright (C) 2023, Inria
  3. # GRAPHDECO research group, https://team.inria.fr/graphdeco
  4. # All rights reserved.
  5. #
  6. # This software is free for non-commercial, research and evaluation use
  7. # under the terms of the LICENSE.md file.
  8. #
  9. # For inquiries contact  george.drettakis@inria.fr
  10. #
  11. import copy
  12. import math
  13. import os
  14. import pdb
  15. from collections import defaultdict
  16. from typing import List, Union, Tuple, Dict, Literal
  17. import pprint
  18.  
  19. import numpy as np
  20. import torch
  21. import random
  22. from random import randint
  23. from utils.loss_utils import l1_loss, ssim
  24. from gaussian_renderer import render, network_gui
  25. import sys
  26. from scene import Scene, GaussianModel
  27. from utils.general_utils import safe_state
  28. import uuid
  29. from tqdm import tqdm
  30. from utils.image_utils import psnr
  31. from argparse import ArgumentParser, Namespace
  32. from arguments import ModelParams, PipelineParams, OptimizationParams
  33.  
  34. import matplotlib.pyplot as plt
  35. import matplotlib
  36.  
# Optional TensorBoard support: record whether SummaryWriter is importable so
# downstream logging code can degrade gracefully when it is missing.
try:
    from torch.utils.tensorboard import SummaryWriter

    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False
  43.  
  44.  
  45. def get_full_batch_gradients(gaussians: GaussianModel, viewpoint_stack, background, pipe, opt, grad_keys):
  46.     gaussians.optimizer.zero_grad(set_to_none=True)
  47.     for i in range(len(viewpoint_stack)):
  48.         viewpoint_cam = viewpoint_stack[i]
  49.         bg = background
  50.         render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
  51.         image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
  52.             render_pkg["visibility_filter"], render_pkg["radii"]
  53.         # Loss
  54.         gt_image = viewpoint_cam.original_image.cuda()
  55.         Ll1 = l1_loss(image, gt_image)
  56.         loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
  57.         loss.backward()
  58.     ret = {k: getattr(gaussians, k).grad for k in grad_keys}
  59.     gaussians.optimizer.zero_grad(set_to_none=True)
  60.     return ret
  61.  
  62.  
def get_grad_stats(gaussians: GaussianModel, viewpoint_stack, background, pipe, opt, grad_keys,
                   sampling: str = "random",
                   accum_steps: int = 1, determininstic_index: Union[int, None] = None,
                   monitor_params: Tuple[str, List[int]] = None) -> Tuple[
    Dict[str, torch.Tensor], Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Accumulate per-view gradients over ``accum_steps`` cameras and, after
    each view, record how the running gradient sum compares with the
    full-batch gradient.

    Parameters
    ----------
    gaussians : model whose parameter gradients are analyzed.
    viewpoint_stack : sequence of training cameras.
    background, pipe, opt : rendering / loss configuration objects.
    grad_keys : attribute names of ``gaussians`` (e.g. ``'_xyz'``) to track.
    sampling : camera-selection strategy: ``"random"`` (with replacement),
        ``"nearby"`` (contiguous window at a random start),
        ``"random_order_whole"`` (one full random permutation; requires
        ``accum_steps == len(viewpoint_stack)``), or ``""``.
    accum_steps : number of views whose gradients are accumulated.
    determininstic_index : [sic, name kept for callers] when given, overrides
        ``sampling`` with the contiguous camera range starting here.
    monitor_params : unsupported; any non-None value raises
        NotImplementedError.

    Returns
    -------
    Tuple of (running gradient sums, sparsities, variances, cosines, SNRs);
    the last four map each grad key to one np.ndarray value per view.

    NOTE(review): when ``sampling == ""`` and ``determininstic_index`` is
    None, ``cam_indices`` stays None and the loop below raises — presumably
    ``""`` is only used together with a deterministic index; confirm callers.
    """
    assert sampling in ["random", "random_order_whole", "nearby", ""]
    cam_indices = None

    if determininstic_index is not None:
        cam_indices = list(range(determininstic_index, determininstic_index + accum_steps))
    else:
        if sampling == "random":
            cam_indices = [randint(0, len(viewpoint_stack) - 1) for _ in range(accum_steps)]
        elif sampling == 'nearby':
            index = randint(0, len(viewpoint_stack) - accum_steps)
            cam_indices = list(range(index, index + accum_steps))
        elif sampling == "random_order_whole":
            assert accum_steps == len(viewpoint_stack)
            cam_indices = np.random.permutation(np.arange(len(viewpoint_stack)))
    grad_running_sum = {k: 0 for k in grad_keys}
    sparsities = {k: [] for k in grad_keys}
    variances = {k: [] for k in grad_keys}
    cosines = {k: [] for k in grad_keys}
    SNRs = {k: [] for k in grad_keys}
    # Reference gradient: the loss gradient summed over *all* views.
    target_grad = get_full_batch_gradients(gaussians, viewpoint_stack, background, pipe, opt, grad_keys)

    gaussians.optimizer.zero_grad(set_to_none=True)
    for i, cam_idx in enumerate(cam_indices):
        viewpoint_cam = viewpoint_stack[cam_idx]
        bg = background
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
            render_pkg["visibility_filter"], render_pkg["radii"]
        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()

        if monitor_params is not None:
            raise NotImplementedError()
        else:
            for k in grad_keys:
                grad_running_sum[k] += getattr(gaussians, k).grad
                sparsities[k].append(float(get_sparsity(grad_running_sum[k])))
                # Second moment about zero of the running *sum* (not average).
                variances[k].append(float(get_variance(grad_running_sum[k], mean=0).mean()))
                cosines[k].append(float(
                    torch.nn.functional.cosine_similarity(grad_running_sum[k].flatten(), target_grad[k].flatten(),
                                                          dim=0)))
                # SNR of the running *average* gradient vs. the full-batch
                # average; with "random_order_whole" they coincide at the end,
                # so the final SNR diverges (callers drop the tail).
                signal = target_grad[k].flatten() / len(cam_indices)
                sample = grad_running_sum[k].flatten() / (i + 1)
                noise = (sample - signal)
                SNRs[k].append(float(torch.inner(signal, signal) / torch.inner(noise, noise)))
        # Zero per-view grads so grad_running_sum alone carries the total.
        gaussians.optimizer.zero_grad(set_to_none=True)
    sparsities_np = {k: np.array(v) for k, v in sparsities.items()}
    variances_np = {k: np.array(v) for k, v in variances.items()}
    cosines_np = {k: np.array(v) for k, v in cosines.items()}
    SNRs_np = {k: np.array(v) for k, v in SNRs.items()}
    return grad_running_sum, sparsities_np, variances_np, cosines_np, SNRs_np
  122.  
  123.  
  124. def get_sparsity(grad: torch.Tensor) -> torch.Tensor:
  125.     return (grad == 0).sum() / grad.numel()
  126.  
  127.  
  128. def get_variance(grad: torch.Tensor, mean) -> torch.Tensor:
  129.     return (grad - mean) ** 2
  130.  
  131.  
  132. def restored_gaussians(model_params, dataset, opt, deepcopy=False) -> GaussianModel:
  133.     if deepcopy:
  134.         model_params = copy.deepcopy(model_params)
  135.     gaussians = GaussianModel(dataset.sh_degree)
  136.     gaussians.training_setup(opt)
  137.     gaussians.restore(model_params, opt)
  138.     return gaussians
  139.  
  140.  
  141. def fill_subplot(ax, title, xs, ys, xlabel, ylabel, xscale='linear', legend_labels: Union[str, List[str]] = ""):
  142.     if isinstance(ys[0], list):
  143.         for i in range(len(ys)):
  144.             if isinstance(xs[0], list):
  145.                 ax.plot(xs[i], ys[i], label=legend_labels[i].replace('_', ''), marker='.')
  146.             else:
  147.                 ax.plot(xs, ys[i], label=legend_labels[i].replace('_', ''), marker='.')
  148.     else:
  149.         ax.plot(xs, ys, label=legend_labels.replace('_', ''), marker='.')
  150.     ax.set_title(title)
  151.     ax.set_xlabel(xlabel)
  152.     ax.set_xscale(xscale)
  153.     # if xscale == 'log':
  154.     #     ax.set_xticks(xs)
  155.     #     ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
  156.     ax.set_ylabel(ylabel)
  157.     if legend_labels != '':
  158.         ax.legend()
  159.  
  160.  
  161. def plot_histogram(grads, num_bins=1000):
  162.     counts, bins = torch.histogram(grads, bins=num_bins)
  163.     plt.hist(bins[:-1], bins, weights=counts)
  164.     # plt.yscale('symlog')
  165.     plt.ylim((0, counts.max()))
  166.     plt.show()
  167.     plt.close()
  168.  
  169.  
  170. def plot_covariance(cov, to_plot: List[int], sqrt=True):
  171.     if len(to_plot) >= 4:
  172.         nrows, ncols = 2, math.ceil(len(to_plot) / 2)
  173.     else:
  174.         nrows, ncols = 1, len(to_plot)
  175.     fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows + 2), dpi=200)
  176.     for i, ax in enumerate(fig.axes):
  177.         if i > len(to_plot):
  178.             break
  179.         if sqrt:
  180.             ax.imshow(torch.sqrt(torch.abs(cov[to_plot[i]])), cmap='gray')
  181.             ax.set_title(f'sqrt(covariance) for parameter {to_plot[i]}')
  182.         else:
  183.             ax.imshow(torch.abs(cov[to_plot[i]]), cmap='gray')
  184.             ax.set_title(f'covariance for parameter {to_plot[i]}')
  185.         ax.set_xlabel('view #')
  186.         ax.set_ylabel('view #')
  187.     plt.suptitle('Grad covariance by view')
  188.     plt.show()
  189.     plt.close()
  190.  
  191.  
  192. def get_variance_sparsity_cosine_SNR(gaussians, train_cameras, background, pipe, opt, keys, num_trials, accum_steps):
  193.     sparsities = {k: 0 for k in keys}
  194.     variances = {k: 0 for k in keys}
  195.     cosines = {k: 0 for k in keys}
  196.     SNRs = {k: 0 for k in keys}
  197.     for trial in range(num_trials):
  198.         grads_runing_sum, s, v, c, snrs = get_grad_stats(gaussians, train_cameras, background, pipe, opt, keys,
  199.                                                          sampling="random_order_whole", accum_steps=accum_steps)
  200.         # grads_runing_sum, s, v, c = get_grad_stats(gaussians, train_cameras, background, pipe, opt, keys,
  201.         #                                         sampling="", accum_steps=accum_steps, determininstic_index=0)
  202.         for k in keys:
  203.             sparsities[k] += s[k] / num_trials
  204.             variances[k] += v[k] / num_trials
  205.             cosines[k] += c[k] / num_trials
  206.             SNRs[k] += snrs[k] / num_trials
  207.     return variances, sparsities, cosines, SNRs
  208.  
  209.  
  210. def plot_variance_sparsity_cosine(dataset, opt, train_cameras, background, pipe, checkpoint, keys, num_trials=32,
  211.                                   iters_list=[15000, 30000]):
  212.     accum_steps = len(train_cameras)
  213.     random.seed()
  214.     variances: List[Dict[str: np.array]] = []
  215.     sparsities: List[Dict[str: np.array]] = []
  216.     cosines: List[Dict[str: np.array]] = []
  217.     SNRs: List[Dict[str: np.array]] = []
  218.     for iter in iters_list:
  219.         chpt = checkpoint.rstrip(".pth").split("chkpnt")[1]
  220.         print('loading checkpoint ', checkpoint.replace("chkpnt" + str(chpt), "chkpnt" + str(iter)))
  221.         (model_params, first_iter) = torch.load(checkpoint.replace("chkpnt" + str(chpt), "chkpnt" + str(iter)))
  222.         gaussians = restored_gaussians(model_params, dataset, opt)
  223.         v, s, c, snrs = get_variance_sparsity_cosine_SNR(gaussians, train_cameras, background, pipe, opt, keys,
  224.                                                          num_trials, accum_steps)
  225.         variances.append(v)
  226.         sparsities.append(s)
  227.         cosines.append(c)
  228.         SNRs.append(snrs)
  229.  
  230.     scene_name = os.path.basename(args.source_path)
  231.     for k in keys:
  232.         fig, ax = plt.subplots(1, 5, figsize=(6 * 5, 6))
  233.         for i, iter in enumerate(iters_list):
  234.             fill_subplot(ax[0], 'Batch size vs Grad Sparsity', np.arange(accum_steps),
  235.                          sparsities[i][k],
  236.                          'Batch size', 'Sparsity', xscale='log', legend_labels='iter ' + str(iter))
  237.             fill_subplot(ax[1], 'Batch size vs Grad Variance', np.arange(accum_steps),
  238.                          variances[i][k],
  239.                          'Batch size', 'Avg Parameter Variance', xscale='linear', legend_labels='iter ' + str(iter))
  240.             fill_subplot(ax[2], 'Batch size vs Cosine Sim. with Full-Batch Grad', np.arange(accum_steps),
  241.                          cosines[i][k],
  242.                          'Batch size', 'Cosine Similarity', xscale='linear', legend_labels='iter ' + str(iter))
  243.             # Ignore the last value because full-batch SNR is infinite
  244.             fill_subplot(ax[3], 'Batch size vs grad SNR', np.arange(accum_steps)[:-10],
  245.                          SNRs[i][k][:-10],
  246.                          'Batch size', 'SNR', xscale='linear', legend_labels='iter ' + str(iter))
  247.             # Ignore the last value because full-batch SNR is infinite
  248.             fill_subplot(ax[4], 'Batch size vs grad NSR', np.arange(accum_steps)[:-10],
  249.                          1 / SNRs[i][k][:-10],
  250.                          'Batch size', 'NSR', xscale='linear', legend_labels='iter ' + str(iter))
  251.         fig.suptitle(f'Scene: {scene_name}. Param group: {k.replace("_", "")}')
  252.         fig.tight_layout()
  253.         os.makedirs(os.path.join('plots', scene_name), exist_ok=True)
  254.         fig.savefig(os.path.join('plots', scene_name,
  255.                                  f'scene_{scene_name}_param_{k.replace("_", "")}_random_order_trials_{num_trials}_snr.png'))
  256.         fig.show()
  257.         plt.close(fig)
  258.  
  259.  
  260. def backward_once(gaussians: GaussianModel, viewpoint_cam, opt, pipe, background):
  261.     render_pkg = render(viewpoint_cam, gaussians, pipe, background)
  262.     image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg[
  263.         "viewspace_points"], \
  264.         render_pkg["visibility_filter"], render_pkg["radii"]
  265.     # Loss
  266.     gt_image = viewpoint_cam.original_image
  267.     Ll1 = l1_loss(image, gt_image)
  268.     loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
  269.     loss.backward()
  270.     return loss
  271.  
  272.  
  273. def run_iterations(gaussians: GaussianModel, train_cameras, opt, camera_ids, batch_size, pipe, background,
  274.                    discard_last=True):
  275.     gaussians.optimizer.zero_grad(set_to_none=True)
  276.     for i, camera_id in enumerate(camera_ids):
  277.         backward_once(gaussians, train_cameras[camera_id], opt, pipe, background)
  278.  
  279.         # Update params every batch_size iterations
  280.         if (i + 1) % batch_size == 0:
  281.             gaussians.optimizer.step()
  282.             gaussians.optimizer.zero_grad(set_to_none=True)
  283.         elif not discard_last and i == len(camera_ids) - 1:
  284.             gaussians.optimizer.step()
  285.             gaussians.optimizer.zero_grad(set_to_none=True)
  286.     gaussians.optimizer.zero_grad(set_to_none=True)
  287.     return gaussians
  288.  
  289.  
  290. def clear_adam_state(optimizer, bach_size, rescale_betas=True,
  291.                      lr_scaling: Literal['constant', 'sqrt', 'linear'] = 'sqrt',
  292.                      disable_momentum: bool = False):
  293.     for group in optimizer.param_groups:
  294.         # clear ADAM state
  295.         if rescale_betas:
  296.             group['betas'] = (group['betas'][0] ** bach_size, group['betas'][1] ** bach_size)
  297.             if disable_momentum:
  298.                 group['betas'] = (group['betas'][0] * 0, group['betas'][1])
  299.         for p in group['params']:
  300.             state = optimizer.state[p]
  301.             state['exp_avg'] *= 0
  302.             state['exp_avg_sq'] *= 0
  303.             state['step'] *= 0
  304.         if lr_scaling == 'constant':
  305.             coeff = 1
  306.         elif lr_scaling == 'sqrt':
  307.             coeff = float(bach_size) ** 0.5
  308.         elif lr_scaling == 'linear':
  309.             coeff = float(bach_size)
  310.         else:
  311.             raise ValueError(f'Unknown lr_scaling {lr_scaling}')
  312.         group['lr'] *= coeff
  313.  
  314.  
  315. def print_learning_rate(optimizer):
  316.     for group in optimizer.param_groups:
  317.         print(group['name'], group['lr'])
  318.  
  319.  
def plot_batch_size_vs_weights_delta_similarity(dataset, opt, train_cameras, background, pipe, checkpoint_path, keys,
                                                num_trials=32,
                                                checkpoints_list=[15000, 30000],
                                                batch_sizes=[1, 4, 16, 64], warmup_epochs=1, run_epochs=5,
                                                rescale_betas=True,
                                                lr_scaling: Literal['constant', 'sqrt', 'linear'] = 'sqrt',
                                                disable_momentum=False):
    """Compare the weight trajectory of batch-size-1 training against larger
    batch sizes started from the same checkpoint.

    For each batch size > 1, two model copies are trained over the *same*
    shuffled camera sequence — one at batch size 1 (the reference) and one at
    the target batch size.  After every optimizer step of the larger-batch
    run, the function records the norm of the weight delta from the
    checkpoint, the cosine similarity between the two runs' weight deltas,
    and the batch-averaged loss.

    Returns (cosines_for_checkpoint, losses_for_checkpoint,
    norms_for_checkpoint), one entry per element of ``checkpoints_list``.

    NOTE(review): ``checkpoints_list``/``batch_sizes`` are mutable defaults;
    they are only read here, but None-defaults would be safer.
    ``rstrip(".pth")`` strips a character *set*, not the suffix — safe for
    the numeric checkpoint names used here, fragile otherwise.  The ``del``
    at the end of the batch loop assumes the inner loop ran and produced
    ``reference_weight_delta`` at least once; otherwise it raises NameError.
    """
    random.seed()
    cosines_for_checkpoint = []
    norms_for_checkpoint = []
    losses_for_checkpoint = []
    # Index of each parameter group inside the captured checkpoint tuple.
    param_index_map = {'_xyz': 1, '_features_dc': 2, '_features_rest': 3, '_scaling': 4, '_rotation': 5, '_opacity': 6}

    for checkpoint_itr in checkpoints_list:
        # Per-key, per-batch-size mapping of iteration -> metric value.
        cosines: Dict[str, List[Dict[int, float]]] = {k: [{} for _ in range(len(batch_sizes))] for k in keys}
        norms: Dict[str, List[Dict[int, float]]] = {k: [{} for _ in range(len(batch_sizes))] for k in keys}
        losses: List[Dict[int, float]] = [defaultdict(float) for _ in range(len(batch_sizes))]
        chpt = checkpoint_path.rstrip(".pth").split("chkpnt")[1]
        cur_checkpoint = checkpoint_path.replace("chkpnt" + str(chpt), "chkpnt" + str(checkpoint_itr))
        print('loading checkpoint ', cur_checkpoint)
        (model_params, first_iter) = torch.load(cur_checkpoint)

        # original_gaussians = restored_gaussians(model_params, opt)
        original_params = {k: model_params[param_index_map[k]] for k in keys}
        # One shared shuffled camera sequence covering warmup + measured runs,
        # so every batch size sees the views in the same order.
        camera_idx = np.concatenate(
            [np.random.permutation(np.arange(len(train_cameras))) for _ in range(warmup_epochs + run_epochs)])

        for batch_size in batch_sizes:
            # Batch size 1 is the reference; it is re-run inside every pair.
            if batch_size == 1:
                continue
            print('Running for batch size', batch_size)
            temp_batch_sizes = [1, batch_size]
            running_gaussians = [restored_gaussians(model_params, dataset, opt, deepcopy=True) for _ in
                                 temp_batch_sizes]
            # Readjust ADAM parameters for batch size > 1
            for i in range(len(temp_batch_sizes)):
                clear_adam_state(running_gaussians[i].optimizer, temp_batch_sizes[i], rescale_betas=rescale_betas,
                                 lr_scaling=lr_scaling, disable_momentum=disable_momentum)
                # warmup new ADAM state
                run_iterations(running_gaussians[i], train_cameras, opt,
                               camera_idx[:len(train_cameras) * warmup_epochs], temp_batch_sizes[i], pipe, background,
                               discard_last=True)
                # NOTE(review): presumably restores the checkpoint weights
                # while keeping the warmed-up Adam state — confirm against
                # GaussianModel.restore_parameters.
                running_gaussians[i].restore_parameters(model_params, opt)

            for i, camera_id in enumerate(tqdm(camera_idx[len(train_cameras) * warmup_epochs:])):
                for j, temp_batch_size in enumerate(temp_batch_sizes):
                    running_gaussian = running_gaussians[j]
                    loss = backward_once(running_gaussian, train_cameras[camera_id], opt, pipe, background)

                    # average gradients from views
                    next_accum_step = min(len(camera_idx), (i // temp_batch_size + 1) * temp_batch_size)
                    last_accum_step = (i // temp_batch_size) * temp_batch_size
                    # loss /= min(temp_batch_size, len(train_cameras) - last_accum_step)
                    # loss /= temp_batch_size

                    # print('batch size ', temp_batch_size, 'next accum step', next_accum_step, 'divider', min(temp_batch_size, len(train_cameras) - last_accum_step))
                    # bs=1 overwrites its slot each step; bs>1 accumulates a
                    # batch-average keyed by the step that will consume it.
                    if temp_batch_size == 1:
                        losses[batch_sizes.index(temp_batch_size)][next_accum_step] = float(
                            loss.item()) / temp_batch_size
                    else:
                        losses[batch_sizes.index(temp_batch_size)][next_accum_step] += float(
                            loss.item()) / temp_batch_size
                    # Update params every batch_size iterations
                    # if (i + 1) % temp_batch_size == 0 or i == len(camera_idx) - 1:
                    if (i + 1) % temp_batch_size == 0:
                        running_gaussian.optimizer.step()
                        running_gaussian.optimizer.zero_grad(set_to_none=True)
                        for k in keys:
                            # compare weight delta from batch-size 1 and that from the current batch-size
                            weight_delta = getattr(running_gaussian, k).detach() - original_params[k]
                            norms[k][batch_sizes.index(temp_batch_size)][i + 1] = float(torch.linalg.norm(weight_delta))
                            if temp_batch_size != 1:
                                reference_weight_delta = getattr(running_gaussians[0], k).detach() - original_params[k]
                                cosines[k][batch_sizes.index(temp_batch_size)][i + 1] = float(
                                    torch.nn.functional.cosine_similarity(reference_weight_delta.flatten(),
                                                                          weight_delta.flatten(), dim=0))
            # Free the two model copies before the next batch size.
            del running_gaussians, running_gaussian, weight_delta, reference_weight_delta, loss
        cosines_for_checkpoint.append(cosines)
        losses_for_checkpoint.append(losses)
        norms_for_checkpoint.append(norms)
        del model_params, first_iter, original_params
    # pprint.pp(losses_for_checkpoint)
    # pprint.pp(norms_for_checkpoint)
    # pprint.pp(cosines_for_checkpoint)
    return cosines_for_checkpoint, losses_for_checkpoint, norms_for_checkpoint
  405.  
  406.  
  407. def plot(cosines, losses, norms, keys, checkpoint_iter, batch_sizes, rescale_betas: bool, lr_scaling: str, warmup_epochs: int, disable_momentum: bool):
  408.     scene_name = os.path.basename(args.source_path)
  409.     for k in keys:
  410.         fig, ax = plt.subplots(1, 3, figsize=(6 * 3, 6), dpi=200)
  411.         fill_subplot(ax[0], 'Batch size vs cosine(weight delta w.r.t bs=1)',
  412.                      [list(cosines[k][j].keys()) for j in range(len(batch_sizes))],
  413.                      [list(cosines[k][j].values()) for j in range(len(batch_sizes))],
  414.                      'Iterations', 'Cosine Sim', xscale='linear', legend_labels=[f'BS {b}' for b in batch_sizes])
  415.         fill_subplot(ax[1], 'Batch size vs Loss',
  416.                      [list(losses[j].keys()) for j in range(len(batch_sizes))],
  417.                      [list(losses[j].values()) for j in range(len(batch_sizes))],
  418.                      'Iterations', 'loss', xscale='linear', legend_labels=[f'BS {b}' for b in batch_sizes])
  419.         fill_subplot(ax[2], 'Batch size vs norm(weight delta)',
  420.                      [list(norms[k][j].keys()) for j in range(len(batch_sizes))],
  421.                      [list(norms[k][j].values()) for j in range(len(batch_sizes))],
  422.                      'Iterations', 'norm', xscale='linear', legend_labels=[f'BS {b}' for b in batch_sizes])
  423.  
  424.         disable_momentum_str = '_disable momentum' if disable_momentum else ''
  425.         fig.suptitle(f'Scene: {scene_name}. Checkpoint {checkpoint_iter}. Rescale betas: {rescale_betas}{disable_momentum_str}. LR scaling: {lr_scaling}. Warmup: {warmup_epochs} epochs. Params: {k}')
  426.         fig.tight_layout()
  427.         os.makedirs(os.path.join('plots_grad_delta', scene_name), exist_ok=True)
  428.         fig.savefig(os.path.join('plots_grad_delta', scene_name,
  429.                                  f'scene_{scene_name}_checkpoint_{checkpoint_iter}_param_{k.replace("_", "")}'
  430.                                  f'_rescale_betas_{rescale_betas}{disable_momentum_str.replace(" ", "_")}_lr_{lr_scaling}_warmup_{warmup_epochs}.png'))
  431.         if k == '_xyz':
  432.             fig.show()
  433.         plt.close(fig)
  434.  
  435.  
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
    """Driver for the batch-size / weight-delta experiments.

    NOTE(review): despite the name, this function never runs the training
    loop to completion: it sweeps (warmup_epochs, lr_scaling,
    disable_momentum) configurations, generates the comparison plots for
    each, then calls quit().  Everything after quit() is unreachable dead
    code kept from the original 3DGS trainer — it also references
    ``gaussians`` (never defined here) and ``scene`` (deleted above), so it
    would crash if ever re-enabled as-is.
    """
    first_iter = 0
    tb_writer = prepare_output_and_logger(dataset)
    scene = Scene(dataset, None)
    # if checkpoint:
    #     (model_params, first_iter) = torch.load(checkpoint)
    #     gaussians.restore(model_params, opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1

    random.seed()
    n_epochs = 0
    # Parameter groups whose gradients / weight deltas are analyzed.
    keys = ['_xyz', '_rotation', '_scaling', '_opacity', '_features_dc']
    num_views = len(scene.getTrainCameras())
    train_cameras = scene.getTrainCameras()
    # plot_variance_sparsity_cosine(dataset, opt, train_cameras, background, pipe, checkpoint, keys, num_trials=4)
    del scene
    # Experiment configuration for the sweep below.
    checkpoints_list = [15000, 30000]
    run_epochs = 4
    warmup_epochs = 2
    batch_sizes = [1, 4, 8, 16, 32, 64]
    disable_momentum = False
    # for lr_scaling in ['sqrt', 'constant', 'linear']:
    #     for rescale_betas in [True, False]:
    #         cosines_checkpoint, losses_checkpoint, norms_checkpoint = plot_batch_size_vs_weights_delta_similarity(
    #             dataset, opt, train_cameras, background, pipe, checkpoint, keys,
    #             checkpoints_list=checkpoints_list, batch_sizes=batch_sizes, run_epochs=run_epochs, warmup_epochs=warmup_epochs, rescale_betas=rescale_betas,
    #             lr_scaling='sqrt')
    #         for i in range(len(checkpoints_list)):
    #             plot(cosines_checkpoint[i], losses_checkpoint[i], norms_checkpoint[i], keys, checkpoints_list[i],
    #                  batch_sizes, rescale_betas, lr_scaling, warmup_epochs)
    # Sweep: momentum is only toggled for the longest warmup setting.
    for warmup_epochs in [0, 1, 2]:
        for lr_scaling in ['sqrt', 'constant', 'linear']:
            disable_momentums = [False, True] if warmup_epochs == 2 else [False]
            for disable_momentum in disable_momentums:
                rescale_betas = True
                cosines_checkpoint, losses_checkpoint, norms_checkpoint = plot_batch_size_vs_weights_delta_similarity(
                    dataset, opt, train_cameras, background, pipe, checkpoint, keys,
                    checkpoints_list=checkpoints_list, batch_sizes=batch_sizes,
                    run_epochs=run_epochs, warmup_epochs=warmup_epochs,
                    rescale_betas=rescale_betas,
                    disable_momentum=disable_momentum,
                    lr_scaling=lr_scaling)
                for i in range(len(checkpoints_list)):
                    plot(cosines_checkpoint[i], losses_checkpoint[i], norms_checkpoint[i], keys, checkpoints_list[i],
                         batch_sizes, rescale_betas, lr_scaling, warmup_epochs, disable_momentum)

    quit()

    # ------------------------------------------------------------------
    # Unreachable below this point (see NOTE in the docstring).
    # ------------------------------------------------------------------
    for iteration in range(first_iter, opt.iterations + 1):
        if network_gui.conn == None:
            network_gui.try_connect()
        while network_gui.conn != None:
            try:
                net_image_bytes = None
                custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
                if custom_cam != None:
                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
                    net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2,
                                                                                                               0).contiguous().cpu().numpy())
                network_gui.send(net_image_bytes, dataset.source_path)
                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                    break
            except Exception as e:
                network_gui.conn = None

        iter_start.record()

        gaussians.update_learning_rate(iteration)

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # Pick a random Camera
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
            epoch_change = True
            n_epochs += 1
        else:
            epoch_change = False
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # Disabled debug print (the `and False` keeps it off).
        if epoch_change and False:
            print('epoch ', n_epochs)

        # Render
        if (iteration - 1) == debug_from:
            pipe.debug = True

        bg = torch.rand((3), device="cuda") if opt.random_background else background

        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \
            render_pkg["visibility_filter"], render_pkg["radii"]

        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss.backward()

        iter_end.record()

        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
            #                 testing_iterations, scene, render, (pipe, background))
            # if (iteration in saving_iterations):
            #     print("\n[ITER {}] Saving Gaussians".format(iteration))
            #     scene.save(iteration)

            # # Densification
            # if iteration < opt.densify_until_iter:
            #     # Keep track of max radii in image-space for pruning
            #     gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter],
            #                                                          radii[visibility_filter])
            #     gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
            #
            #     if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
            #         size_threshold = 20 if iteration > opt.opacity_reset_interval else None
            #         gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
            #
            #     if iteration % opt.opacity_reset_interval == 0 or (
            #             dataset.white_background and iteration == opt.densify_from_iter):
            #         gaussians.reset_opacity()

            # Optimizer step
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            # if (iteration in checkpoint_iterations):
            #     print("\n[ITER {}] Saving Checkpoint".format(iteration))
            #     torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
  588.  
  589.  
  590. def prepare_output_and_logger(args):
  591.     if not args.model_path:
  592.         if os.getenv('OAR_JOB_ID'):
  593.             unique_str = os.getenv('OAR_JOB_ID')
  594.         else:
  595.             unique_str = str(uuid.uuid4())
  596.         args.model_path = os.path.join("/tmp/sparsity-output/", unique_str[0:10])
  597.  
  598.     # Set up output folder
  599.     print("Output folder: {}".format(args.model_path))
  600.     os.makedirs(args.model_path, exist_ok=True)
  601.     with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
  602.         cfg_log_f.write(str(Namespace(**vars(args))))
  603.  
  604.     # Create Tensorboard writer
  605.     tb_writer = None
  606.     if TENSORBOARD_FOUND:
  607.         tb_writer = SummaryWriter(args.model_path)
  608.     else:
  609.         print("Tensorboard not available: not logging progress")
  610.     return tb_writer
  611.  
  612.  
  613. def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
  614.                     renderArgs):
  615.     if tb_writer:
  616.         tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
  617.         tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
  618.         tb_writer.add_scalar('iter_time', elapsed, iteration)
  619.  
  620.     # Report test and samples of training set
  621.     if iteration in testing_iterations:
  622.         torch.cuda.empty_cache()
  623.         validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
  624.                               {'name': 'train',
  625.                                'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
  626.                                            range(5, 30, 5)]})
  627.  
  628.         for config in validation_configs:
  629.             if config['cameras'] and len(config['cameras']) > 0:
  630.                 l1_test = 0.0
  631.                 psnr_test = 0.0
  632.                 for idx, viewpoint in enumerate(config['cameras']):
  633.                     image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
  634.                     gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
  635.                     if tb_writer and (idx < 5):
  636.                         tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
  637.                                              image[None], global_step=iteration)
  638.                         if iteration == testing_iterations[0]:
  639.                             tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
  640.                                                  gt_image[None], global_step=iteration)
  641.                     l1_test += l1_loss(image, gt_image).mean().double()
  642.                     psnr_test += psnr(image, gt_image).mean().double()
  643.                 psnr_test /= len(config['cameras'])
  644.                 l1_test /= len(config['cameras'])
  645.                 print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
  646.                 if tb_writer:
  647.                     tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
  648.                     tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
  649.  
  650.         if tb_writer:
  651.             tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
  652.             tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
  653.         torch.cuda.empty_cache()
  654.  
  655.  
  656. if __name__ == "__main__":
  657.     # Set up command line argument parser
  658.     parser = ArgumentParser(description="Training script parameters")
  659.     lp = ModelParams(parser)
  660.     op = OptimizationParams(parser)
  661.     pp = PipelineParams(parser)
  662.     save_iters = [0, 1_000, 4_500, 7_000, 11_000, 15_000, 18_000, 21_000, 24_000, 27_000, 30_000]
  663.     parser.add_argument('--ip', type=str, default="127.0.0.1")
  664.     parser.add_argument('--port', type=int, default=6009)
  665.     parser.add_argument('--debug_from', type=int, default=-1)
  666.     parser.add_argument('--detect_anomaly', action='store_true', default=False)
  667.     parser.add_argument("--test_iterations", nargs="+", type=int, default=save_iters)
  668.     parser.add_argument("--save_iterations", nargs="+", type=int, default=save_iters)
  669.     parser.add_argument("--quiet", action="store_true")
  670.     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=save_iters)
  671.     parser.add_argument("--start_checkpoint", type=str, default=None)
  672.     args = parser.parse_args(sys.argv[1:])
  673.     args.save_iterations.append(args.iterations)
  674.  
  675.     print("Optimizing " + args.model_path)
  676.  
  677.     # Initialize system state (RNG)
  678.     safe_state(args.quiet)
  679.  
  680.     # Start GUI server, configure and run training
  681.     # network_gui.init(args.ip, args.port)
  682.     torch.autograd.set_detect_anomaly(args.detect_anomaly)
  683.     training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations,
  684.              args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
  685.  
  686.     # All done
  687.     print("\nTraining complete.")
  688.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement