{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "First, clone GPTQ-for-LLaMa and install its CUDA extension."
      ],
      "metadata": {
        "id": "ZUbS-f1DPXvV"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tVeqisMQ5wr"
      },
      "outputs": [],
      "source": [
        "#!apt-get -y update\n",
        "#!apt-get -y install python3.10-dev\n",
        "#!python -m pip install --upgrade pip\n",
        "#!git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git\n",
        "!git clone https://github.com/MasterTaffer/GPTQ-for-LLaMa.git\n",
        "%cd 'GPTQ-for-LLaMa'\n",
        "!python setup_cuda.py install\n",
        "#!python test_kernel.py"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, restart the runtime (but don't delete it). This is needed so that Colab can import the newly built quant_cuda C++/CUDA extension.\n",
        "\n",
        "Afterward, come back here and run the cell below. It installs the required libraries, downloads your 4-bit LLaMA model of choice, and adds the cloned repo to the Python path."
      ],
      "metadata": {
        "id": "R5-qdqtyPu1g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import sys\n",
        "import torch\n",
        "import quant_cuda\n",
        "\n",
        "!pip install transformers\n",
        "!pip install sentencepiece\n",
        "weights_url = 'https://huggingface.co/wcde/llama-13b-4bit-gr128/resolve/main/llama-13b-4bit-gr128.pt' #@param {type:\"string\"}\n",
        "num_params = \"13b\" #@param [\"7b\", \"13b\", \"30b\", \"65b\"]\n",
        "group_size = 128 #@param {type:\"number\"}\n",
        "wbits = 4 #@param [\"2\", \"3\", \"4\", \"8\"]\n",
        "wbits = int(wbits)  # the dropdown options above are strings\n",
        "!wget {weights_url}\n",
        "!pip install git+https://github.com/zphang/transformers@llama_push\n",
        "sys.path.insert(0, 'GPTQ-for-LLaMa/')#sys.path.insert(0, '/content/GPTQ-for-LLaMa/')\n",
        "#!CUDA_VISIBLE_DEVICES=0 python llama_inference.py decapoda-research/llama-13b-hf --wbits 4 --load llama-13b-4bit.pt --text \"It was the best of times, it was the worst of times\""
      ],
      "metadata": {
        "id": "a4Q7JnyOZHB-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now run this cell to load the model. The split_checkpoint flag controls whether the checkpoint is first loaded into GPU VRAM and split into two halves on disk, which are then loaded into the model one at a time (needed for 13B on the free tier). The context size is set in the GUI cell further down; if you're on the free tier and running 13B, keep it fairly low, or you may run out of memory or get very slow generation."
      ],
      "metadata": {
        "id": "3hBn0BoIQoNZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import transformers\n",
        "\n",
        "from gptq import *\n",
        "from modelutils import *\n",
        "from quant import *\n",
        "\n",
        "from transformers import LlamaTokenizer\n",
        "\n",
        "DEV = torch.device('cuda:0')\n",
        "#context_size = 1024 #@param {type:\"number\"}\n",
        "split_checkpoint = True #@param {type:\"boolean\"}\n",
        "\n",
        "def load_quant(model, checkpoint, wbits, group_size):\n",
        "    from transformers import LlamaConfig, LlamaForCausalLM\n",
        "    config = LlamaConfig.from_pretrained(model)\n",
        "    def noop(*args, **kwargs):\n",
        "        pass\n",
        "    torch.nn.init.kaiming_uniform_ = noop\n",
        "    torch.nn.init.uniform_ = noop\n",
        "    torch.nn.init.normal_ = noop\n",
        "\n",
        "    if split_checkpoint:\n",
        "        print('Splitting checkpoint ...')\n",
        "        ckpt = torch.load(checkpoint, map_location='cuda')\n",
        "\n",
        "        d1 = dict(list(ckpt.items())[:len(ckpt)//2])\n",
        "        torch.save(d1, checkpoint + '0')\n",
        "        del(d1)\n",
        "\n",
        "        d2 = dict(list(ckpt.items())[len(ckpt)//2:])\n",
        "        torch.save(d2, checkpoint + '1')\n",
        "        del(d2)\n",
        "\n",
        "        del(ckpt)\n",
        "\n",
        "    torch.set_default_dtype(torch.half)\n",
        "    transformers.modeling_utils._init_weights = False\n",
        "    model = LlamaForCausalLM(config)\n",
        "    torch.set_default_dtype(torch.float)\n",
        "    model = model.eval()\n",
        "    layers = find_layers(model)\n",
        "    for name in ['lm_head']:\n",
        "        if name in layers:\n",
        "            del layers[name]\n",
        "    make_quant(model, layers, wbits, group_size)\n",
        "\n",
        "    if split_checkpoint:\n",
        "        print('Loading model ...')\n",
        "        for i in range(2):\n",
        "            ckpt = torch.load(checkpoint + str(i))\n",
        "            model.load_state_dict(ckpt, strict=False)\n",
        "            del(ckpt)\n",
        "        print('Done.')\n",
        "\n",
        "    else:\n",
        "        print('Loading model ...')\n",
        "        # Load the checkpoint once and reuse it.\n",
        "        ckpt = torch.load(checkpoint)\n",
        "        model.load_state_dict(ckpt)\n",
        "        del(ckpt)\n",
        "        print('Done.')\n",
        "\n",
        "    #model.seqlen = context_size\n",
        "    return model\n",
        "\n",
        "model = load_quant('decapoda-research/llama-{}-hf'.format(num_params), 'llama-{}-{}bit-gr{}.pt'.format(num_params, wbits, group_size), wbits, group_size).cuda()\n",
        "model.to(DEV)\n",
        "tokenizer = LlamaTokenizer.from_pretrained('decapoda-research/llama-{}-hf'.format(num_params))"
      ],
      "metadata": {
        "id": "KleSQ3ziiQ3n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define the token generation functions (a regular version and a streaming generator version) and the sampling helpers they use."
      ],
      "metadata": {
        "id": "35GvK2M5BASW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "def gen_next_tokens(model, tokenizer, tokenized, context_len, max_gen_len,\n",
        "                    mask_id, temperature=0.8, top_p=0.95, tfs=1.0, typical=1.0,\n",
        "                    penalty_range=1024, penalty_slope=0.7, penalty=1.1,\n",
        "                    past_key_values=None):\n",
        "    #tokenized = tokenizer.encode(inp, return_tensors='pt').to(DEV)\n",
        "    total_len = min(context_len, max_gen_len + tokenized.shape[1])\n",
        "\n",
        "    tokens = torch.full((1, total_len), mask_id).to(DEV)\n",
        "    tokens[0, :tokenized.shape[1]] = tokenized[0]\n",
        "\n",
        "    # Always defined, even if the loop below never runs.\n",
        "    output_past_key_values = past_key_values\n",
        "\n",
        "    for cur_id in range(tokenized.shape[1], total_len):\n",
        "        #print(cur_id - tokenized.shape[1])\n",
        "        if past_key_values:\n",
        "            output = model(tokens[:, cur_id-1:cur_id], past_key_values=past_key_values, use_cache=True)\n",
        "        else:\n",
        "            output = model(tokens[:, :cur_id], use_cache=True)\n",
        "\n",
        "        if not past_key_values:\n",
        "            logits = output.logits[:, cur_id-1, :]\n",
        "            output_past_key_values = output.past_key_values\n",
        "        else:\n",
        "            logits = output.logits[:, 0, :]\n",
        "\n",
        "        past_key_values = output.past_key_values\n",
        "        # Every token so far, so the repetition penalty covers the whole\n",
        "        # context (clipped to penalty_range), not just the last token.\n",
        "        input_ids = tokens[:, :cur_id]\n",
        "\n",
        "        # Apply samplers - do greedy sampling if temperature is 0.\n",
        "        if temperature > 0:\n",
        "            next_token_scores = sample_top_p_actual(input_ids, logits, top_p)\n",
        "            next_token_scores = sample_tail_free(input_ids, next_token_scores, tfs)\n",
        "            next_token_scores = sample_typical(input_ids, next_token_scores, typical)\n",
        "            next_token_scores = sample_temperature(input_ids, next_token_scores, temperature)\n",
        "            next_token_scores = sample_advanced_repetition_penalty(\n",
        "                input_ids, next_token_scores, penalty_range, penalty_slope, penalty)\n",
        "\n",
        "            next_token_scores = torch.nn.functional.softmax(next_token_scores, dim=-1)\n",
        "\n",
        "            next_token = torch.multinomial(next_token_scores, num_samples=1).squeeze(1)\n",
        "        else:\n",
        "            next_token = torch.argmax(logits, axis=-1)[0]\n",
        "\n",
        "        tokens[0, cur_id] = next_token\n",
        "        if next_token.item() == tokenizer.eos_token_id:\n",
        "            return tokens[:, :cur_id], output_past_key_values\n",
        "\n",
        "    return tokens, output_past_key_values\n",
        "\n",
        "def stm_next_tokens(model, tokenizer, tokenized, context_len, max_gen_len,\n",
        "                    mask_id, temperature=0.8, top_p=0.95, tfs=1.0, typical=1.0,\n",
        "                    penalty_range=1024, penalty_slope=0.7, penalty=1.1,\n",
        "                    past_key_values=None):\n",
        "    #tokenized = tokenizer.encode(inp, return_tensors='pt').to(DEV)\n",
        "    total_len = min(context_len, max_gen_len + tokenized.shape[1])\n",
        "\n",
        "    tokens = torch.full((1, total_len), mask_id).to(DEV)\n",
        "    tokens[0, :tokenized.shape[1]] = tokenized[0]\n",
        "\n",
        "    # Always defined, even if the loop below never runs.\n",
        "    output_past_key_values = past_key_values\n",
        "\n",
        "    for cur_id in range(tokenized.shape[1], total_len):\n",
        "        #print(cur_id - tokenized.shape[1])\n",
        "        if past_key_values:\n",
        "            output = model(tokens[:, cur_id-1:cur_id], past_key_values=past_key_values, use_cache=True)\n",
        "        else:\n",
        "            output = model(tokens[:, :cur_id], use_cache=True)\n",
        "\n",
        "        if not past_key_values:\n",
        "            logits = output.logits[:, cur_id-1, :]\n",
        "            output_past_key_values = output.past_key_values\n",
        "        else:\n",
        "            logits = output.logits[:, 0, :]\n",
        "\n",
        "        past_key_values = output.past_key_values\n",
        "        # Every token so far, so the repetition penalty covers the whole\n",
        "        # context (clipped to penalty_range), not just the last token.\n",
        "        input_ids = tokens[:, :cur_id]\n",
        "\n",
        "        # Apply samplers - do greedy sampling if temperature is 0.\n",
        "        if temperature > 0:\n",
        "            next_token_scores = sample_top_p_actual(input_ids, logits, top_p)\n",
        "            next_token_scores = sample_tail_free(input_ids, next_token_scores, tfs)\n",
        "            next_token_scores = sample_typical(input_ids, next_token_scores, typical)\n",
        "            next_token_scores = sample_temperature(input_ids, next_token_scores, temperature)\n",
        "            next_token_scores = sample_advanced_repetition_penalty(\n",
        "                input_ids, next_token_scores, penalty_range, penalty_slope, penalty)\n",
        "\n",
        "            next_token_scores = torch.nn.functional.softmax(next_token_scores, dim=-1)\n",
        "\n",
        "            next_token = torch.multinomial(next_token_scores, num_samples=1).squeeze(1)\n",
        "        else:\n",
        "            next_token = torch.argmax(logits, axis=-1)[0]\n",
        "\n",
        "        tokens[0, cur_id] = next_token\n",
        "        yield next_token, None\n",
        "\n",
        "        if next_token.item() == tokenizer.eos_token_id:\n",
        "            yield None, output_past_key_values\n",
        "            return\n",
        "\n",
        "    yield None, output_past_key_values\n",
        "    return\n",
        "\n",
        "# taken from Kobold and transformers so this stuff is AGPL I guess\n",
        "def sample_temperature(input_ids, scores, tempt):\n",
        "    scores = scores / tempt\n",
        "    return scores\n",
        "\n",
        "def sample_typical(input_ids, scores, typical, filter_value = -float(\"Inf\"),\n",
        "                   min_tokens_to_keep = 1):\n",
        "    # typical >= 1.0 disables this sampler.\n",
        "    if typical >= 1.0:\n",
        "        return scores\n",
        "\n",
        "    probs = scores.softmax(dim=-1)\n",
        "    log_probs = probs.log()\n",
        "\n",
        "    neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)\n",
        "\n",
        "    entropy_deviation = (neg_entropy - log_probs).abs()\n",
        "\n",
        "    _, sorted_indices = torch.sort(entropy_deviation)\n",
        "    sorted_logits = probs.gather(-1, sorted_indices)\n",
        "    sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= typical\n",
        "    sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)\n",
        "\n",
        "    min_tokens_to_keep = max(min_tokens_to_keep, 1)\n",
        "    # Keep at least min_tokens_to_keep\n",
        "    sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
        "\n",
        "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
        "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
        "    return scores\n",
        "\n",
        "def sample_top_p_actual(input_ids, scores, top_p, filter_value = -float(\"Inf\"),\n",
        "                        min_tokens_to_keep = 1):\n",
        "    sorted_logits, sorted_indices = torch.sort(scores, descending=False)\n",
        "    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n",
        "\n",
        "    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)\n",
        "    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n",
        "    if min_tokens_to_keep > 1:\n",
        "        # Keep at least min_tokens_to_keep\n",
        "        sorted_indices_to_remove[..., -min_tokens_to_keep :] = 0\n",
        "\n",
        "    # scatter sorted tensors to original indexing\n",
        "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
        "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
        "    return scores\n",
        "\n",
        "def sample_advanced_repetition_penalty(input_ids, scores, penalty_range,\n",
        "                                       penalty_slope, penalty):\n",
        "    penalty_range = int(penalty_range)\n",
        "    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)\n",
        "\n",
        "    if penalty != 1.0:\n",
        "        if penalty_range > 0:\n",
        "            if clipped_penalty_range < input_ids.shape[1]:\n",
        "                input_ids = input_ids[..., -clipped_penalty_range:]\n",
        "\n",
        "            if penalty_slope != 0:\n",
        "                _penalty = (torch.arange(penalty_range, dtype=scores.dtype,\n",
        "                                         device=scores.device)/(penalty_range - 1)) * 2. - 1\n",
        "                _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))\n",
        "                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)\n",
        "                penalty = _penalty[..., -clipped_penalty_range:]\n",
        "\n",
        "        score = torch.gather(scores, 1, input_ids)\n",
        "        score = torch.where(score <= 0, score * penalty, score / penalty)\n",
        "        scores.scatter_(1, input_ids, score)\n",
        "\n",
        "    return scores\n",
        "\n",
        "def sample_top_a(input_ids, scores, top_a, filter_value = -float(\"Inf\"),\n",
        "                 min_tokens_to_keep = 1):\n",
        "    # top_a <= 0.0 disables this sampler.\n",
        "    if top_a <= 0.0:\n",
        "        return scores\n",
        "\n",
        "    sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
        "    probs = sorted_logits.softmax(dim=-1)\n",
        "\n",
        "    # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)\n",
        "    probs_max = probs[..., 0, None]\n",
        "    sorted_indices_to_remove = probs < probs_max * probs_max * top_a\n",
        "\n",
        "    if min_tokens_to_keep > 1:\n",
        "        # Keep at least min_tokens_to_keep\n",
        "        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
        "\n",
        "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
        "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
        "    return scores\n",
        "\n",
        "def sample_tail_free(input_ids, scores, tfs, filter_value = -float(\"Inf\"),\n",
        "                     min_tokens_to_keep = 1):\n",
        "    # tfs >= 1.0 disables this sampler.\n",
        "    if tfs >= 1.0:\n",
        "        return scores\n",
        "    sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
        "    probs = sorted_logits.softmax(dim=-1)\n",
        "\n",
        "    # Compute second derivative normalized CDF\n",
        "    d2 = probs.diff().diff().abs()\n",
        "    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)\n",
        "    normalized_d2_cdf = normalized_d2.cumsum(dim=-1)\n",
        "\n",
        "    # Remove tokens with CDF value above the threshold (token with 0 are kept)\n",
        "    sorted_indices_to_remove = normalized_d2_cdf > tfs\n",
        "\n",
        "    # Centre the distribution around the cutoff as in the original implementation of the algorithm\n",
        "    sorted_indices_to_remove = torch.cat(\n",
        "        (\n",
        "            torch.zeros(scores.shape[0], 1, dtype=torch.bool,\n",
        "                        device=scores.device),\n",
        "            sorted_indices_to_remove,\n",
        "            torch.ones(scores.shape[0], 1, dtype=torch.bool,\n",
        "                       device=scores.device),\n",
        "        ),\n",
        "        dim=-1,\n",
        "    )\n",
        "\n",
        "    if min_tokens_to_keep > 1:\n",
        "        # Keep at least min_tokens_to_keep\n",
        "        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
        "\n",
        "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
        "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
        "    return scores"
      ],
      "metadata": {
        "id": "9ZWJd4lzLjkP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Main GUI."
      ],
      "metadata": {
        "id": "nZ44wSJGNoDY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import ipywidgets as widgets\n",
        "from IPython.display import display\n",
        "import time\n",
        "from enum import Enum\n",
        "\n",
        "context_size = 1024 #@param {type:\"number\"}\n",
        "max_gen_len = 160 #@param {type:\"number\"}\n",
        "temperature = 1.0 #@param {type:\"number\"}\n",
        "top_p = 0.95 #@param {type:\"number\"}\n",
        "tfs = 1.0 #@param {type:\"number\"}\n",
        "typical = 1.0 #@param {type:\"number\"}\n",
        "penalty_range = 1024 #@param {type:\"number\"}\n",
        "penalty_slope = 0.7 #@param {type:\"number\"}\n",
        "penalty = 1.1 #@param {type:\"number\"}\n",
        "output_streaming = True #@param {type:\"boolean\"}\n",
        "kv_cache_size = 1 #@param {type:\"number\"}\n",
        "\n",
        "input_text_area = widgets.Textarea(placeholder='Enter a prompt...',\n",
        "                                   layout=widgets.Layout(width='900px',\n",
        "                                                         height='600px'))\n",
        "model.seqlen = context_size\n",
        "send_button = widgets.Button(description='Send')\n",
        "undo_button = widgets.Button(description='Undo')\n",
        "redo_button = widgets.Button(description='Redo')\n",
        "retry_button = widgets.Button(description='Retry')\n",
        "prev_retry_button = widgets.Button(description='Previous Retry')\n",
        "memory_button = widgets.ToggleButton(description='Memory')\n",
        "context_button = widgets.ToggleButton(description='Context')\n",
        "\n",
        "hbox = widgets.HBox([input_text_area,\n",
        "                     widgets.VBox([send_button, undo_button, redo_button,\n",
        "                                   retry_button, prev_retry_button, memory_button,\n",
        "                                   context_button])])\n",
        "output = widgets.Output()\n",
        "\n",
        "Mode = Enum('Mode', ['INPUT', 'MEMORY', 'GENERATING', 'CONTEXT'])\n",
        "\n",
        "class State:\n",
        "    def __init__(self, pos, mode):\n",
        "        self.pos = pos\n",
        "        self.mode = mode\n",
        "        self.mem = ''\n",
        "        self.saved_input = ''\n",
        "        self.kv_dict = {}\n",
        "        self.kv_queue = []\n",
        "\n",
        "    def get_kv(self, id):\n",
        "        if id in self.kv_dict:\n",
        "            return self.kv_dict[id]\n",
        "        else:\n",
        "            return None\n",
        "\n",
        "    def delete_kv(self, id):\n",
        "        if id in self.kv_dict:\n",
        "            self.kv_queue.remove(id)\n",
        "            del self.kv_dict[id]\n",
        "\n",
        "    '''\n",
        "    def add_kv(self, id, kv):\n",
        "        if id in self.kv_dict:\n",
        "            self.kv_queue.remove(id)\n",
        "            self.kv_queue.insert(0, id)\n",
        "        else:\n",
        "            if len(self.kv_queue) == kv_cache_size:\n",
        "                old_id = self.kv_queue.pop()\n",
        "                del self.kv_dict[old_id]\n",
        "            self.kv_queue.insert(0, id)\n",
        "        self.kv_dict[id] = kv\n",
        "    '''\n",
        "\n",
        "    def add_kv(self, id, kv):\n",
        "        if id in self.kv_dict:\n",
        "            self.kv_queue.remove(id)\n",
        "            self.kv_queue.insert(0, id)\n",
        "        elif self.get_len_kv() < kv_cache_size:\n",
        "            self.kv_queue.insert(0, id)\n",
        "        self.kv_dict[id] = kv\n",
        "\n",
        "    def remove_last_kv(self):\n",
        "        if self.kv_queue:\n",
        "            old_id = self.kv_queue.pop()\n",
        "            del self.kv_dict[old_id]\n",
        "\n",
        "    def get_len_kv(self):\n",
        "        return len(self.kv_queue)\n",
        "\n",
        "    def clear_kv(self):\n",
        "        for id in self.kv_queue:\n",
        "            del self.kv_dict[id]\n",
        "        self.kv_queue = []\n",
        "\n",
        "class Position:\n",
        "    cur_id = 0\n",
        "    def __init__(self):\n",
        "        self.id = Position.cur_id\n",
        "        self.pred = None\n",
        "        self.succs = []\n",
        "        self.succ_idx = -1\n",
        "        self.text = ''\n",
        "        Position.cur_id += 1\n",
        "\n",
        "init_pos = Position()\n",
        "cur_state = State(init_pos, Mode.INPUT)\n",
        "\n",
        "def build_context():\n",
        "    # When creating the context, first, place the full memory followed by a\n",
        "    # newline.\n",
        "    #\n",
        "    # Next, taking the last (max_seq_len-1-max_gen_len-len(mem)) tokens,\n",
        "    # place these tokens in the context.\n",
        "\n",
        "    if cur_state.mem:\n",
        "        mem_tokenized = tokenizer.encode(cur_state.mem + '\\n', return_tensors='pt')[0].tolist()\n",
        "    else:\n",
        "        mem_tokenized = []\n",
        "\n",
        "    inp_tokenized = tokenizer.encode(input_text_area.value, return_tensors='pt')[0].tolist()\n",
        "    num_inp_tokens = max(model.seqlen-1-max_gen_len-len(mem_tokenized), 0)\n",
        "\n",
        "    if num_inp_tokens > 0:\n",
        "        tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]\n",
        "    elif len(mem_tokenized) > 0:\n",
        "        num_mem_tokens = model.seqlen-1-max_gen_len\n",
        "        tokenized = mem_tokenized[-num_mem_tokens:]\n",
        "    else:\n",
        "        tokenized = []\n",
        "\n",
        "    detokenized = tokenizer.decode(tokenized)\n",
        "    return detokenized\n",
        "\n",
        "def generate():\n",
        "    # Create the context and send it to the model, update the text area.\n",
        "\n",
        "    gen_context = build_context()\n",
        "    retokenized = tokenizer.encode(gen_context, return_tensors='pt').to(DEV)\n",
        "    prev_num_tokens = len(retokenized[0])\n",
        "\n",
        "    output = ''\n",
        "    past_key_values = None\n",
        "    num_characters = 0\n",
        "\n",
        "    # Note: output[len(output) - num_characters:] is used instead of\n",
        "    # output[-num_characters:], which would return the whole string when\n",
        "    # num_characters is 0.\n",
        "    if output_streaming:\n",
        "        with torch.no_grad():\n",
        "            out_tokens = retokenized[0].tolist()\n",
        "            gen = stm_next_tokens(model, tokenizer, retokenized, model.seqlen,\n",
        "                                  max_gen_len, 1, temperature=temperature, top_p=top_p, tfs=tfs,\n",
        "                                  typical=typical, penalty_range=penalty_range,\n",
        "                                  penalty_slope=penalty_slope, penalty=penalty,\n",
        "                                  past_key_values=cur_state.get_kv(cur_state.pos.id))\n",
        "            for tkn, pkv in gen:\n",
        "                if pkv is not None:\n",
        "                    past_key_values = pkv\n",
        "                else:\n",
        "                    out_tokens.append(tkn.item())\n",
        "                    output = tokenizer.decode(out_tokens)\n",
        "                    num_characters = len(output) - len(gen_context) - 1\n",
        "                    input_text_area.value = cur_state.pos.text + output[len(output) - num_characters:]\n",
        "    else:\n",
        "        with torch.no_grad():\n",
        "            output_tokenized, past_key_values = gen_next_tokens(model, tokenizer,\n",
        "                retokenized, model.seqlen, max_gen_len, 1, temperature=temperature,\n",
        "                top_p=top_p, tfs=tfs, typical=typical, penalty_range=penalty_range,\n",
        "                penalty_slope=penalty_slope, penalty=penalty,\n",
        "                past_key_values=cur_state.get_kv(cur_state.pos.id))\n",
        "        output = tokenizer.decode(output_tokenized[0].tolist())\n",
        "        num_characters = len(output) - len(gen_context) - 1\n",
        "        input_text_area.value = cur_state.pos.text + output[len(output) - num_characters:]\n",
        "    return output[len(output) - num_characters:], past_key_values\n",
        "\n",
        "def on_update_input_text_area(change):\n",
        "    # Input mode: Destroy all successors in the node list.\n",
        "    #\n",
        "    # Memory mode: n/a.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input' or state.mode == 'memory'.\n",
        "\n",
        "    if cur_state.mode == Mode.INPUT and (cur_state.pos.succs or cur_state.get_kv(cur_state.pos.id)) and cur_state.pos.text != input_text_area.value:\n",
        "        if cur_state.pos.succs:\n",
        "            del cur_state.pos.succs\n",
        "            cur_state.pos.succs = []\n",
        "            cur_state.pos.succ_idx = -1\n",
        "            update_buttons_visible()\n",
        "        if cur_state.get_kv(cur_state.pos.id):\n",
        "            cur_state.delete_kv(cur_state.pos.id)\n",
        "\n",
        "def send():\n",
        "    cur_state.pos.text = input_text_area.value\n",
        "    cur_state.mode = Mode.GENERATING\n",
        "    update_buttons_visible()\n",
        "\n",
        "    if cur_state.get_len_kv() == kv_cache_size and not cur_state.get_kv(cur_state.pos.id):\n",
        "        cur_state.remove_last_kv()\n",
        "\n",
        "    generation, past_key_values = generate()\n",
        "\n",
        "    new_succ = Position()\n",
        "    new_succ.pred = cur_state.pos\n",
        "    #new_succ.text = input_text_area.value + generation\n",
        "    new_succ.text = input_text_area.value\n",
        "    cur_state.pos.succs.append(new_succ)\n",
        "    cur_state.pos.succ_idx = len(cur_state.pos.succs) - 1\n",
        "    cur_state.add_kv(cur_state.pos.id, past_key_values)\n",
        "\n",
        "    jump_to(new_succ)\n",
        "\n",
        "    cur_state.mode = Mode.INPUT\n",
        "    update_buttons_visible()\n",
        "\n",
        "def send_button_clicked(b):\n",
        "    # Set text in current node to whatever is in the input area, generate text\n",
        "    # (setting mode to 'generating' in the meantime), create a new successor at\n",
        "    # head of list with text, set successor position to it, jump to it.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input'.\n",
        "\n",
        "    send()\n",
        "\n",
        "def undo_button_clicked(b):\n",
        "    # Jump to predecessor.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input', state.predecessor != nil.\n",
        "\n",
        "    jump_to(cur_state.pos.pred)\n",
        "\n",
        "def redo_button_clicked(b):\n",
        "    # Jump to current successor.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input', state.successor_list !=\n",
        "    # nil.\n",
        "\n",
        "    jump_to(cur_state.pos.succs[cur_state.pos.succ_idx])\n",
        "\n",
        "def retry_button_clicked(b):\n",
        "    # Jump to predecessor, then set successor position to next in the list if\n",
        "    # it exists and jump to it, otherwise send_button_clicked().\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input', state.predecessor != nil.\n",
        "\n",
        "    jump_to(cur_state.pos.pred)\n",
        "\n",
        "    if cur_state.pos.succ_idx < len(cur_state.pos.succs) - 1:\n",
        "        cur_state.pos.succ_idx += 1\n",
        "        jump_to(cur_state.pos.succs[cur_state.pos.succ_idx])\n",
        "    else:\n",
        "        send()\n",
        "\n",
        "def prev_retry_button_clicked(b):\n",
        "    # Jump to predecessor, then set successor position to prev in the list and\n",
        "    # jump to it.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input', state.predecessor != nil,\n",
        "    # state.predecessor.succ_idx > 0.\n",
        "\n",
        "    jump_to(cur_state.pos.pred)\n",
        "    cur_state.pos.succ_idx -= 1\n",
        "    jump_to(cur_state.pos.succs[cur_state.pos.succ_idx])\n",
        "\n",
        "def memory_button_clicked(b):\n",
        "    # Input mode: switch modes to 'memory', save current state.\n",
        "    #\n",
        "    # Memory mode: switch modes to 'input', save memory, restore current state.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input' or state.mode == 'memory'.\n",
        "\n",
        "    if cur_state.mode == Mode.INPUT:\n",
        "        cur_state.mode = Mode.MEMORY\n",
        "        cur_state.saved_input = input_text_area.value\n",
        "        input_text_area.value = cur_state.mem\n",
        "        update_buttons_visible()\n",
        "    elif cur_state.mode == Mode.MEMORY:\n",
        "        if cur_state.mem != input_text_area.value:\n",
        "            cur_state.clear_kv()\n",
        "        cur_state.mode = Mode.INPUT\n",
        "        cur_state.mem = input_text_area.value\n",
        "        input_text_area.value = cur_state.saved_input\n",
        "        update_buttons_visible()\n",
        "\n",
        "def context_button_clicked(b):\n",
        "    # Input mode: switch modes to 'context', save current state.\n",
        "    #\n",
        "    # Context mode: switch mode to 'input', restore current state.\n",
        "    #\n",
        "    # Action allowed criterion: state.mode == 'input' or state.mode == 'context'.\n",
        "\n",
        "    if cur_state.mode == Mode.INPUT:\n",
        "        cur_state.mode = Mode.CONTEXT\n",
        "        cur_state.saved_input = input_text_area.value\n",
        "        input_text_area.value = build_context()\n",
        "        update_buttons_visible()\n",
        "    elif cur_state.mode == Mode.CONTEXT:\n",
        "        cur_state.mode = Mode.INPUT\n",
        "        input_text_area.value = cur_state.saved_input\n",
        "        update_buttons_visible()\n",
        "\n",
        "def jump_to(pos):\n",
        "    cur_state.pos = pos\n",
        "    input_text_area.value = pos.text\n",
        "    update_buttons_visible()\n",
        "\n",
        "def update_buttons_visible():\n",
        "    send_button.disabled = cur_state.mode != Mode.INPUT\n",
        "    undo_button.disabled = cur_state.mode != Mode.INPUT or not cur_state.pos.pred\n",
        "    redo_button.disabled = cur_state.mode != Mode.INPUT or not cur_state.pos.succs\n",
        "    retry_button.disabled = cur_state.mode != Mode.INPUT or not cur_state.pos.pred\n",
        "    prev_retry_button.disabled = cur_state.mode != Mode.INPUT or not cur_state.pos.pred or not cur_state.pos.pred.succ_idx > 0\n",
        "    memory_button.disabled = cur_state.mode != Mode.INPUT and cur_state.mode != Mode.MEMORY\n",
        "    context_button.disabled = cur_state.mode != Mode.INPUT and cur_state.mode != Mode.CONTEXT\n",
        "    input_text_area.disabled = cur_state.mode == Mode.GENERATING or cur_state.mode == Mode.CONTEXT\n",
        "\n",
        "send_button.on_click(send_button_clicked)\n",
        "undo_button.on_click(undo_button_clicked)\n",
        "redo_button.on_click(redo_button_clicked)\n",
        "retry_button.on_click(retry_button_clicked)\n",
        "prev_retry_button.on_click(prev_retry_button_clicked)\n",
        "memory_button.observe(memory_button_clicked, names='value')\n",
        "context_button.observe(context_button_clicked, names='value')\n",
        "input_text_area.observe(on_update_input_text_area, names='value')\n",
        "update_buttons_visible()\n",
        "\n",
        "display(hbox, output)"
      ],
      "metadata": {
        "id": "uYX2FVP4BlVC"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}