Advertisement
Guest User

Untitled

a guest
Jan 18th, 2019
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 11.25 KB | None | 0 0
  1. #!/usr/local/anaconda3/envs/experiments/bin/python3
  2. #!/usr/bin/python3
  3.  
  4. import torch
  5. import torch.nn as nn
  6. import time
  7. import random
  8. import sys
  9. import numpy
  10. from termcolor import colored
  11. import datetime
  12. import math
  13. import atexit
  14.  
  15. #enable CUDA
  16. if torch.cuda.is_available():
  17. print("CUDA")
  18. else:
  19. print("CPU")
  20.  
  21. sys.stdout.flush()
  22.  
  23. use_greedy = True
  24.  
  25. #print("\nParameters: ")
  26.  
  27. #for arg in sys.argv:
  28. # print(arg, " ", end="")
  29. #print("\n")
  30.  
  31. class GRU_TEST(nn.Module):
  32.  
  33. def __init__(self, size, p, hidden, clip):
  34.  
  35. super(GRU_TEST, self).__init__()
  36.  
  37. self.size = size
  38. self.hidden_size = hidden
  39.  
  40. self.clip = clip
  41.  
  42. p -= 0
  43.  
  44. self.r = torch.nn.Linear(hidden+(size*p), hidden)
  45. self.z = torch.nn.Linear(hidden+(size*p), hidden)
  46. self.h = torch.nn.Linear(hidden+(size*p), hidden)
  47.  
  48. self.context = torch.zeros(hidden).cuda()
  49.  
  50. self.tanh = torch.nn.Tanh()
  51. self.sigmoid = torch.nn.Sigmoid()
  52.  
  53. self.context = torch.zeros(hidden).cuda()
  54. self.hidden = torch.zeros(hidden).cuda()
  55.  
  56. def clear_internal_states(self):
  57. self.hidden = torch.zeros(hidden).cuda()
  58. self.outs = []
  59. self.targets = []
  60.  
  61. def forward(self, inp):
  62.  
  63. h = torch.cat((inp, self.hidden.detach()))
  64.  
  65. #process layers
  66. rt = self.sigmoid(self.r(h))
  67. zt = self.sigmoid(self.z(h))
  68.  
  69. ht = torch.cat(((self.hidden * rt), inp))
  70. ht = self.tanh(self.h(ht))
  71.  
  72. self.hidden = (self.hidden * (1 - zt))
  73. self.hidden = self.hidden + zt * ht
  74.  
  75. torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip, 1)
  76.  
  77. return self.hidden
  78. class Model(nn.Module):
  79.  
  80. def __init__(self, size, p, dropout, hidden, clip):
  81. super(Model, self).__init__()
  82.  
  83. self.gru1 = GRU_TEST(alphabet_size, n_prev, hidden, clip).float().cuda()
  84. self.gru2 = GRU_TEST(hidden, 1, hidden, clip).float().cuda()
  85.  
  86. self.softmax = torch.nn.modules.activation.LogSoftmax(dim=0)
  87. self.dropout = torch.nn.Dropout(dropout)
  88.  
  89. self.output_decoder = torch.nn.Linear(hidden, alphabet_size).cuda()
  90.  
  91. self.epochs = 1
  92. self.batches = 1
  93. self.count = 0
  94. self.counter = 0
  95.  
  96. self.outs = []
  97. self.targets = []
  98.  
  99. self.output_counter = 0
  100.  
  101. self.loss_function = torch.nn.BCEWithLogitsLoss().float()
  102. self.optimizer = torch.optim.Adam([
  103. {'params': self.parameters(),
  104. 'weight_decay': 0.25}
  105. ], lr=rate)
  106.  
  107. def clear_internal_states(self):
  108. self.gru1.clear_internal_states()
  109. self.gru2.clear_internal_states()
  110.  
  111. def forward(self, inp):
  112.  
  113. x = torch.stack(inp).view(-1).detach()
  114. x = self.gru1(x)
  115. #x = self.gru2(x)
  116. x = self.dropout(x)
  117. x = self.output_decoder(x)
  118. x = self.softmax(x)
  119.  
  120. return x
  121.  
  122. def splash(a):
  123. if a:
  124. print("GRU Text Generator\nUsage:\n\n-f --filename: filename of input text - required\n-h --hidden: number of hidden layers, default 1\n-r --rate: learning rate\n-p --prev: number of previous states to observe, default 0.05")
  125. print("\nExample usage: <command> -f input.txt -h 5 -r 0.025")
  126. else:
  127. print("\nGRU Text Generator\n")
  128. print("2019\n")
  129.  
  130. def getIndexFromLetter(letter, list):
  131. return list.index(letter)
  132.  
  133. def getLetterFromIndex(i, list):
  134. return list[i]
  135.  
  136. def parse(args, arg):
  137.  
  138. for i in range(len(args)):
  139. if args[i] in arg:
  140. if len(args) < i+1:
  141. return ""
  142. if args[i+1].startswith("-"):
  143. splash(True)
  144. else:
  145. return args[i+1]
  146.  
  147. return False
  148.  
def get_output(inp, greed):
    """Choose an output character from score vector `inp` (len == alphabet size).

    greed=False: sample a character from Softmax(inp); greed=True: take the
    argmax. Returns (character, tensor derived from `inp` for the loss).
    """

    #if model.output_counter % 10 == 0:
    # greed = True
    #else:
    # greed = False
    #
    # model.output_counter += 1

    if not greed:
        # Convert scores to probabilities, then sample an index.
        inp = torch.nn.Softmax(dim=0)(inp)
        outchar = numpy.random.choice(alphabet, p=inp.cpu().detach().numpy())
        i = alphabet.index(outchar)
    else:
        i = torch.argmax(inp)
        outchar = alphabet[i]

    # Build a one-hot mask at i; multiplying by `inp` keeps the result
    # connected to the autograd graph for the later BCE loss.
    out = torch.zeros(len(inp)).cuda()
    out[i] = 1
    out = out * inp
    # NOTE(review): in greedy mode `inp` holds log-softmax scores, so
    # out[i] is negative and argmax(out) can land on a zero entry rather
    # than i — the returned vector may not be one-hot at the chosen index.
    # Confirm whether this is intended before changing it.
    out[torch.argmax(out)] = 1

    return outchar, out
  172.  
  173. def savemodel():
  174.  
  175. print("Save model parameters? [y/n]➡")
  176. filename_input = input()
  177.  
  178. if filename_input == 'y' or filename_input == 'Y' or filename_input.lower() == 'yes':
  179. filename = "Model-" + str(datetime.datetime.now()).replace(" ", "_")
  180. print("Save as filename [default: {}]➡".format(filename))
  181.  
  182. filename_input = input()
  183. if not filename_input == "":
  184. filename = "Model-" + str(filename_input).replace(" ", "_")
  185.  
  186. print("Saving model as {}...".format(filename))
  187. modelname = "./models/{}".format(filename)
  188.  
  189. torch.save({
  190. 'model_state_dict': model.state_dict(),
  191. 'optimizer_state_dict': model.optimizer.state_dict()
  192. }, modelname)
  193.  
  194. quit()
  195.  
  196. def loadmodel():
  197. #load model parameters if checkpoint specified
  198. if not model_filename == False:
  199. try:
  200. checkpoint = torch.load(model_filename)
  201. model.load_state_dict(checkpoint['model_state_dict'])
  202. model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  203. except FileNotFoundError:
  204. print("Model not found.")
  205. quit()
  206. else:
  207. print("New model")
  208.  
  209. atexit.register(savemodel)
  210. model_filename = None
  211.  
  212. try:
  213. model_filename = parse(sys.argv, ["--load", "-l"])
  214. filename = parse(sys.argv, ["--filename", "-f"])
  215. if not filename or filename == "":
  216. splash()
  217. rate = float(parse(sys.argv, ["--rate", "-r"]))
  218. if not rate or rate == "":
  219. rate = 0.05
  220. hidden = int(parse(sys.argv, ["--hidden", "-h"]))
  221. if not hidden or hidden == "":
  222. hidden = 1
  223. clip = float(parse(sys.argv, ["--clip", "-c"]))
  224. if not clip:
  225. clip = 1
  226. epochs = int(parse(sys.argv, ["--epoch", "-e"]))
  227. if not epochs:
  228. epochs = 1
  229. nbatches = int(parse(sys.argv, ["--batch", "-b"]))
  230. if not nbatches:
  231. nbatches = 1
  232. momentum = float(parse(sys.argv, ["--momentum", "-m"]))
  233. if not momentum:
  234. momentum = 0.4
  235. n_prev = int(parse(sys.argv, ["--previous", "-p"]))
  236. if not n_prev:
  237. n_prev = 5
  238. dropout = float(parse(sys.argv, ["--dropout", "-d"]))
  239. if not dropout:
  240. dropout = 0.0
  241. sample_size = int(parse(sys.argv, ["--sample", "-s"]))
  242. if not sample_size:
  243. sample_size = 1000
  244.  
  245. except:
  246. splash(True)
  247. quit()
  248.  
#open file
master_words = []
alphabet = []       # distinct characters found in the corpus, in first-seen order
#writer = SummaryWriter()
text = []           # the corpus as a flat list of single characters
e = 0
c = 0               # progress counter for the alphabet scan below
randomness = 0      # consumed by select(); 0 disables the extra randomness

with open(filename, "r") as f:
    # reads all lines and removes non alphabet words
    intext = f.read()

# Map newlines to a visible pilcrow so they survive single-line display;
# flag stray ESC characters.
for l in list(intext):
    if l == "\n": l = "¶"
    if l == "\x1b": print("XXX")
    text.append(l)

# Build the alphabet (one entry per distinct character) with a progress line.
# NOTE(review): indentation was lost in this paste — the progress print and
# counter are reconstructed as per-character, not per-new-letter; confirm.
for l in text:
    sys.stdout.flush()

    if l not in alphabet:
        alphabet.append(l)
    print("\r{}% - {}/{}".format(int(c/len(text)*100), c, len(text)), end="")
    c+=1

# NOTE(review): this overwrites the epochs value parsed from the CLI —
# confirm that forcing a single epoch is intentional.
epochs = 1
alphabet_size = len(alphabet)

splash(False)
model = Model(alphabet_size, n_prev, dropout, hidden, clip).float().cuda()

nchars = len(text)

replay_memory = []
graph_time = 0

# Peak CUDA memory so far, printed once at startup for reference.
mem_max = torch.cuda.max_memory_allocated()
print(mem_max)
  288.  
  289. def select(out):
  290.  
  291. r = random.randint(0,randomness)
  292.  
  293. for i in range(r):
  294. out[torch.argmax(out)] = 0
  295.  
  296. return out
  297.  
  298. def get_outchar(out):
  299. return alphabet[torch.argmax(out)]
  300.  
def get_next():
    """Fetch the next training window from the corpus.

    Returns (inputs, target, done): `inputs` is a list of n_prev one-hot
    vectors starting at text[model.count]; `done` becomes True once the
    cursor runs past the end of the corpus (epoch boundary), in which case
    inputs/target are None. Advances model.count by one per call.
    """
    output = []
    try:

        for n in range(n_prev):
            letter = text[(model.count+n)]
            letter = letter.replace("\x1b", " ")    # strip stray ESC chars

            if letter == "\n": letter = "¶"

            # One-hot encode this character of the window.
            o = torch.autograd.Variable(torch.zeros(alphabet_size), requires_grad=True).cuda()
            if letter == "[": print(letter)
            o[alphabet.index(letter)] = 1
            output.append(o)

        model.count += 1
        # NOTE(review): the target is the LAST character of the input window,
        # not the character that follows it; the commented-out line below
        # suggests the next character may have been intended. Confirm.
        target = torch.autograd.Variable(torch.zeros(alphabet_size), requires_grad=True).cuda()
        target[alphabet.index(letter)] = 1
        #target = alphabet.index(text[(model.count + n)])

        return output, target, False

    except IndexError:
        # Ran off the end of the corpus: count an epoch and signal done.
        print("ZZZ")
        model.epochs += 1
        return None, None, True
  327.  
model.train(True)

steps = 5000        # characters consumed per batch before one backward pass
t = 0

inp, target, done = get_next()
generate_init = 0

# Main loop: forward over `steps` characters, one backward pass per batch,
# then a 1000-character free-running generation pass. Runs until interrupted
# (savemodel is registered with atexit).
# NOTE(review): nesting reconstructed — indentation was lost in this paste.
while True:
    #txt = colored("\nForward prop... Epoch: {} | Batch: {} | Working...".format(model.epochs, model.batches), attrs=['reverse'])
    #print(txt)
    torch.cuda.empty_cache()

    t = 0
    outs = []         # network outputs collected for this batch
    targets = []      # matching one-hot targets
    variation = []    # distinct characters produced this batch
    print('\n')
    done = False

    start = datetime.datetime.now()

    # ---- forward pass over one batch ----
    while t < steps:
        inp, target, done = get_next()

        if done:
            # Corpus exhausted: rewind the cursor and close the batch early.
            model.count = 0
            break

        t += 1

        out = model.forward(inp)
        char, out = get_output(out, use_greedy)

        #print(char, end="")
        #sys.stdout.flush()

        outs.append(out)
        targets.append(target)

        #inp.append(out)
        #inp.pop(0)

        # Status field: input window | predicted char | expected char.
        field = ""

        for o in inp:
            field += alphabet[torch.argmax(o)]
        field += "| {} | {} |".format(char, alphabet[torch.argmax(target)])

        progress = int(100*(t/steps))

        mem_used = torch.cuda.memory_allocated()
        mem_avail = torch.cuda.memory_cached()

        percentage = 100 * (mem_used / mem_avail)

        if char not in variation:
            variation.append(char)

        txt = "\rForward prop... | Epoch: {} | Batch: {} | {}/{} | {} | Progress: {}% | Memory: {}/{} ({}%)".format(model.epochs, model.batches, t, steps, field, progress, mem_used, mem_avail, percentage)
        print(txt, end=" | ")

        #inp.append(out)
        #inp.pop(0)

    end = datetime.datetime.now()
    print("\nTime: {} @ {}\n".format((end-start), end))

    i = 0

    # ---- backward pass over the whole batch at once ----
    txt = colored("\rBackward prop... | Epoch: {} | Batch: {} | Working...".format(model.epochs, model.batches),
                  attrs=['reverse'])
    print(txt,end="")

    out = torch.stack(outs).cuda()
    target = torch.stack(targets).cuda()

    model.optimizer.zero_grad()
    loss = model.loss_function(out, target)
    loss.backward(retain_graph=True)
    model.optimizer.step()

    del loss, out, target

    i += 1

    targets.clear()
    outs.clear()

    # NOTE(review): rebinds the imported `time` module to a datetime value —
    # shadowing; harmless only if time.* is never used later. Confirm.
    time = datetime.datetime.now()

    variation_int = int(100 * (len(variation) / alphabet_size))

    print("\nVariation: {} | {}/{} | \nAt {} - Memory: {}/{} ({}%)".format(variation_int, len(variation), alphabet_size, time, mem_used, mem_avail, percentage))
    variation = []

    model.optimizer.zero_grad()
    i = 0

    model.batches += 1
    model.counter += 1
    model.count = 0

    inp, target, done = get_next()

    print("\n Generating text... \n")
    # ---- free-running generation: feed the model its own output ----
    for i in range(1000):

        out = model.forward(inp)

        #sample argmax letter
        char, out = get_output(out, use_greedy)

        print(char, end="")
        sys.stdout.flush()

        # Slide the window: append the produced vector, drop the oldest.
        inp.append(out)
        inp.pop(0)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement