Llama_Colab

  1. {
  2. "nbformat": 4,
  3. "nbformat_minor": 0,
  4. "metadata": {
  5. "colab": {
  6. "provenance": []
  7. },
  8. "kernelspec": {
  9. "name": "python3",
  10. "display_name": "Python 3"
  11. },
  12. "language_info": {
  13. "name": "python"
  14. },
  15. "gpuClass": "standard",
  16. "accelerator": "GPU"
  17. },
  18. "cells": [
  19. {
  20. "cell_type": "markdown",
  21. "source": [
  22. "Initialize repository, copy weights from Google drive."
  23. ],
  24. "metadata": {
  25. "id": "uh_RON9_FqBV"
  26. }
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": null,
  31. "metadata": {
  32. "id": "aChaziGm-OrN"
  33. },
  34. "outputs": [],
  35. "source": [
  36. "import os\n",
  37. "import sys\n",
  38. "from google.colab import drive\n",
  39. "\n",
  40. "# Mount google drive.\n",
  41. "drive.mount('/gdrive')\n",
  42. "\n",
  43. "#@markdown Location of tokenizer.\n",
  44. "tokenizer_loc = '/gdrive/My Drive/tokenizer.model' #@param {type:\"string\"}\n",
  45. "\n",
  46. "# @markdown Location of directory containing model weights / parameters.\n",
  47. "weight_loc = '/gdrive/My Drive/7B/' #@param {type:\"string\"}\n",
  48. "\n",
  49. "!pip install fairscale\n",
  50. "!pip install sentencepiece\n",
  51. "!git clone https://github.com/facebookresearch/llama.git\n",
  52. "\n",
  53. "sys.path.insert(0, '/content/llama/')\n",
  54. "\n",
  55. "!nvidia-smi"
  56. ]
  57. },
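{
"cell_type": "markdown",
"source": [
"Optional sanity check (a minimal sketch): confirm that the tokenizer and weight locations set above are actually visible on the mounted Drive before continuing. It assumes the standard 7B file names (params.json, consolidated.00.pth); adjust the list if your copy is laid out differently."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Rough sanity check: the file names below are assumptions based on the\n",
"# default 7B download layout; edit `expected` if your Drive copy differs.\n",
"expected = [tokenizer_loc,\n",
"            os.path.join(weight_loc, 'params.json'),\n",
"            os.path.join(weight_loc, 'consolidated.00.pth')]\n",
"\n",
"for path in expected:\n",
"    status = 'OK' if os.path.exists(path) else 'MISSING'\n",
"    print(status, path)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},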
  58. {
  59. "cell_type": "markdown",
  60. "source": [
  61. "The 7B checkpoint is too large to fit into RAM. Run this cell if you need to split the 7B checkpoint. Will save the results to your 7B directory so you should only ever need to run this cell once. You may need to restart the runtime afterward."
  62. ],
  63. "metadata": {
  64. "id": "YUViS0koD_aj"
  65. }
  66. },
  67. {
  68. "cell_type": "code",
  69. "source": [
  70. "import torch\n",
  71. "\n",
  72. "checkpoint = torch.load(os.path.join(weight_loc, 'consolidated.00.pth'),\n",
  73. " map_location=\"cuda\")\n",
  74. "\n",
  75. "d1 = dict(list(checkpoint.items())[:len(checkpoint)//2])\n",
  76. "torch.save(d1, os.path.join(weight_loc, 'consolidated.00.00.pth'))\n",
  77. "del(d1)\n",
  78. "\n",
  79. "d2 = dict(list(checkpoint.items())[len(checkpoint)//2:])\n",
  80. "torch.save(d2, os.path.join(weight_loc, 'consolidated.00.01.pth'))\n",
  81. "del(d2)\n",
  82. "\n",
  83. "del(checkpoint)"
  84. ],
  85. "metadata": {
  86. "id": "MxwFpC1fCAdz"
  87. },
  88. "execution_count": null,
  89. "outputs": []
  90. },
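{
"cell_type": "markdown",
"source": [
"Optional follow-up (a minimal sketch): check that both halves of the split (consolidated.00.00.pth and consolidated.00.01.pth) were written to the weight directory and report their sizes. It only looks at the files on disk; nothing is loaded into memory."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Confirm the two split halves exist without loading them into memory.\n",
"for name in ['consolidated.00.00.pth', 'consolidated.00.01.pth']:\n",
"    path = os.path.join(weight_loc, name)\n",
"    if os.path.exists(path):\n",
"        print(name, '-', round(os.path.getsize(path) / 1e9, 2), 'GB')\n",
"    else:\n",
"        print(name, '- not found; run the split cell above first')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},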
  91. {
  92. "cell_type": "markdown",
  93. "source": [
  94. "Include that one anon's additional sampling methods so we have Kobold parameters like repetition penalty, tfs, etc."
  95. ],
  96. "metadata": {
  97. "id": "BMbLXEqjcmdi"
  98. }
  99. },
  100. {
  101. "cell_type": "code",
  102. "source": [
  103. "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
  104. "# This software may be used and distributed according to the terms of the GNU\n",
  105. "# General Public License version 3.\n",
  106. "\n",
  107. "from typing import List\n",
  108. "\n",
  109. "import torch\n",
  110. "\n",
  111. "from llama.tokenizer import Tokenizer\n",
  112. "from llama.model import Transformer\n",
  113. "\n",
  114. "class LLaMA:\n",
  115. " def __init__(self, model: Transformer, tokenizer: Tokenizer):\n",
  116. " self.model = model\n",
  117. " self.tokenizer = tokenizer\n",
  118. "\n",
  119. " def generate(\n",
  120. " self,\n",
  121. " prompts: List[str],\n",
  122. " max_gen_len: int,\n",
  123. " temperature: float = 0.8,\n",
  124. " top_p: float = 0.95,\n",
  125. " tfs: float = 1.0,\n",
  126. " typical: float = 1.0,\n",
  127. " penalty_range: float = 1024,\n",
  128. " penalty_slope: float = 0.7,\n",
  129. " penalty: float = 1.1\n",
  130. " ) -> List[str]:\n",
  131. " bsz = len(prompts)\n",
  132. " params = self.model.params\n",
  133. " assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)\n",
  134. "\n",
  135. " prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False)\n",
  136. " for x in prompts]\n",
  137. "\n",
  138. " min_prompt_size = min([len(t) for t in prompt_tokens])\n",
  139. " max_prompt_size = max([len(t) for t in prompt_tokens])\n",
  140. "\n",
  141. " total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)\n",
  142. "\n",
  143. " tokens = torch.full((bsz, total_len),\n",
  144. " self.tokenizer.pad_id).cuda().long()\n",
  145. " for k, t in enumerate(prompt_tokens):\n",
  146. " tokens[k, : len(t)] = torch.tensor(t).long()\n",
  147. " input_text_mask = tokens != self.tokenizer.pad_id\n",
  148. " start_pos = min_prompt_size\n",
  149. " prev_pos = 0\n",
  150. " for cur_pos in range(start_pos, total_len):\n",
  151. " input_ids = tokens[:, prev_pos:cur_pos]\n",
  152. " logits = self.model.forward(input_ids, prev_pos)\n",
  153. " if temperature > 0:\n",
  154. "\n",
  155. " next_token_scores = sample_top_p_actual(input_ids, logits,\n",
  156. " top_p)\n",
  157. " next_token_scores = sample_tail_free(input_ids,\n",
  158. " next_token_scores, tfs)\n",
  159. " next_token_scores = sample_typical(input_ids, next_token_scores,\n",
  160. " typical)\n",
  161. " next_token_scores = sample_temperature(input_ids,\n",
  162. " next_token_scores,\n",
  163. " temperature)\n",
  164. " next_token_scores = sample_advanced_repetition_penalty(input_ids,\n",
  165. " next_token_scores,\n",
  166. " penalty_range,\n",
  167. " penalty_slope,\n",
  168. " penalty)\n",
  169. "\n",
  170. " next_token_scores = torch.nn.functional.softmax(next_token_scores,\n",
  171. " dim=-1)\n",
  172. " next_token = torch.multinomial(next_token_scores,\n",
  173. " num_samples=1).squeeze(1)\n",
  174. " else:\n",
  175. " next_token = torch.argmax(logits, dim=-1)\n",
  176. " next_token = next_token.reshape(-1)\n",
  177. " # only replace token if prompt has already been generated\n",
  178. " next_token = torch.where(\n",
  179. " input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token\n",
  180. " )\n",
  181. " tokens[:, cur_pos] = next_token\n",
  182. " prev_pos = cur_pos\n",
  183. "\n",
  184. " decoded = []\n",
  185. " for i, t in enumerate(tokens.tolist()):\n",
  186. " # cut to max gen len\n",
  187. " t = t[: len(prompt_tokens[i]) + max_gen_len]\n",
  188. " # cut to eos tok if any\n",
  189. " try:\n",
  190. " t = t[: t.index(self.tokenizer.eos_id)]\n",
  191. " except ValueError:\n",
  192. " pass\n",
  193. " decoded.append(self.tokenizer.decode(t))\n",
  194. " return decoded\n",
  195. "\n",
  196. "# taken from Kobold and transformers so this stuff is AGPL I guess\n",
  197. "def sample_temperature(input_ids, scores, tempt):\n",
  198. " scores = scores / tempt\n",
  199. " return scores\n",
  200. "\n",
  201. "def sample_typical(input_ids, scores, typical, filter_value = -float(\"Inf\"),\n",
  202. " min_tokens_to_keep = 1):\n",
  203. " if filter_value >= 1.0:\n",
  204. " return scores\n",
  205. "\n",
  206. " probs = scores.softmax(dim=-1)\n",
  207. " log_probs = probs.log()\n",
  208. "\n",
  209. " neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)\n",
  210. "\n",
  211. " entropy_deviation = (neg_entropy - log_probs).abs()\n",
  212. "\n",
  213. " _, sorted_indices = torch.sort(entropy_deviation)\n",
  214. " sorted_logits = probs.gather(-1, sorted_indices)\n",
  215. " sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= typical\n",
  216. " sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)\n",
  217. "\n",
  218. " min_tokens_to_keep = max(min_tokens_to_keep, 1)\n",
  219. " # Keep at least min_tokens_to_keep\n",
  220. " sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
  221. "\n",
  222. " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
  223. " scores = scores.masked_fill(indices_to_remove, filter_value)\n",
  224. " return scores \n",
  225. "\n",
  226. "def sample_top_p_actual(input_ids, scores, top_p, filter_value = -float(\"Inf\"),\n",
  227. " min_tokens_to_keep = 1):\n",
  228. " sorted_logits, sorted_indices = torch.sort(scores, descending=False)\n",
  229. " cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n",
  230. "\n",
  231. " # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)\n",
  232. " sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n",
  233. " if min_tokens_to_keep > 1:\n",
  234. " # Keep at least min_tokens_to_keep\n",
  235. " sorted_indices_to_remove[..., -min_tokens_to_keep :] = 0\n",
  236. "\n",
  237. " # scatter sorted tensors to original indexing\n",
  238. " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,\n",
  239. " sorted_indices_to_remove)\n",
  240. " scores = scores.masked_fill(indices_to_remove, filter_value)\n",
  241. " return scores\n",
  242. "\n",
  243. "def sample_advanced_repetition_penalty(input_ids, scores, penalty_range,\n",
  244. " penalty_slope, penalty):\n",
  245. " penalty_range = int(penalty_range)\n",
  246. " clipped_penalty_range = min(input_ids.shape[-1], penalty_range)\n",
  247. "\n",
  248. " if penalty != 1.0:\n",
  249. " if penalty_range > 0:\n",
  250. " if clipped_penalty_range < input_ids.shape[1]:\n",
  251. " input_ids = input_ids[..., -clipped_penalty_range:]\n",
  252. "\n",
  253. " if penalty_slope != 0:\n",
  254. " _penalty = (torch.arange(penalty_range, dtype=scores.dtype,\n",
  255. " device=scores.device)/(penalty_range - 1)) * 2. - 1\n",
  256. " _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))\n",
  257. " _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)\n",
  258. " penalty = _penalty[..., -clipped_penalty_range:]\n",
  259. "\n",
  260. " score = torch.gather(scores, 1, input_ids)\n",
  261. " score = torch.where(score <= 0, score * penalty, score / penalty)\n",
  262. " scores.scatter_(1, input_ids, score)\n",
  263. "\n",
  264. " return scores \n",
  265. "\n",
  266. "def sample_top_a(input_ids, scores, top_a, filter_value = -float(\"Inf\"),\n",
  267. " min_tokens_to_keep = 1):\n",
  268. " if filter_value >= 1.0:\n",
  269. " return scores\n",
  270. "\n",
  271. " sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
  272. " probs = sorted_logits.softmax(dim=-1)\n",
  273. "\n",
  274. " # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)\n",
  275. " probs_max = probs[..., 0, None]\n",
  276. " sorted_indices_to_remove = probs < probs_max * probs_max * top_a\n",
  277. "\n",
  278. " if min_tokens_to_keep > 1:\n",
  279. " # Keep at least min_tokens_to_keep\n",
  280. " sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
  281. "\n",
  282. " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,\n",
  283. " sorted_indices_to_remove)\n",
  284. " scores = scores.masked_fill(indices_to_remove, filter_value)\n",
  285. " return scores \n",
  286. "\n",
  287. "def sample_tail_free(input_ids, scores, tfs, filter_value = -float(\"Inf\"),\n",
  288. " min_tokens_to_keep = 1):\n",
  289. " if filter_value >= 1.0:\n",
  290. " return scores\n",
  291. " sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
  292. " probs = sorted_logits.softmax(dim=-1)\n",
  293. "\n",
  294. " # Compute second derivative normalized CDF\n",
  295. " d2 = probs.diff().diff().abs()\n",
  296. " normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)\n",
  297. " normalized_d2_cdf = normalized_d2.cumsum(dim=-1)\n",
  298. "\n",
  299. " # Remove tokens with CDF value above the threshold (token with 0 are kept)\n",
  300. " sorted_indices_to_remove = normalized_d2_cdf > tfs\n",
  301. "\n",
  302. " # Centre the distribution around the cutoff as in the original implementation of the algorithm\n",
  303. " sorted_indices_to_remove = torch.cat(\n",
  304. " (\n",
  305. " torch.zeros(scores.shape[0], 1, dtype=torch.bool,\n",
  306. " device=scores.device),\n",
  307. " sorted_indices_to_remove,\n",
  308. " torch.ones(scores.shape[0], 1, dtype=torch.bool,\n",
  309. " device=scores.device),\n",
  310. " ),\n",
  311. " dim=-1,\n",
  312. " )\n",
  313. "\n",
  314. " if min_tokens_to_keep > 1:\n",
  315. " # Keep at least min_tokens_to_keep\n",
  316. " sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
  317. "\n",
  318. " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,\n",
  319. " sorted_indices_to_remove)\n",
  320. " scores = scores.masked_fill(indices_to_remove, filter_value)\n",
  321. " return scores"
  322. ],
  323. "metadata": {
  324. "id": "fYMxCH_Zcajj"
  325. },
  326. "execution_count": null,
  327. "outputs": []
  328. },
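{
"cell_type": "markdown",
"source": [
"A tiny illustration (not part of the generator) of what these samplers do: on a fake six-token vocabulary, tokens that fall outside the top-p nucleus or the tail-free cutoff get their scores pushed to -inf, so torch.multinomial can never draw them. The generate() method above applies the same functions to real logits."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Toy demonstration on a fake 6-token vocabulary; the logit values are made up.\n",
"toy_logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.5, 0.1]])\n",
"dummy_ids = torch.zeros((1, 1), dtype=torch.long)  # unused by these samplers\n",
"\n",
"print('top-p 0.9  :', sample_top_p_actual(dummy_ids, toy_logits.clone(), 0.9))\n",
"print('tfs 0.9    :', sample_tail_free(dummy_ids, toy_logits.clone(), 0.9))\n",
"print('typical 0.9:', sample_typical(dummy_ids, toy_logits.clone(), 0.9))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},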
  329. {
  330. "cell_type": "markdown",
  331. "source": [
  332. "Load model."
  333. ],
  334. "metadata": {
  335. "id": "47yXVODMO6l0"
  336. }
  337. },
  338. {
  339. "cell_type": "code",
  340. "source": [
  341. "from typing import Tuple\n",
  342. "import os\n",
  343. "import sys\n",
  344. "import torch\n",
  345. "import time\n",
  346. "import json\n",
  347. "\n",
  348. "from pathlib import Path\n",
  349. "\n",
  350. "from fairscale.nn.model_parallel.initialize import initialize_model_parallel\n",
  351. "\n",
  352. "from llama.model import ModelArgs, Transformer\n",
  353. "from llama.tokenizer import Tokenizer\n",
  354. "\n",
  355. "\n",
  356. "os.environ['RANK'] = '0'\n",
  357. "os.environ['WORLD_SIZE'] = '1'\n",
  358. "os.environ['MP'] = '1'\n",
  359. "os.environ['MASTER_ADDR'] = '127.0.0.1'\n",
  360. "os.environ['MASTER_PORT'] = '2223'\n",
  361. "\n",
  362. "\n",
  363. "def setup_model_parallel() -> Tuple[int, int]:\n",
  364. " local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n",
  365. " world_size = int(os.environ.get(\"WORLD_SIZE\", -1))\n",
  366. "\n",
  367. " torch.distributed.init_process_group(\"gloo\")\n",
  368. " initialize_model_parallel(world_size)\n",
  369. " torch.cuda.set_device(local_rank)\n",
  370. "\n",
  371. " # seed must be the same in all processes\n",
  372. " torch.manual_seed(1)\n",
  373. " return local_rank, world_size\n",
  374. "\n",
  375. "\n",
  376. "'''\n",
  377. "def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int,\n",
  378. " max_seq_len: int, max_batch_size: int) -> LLaMA:\n",
  379. " start_time = time.time()\n",
  380. " checkpoints = sorted(Path(ckpt_dir).glob(\"*.pth\"))\n",
  381. " assert (\n",
  382. " world_size == len(checkpoints)\n",
  383. " ), f\"Loading a checkpoint for MP={len(checkpoints)} but world size is\n",
  384. " {world_size}\"\n",
  385. " ckpt_path = checkpoints[local_rank]\n",
  386. " print(\"Loading\")\n",
  387. " \n",
  388. " checkpoint = torch.load(ckpt_path, map_location=\"cpu\")\n",
  389. " with open(Path(ckpt_dir) / \"params.json\", \"r\") as f:\n",
  390. " params = json.loads(f.read())\n",
  391. "\n",
  392. " model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len,\n",
  393. " max_batch_size=max_batch_size,\n",
  394. " **params)\n",
  395. " tokenizer = Tokenizer(model_path=tokenizer_path)\n",
  396. " model_args.vocab_size = tokenizer.n_words\n",
  397. " torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
  398. " model = Transformer(model_args).cuda().half()\n",
  399. " torch.set_default_tensor_type(torch.FloatTensor)\n",
  400. " model.load_state_dict(checkpoint, strict=False)\n",
  401. "\n",
  402. " generator = LLaMA(model, tokenizer)\n",
  403. " print(f\"Loaded in {time.time() - start_time:.2f} seconds\")\n",
  404. " return generator\n",
  405. "'''\n",
  406. "\n",
  407. "\n",
  408. "def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int,\n",
  409. " max_seq_len: int, max_batch_size: int) -> LLaMA:\n",
  410. " start_time = time.time()\n",
  411. " \n",
  412. " print(\"Loading\")\n",
  413. " with open(Path(ckpt_dir) / \"params.json\", \"r\") as f:\n",
  414. " params = json.loads(f.read())\n",
  415. "\n",
  416. " model_args: ModelArgs = ModelArgs(max_seq_len=max_seq_len,\n",
  417. " max_batch_size=max_batch_size,\n",
  418. " **params)\n",
  419. " tokenizer = Tokenizer(model_path=tokenizer_path)\n",
  420. " model_args.vocab_size = tokenizer.n_words\n",
  421. " torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
  422. " model = Transformer(model_args).cuda().half()\n",
  423. " torch.set_default_tensor_type(torch.FloatTensor)\n",
  424. "\n",
  425. " checkpoint_paths = [os.path.join(ckpt_dir, 'consolidated.00.00.pth'),\n",
  426. " os.path.join(ckpt_dir, 'consolidated.00.01.pth')]\n",
  427. " \n",
  428. " for checkpoint_path in checkpoint_paths:\n",
  429. " checkpoint = torch.load(checkpoint_path, map_location='cpu')\n",
  430. " model.load_state_dict(checkpoint, strict=False)\n",
  431. " del checkpoint\n",
  432. "\n",
  433. " generator = LLaMA(model, tokenizer)\n",
  434. " print(f\"Loaded in {time.time() - start_time:.2f} seconds\")\n",
  435. " return generator\n",
  436. "\n",
  437. "# @markdown Context size. Can be up to 2048, but Colab GPU doesn't always play well with high values.\n",
  438. "max_seq_len = 1024 # @param {type:\"number\"}\n",
  439. "max_batch_size = 1\n",
  440. "\n",
  441. "local_rank, world_size = setup_model_parallel()\n",
  442. "if local_rank > 0:\n",
  443. " sys.stdout = open(os.devnull, 'w')\n",
  444. "\n",
  445. "generator = load(weight_loc, tokenizer_loc, local_rank, world_size,\n",
  446. " max_seq_len, max_batch_size)\n",
  447. "tokenizer = generator.tokenizer"
  448. ],
  449. "metadata": {
  450. "id": "U5eiyiLJMNpz"
  451. },
  452. "execution_count": null,
  453. "outputs": []
  454. },
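{
"cell_type": "markdown",
"source": [
"Optional smoke test (a minimal sketch; any short prompt will do): generate a handful of tokens directly with generator.generate to confirm the checkpoint loaded correctly before starting the GUI below."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# One short completion as a smoke test; the prompt and settings are just examples.\n",
"test_prompt = 'The quick brown fox'\n",
"print(generator.generate([test_prompt],\n",
"                         max_gen_len=16,\n",
"                         temperature=0.8,\n",
"                         top_p=0.95)[0])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},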
  455. {
  456. "cell_type": "markdown",
  457. "source": [
  458. "Main GUI. If you change the presets, you'll have to reload the cell for the changes to take effect."
  459. ],
  460. "metadata": {
  461. "id": "yfgItXzVPGGS"
  462. }
  463. },
  464. {
  465. "cell_type": "code",
  466. "source": [
  467. "import ipywidgets as widgets\n",
  468. "from IPython.display import display\n",
  469. "import time\n",
  470. "\n",
  471. "max_gen_len = 64 #@param {type:\"number\"}\n",
  472. "temperature = 0.8 #@param {type:\"number\"}\n",
  473. "top_p = 0.95 #@param {type:\"number\"}\n",
  474. "tfs = 1.0 #@param {type:\"number\"}\n",
  475. "typical = 1.0 #@param {type:\"number\"}\n",
  476. "penalty_range = 1024 #@param {type:\"number\"}\n",
  477. "penalty_slope = 0.7 #@param {type:\"number\"}\n",
  478. "penalty = 1.1 #@param {type:\"number\"}\n",
  479. "\n",
  480. "input_text_area = widgets.Textarea(placeholder='Enter a prompt...',\n",
  481. " layout=widgets.Layout(width='1200px',\n",
  482. " height='600px'))\n",
  483. "send_button = widgets.Button(description='Send')\n",
  484. "undo_button = widgets.Button(description='Undo')\n",
  485. "redo_button = widgets.Button(description='Redo')\n",
  486. "retry_button = widgets.Button(description='Retry')\n",
  487. "memory_button = widgets.ToggleButton(description='Memory')\n",
  488. "\n",
  489. "hbox = widgets.HBox([input_text_area,\n",
  490. " widgets.VBox([send_button, undo_button, redo_button,\n",
  491. " retry_button, memory_button])])\n",
  492. "output = widgets.Output()\n",
  493. "\n",
  494. "undo_button.disabled = True\n",
  495. "redo_button.disabled = True\n",
  496. "retry_button.disabled = True\n",
  497. "\n",
  498. "listen_for_updates = False\n",
  499. "cur_outputs = []\n",
  500. "cur_outputs_idx = -1\n",
  501. "memory_text = ''\n",
  502. "input_text = ''\n",
  503. "\n",
  504. "def generate():\n",
  505. " # When creating the context, first, place the full memory followed by a\n",
  506. " # newline.\n",
  507. " #\n",
  508. " # Next, taking the last (max_seq_len-1-max_gen_len-len(mem)) tokens,\n",
  509. " # place these tokens in the context.\n",
  510. " \n",
  511. " if memory_text:\n",
  512. " mem_tokenized = tokenizer.encode(memory_text + '\\n', bos=False, eos=False)\n",
  513. " else:\n",
  514. " mem_tokenized = []\n",
  515. " \n",
  516. " inp_tokenized = tokenizer.encode(input_text_area.value, bos=False, eos=False)\n",
  517. " num_inp_tokens = max(max_seq_len-1-max_gen_len-len(mem_tokenized), 0)\n",
  518. "\n",
  519. " if num_inp_tokens > 0:\n",
  520. " tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]\n",
  521. " elif len(mem_tokenized) > 0:\n",
  522. " num_mem_tokens = max_seq_len-1-max_gen_len\n",
  523. " tokenized = mem_tokenized[-num_mem_tokens:]\n",
  524. " else:\n",
  525. " tokenized = []\n",
  526. " \n",
  527. " detokenized = tokenizer.decode(tokenized)\n",
  528. " output = generator.generate([detokenized],\n",
  529. " max_gen_len=max_gen_len,\n",
  530. " temperature=temperature,\n",
  531. " top_p=top_p,\n",
  532. " tfs=tfs,\n",
  533. " typical=typical,\n",
  534. " penalty_range=penalty_range,\n",
  535. " penalty_slope=penalty_slope,\n",
  536. " penalty=penalty)\n",
  537. "\n",
  538. " num_characters = len(output) - len(detokenized) - 1\n",
  539. " return output[0][-num_characters:]\n",
  540. "\n",
  541. " '''\n",
  542. " tokenized = tokenizer.encode(input_text_area.value, bos=True, eos=False)\n",
  543. " detokenized = tokenizer.decode(tokenized[-(max_seq_len-1-max_gen_len):])\n",
  544. " output = generator.generate([detokenized],\n",
  545. " max_gen_len=max_gen_len,\n",
  546. " temperature=temperature,\n",
  547. " top_p=top_p,\n",
  548. " tfs=tfs,\n",
  549. " typical=typical,\n",
  550. " penalty_range=penalty_range,\n",
  551. " penalty_slope=penalty_slope,\n",
  552. " penalty=penalty)\n",
  553. " num_characters = len(output) - len(detokenized) - 1\n",
  554. " return output[0][-num_characters:]\n",
  555. " '''\n",
  556. "\n",
  557. "def on_update_input_text_area(change):\n",
  558. " global listen_for_updates, cur_outputs, cur_outputs_idx\n",
  559. "\n",
  560. " if listen_for_updates:\n",
  561. " cur_outputs = []\n",
  562. " cur_outputs_idx = -1\n",
  563. " undo_button.disabled = True\n",
  564. " redo_button.disabled = True\n",
  565. " retry_button.disabled = True\n",
  566. "\n",
  567. "def send():\n",
  568. " global listen_for_updates, cur_outputs, cur_outputs_idx\n",
  569. "\n",
  570. " input_text_area.disabled = True\n",
  571. " memory_button.disabled = True\n",
  572. " listen_for_updates = False\n",
  573. "\n",
  574. " generation = generate()\n",
  575. " input_text_area.value += generation\n",
  576. " cur_outputs_idx += 1\n",
  577. " cur_outputs = cur_outputs[:cur_outputs_idx]\n",
  578. " cur_outputs.append(generation)\n",
  579. "\n",
  580. " undo_button.disabled = False\n",
  581. " redo_button.disabled = True\n",
  582. " retry_button.disabled = False\n",
  583. " listen_for_updates = True\n",
  584. " memory_button.disabled = False\n",
  585. " input_text_area.disabled = False\n",
  586. "\n",
  587. "def undo():\n",
  588. " global listen_for_updates, cur_outputs, cur_outputs_idx\n",
  589. "\n",
  590. " listen_for_updates = False\n",
  591. " num_chars = len(cur_outputs[cur_outputs_idx])\n",
  592. " input_text_area.value = input_text_area.value[:-num_chars]\n",
  593. " cur_outputs_idx -= 1\n",
  594. "\n",
  595. " if cur_outputs_idx == -1:\n",
  596. " undo_button.disabled = True\n",
  597. " retry_button.disabled = True\n",
  598. " if len(cur_outputs) > 0:\n",
  599. " redo_button.disabled = False\n",
  600. "\n",
  601. " listen_for_updates = True\n",
  602. "\n",
  603. "def redo():\n",
  604. " global listen_for_updates, cur_outputs, cur_outputs_idx\n",
  605. "\n",
  606. " listen_for_updates = False\n",
  607. " input_text_area.value += cur_outputs[cur_outputs_idx+1]\n",
  608. " cur_outputs_idx += 1\n",
  609. "\n",
  610. " if cur_outputs_idx == len(cur_outputs) - 1:\n",
  611. " redo_button.disabled = True\n",
  612. " if len(cur_outputs) > 0:\n",
  613. " undo_button.disabled = False\n",
  614. " retry_button.disabled = False\n",
  615. "\n",
  616. " listen_for_updates = True\n",
  617. "\n",
  618. "def send_button_clicked(b):\n",
  619. " send()\n",
  620. "\n",
  621. "def undo_button_clicked(b):\n",
  622. " undo()\n",
  623. "\n",
  624. "def redo_button_clicked(b):\n",
  625. " redo()\n",
  626. "\n",
  627. "def retry_button_clicked(b):\n",
  628. " undo()\n",
  629. " send()\n",
  630. "\n",
  631. "def memory_button_clicked(b):\n",
  632. " global listen_for_updates, cur_outputs, cur_outputs_idx, memory_text, \\\n",
  633. " input_text\n",
  634. " if memory_button.value:\n",
  635. " listen_for_updates = False\n",
  636. " send_button.disabled = True\n",
  637. " undo_button.disabled = True\n",
  638. " redo_button.disabled = True\n",
  639. " retry_button.disabled = True\n",
  640. " input_text = input_text_area.value\n",
  641. " input_text_area.value = memory_text\n",
  642. " else:\n",
  643. " memory_text = input_text_area.value\n",
  644. " input_text_area.value = input_text\n",
  645. " input_text = ''\n",
  646. " send_button.disabled = False\n",
  647. " undo_button.disabled = cur_outputs_idx < 0\n",
  648. " redo_button.disabled = cur_outputs_idx >= len(cur_outputs) - 1\n",
  649. " retry_button.disabled = undo_button.disabled\n",
  650. " listen_for_updates = True\n",
  651. "\n",
  652. "send_button.on_click(send_button_clicked)\n",
  653. "undo_button.on_click(undo_button_clicked)\n",
  654. "redo_button.on_click(redo_button_clicked)\n",
  655. "retry_button.on_click(retry_button_clicked)\n",
  656. "memory_button.observe(memory_button_clicked, names='value')\n",
  657. "input_text_area.observe(on_update_input_text_area, names='value')\n",
  658. "\n",
  659. "display(hbox, output)"
  660. ],
  661. "metadata": {
  662. "id": "RRpoZt05O4vx"
  663. },
  664. "execution_count": null,
  665. "outputs": []
  666. }
  667. ]
  668. }