import os

import numpy as np
import torch
import torch.nn.functional as F

from data_utils.svg_utils import render
from models.modality_fusion import ModalityFusion
from models.svg_decoder import SVGLSTMDecoder, SVGMDNTop
from options import get_parser_main_model

opts = get_parser_main_model().parse_args()
opts.experiment_name = opts.experiment_name + '_' + opts.model_name
os.makedirs("experiments", exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
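
# Example invocation (a sketch -- the flag names are inferred from the opts.*
# attributes used below and the values are placeholders; the authoritative
# parser lives in options.py):
#   python <this_script>.py --mode test --model_name main_model \
#       --experiment_name my_exp --test_epoch 100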

def get_z(temperature=1.0):
    """Sample a random latent code for the sequence decoder."""
    z = torch.randn(1, 128).to(device) * temperature
    return z


def get_img_feat(temperature=0.5):
    """Sample a random image feature vector (not used in this script)."""
    z = torch.randn(1, 1024).to(device) * temperature
    return z


def get_enc_z(temperature=0.5):
    """Sample random encoder latents (not used in this script)."""
    z = torch.randn(4, 1024).to(device) * temperature
    return z

svg_dir = "output"
os.makedirs(svg_dir, exist_ok=True)  # generated SVGs are written here

exp_dir = os.path.join("experiments", opts.experiment_name)
ckpt_dir = os.path.join(exp_dir, "checkpoints")
res_dir = os.path.join(exp_dir, "results")

svg_decoder = SVGLSTMDecoder(char_categories=opts.char_categories,
                             bottleneck_bits=opts.bottleneck_bits,
                             mode=opts.mode,
                             max_sequence_length=opts.max_seq_len,
                             hidden_size=opts.hidden_size,
                             num_hidden_layers=opts.num_hidden_layers,
                             feature_dim=opts.seq_feature_dim,
                             ff_dropout=opts.ff_dropout,
                             rec_dropout=opts.rec_dropout)

mdn_top_layer = SVGMDNTop(num_mixture=opts.num_mixture,
                          seq_len=opts.max_seq_len,
                          hidden_size=opts.hidden_size,
                          mode=opts.mode,
                          mix_temperature=opts.mix_temperature,
                          gauss_temperature=opts.gauss_temperature,
                          dont_reduce=opts.dont_reduce_loss)

modality_fusion = ModalityFusion(img_feat_dim=16 * opts.image_size,
                                 hidden_size=opts.hidden_size,
                                 ref_nshot=opts.ref_nshot,
                                 bottleneck_bits=opts.bottleneck_bits,
                                 mode=opts.mode)

# load parameters
epoch = opts.test_epoch

svg_decoder_fpath = os.path.join(ckpt_dir, f"{opts.model_name}_{epoch}.seqdec.pth")
svg_decoder.load_state_dict(torch.load(svg_decoder_fpath, map_location=device))
svg_decoder.eval()

mdn_top_layer_fpath = os.path.join(ckpt_dir, f"{opts.model_name}_{epoch}.mdntl.pth")
mdn_top_layer.load_state_dict(torch.load(mdn_top_layer_fpath, map_location=device))
mdn_top_layer.eval()

modality_fusion_fpath = os.path.join(ckpt_dir, f"{opts.model_name}_{epoch}.modalfuse.pth")
modality_fusion.load_state_dict(torch.load(modality_fusion_fpath, map_location=device))
modality_fusion.eval()

# to device
modality_fusion = modality_fusion.to(device)
svg_decoder = svg_decoder.to(device)
mdn_top_layer = mdn_top_layer.to(device)

mean = np.load('./data/mean.npz')
std = np.load('./data/stdev.npz')
# NOTE: despite the .npz extension these files are assumed to be plain arrays
# saved with np.save; if they are real .npz archives, index into them first,
# e.g. np.load('./data/mean.npz')['mean'].
mean = torch.from_numpy(mean).to(device).to(torch.float32)
std = torch.from_numpy(std).to(device).to(torch.float32)

latent_feat = get_z()
latent_feat = latent_feat.repeat(opts.char_categories, 1)

# one one-hot class label per glyph category, on the same device as latent_feat
trg_char = F.one_hot(torch.arange(opts.char_categories),
                     num_classes=opts.char_categories).to(device)
sd_init_state = svg_decoder.init_state_input(latent_feat, trg_char)
hidden, cell = sd_init_state['hidden'], sd_init_state['cell']
hidden_self, cell_self = hidden, cell

tgt_len = 51                        # number of drawing commands to sample per glyph
num_glyphs = opts.char_categories   # one sequence per glyph category (52 = a-zA-Z)
sampled_svg = torch.zeros(tgt_len, num_glyphs, opts.seq_feature_dim).to(device)

with torch.no_grad():
    for t in range(tgt_len):
        # autoregressive sampling: feed the previous step back in (zeros at t == 0)
        if t == 0:
            inpt_self = torch.zeros(num_glyphs, opts.seq_feature_dim).to(device)
        else:
            inpt_self = sampled_svg[t - 1]
        decoder_output_self = svg_decoder(inpt_self, hidden_self, cell_self)
        output_self = decoder_output_self['output']
        hidden_self = decoder_output_self['hidden']
        cell_self = decoder_output_self['cell']
        # the MDN head turns the LSTM output into mixture parameters, then
        # samples a concrete drawing command from them
        top_output_self = mdn_top_layer(output_self)
        sampled_step = mdn_top_layer.sample(top_output_self, output_self, opts.mode)
        sampled_svg[t] = sampled_step

# (T, N, D) -> (N, T, D), then undo the training-time normalization
svg_dec_out = sampled_svg.detach().transpose(0, 1)
svg_dec_out = svg_dec_out * std + mean

for i, one_seq in enumerate(svg_dec_out):
    syn_svg_outfile = os.path.join(svg_dir, f"generated{i}.svg")
    try:
        svg = render(one_seq.cpu().numpy())
    except Exception:
        # rendering can fail on degenerate sequences; skip those glyphs
        continue
    with open(syn_svg_outfile, 'w') as syn_svg_f:
        syn_svg_f.write(svg)
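
# Optional follow-up (a sketch, not part of the original script): rasterize the
# generated SVGs for quick visual inspection. Assumes the third-party cairosvg
# package is installed (pip install cairosvg); uncomment to use.
# import cairosvg
# for fname in sorted(os.listdir(svg_dir)):
#     if fname.endswith(".svg"):
#         cairosvg.svg2png(url=os.path.join(svg_dir, fname),
#                          write_to=os.path.join(svg_dir, fname[:-4] + ".png"))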