import os
import json
import time
import torch
import argparse
import numpy as np
from multiprocessing import cpu_count
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from collections import OrderedDict, defaultdict

from ptb import PTB
from utils import to_var, idx2word, expierment_name  # (sic: matches the name exported by utils.py)
from model import SentenceVAE


def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )

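    # Assemble the sentence VAE: an RNN encoder/decoder around a Gaussian
    # latent, in the style of Bowman et al. (2016), "Generating Sentences
    # from a Continuous Space" (see model.py for the architecture).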
    model = SentenceVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
    )

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

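    # KL annealing: the KL term's weight ramps from 0 towards 1 so the decoder
    # first learns to reconstruct before the posterior is pulled to the prior.
    # 'logistic' uses 1 / (1 + exp(-k * (step - x0))); 'linear' uses step / x0,
    # capped at 1.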
    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)

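    # The training objective is the negative ELBO: a reconstruction term (token
    # NLL, summed rather than averaged so padding can be ignored exactly) plus
    # the closed-form KL divergence between the approximate posterior
    # N(mean, exp(logv)) and the standard normal prior.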
    NLL = torch.nn.NLLLoss(reduction='sum', ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):

        # cut off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight

    # single joint optimizer from the original script, kept for reference:
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

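    # This variant instead trains encoder and decoder with separate Adam
    # optimizers and alternates which module is updated every `sub_batch`
    # steps (see the `flag` toggle in the training loop below).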
    encoder_optimizer = torch.optim.Adam(model.encoder_rnn.parameters(), lr=args.learning_rate)
    decoder_optimizer = torch.optim.Adam(model.decoder_rnn.parameters(), lr=args.learning_rate)
    sub_batch = 100
    flag = True  # True: update the encoder; False: update the decoder

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0

    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split == 'train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])
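                # logp: log-probabilities over the vocabulary at each timestep
                # (fed to NLLLoss below); mean/logv: parameters of the latent
                # Gaussian posterior; z: the sampled latent code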

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, args.k, args.x0)

                # validation/test use the full (un-annealed) KL term
                if split != 'train':
                    KL_weight = 1

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization: alternate between updating only the
                # encoder and only the decoder, toggling every `sub_batch` steps
                # (the active optimizer's zero_grad() also discards any gradient
                # accumulated while its module was frozen)
                if split == 'train':
                    if flag:
                        encoder_optimizer.zero_grad()
                    else:
                        decoder_optimizer.zero_grad()
                    loss.backward()
                    if flag:
                        encoder_optimizer.step()
                    else:
                        decoder_optimizer.step()
                    step += 1
                    if step % sub_batch == 0:
                        flag = not flag

                    # joint update from the original script, kept for reference:
                    # optimizer.zero_grad()
                    # loss.backward()
                    # optimizer.step()
                    # step += 1

                # bookkeeping
                # (0-dim loss -> shape (1,) so it can be concatenated)
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data.view(1)))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.item(), epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(), NLL_loss.item() / batch_size, epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(), KL_loss.item() / batch_size, epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(), KL_weight, epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                          % (split.upper(), iteration, len(data_loader) - 1, loss.item(),
                             NLL_loss.item() / batch_size, KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

  174. print("%s Epoch %02d/%i, Mean ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))
  175.  
  176. if args.tensorboard_logging:
  177. writer.add_scalar("%s-Epoch/ELBO"%split.upper(), torch.mean(tracker['ELBO']), epoch)
  178.  
            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs(os.path.join('dumps', ts))
                with open(os.path.join('dumps', ts, 'valid_E%i.json' % epoch), 'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', type=str, default='data')
    parser.add_argument('--create_data', action='store_true')
    parser.add_argument('--max_sequence_length', type=int, default=60)
    parser.add_argument('--min_occ', type=int, default=1)
    parser.add_argument('--test', action='store_true')

    parser.add_argument('-ep', '--epochs', type=int, default=10)
    parser.add_argument('-bs', '--batch_size', type=int, default=32)
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001)

    parser.add_argument('-eb', '--embedding_size', type=int, default=300)
    parser.add_argument('-rnn', '--rnn_type', type=str, default='gru')
    parser.add_argument('-hs', '--hidden_size', type=int, default=256)
    parser.add_argument('-nl', '--num_layers', type=int, default=1)
    parser.add_argument('-bi', '--bidirectional', action='store_true')
    parser.add_argument('-ls', '--latent_size', type=int, default=16)
    parser.add_argument('-wd', '--word_dropout', type=float, default=0)
    parser.add_argument('-ed', '--embedding_dropout', type=float, default=0.5)

    parser.add_argument('-af', '--anneal_function', type=str, default='logistic')
    parser.add_argument('-k', '--k', type=float, default=0.0025)
    parser.add_argument('-x0', '--x0', type=int, default=2500)

    parser.add_argument('-v', '--print_every', type=int, default=50)
    parser.add_argument('-tb', '--tensorboard_logging', action='store_true')
    parser.add_argument('-log', '--logdir', type=str, default='logs')
    parser.add_argument('-bin', '--save_model_path', type=str, default='bin')

    args = parser.parse_args()

    args.rnn_type = args.rnn_type.lower()
    args.anneal_function = args.anneal_function.lower()

    assert args.rnn_type in ['rnn', 'lstm', 'gru']
    assert args.anneal_function in ['logistic', 'linear']
    assert 0 <= args.word_dropout <= 1

    main(args)
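
# Example usage (hypothetical: assumes this script is saved as train.py and
# that the PTB files expected by ptb.py live under ./data):
#
#   python train.py --create_data -ep 10 -bs 32 -tb
#
# --create_data preprocesses the corpus and builds the vocabulary on first run;
# -tb writes ELBO/NLL/KL curves under ./logs for TensorBoard.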