import os
import json
import time
import torch
import argparse
import numpy as np
from multiprocessing import cpu_count
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from collections import OrderedDict, defaultdict

from ptb import PTB
from utils import to_var, idx2word, expierment_name
from model import SentenceVAE

def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )
    model = SentenceVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
    )
    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)
    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)
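
    # KL annealing schedules: 'logistic' is a sigmoid in the training step,
    # crossing 0.5 at step == x0 with slope controlled by k; 'linear' ramps
    # from 0 to 1 over the first x0 steps and stays at 1 afterwards.
    # With the defaults below (k=0.0025, x0=2500), the logistic weight is
    # ~0.002 at step 0, 0.5 at step 2500, and ~0.998 at step 5000.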

    # summed (not averaged) NLL over non-padding tokens, so the loss scales
    # with sequence length
    NLL = torch.nn.NLLLoss(reduction='sum', ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):

        # cut off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight
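
    # KL_loss above is the closed-form KL divergence between the approximate
    # posterior N(mean, exp(logv)) and the standard-normal prior, summed over
    # the batch and latent dimensions. The training loss assembled below,
    # (NLL_loss + KL_weight * KL_loss) / batch_size, is therefore the
    # per-sentence negative ELBO with an annealed KL term.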

    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    encoder_optimizer = torch.optim.Adam(model.encoder_rnn.parameters(), lr=args.learning_rate)
    decoder_optimizer = torch.optim.Adam(model.decoder_rnn.parameters(), lr=args.learning_rate)
    sub_batch = 100  # number of steps before switching the active optimizer
    flag = True      # True: update the encoder RNN; False: update the decoder RNN

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
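
    # Note: this alternating scheme only ever updates model.encoder_rnn and
    # model.decoder_rnn. If SentenceVAE follows the reference implementation,
    # its embedding layer and the hidden-to-latent / latent-to-hidden / output
    # projections belong to neither parameter group and would stay at their
    # initial weights; the commented-out single optimizer over
    # model.parameters() covers everything.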
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=(split == 'train'),
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)
                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, args.k, args.x0)

                # report the un-annealed ELBO on validation/test
                if split != 'train':
                    KL_weight = 1

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size
                # backward + optimization: alternate between encoder and
                # decoder updates every sub_batch steps
                if split == 'train':
                    if flag:
                        encoder_optimizer.zero_grad()
                    else:
                        decoder_optimizer.zero_grad()
                    loss.backward()
                    if flag:
                        encoder_optimizer.step()
                    else:
                        decoder_optimizer.step()
                    step += 1
                    if step % sub_batch == 0:
                        flag = not flag

                    # single-optimizer variant (disabled):
                    # optimizer.zero_grad()
                    # loss.backward()
                    # optimizer.step()
                    # step += 1
                # bookkeeping
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.detach().view(1)))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.item(), epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(), NLL_loss.item() / batch_size, epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(), KL_loss.item() / batch_size, epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(), KL_weight, epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                          % (split.upper(), iteration, len(data_loader) - 1, loss.item(), NLL_loss.item() / batch_size, KL_loss.item() / batch_size, KL_weight))
                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)
- print("%s Epoch %02d/%i, Mean ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))
- if args.tensorboard_logging:
- writer.add_scalar("%s-Epoch/ELBO"%split.upper(), torch.mean(tracker['ELBO']), epoch)
- # save a dump of all sentences and the encoded latent space
- if split == 'valid':
- dump = {'target_sents':tracker['target_sents'], 'z':tracker['z'].tolist()}
- if not os.path.exists(os.path.join('dumps', ts)):
- os.makedirs('dumps/'+ts)
- with open(os.path.join('dumps/'+ts+'/valid_E%i.json'%epoch), 'w') as dump_file:
- json.dump(dump,dump_file)
            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', type=str, default='data')
    parser.add_argument('--create_data', action='store_true')
    parser.add_argument('--max_sequence_length', type=int, default=60)
    parser.add_argument('--min_occ', type=int, default=1)
    parser.add_argument('--test', action='store_true')

    parser.add_argument('-ep', '--epochs', type=int, default=10)
    parser.add_argument('-bs', '--batch_size', type=int, default=32)
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.001)

    parser.add_argument('-eb', '--embedding_size', type=int, default=300)
    parser.add_argument('-rnn', '--rnn_type', type=str, default='gru')
    parser.add_argument('-hs', '--hidden_size', type=int, default=256)
    parser.add_argument('-nl', '--num_layers', type=int, default=1)
    parser.add_argument('-bi', '--bidirectional', action='store_true')
    parser.add_argument('-ls', '--latent_size', type=int, default=16)
    parser.add_argument('-wd', '--word_dropout', type=float, default=0)
    parser.add_argument('-ed', '--embedding_dropout', type=float, default=0.5)

    parser.add_argument('-af', '--anneal_function', type=str, default='logistic')
    parser.add_argument('-k', '--k', type=float, default=0.0025)
    parser.add_argument('-x0', '--x0', type=int, default=2500)

    parser.add_argument('-v', '--print_every', type=int, default=50)
    parser.add_argument('-tb', '--tensorboard_logging', action='store_true')
    parser.add_argument('-log', '--logdir', type=str, default='logs')
    parser.add_argument('-bin', '--save_model_path', type=str, default='bin')

    args = parser.parse_args()

    args.rnn_type = args.rnn_type.lower()
    args.anneal_function = args.anneal_function.lower()

    assert args.rnn_type in ['rnn', 'lstm', 'gru']
    assert args.anneal_function in ['logistic', 'linear']
    assert 0 <= args.word_dropout <= 1

    main(args)
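
    # Example invocation (assuming this file is saved as train.py; all flags
    # are defined above, paths are the defaults):
    #   python train.py --data_dir data --create_data -ep 10 -tb
    # trains on PTB with TensorBoard logging, writing checkpoints to
    # bin/<ts>/ and validation dumps to dumps/<ts>/.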