SHARE
TWEET

Untitled

a guest Jan 20th, 2019 79 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/local/anaconda3/envs/experiments/bin/python3
  2.  
  3. import torch
  4. import torch.nn as nn
  5. import sys
  6. import math
  7. from termcolor import colored
  8. import datetime
  9. import atexit
  10. import tensorboardX
  11. import os
  12.  
  13. writer = tensorboardX.SummaryWriter()
  14.  
  15. #enable CUDA
  16. if torch.cuda.is_available():
  17.     print("CUDA")
  18. else:
  19.     print("CPU")
  20.  
  21. sys.stdout.flush()
  22.  
  23. use_greedy = False
  24.  
  25. #print("\nParameters: ")
  26.  
  27. #for arg in sys.argv:
  28. #    print(arg, " ", end="")
  29. #print("\n")
  30.  
  31. class GRU_TEST(nn.Module):
  32.  
  33.     def __init__(self, size, prev, batch_size, hidden, clip):
  34.  
  35.         super(GRU_TEST, self).__init__()
  36.  
  37.         self.size = size
  38.         self.hidden_size = hidden
  39.  
  40.         self.clip = clip
  41.  
  42.         self.r = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  43.         self.z = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  44.         self.h = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  45.  
  46.         self.context = torch.zeros(hidden).cuda()
  47.  
  48.         self.tanh = torch.nn.Tanh()
  49.         self.sigmoid = torch.nn.Sigmoid()
  50.  
  51.         self.hidden = torch.zeros(hidden).cuda()
  52.  
  53.     def reset(self):
  54.  
  55.         del self.hidden
  56.         self.hidden = torch.zeros(hidden).cuda()
  57.  
  58.  
  59.     def forward(self, inp):
  60.  
  61.         h = torch.cat((inp, self.hidden))
  62.  
  63.         #process layers
  64.         rt = self.sigmoid(self.r(h))
  65.         zt = self.sigmoid(self.z(h))
  66.  
  67.         ht = torch.cat(((self.hidden * rt), inp))
  68.         ht = self.tanh(self.h(ht))
  69.  
  70.         self.hidden = (self.hidden * (1 - zt))
  71.         self.hidden = self.hidden + zt * ht
  72.  
  73.         torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip, 1)
  74.  
  75.         return self.hidden
  76.  
  77. class Model(nn.Module):
  78.  
  79.     def __init__(self, size, prev, batch_size, momentum, dropout, hidden, clip):
  80.         super(Model, self).__init__()
  81.  
  82.         self.gru1 = GRU_TEST(size, prev, batch_size, hidden, clip).cuda()
  83.         self.gru1.reset()
  84.  
  85.         self.softmax = torch.nn.modules.activation.LogSoftmax(dim=0)
  86.         self.dropout = torch.nn.Dropout(dropout)
  87.  
  88.         self.output_decoder = torch.nn.Linear(hidden, alphabet_size*batch_size).cuda()
  89.  
  90.         self.epochs = 1
  91.         self.batches = 1
  92.         self.counter = 0
  93.         self.runs = 0
  94.  
  95.         self.loss_function = torch.nn.CrossEntropyLoss()
  96.         self.optimizer = torch.optim.Adam([
  97.             {'params': self.parameters(),
  98.              'weight_decay': 0.25}
  99.             ], lr=rate)
  100.         #self.optimizer = torch.optim.SGD(params=self.parameters(), lr=rate, momentum=momentum)
  101.  
  102.     def clear_internal_states(self):
  103.         self.gru1.reset()
  104.  
  105.         std = 1.0/math.sqrt(self.gru1.hidden_size)
  106.  
  107.         for p in self.parameters():
  108.             p.data.uniform_(-std, std)
  109.  
  110.     def forward(self, inp):
  111.  
  112.         x = torch.autograd.Variable((inp.detach()).view(-1))
  113.         x = self.gru1(x)
  114.         x = self.dropout(x)
  115.         x = self.output_decoder(x)
  116.         x = x.view(nbatches, -1)
  117.         return x
  118.  
  119. def splash(a):
  120.     if a:
  121.         print("GRU Text Generator\nUsage:\n\n-f --filename: filename of input text - required\n-h --hidden: number of hidden layers, default 1\n-r --rate: learning rate\n-p --prev: number of previous states to observe, default 0.05")
  122.         print("\nExample usage: <command> -f input.txt -h 5 -r 0.025")
  123.     else:
  124.         print("\nGRU Text Generator\n")
  125.         print("Alphabet size: {}".format(alphabet_size))
  126.  
  127.         print("Hyperparameters:")
  128.         params = sys.argv
  129.         params.pop(0)
  130.        
  131.         for a in list(params):
  132.             print(a, " ",end="")
  133.  
  134.         os.system("clear")
  135.         print("\n")
  136.         print(datetime.datetime.now())
  137.         print("\n")
  138.  
  139. def getIndexFromLetter(letter, list):
  140.     return list.index(letter)
  141.  
  142. def getLetterFromIndex(i, list):
  143.     return list[i]
  144.  
  145. def parse(args, arg):
  146.  
  147.     for i in range(len(args)):
  148.         if args[i] in arg:
  149.             if len(args) < i+1:
  150.                 return ""
  151.             if args[i+1].startswith("-"):
  152.                 splash(True)
  153.             else:
  154.                 return args[i+1]
  155.  
  156.     return False
  157.  
  158. def savemodel():
  159.  
  160.     print("Save model parameters? [y/n]➡")
  161.     filename_input = input()
  162.  
  163.     if filename_input == 'y' or filename_input == 'Y' or filename_input.lower() == 'yes':
  164.         filename = "Model-" + str(datetime.datetime.now()).replace(" ", "_")
  165.         print("Save as filename [default: {}]➡".format(filename))
  166.  
  167.         filename_input = input()
  168.         if not filename_input == "":
  169.             filename = "Model-" + str(filename_input).replace(" ", "_")
  170.  
  171.         print("Saving model as {}...".format(filename))
  172.         modelname = "./models/{}".format(filename)
  173.  
  174.         torch.save({
  175.             'model_state_dict': model.state_dict(),
  176.             'optimizer_state_dict': model.optimizer.state_dict()
  177.         }, modelname)
  178.  
  179.     quit()
  180.  
  181. def loadmodel():
  182.     #load model parameters if checkpoint specified
  183.     if not model_filename == False:
  184.         try:
  185.             checkpoint = torch.load(model_filename)
  186.             model.load_state_dict(checkpoint['model_state_dict'])
  187.             model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  188.         except FileNotFoundError:
  189.             print("Model not found.")
  190.             quit()
  191.     else:
  192.         print("New model")
  193.  
  194. atexit.register(savemodel)
  195. model_filename = None
  196.  
  197. try:
  198.     model_filename = parse(sys.argv, ["--load", "-l"])
  199.     filename = parse(sys.argv, ["--filename", "-f"])
  200.     if not filename or filename == "":
  201.         splash()
  202.     rate = float(parse(sys.argv, ["--rate", "-r"]))
  203.     if not rate or rate == "":
  204.         rate = 0.05
  205.     hidden = int(parse(sys.argv, ["--hidden", "-h"]))
  206.     if not hidden or hidden == "":
  207.         hidden = 1
  208.     clip = float(parse(sys.argv, ["--clip", "-c"]))
  209.     if not clip:
  210.         clip = 1
  211.     epochs = int(parse(sys.argv, ["--epoch", "-e"]))
  212.     if not epochs:
  213.         epochs = 1
  214.     nbatches = int(parse(sys.argv, ["--batch", "-b"]))
  215.     if not nbatches:
  216.         nbatches = 1
  217.     momentum = float(parse(sys.argv, ["--momentum", "-m"]))
  218.     if not momentum:
  219.         momentum = 0.4
  220.     n_prev = int(parse(sys.argv, ["--previous", "-p"]))
  221.     if not n_prev:
  222.         n_prev = 5
  223.     dropout = float(parse(sys.argv, ["--dropout", "-d"]))
  224.     if not dropout:
  225.         dropout = 0.0
  226.     sample_size = int(parse(sys.argv, ["--sample", "-s"]))
  227.     if not sample_size:
  228.         sample_size = 1000
  229.     temperature = float(parse(sys.argv, ["--temperature", "-t"]))
  230.     if not temperature:
  231.         temperature = 1.0
  232.  
  233. except:
  234.     splash(True)
  235.     quit()
  236.  
  237. #open file
  238. master_words = []
  239. alphabet = []
  240. #writer = SummaryWriter()
  241. text = []
  242. e = 0
  243. c = 0
  244. randomness = 0
  245.  
  246. with open(filename, "r") as f:
  247.     # reads all lines and removes non alphabet words
  248.     intext = f.read()
  249.  
  250. for l in list(intext):
  251.     if l == "\n": l = "¶"
  252.     if l == "\x1b": print("XXX")
  253.     text.append(l)
  254.  
  255. for l in text:
  256.     sys.stdout.flush()
  257.  
  258.     if l not in alphabet:
  259.         alphabet.append(l)
  260.         print("\r{}% - {}/{}".format(int(c/len(text)*100), c, len(text)), end="")
  261.     c+=1
  262.  
  263. epochs = 1
  264. alphabet_size = len(alphabet)
  265.  
  266. splash(False)
  267. model = Model(alphabet_size, n_prev, nbatches, momentum, dropout, hidden, clip).cuda()
  268.  
  269. nchars = len(text)
  270.  
  271. graph_time = 0
  272.  
  273. mem_max = torch.cuda.max_memory_allocated()
  274. field = [" " for _ in range(10)]
  275.  
  276. def one_hot(char):
  277.     output = torch.zeros(alphabet_size).cuda()
  278.     output[alphabet.index(char)] = 1
  279.  
  280.     return output
  281. def get_input_vector(chars):
  282.     out = []
  283.  
  284.     for b in range(nbatches):
  285.         batch = []
  286.  
  287.         for c in chars:
  288.             batch.append(one_hot(c))
  289.  
  290.         o = torch.stack(batch)
  291.         out.append(o)
  292.  
  293.     out = torch.stack(out).cuda()
  294.     return out
  295. def get_output(inp):
  296.  
  297.     inp = torch.nn.Softmax(dim=1)(inp / temperature)
  298.     sample = torch.multinomial(inp, 1)[:]
  299.  
  300.     return alphabet[sample]
  301.     model.counter = n_prev
  302. while True:
  303.  
  304.     t = 0
  305.  
  306.     variation = []
  307.     done = False
  308.     total_loss = 0
  309.     total_time = 0
  310.  
  311.     mem = 0
  312.     loss = 0
  313.  
  314.     model.optimizer.zero_grad()
  315.     model.clear_internal_states()
  316.  
  317.     steps = 2500
  318.     model.runs += 1
  319.  
  320.     field = [text[model.counter+x] for x in range(n_prev)]
  321.     start = datetime.datetime.timestamp(datetime.datetime.now())
  322.  
  323.     while t < steps:
  324.         #get target char
  325.         target = alphabet.index(text[(model.counter+1)%len(text)])
  326.  
  327.         #make target vector
  328.         target = [target for _ in range(nbatches)]
  329.         target = torch.tensor(target).cuda()
  330.  
  331.         #forward pass
  332.         inp = get_input_vector(field)
  333.         out = model.forward(inp)
  334.  
  335.         #get outputs
  336.         char = []
  337.         for o in out.split(1):
  338.             a = get_output(o)
  339.             char.append(a)
  340.  
  341.         f = ''.join(str(e) for e in field)
  342.         c = ''.join(str(e) for e in char)
  343.  
  344.         progress = int(100*(t / steps))
  345.  
  346.         txt = colored("\r ▲ {} | Forward prop... | Progress: {}% | Epoch: {} | Batch: {} | {} | {} | ...".format(model.runs, progress, model.epochs, model.batches, f, c),
  347.                       attrs=['reverse'])
  348.         print(txt,end="")
  349.  
  350.         model.counter += 1
  351.         l = model.loss_function(out, target)
  352.         writer.add_scalar('loss', l, model.counter)
  353.         loss += l
  354.  
  355.         t += 1
  356.  
  357.         if model.counter > len(text):
  358.             model.epochs += 1
  359.             model.batches = 0
  360.             model.counter = 0
  361.             print("\nxz")
  362.  
  363.         field.append(alphabet[target[0]])
  364.         field.pop(0)
  365.  
  366.     end = datetime.datetime.timestamp(datetime.datetime.now())
  367.     total_time = int(end - start)
  368.  
  369.     writer.add_scalar('time', torch.tensor(total_time), model.runs)
  370.  
  371.     print("\nBackward prop...")
  372.     sys.stdout.flush()
  373.  
  374.     model.batches += 1
  375.     writer.add_scalar('total_loss', int(loss), model.counter)
  376.     loss.backward()
  377.     model.optimizer.step()
  378.  
  379.     field = [text[x] for x in range(n_prev)]
  380.  
  381.     if model.runs % 1 == 0:
  382.         variety = []
  383.  
  384.         print("\nGenerating...\n")
  385.         for i in range(1000):
  386.  
  387.             inp = get_input_vector(field)
  388.             out = model.forward(inp)
  389.  
  390.             char = []
  391.             for o in out.split(1):
  392.                 #a = get_output(o)
  393.                 a = alphabet[torch.argmax(o)]
  394.                 char.append(a)
  395.  
  396.             print(char[0], end="")
  397.             sys.stdout.flush()
  398.  
  399.             if char[0] not in variety:
  400.                 variety.append(char[0])
  401.  
  402.             field.append(char[0])
  403.             field.pop(0)
  404.  
  405.         variety = int(100*(len(variety)/alphabet_size))
  406.         print("\nVariety: {}\n".format(variety))
  407.         writer.add_scalar('variety', variety, model.runs)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top