daily pastebin goal
38%
SHARE
TWEET

Untitled

a guest Jan 20th, 2019 61 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #!/usr/local/anaconda3/envs/experiments/bin/python3
  2.  
  3. import torch
  4. import torch.nn as nn
  5. import sys
  6. import math
  7. from termcolor import colored
  8. import datetime
  9. import atexit
  10. import tensorboardX
  11. import os
  12.  
  13. writer = tensorboardX.SummaryWriter()
  14.  
  15. #enable CUDA
  16. if torch.cuda.is_available():
  17.     print("CUDA")
  18. else:
  19.     print("CPU")
  20.  
  21. sys.stdout.flush()
  22.  
  23. use_greedy = False
  24.  
  25. #print("\nParameters: ")
  26.  
  27. #for arg in sys.argv:
  28. #    print(arg, " ", end="")
  29. #print("\n")
  30.  
  31. class GRU_TEST(nn.Module):
  32.  
  33.     def __init__(self, size, prev, batch_size, hidden, clip):
  34.  
  35.         super(GRU_TEST, self).__init__()
  36.  
  37.         self.size = size
  38.         self.hidden_size = hidden
  39.  
  40.         self.clip = clip
  41.  
  42.         self.r = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  43.         self.z = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  44.         self.h = torch.nn.Linear(size*prev*batch_size+hidden, hidden)
  45.  
  46.         self.context = torch.zeros(hidden).cuda()
  47.  
  48.         self.tanh = torch.nn.Tanh()
  49.         self.sigmoid = torch.nn.Sigmoid()
  50.  
  51.         self.hidden = torch.zeros(hidden).cuda()
  52.  
  53.     def reset(self):
  54.  
  55.         del self.hidden
  56.         self.hidden = torch.zeros(hidden).cuda()
  57.  
  58.  
  59.     def forward(self, inp):
  60.  
  61.         h = torch.cat((inp, self.hidden))
  62.  
  63.         #process layers
  64.         rt = self.sigmoid(self.r(h))
  65.         zt = self.sigmoid(self.z(h))
  66.  
  67.         ht = torch.cat(((self.hidden * rt), inp))
  68.         ht = self.tanh(self.h(ht))
  69.  
  70.         self.hidden = (self.hidden * (1 - zt))
  71.         self.hidden = self.hidden + zt * ht
  72.  
  73.         torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip, 1)
  74.  
  75.         return self.hidden
  76.  
  77. class Model(nn.Module):
  78.  
  79.     def __init__(self, size, prev, batch_size, momentum, dropout, hidden, clip):
  80.         super(Model, self).__init__()
  81.  
  82.         self.gru1 = GRU_TEST(size, prev, batch_size, hidden, clip).cuda()
  83.         self.gru1.reset()
  84.  
  85.         self.softmax = torch.nn.modules.activation.LogSoftmax(dim=0)
  86.         self.dropout = torch.nn.Dropout(dropout)
  87.  
  88.         self.output_decoder = torch.nn.Linear(hidden, alphabet_size*batch_size).cuda()
  89.  
  90.         self.epochs = 1
  91.         self.batches = 1
  92.         self.counter = 0
  93.         self.runs = 0
  94.  
  95.         self.loss_function = torch.nn.CrossEntropyLoss()
  96.         self.optimizer = torch.optim.Adam([
  97.             {'params': self.parameters(),
  98.              'weight_decay': 0.25}
  99.             ], lr=rate)
  100.         #self.optimizer = torch.optim.SGD(params=self.parameters(), lr=rate, momentum=momentum)
  101.  
  102.     def clear_internal_states(self):
  103.         self.gru1.reset()
  104.  
  105.         std = 1.0/math.sqrt(self.gru1.hidden_size)
  106.  
  107.         for p in self.parameters():
  108.             p.data.uniform_(-std, std)
  109.  
  110.     def forward(self, inp):
  111.  
  112.         x = torch.autograd.Variable((inp.detach()).view(-1))
  113.         x = self.gru1(x)
  114.         x = self.dropout(x)
  115.         x = self.output_decoder(x)
  116.         x = x.view(nbatches, -1)
  117.         return x
  118.  
  119. def splash(a):
  120.     if a:
  121.         print("GRU Text Generator\nUsage:\n\n-f --filename: filename of input text - required\n-h --hidden: number of hidden layers, default 1\n-r --rate: learning rate\n-p --prev: number of previous states to observe, default 0.05")
  122.         print("\nExample usage: <command> -f input.txt -h 5 -r 0.025")
  123.     else:
  124.         print("\nGRU Text Generator\n")
  125.         print("Alphabet size: {}".format(alphabet_size))
  126.  
  127.         print("Hyperparameters:")
  128.         for a in list(sys.argv).pop(0):
  129.             print(a, " ",end="")
  130.  
  131.         os.system("clear")
  132.         print("\n")
  133.         print(datetime.datetime.now())
  134.         print("\n")
  135.  
  136. def getIndexFromLetter(letter, list):
  137.     return list.index(letter)
  138.  
  139. def getLetterFromIndex(i, list):
  140.     return list[i]
  141.  
  142. def parse(args, arg):
  143.  
  144.     for i in range(len(args)):
  145.         if args[i] in arg:
  146.             if len(args) < i+1:
  147.                 return ""
  148.             if args[i+1].startswith("-"):
  149.                 splash(True)
  150.             else:
  151.                 return args[i+1]
  152.  
  153.     return False
  154.  
  155. def savemodel():
  156.  
  157.     print("Save model parameters? [y/n]➡")
  158.     filename_input = input()
  159.  
  160.     if filename_input == 'y' or filename_input == 'Y' or filename_input.lower() == 'yes':
  161.         filename = "Model-" + str(datetime.datetime.now()).replace(" ", "_")
  162.         print("Save as filename [default: {}]➡".format(filename))
  163.  
  164.         filename_input = input()
  165.         if not filename_input == "":
  166.             filename = "Model-" + str(filename_input).replace(" ", "_")
  167.  
  168.         print("Saving model as {}...".format(filename))
  169.         modelname = "./models/{}".format(filename)
  170.  
  171.         torch.save({
  172.             'model_state_dict': model.state_dict(),
  173.             'optimizer_state_dict': model.optimizer.state_dict()
  174.         }, modelname)
  175.  
  176.     quit()
  177.  
  178. def loadmodel():
  179.     #load model parameters if checkpoint specified
  180.     if not model_filename == False:
  181.         try:
  182.             checkpoint = torch.load(model_filename)
  183.             model.load_state_dict(checkpoint['model_state_dict'])
  184.             model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  185.         except FileNotFoundError:
  186.             print("Model not found.")
  187.             quit()
  188.     else:
  189.         print("New model")
  190.  
  191. atexit.register(savemodel)
  192. model_filename = None
  193.  
  194. try:
  195.     model_filename = parse(sys.argv, ["--load", "-l"])
  196.     filename = parse(sys.argv, ["--filename", "-f"])
  197.     if not filename or filename == "":
  198.         splash()
  199.     rate = float(parse(sys.argv, ["--rate", "-r"]))
  200.     if not rate or rate == "":
  201.         rate = 0.05
  202.     hidden = int(parse(sys.argv, ["--hidden", "-h"]))
  203.     if not hidden or hidden == "":
  204.         hidden = 1
  205.     clip = float(parse(sys.argv, ["--clip", "-c"]))
  206.     if not clip:
  207.         clip = 1
  208.     epochs = int(parse(sys.argv, ["--epoch", "-e"]))
  209.     if not epochs:
  210.         epochs = 1
  211.     nbatches = int(parse(sys.argv, ["--batch", "-b"]))
  212.     if not nbatches:
  213.         nbatches = 1
  214.     momentum = float(parse(sys.argv, ["--momentum", "-m"]))
  215.     if not momentum:
  216.         momentum = 0.4
  217.     n_prev = int(parse(sys.argv, ["--previous", "-p"]))
  218.     if not n_prev:
  219.         n_prev = 5
  220.     dropout = float(parse(sys.argv, ["--dropout", "-d"]))
  221.     if not dropout:
  222.         dropout = 0.0
  223.     sample_size = int(parse(sys.argv, ["--sample", "-s"]))
  224.     if not sample_size:
  225.         sample_size = 1000
  226.     temperature = float(parse(sys.argv, ["--temperature", "-t"]))
  227.     if not temperature:
  228.         temperature = 1.0
  229.  
  230. except:
  231.     splash(True)
  232.     quit()
  233.  
  234. #open file
  235. master_words = []
  236. alphabet = []
  237. #writer = SummaryWriter()
  238. text = []
  239. e = 0
  240. c = 0
  241. randomness = 0
  242.  
  243. with open(filename, "r") as f:
  244.     # reads all lines and removes non alphabet words
  245.     intext = f.read()
  246.  
  247. for l in list(intext):
  248.     if l == "\n": l = "¶"
  249.     if l == "\x1b": print("XXX")
  250.     text.append(l)
  251.  
  252. for l in text:
  253.     sys.stdout.flush()
  254.  
  255.     if l not in alphabet:
  256.         alphabet.append(l)
  257.         print("\r{}% - {}/{}".format(int(c/len(text)*100), c, len(text)), end="")
  258.     c+=1
  259.  
  260. epochs = 1
  261. alphabet_size = len(alphabet)
  262.  
  263. splash(False)
  264. model = Model(alphabet_size, n_prev, nbatches, momentum, dropout, hidden, clip).cuda()
  265.  
  266. nchars = len(text)
  267.  
  268. graph_time = 0
  269.  
  270. mem_max = torch.cuda.max_memory_allocated()
  271. field = [" " for _ in range(10)]
  272.  
  273. def one_hot(char):
  274.     output = torch.zeros(alphabet_size).cuda()
  275.     output[alphabet.index(char)] = 1
  276.  
  277.     return output
  278. def get_input_vector(chars):
  279.     out = []
  280.  
  281.     for b in range(nbatches):
  282.         batch = []
  283.  
  284.         for c in chars:
  285.             batch.append(one_hot(c))
  286.  
  287.         o = torch.stack(batch)
  288.         out.append(o)
  289.  
  290.     out = torch.stack(out).cuda()
  291.     return out
  292. def get_output(inp):
  293.  
  294.     inp = torch.nn.Softmax(dim=1)(inp / temperature)
  295.     sample = torch.multinomial(inp, 1)[:]
  296.  
  297.     return alphabet[sample]
  298.     model.counter = n_prev
  299. while True:
  300.  
  301.     t = 0
  302.  
  303.     variation = []
  304.     done = False
  305.     total_loss = 0
  306.     total_time = 0
  307.  
  308.     mem = 0
  309.     loss = 0
  310.  
  311.     model.optimizer.zero_grad()
  312.     model.clear_internal_states()
  313.  
  314.     steps = 2500
  315.     model.runs += 1
  316.  
  317.     field = [text[model.counter+x] for x in range(n_prev)]
  318.     start = datetime.datetime.timestamp(datetime.datetime.now())
  319.  
  320.     while t < steps:
  321.         #get target char
  322.         target = alphabet.index(text[(model.counter+1)%len(text)])
  323.  
  324.         #make target vector
  325.         target = [target for _ in range(nbatches)]
  326.         target = torch.tensor(target).cuda()
  327.  
  328.         #forward pass
  329.         inp = get_input_vector(field)
  330.         out = model.forward(inp)
  331.  
  332.         #get outputs
  333.         char = []
  334.         for o in out.split(1):
  335.             a = get_output(o)
  336.             char.append(a)
  337.  
  338.         f = ''.join(str(e) for e in field)
  339.         c = ''.join(str(e) for e in char)
  340.  
  341.         progress = int(100*(t / steps))
  342.  
  343.         txt = colored("\r ▲ {} | Forward prop... | Progress: {}% | Epoch: {} | Batch: {} | {} | {} | ...".format(model.runs, progress, model.epochs, model.batches, f, c),
  344.                       attrs=['reverse'])
  345.         print(txt,end="")
  346.  
  347.         model.counter += 1
  348.         l = model.loss_function(out, target)
  349.         writer.add_scalar('loss', l, model.counter)
  350.         loss += l
  351.  
  352.         t += 1
  353.  
  354.         if model.counter > len(text):
  355.             model.epochs += 1
  356.             model.batches = 0
  357.             model.counter = 0
  358.             print("\nxz")
  359.  
  360.         field.append(alphabet[target[0]])
  361.         field.pop(0)
  362.  
  363.     end = datetime.datetime.timestamp(datetime.datetime.now())
  364.     total_time = int(end - start)
  365.  
  366.     writer.add_scalar('time', torch.tensor(total_time), model.runs)
  367.  
  368.     print("\nBackward prop...")
  369.     sys.stdout.flush()
  370.  
  371.     model.batches += 1
  372.     writer.add_scalar('total_loss', int(loss), model.counter)
  373.     loss.backward()
  374.     model.optimizer.step()
  375.  
  376.     field = [text[x] for x in range(n_prev)]
  377.  
  378.     if model.runs % 1 == 0:
  379.         variety = []
  380.  
  381.         print("\nGenerating...\n")
  382.         for i in range(1000):
  383.  
  384.             inp = get_input_vector(field)
  385.             out = model.forward(inp)
  386.  
  387.             char = []
  388.             for o in out.split(1):
  389.                 #a = get_output(o)
  390.                 a = alphabet[torch.argmax(o)]
  391.                 char.append(a)
  392.  
  393.             print(char[0], end="")
  394.             sys.stdout.flush()
  395.  
  396.             if char[0] not in variety:
  397.                 variety.append(char[0])
  398.  
  399.             field.append(char[0])
  400.             field.pop(0)
  401.  
  402.         variety = int(100*(len(variety)/alphabet_size))
  403.         print("\nVariety: {}\n".format(variety))
  404.         writer.add_scalar('variety', variety, model.runs)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top