Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "gpuClass": "standard",
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "Specify the model size and quantization group size, and whether to use Google Drive. With drive enabled, the checkpoint is saved to drive (by default `/content/drive/My Drive/neko_models/`) after it is first downloaded, and loaded from there on subsequent runs. If you don't have enough drive space for the model, or you're running this somewhere without Google Drive support, leave it off.\n",
- "\n",
- "An https://huggingface.co/ URL to the quantized model must be provided. It must point at the checkpoint file itself, i.e. the URL of the `.safetensors`, `.pt`, or `.bin` file in the repository, not the repository page. The checkpoint must also be compatible with the Triton branch of GPTQ-for-LLaMa (not the CUDA branch); otherwise it won't work."
- ],
- "metadata": {
- "id": "hxMbk3dZiy3x"
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "blPQFwhrRsFE"
- },
- "outputs": [],
- "source": [
- "use_drive = True #@param {type:\"boolean\"}\n",
- "\n",
- "# NOTE(review): num_params is not read anywhere else in this notebook.\n",
- "num_params = \"13b\" #@param [\"7b\", \"13b\", \"30b\", \"65b\"]\n",
- "\n",
- "group_size = \"128\" #@param [\"32\", \"128\"]\n",
- "group_size = int(group_size)\n",
- "\n",
- "huggingface_url = '' #@param {type:\"string\"}\n",
- "\n",
- "# Validate the URL here so a missing or malformed value fails with a readable\n",
- "# message instead of an opaque wget / from_pretrained error in later cells.\n",
- "HF_PREFIX = 'https://huggingface.co/'\n",
- "url_parts = huggingface_url[len(HF_PREFIX):].split('/')\n",
- "if (not huggingface_url.startswith(HF_PREFIX) or len(url_parts) < 3\n",
- "        or not all(url_parts[:2]) or not url_parts[-1]):\n",
- "    raise ValueError('huggingface_url must be an https://huggingface.co/USER/REPO/... URL '\n",
- "                     'pointing at the checkpoint file, got {!r}'.format(huggingface_url))\n",
- "\n",
- "# USER/REPO: the Hub repository the config and tokenizer are loaded from.\n",
- "group_and_model = '/'.join(url_parts[:2])"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Download repository, install necessary Python packages, retrieve model checkpoint."
- ],
- "metadata": {
- "id": "0w7eC5-mkoN6"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#if use_runpod:\n",
- "#    !apt-get -y update\n",
- "#    !apt-get -y install python3.10-dev\n",
- "#    !python -m pip install --upgrade pip\n",
- "\n",
- "# Fetch GPTQ-for-LLaMa at a pinned commit and install its requirements.\n",
- "!git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git\n",
- "%cd 'GPTQ-for-LLaMa'\n",
- "!git checkout 29c0e4a4552c26a8a18abff305504d2cc86be539\n",
- "!pip install -r requirements.txt\n",
- "%cd ..\n",
- "\n",
- "import os\n",
- "import sys\n",
- "\n",
- "# Checkpoint file name: the last path component of the URL.\n",
- "weights_name = os.path.basename(huggingface_url)\n",
- "# Relative path: assumes the working directory is /content (the Colab default),\n",
- "# where drive is mounted below.\n",
- "drive_models_dir = 'drive/My Drive/neko_models/'\n",
- "drive_weights_path = os.path.join(drive_models_dir, weights_name)\n",
- "\n",
- "if use_drive:\n",
- "    # Connect to drive.\n",
- "    from google.colab import drive\n",
- "    drive.mount('/content/drive')\n",
- "    weights_path = drive_weights_path\n",
- "else:\n",
- "    weights_path = weights_name\n",
- "\n",
- "if not use_drive or not os.path.exists(drive_weights_path):\n",
- "    # Download the model unless an earlier run of this cell already did (wget\n",
- "    # would otherwise save a second, numbered copy), then save it to drive if\n",
- "    # the flag is set. Interpolated values are quoted for the shell.\n",
- "    if not os.path.exists(weights_name):\n",
- "        !wget \"{huggingface_url}\"\n",
- "    if use_drive:\n",
- "        os.makedirs(drive_models_dir, exist_ok=True)\n",
- "        !cp \"{weights_name}\" \"{drive_weights_path}\"\n",
- "\n",
- "# Make the GPTQ-for-LLaMa modules (gptq, modelutils, quant) importable.\n",
- "sys.path.insert(0, 'GPTQ-for-LLaMa/')"
- ],
- "metadata": {
- "id": "MYKNE67ciiqc"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Load model."
- ],
- "metadata": {
- "id": "h-yldfa3eFTR"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import time\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "\n",
- "from gptq import *\n",
- "from modelutils import *\n",
- "from quant import *\n",
- "\n",
- "from transformers import AutoTokenizer\n",
- "\n",
- "DEV = torch.device('cuda')\n",
- "\n",
- "def load_quant(model, checkpoint, wbits, groupsize, device):\n",
- "    # Build a LLaMA model with quantized layers and load a GPTQ checkpoint.\n",
- "    #\n",
- "    # model: repo id / path for LlamaConfig.from_pretrained (the name is then\n",
- "    #     reused for the constructed model).\n",
- "    # checkpoint: '.safetensors' files are loaded one tensor at a time,\n",
- "    #     anything else with torch.load.\n",
- "    # wbits, groupsize: forwarded to make_quant.\n",
- "    # device: unused; safetensors tensors are opened on CUDA device 0.\n",
- "    # Returns the LlamaForCausalLM in eval mode with the weights loaded.\n",
- "\n",
- "    # Imported explicitly: `transformers` was previously only in scope because\n",
- "    # `from gptq import *` happened to re-export it.\n",
- "    import transformers\n",
- "    from transformers import LlamaConfig, LlamaForCausalLM\n",
- "    config = LlamaConfig.from_pretrained(model)\n",
- "    def noop(*args, **kwargs):\n",
- "        pass\n",
- "    # Skip random initialisation since the weights are loaded below. Note this\n",
- "    # patches torch.nn.init for the rest of the session.\n",
- "    torch.nn.init.kaiming_uniform_ = noop\n",
- "    torch.nn.init.uniform_ = noop\n",
- "    torch.nn.init.normal_ = noop\n",
- "    transformers.modeling_utils._init_weights = False\n",
- "\n",
- "    # Construct the model in fp16 and restore the default dtype even if\n",
- "    # construction fails.\n",
- "    torch.set_default_dtype(torch.half)\n",
- "    try:\n",
- "        model = LlamaForCausalLM(config)\n",
- "    finally:\n",
- "        torch.set_default_dtype(torch.float)\n",
- "    model = model.eval()\n",
- "\n",
- "    # Quantize the layers find_layers reports, except lm_head.\n",
- "    layers = find_layers(model)\n",
- "    for name in ['lm_head']:\n",
- "        if name in layers:\n",
- "            del layers[name]\n",
- "    make_quant(model, layers, wbits, groupsize)\n",
- "\n",
- "    print('Loading model ...')\n",
- "    if checkpoint.endswith('.safetensors'):\n",
- "        from safetensors import safe_open\n",
- "        # Single-entry state dicts need strict=False: all other keys are missing.\n",
- "        with safe_open(checkpoint, framework=\"pt\", device=0) as f:\n",
- "            for k in f.keys():\n",
- "                layer = {}\n",
- "                layer[k] = f.get_tensor(k)\n",
- "                model.load_state_dict(layer, strict=False)\n",
- "                del(layer)\n",
- "    else:\n",
- "        model.load_state_dict(torch.load(checkpoint))\n",
- "    print('Done.')\n",
- "\n",
- "    return model\n",
- "\n",
- "# Load the 4-bit checkpoint, move the model to the GPU and load the slow\n",
- "# (use_fast=False) tokenizer from the same Hub repository.\n",
- "model = load_quant(group_and_model, weights_path, 4, group_size, DEV).to(DEV)\n",
- "tokenizer = AutoTokenizer.from_pretrained(group_and_model, use_fast=False)"
- ],
- "metadata": {
- "id": "v7NFb4Q5fR-J"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Define generation methods."
- ],
- "metadata": {
- "id": "Q_G_lmDseJat"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import torch\n",
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
- "\n",
- "def stm_next_tokens(model, tokenizer, tokenized, context_len, max_gen_len,\n",
- "                    mask_id, temperature=0.8, top_p=0.95, tfs=1.0, typical=1.0,\n",
- "                    penalty_range=1024, penalty_slope=0.7, penalty=1.1,\n",
- "                    past_key_values=None, return_past_key_values=True):\n",
- "    # Sample up to max_gen_len tokens after the prompt, yielding them one by one.\n",
- "    #\n",
- "    # tokenized: prompt ids, row 0 is the prompt (shape (1, prompt_len)).\n",
- "    # context_len: cap on prompt + generated length.\n",
- "    # mask_id: fill value for buffer slots not generated yet.\n",
- "    # temperature: <= 0 selects greedy argmax decoding.\n",
- "    # top_p, tfs, typical, penalty_*: sampler settings (see the samplers below).\n",
- "    # past_key_values: optional prompt cache from an earlier call with the same\n",
- "    #     prompt; it covers every prompt token except the last.\n",
- "    # return_past_key_values: whether the final yield carries such a cache.\n",
- "    #\n",
- "    # Yields (token, None) per sampled token, then one final (None, cache),\n",
- "    # where cache is the reusable prompt cache or None. Stops early on EOS.\n",
- "    total_len = min(context_len, max_gen_len + tokenized.shape[1])\n",
- "\n",
- "    tokens = torch.full((1, total_len), mask_id).to(DEV)\n",
- "    tokens[0, :tokenized.shape[1]] = tokenized[0]\n",
- "\n",
- "    output_past_key_values = None\n",
- "\n",
- "    if past_key_values and return_past_key_values:\n",
- "        output_past_key_values = past_key_values\n",
- "\n",
- "    for cur_id in range(tokenized.shape[1], total_len):\n",
- "        if past_key_values:\n",
- "            # Earlier tokens are cached; only the newest one is fed.\n",
- "            output = model(tokens[:, cur_id-1:cur_id], past_key_values=past_key_values, use_cache=True)\n",
- "        else:\n",
- "            output = model(tokens[:, :cur_id], use_cache=True)\n",
- "\n",
- "        if not past_key_values:\n",
- "            logits = output.logits[:, cur_id-1, :]\n",
- "            if return_past_key_values:\n",
- "                # Cache the prompt minus its last token: a caller reusing the\n",
- "                # cache re-feeds that token (branch above) to get its logits,\n",
- "                # so caching it too made the model see it twice. Assumes the\n",
- "                # legacy tuple-of-(key, value) cache with sequence on dim -2.\n",
- "                output_past_key_values = tuple(\n",
- "                    tuple(t[..., :-1, :] for t in layer)\n",
- "                    for layer in output.past_key_values)\n",
- "        else:\n",
- "            logits = output.logits[:, 0, :]\n",
- "\n",
- "        past_key_values = output.past_key_values\n",
- "        # Whole history so far; the repetition penalty clips it to its own\n",
- "        # range. (Previously only the single preceding token was passed, so\n",
- "        # penalty_range had no effect.)\n",
- "        input_ids = tokens[:, :cur_id]\n",
- "\n",
- "        # Apply samplers - do greedy sampling if temperature is 0.\n",
- "        if temperature > 0:\n",
- "            next_token_scores = sample_top_p_actual(input_ids, logits, top_p)\n",
- "            next_token_scores = sample_tail_free(input_ids, next_token_scores, tfs)\n",
- "            next_token_scores = sample_typical(input_ids, next_token_scores, typical)\n",
- "            next_token_scores = sample_temperature(input_ids, next_token_scores,\n",
- "                                                   temperature)\n",
- "            next_token_scores = sample_advanced_repetition_penalty(\n",
- "                input_ids, next_token_scores, penalty_range, penalty_slope, penalty)\n",
- "\n",
- "            next_token_scores = torch.nn.functional.softmax(next_token_scores, dim=-1)\n",
- "\n",
- "            next_token = torch.multinomial(next_token_scores, num_samples=1).squeeze(1)\n",
- "        else:\n",
- "            next_token = torch.argmax(logits, axis=-1)[0]\n",
- "\n",
- "        tokens[0, cur_id] = next_token\n",
- "        yield next_token, None\n",
- "\n",
- "        if next_token.item() == tokenizer.eos_token_id:\n",
- "            yield None, output_past_key_values\n",
- "            return\n",
- "\n",
- "    yield None, output_past_key_values\n",
- "    return\n",
- "\n",
- "# taken from Kobold and transformers so this stuff is AGPL I guess\n",
- "def sample_temperature(input_ids, scores, tempt):\n",
- "    # Temperature scaling: divide the logits by tempt. input_ids is unused.\n",
- "    return scores.div(tempt)\n",
- "\n",
- "def sample_typical(input_ids, scores, typical, filter_value = -float(\"Inf\"),\n",
- "                   min_tokens_to_keep = 1):\n",
- "    # Typical sampling: rank tokens by how far their surprisal (-log p) lies\n",
- "    # from the entropy H, i.e. |H + log p|, keep them best-first until their\n",
- "    # cumulative probability reaches `typical`, and set the rest to\n",
- "    # filter_value. input_ids is unused; typical >= 1.0 disables the filter.\n",
- "\n",
- "    # The guard used to test filter_value (-inf by default) instead of\n",
- "    # typical, so it never fired and the filter ran even when disabled.\n",
- "    if typical >= 1.0:\n",
- "        return scores\n",
- "\n",
- "    probs = scores.softmax(dim=-1)\n",
- "    log_probs = probs.log()\n",
- "\n",
- "    # sum(p * log p) = -H; nansum skips the 0 * -inf terms of zero-probability tokens.\n",
- "    neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)\n",
- "\n",
- "    entropy_deviation = (neg_entropy - log_probs).abs()\n",
- "\n",
- "    _, sorted_indices = torch.sort(entropy_deviation)\n",
- "    sorted_logits = probs.gather(-1, sorted_indices)\n",
- "    sorted_indices_to_remove = sorted_logits.cumsum(dim=-1) >= typical\n",
- "    # Shift right by one so the token that crosses the threshold is kept.\n",
- "    sorted_indices_to_remove = sorted_indices_to_remove.roll(1, dims=-1)\n",
- "\n",
- "    min_tokens_to_keep = max(min_tokens_to_keep, 1)\n",
- "    # Keep at least min_tokens_to_keep\n",
- "    sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
- "\n",
- "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
- "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
- "    return scores\n",
- "\n",
- "def sample_top_p_actual(input_ids, scores, top_p, filter_value = -float(\"Inf\"),\n",
- "                        min_tokens_to_keep = 1):\n",
- "    # Nucleus (top-p) filtering: the least likely tokens whose combined\n",
- "    # probability is at most 1 - top_p are set to filter_value. input_ids is\n",
- "    # unused.\n",
- "    ascending_scores, order = scores.sort(descending=False)\n",
- "    tail_mass = ascending_scores.softmax(dim=-1).cumsum(dim=-1)\n",
- "\n",
- "    # Sorted ascending, so the dropped tokens form a prefix.\n",
- "    drop_sorted = tail_mass <= 1 - top_p\n",
- "    if min_tokens_to_keep > 1:\n",
- "        # Never drop the min_tokens_to_keep most likely tokens.\n",
- "        drop_sorted[..., -min_tokens_to_keep :] = 0\n",
- "\n",
- "    # Map the mask back from sorted order to vocabulary order.\n",
- "    drop = drop_sorted.scatter(1, order, drop_sorted)\n",
- "    return scores.masked_fill(drop, filter_value)\n",
- "\n",
- "def sample_advanced_repetition_penalty(input_ids, scores, penalty_range,\n",
- "                                       penalty_slope, penalty):\n",
- "    # Repetition penalty over the most recent penalty_range ids of input_ids:\n",
- "    # their scores are divided by the penalty (multiplied when <= 0). With a\n",
- "    # non-zero penalty_slope the penalty ramps along the window, reaching\n",
- "    # `penalty` at the most recent position. penalty_range <= 0 penalizes all\n",
- "    # of input_ids with the flat penalty. Modifies `scores` in place and\n",
- "    # returns it.\n",
- "    penalty_range = int(penalty_range)\n",
- "    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)\n",
- "\n",
- "    if penalty != 1.0:\n",
- "        if penalty_range > 0:\n",
- "            if clipped_penalty_range < input_ids.shape[1]:\n",
- "                input_ids = input_ids[..., -clipped_penalty_range:]\n",
- "\n",
- "            # The ramp divides by penalty_range - 1, so it needs at least two\n",
- "            # positions; penalty_range == 1 used to produce NaN scores. A flat\n",
- "            # penalty is used then.\n",
- "            if penalty_slope != 0 and penalty_range > 1:\n",
- "                _penalty = (torch.arange(penalty_range, dtype=scores.dtype,\n",
- "                            device=scores.device)/(penalty_range - 1)) * 2. - 1\n",
- "                _penalty = (penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (penalty_slope - 1))\n",
- "                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)\n",
- "                penalty = _penalty[..., -clipped_penalty_range:]\n",
- "\n",
- "        score = torch.gather(scores, 1, input_ids)\n",
- "        score = torch.where(score <= 0, score * penalty, score / penalty)\n",
- "        scores.scatter_(1, input_ids, score)\n",
- "\n",
- "    return scores\n",
- "\n",
- "def sample_top_a(input_ids, scores, top_a, filter_value = -float(\"Inf\"),\n",
- "                 min_tokens_to_keep = 1):\n",
- "    # Top-a filtering: tokens with probability below top_a * max_prob**2 are\n",
- "    # set to filter_value. input_ids is unused. Not called by stm_next_tokens.\n",
- "\n",
- "    # top_a <= 0 cannot remove anything, so return early. (The guard used to\n",
- "    # test filter_value, -inf by default, and never fired.)\n",
- "    if top_a <= 0.0:\n",
- "        return scores\n",
- "\n",
- "    sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
- "    probs = sorted_logits.softmax(dim=-1)\n",
- "\n",
- "    # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)\n",
- "    probs_max = probs[..., 0, None]\n",
- "    sorted_indices_to_remove = probs < probs_max * probs_max * top_a\n",
- "\n",
- "    if min_tokens_to_keep > 1:\n",
- "        # Keep at least min_tokens_to_keep\n",
- "        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
- "\n",
- "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,\n",
- "                                                         sorted_indices_to_remove)\n",
- "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
- "    return scores\n",
- "\n",
- "def sample_tail_free(input_ids, scores, tfs, filter_value = -float(\"Inf\"),\n",
- "                     min_tokens_to_keep = 1):\n",
- "    # Tail-free sampling: sort probabilities in descending order, normalise the\n",
- "    # absolute second difference into a CDF and set tokens past the point where\n",
- "    # it exceeds `tfs` to filter_value. input_ids is unused; tfs >= 1.0\n",
- "    # disables the filter.\n",
- "\n",
- "    # The guard used to test filter_value (-inf by default) instead of tfs, so\n",
- "    # with tfs=1.0 the filter still ran and always dropped at least the least\n",
- "    # likely token (the trailing `ones` column below).\n",
- "    if tfs >= 1.0:\n",
- "        return scores\n",
- "    sorted_logits, sorted_indices = torch.sort(scores, descending=True)\n",
- "    probs = sorted_logits.softmax(dim=-1)\n",
- "\n",
- "    # Compute second derivative normalized CDF\n",
- "    d2 = probs.diff().diff().abs()\n",
- "    normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)\n",
- "    normalized_d2_cdf = normalized_d2.cumsum(dim=-1)\n",
- "\n",
- "    # Remove tokens with CDF value above the threshold (token with 0 are kept)\n",
- "    sorted_indices_to_remove = normalized_d2_cdf > tfs\n",
- "\n",
- "    # Centre the distribution around the cutoff as in the original implementation of the algorithm\n",
- "    sorted_indices_to_remove = torch.cat(\n",
- "        (\n",
- "            torch.zeros(scores.shape[0], 1, dtype=torch.bool,\n",
- "                        device=scores.device),\n",
- "            sorted_indices_to_remove,\n",
- "            torch.ones(scores.shape[0], 1, dtype=torch.bool,\n",
- "                       device=scores.device),\n",
- "        ),\n",
- "        dim=-1,\n",
- "    )\n",
- "\n",
- "    if min_tokens_to_keep > 1:\n",
- "        # Keep at least min_tokens_to_keep\n",
- "        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0\n",
- "\n",
- "    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,\n",
- "                                                         sorted_indices_to_remove)\n",
- "    scores = scores.masked_fill(indices_to_remove, filter_value)\n",
- "    return scores"
- ],
- "metadata": {
- "id": "1mmAK2zhi6Ak"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Main GUI. Includes sampler parameters, memory, and lorebook.\n",
- "\n",
- "Lorebook tip: multiple keys can be associated with an entry by separating them via commas."
- ],
- "metadata": {
- "id": "KZcyrZdreNGX"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "import ipywidgets as widgets\n",
- "from IPython.display import display\n",
- "import time\n",
- "from enum import Enum\n",
- "\n",
- "context_size = 2048 #@param {type:\"number\"}\n",
- "max_gen_len = 160 #@param {type:\"number\"}\n",
- "temperature = 1.0 #@param {type:\"number\"}\n",
- "top_p = 0.95 #@param {type:\"number\"}\n",
- "tfs = 1.0 #@param {type:\"number\"}\n",
- "typical = 1.0 #@param {type:\"number\"}\n",
- "penalty_range = 2048 #@param {type:\"number\"}\n",
- "penalty_slope = 0.7 #@param {type:\"number\"}\n",
- "penalty = 1.1 #@param {type:\"number\"}\n",
- "lorebook_token_search_range = 1000 #@param {type:\"number\"}\n",
- "kv_cache_size = 0 #@param {type:\"number\"}\n",
- "\n",
- "# Prompt editor; in INPUT mode it holds the text of the current position.\n",
- "input_text_area = widgets.Textarea(\n",
- "    placeholder='Enter a prompt...',\n",
- "    layout=widgets.Layout(width='900px', height='600px'),\n",
- ")\n",
- "model.seqlen = context_size\n",
- "\n",
- "# Side panel: generation, history navigation and mode toggles.\n",
- "send_button = widgets.Button(description='Send')\n",
- "undo_button = widgets.Button(description='Undo')\n",
- "redo_button = widgets.Button(description='Redo')\n",
- "retry_button = widgets.Button(description='Retry')\n",
- "prev_retry_button = widgets.Button(description='Previous Retry')\n",
- "memory_button = widgets.ToggleButton(description='Memory')\n",
- "context_button = widgets.ToggleButton(description='Context')\n",
- "side_panel = widgets.VBox([\n",
- "    send_button, undo_button, redo_button, retry_button, prev_retry_button,\n",
- "    memory_button, context_button,\n",
- "])\n",
- "\n",
- "# Lorebook editor: entry picker, key/value fields and entry actions.\n",
- "lorebook_dropdown = widgets.Dropdown(\n",
- "    options=['(Empty)'], layout=widgets.Layout(width='200px'))\n",
- "lorebook_add_button = widgets.Button(description='Add New')\n",
- "lorebook_remove_button = widgets.Button(description='Remove')\n",
- "lorebook_apply_changes_button = widgets.Button(description='Apply Changes')\n",
- "lorebook_key_field = widgets.Text(description='Key')\n",
- "lorebook_value_field = widgets.Textarea(\n",
- "    description='Value', layout=widgets.Layout(height='100px'))\n",
- "\n",
- "lorebook_fields_row = widgets.HBox(\n",
- "    [lorebook_dropdown, lorebook_key_field, lorebook_value_field])\n",
- "lorebook_actions_row = widgets.HBox(\n",
- "    [lorebook_add_button, lorebook_remove_button, lorebook_apply_changes_button])\n",
- "lorebook_panel = widgets.VBox([lorebook_fields_row, lorebook_actions_row])\n",
- "\n",
- "# Starts collapsed (selected_index=None). The title is given both as the\n",
- "# `titles` argument and via set_title, presumably for compatibility across\n",
- "# ipywidgets versions.\n",
- "lorebook_accordion = widgets.Accordion(children=[lorebook_panel],\n",
- "                                       titles=('Lorebook',),\n",
- "                                       selected_index=None)\n",
- "lorebook_accordion.set_title(0, 'Lorebook')\n",
- "\n",
- "editor_column = widgets.VBox([input_text_area, lorebook_accordion])\n",
- "main_panel = widgets.HBox([editor_column, side_panel])\n",
- "output = widgets.Output()\n",
- "\n",
- "# UI modes: INPUT (editing the story), MEMORY (editing the memory), GENERATING,\n",
- "# and CONTEXT (read-only view of build_context()).\n",
- "Mode = Enum('Mode', ['INPUT', 'MEMORY', 'GENERATING', 'CONTEXT'])\n",
- "\n",
- "class State:\n",
- "    # Mutable GUI session state: current Position and Mode, the memory text,\n",
- "    # the text stashed while in MEMORY/CONTEXT mode, a small cache of\n",
- "    # past_key_values keyed by Position id, and the lorebook.\n",
- "    def __init__(self, pos, mode):\n",
- "        self.pos = pos\n",
- "        self.mode = mode\n",
- "        self.mem = ''\n",
- "        self.saved_input = ''\n",
- "        # Position id -> cached past_key_values.\n",
- "        self.kv_dict = {}\n",
- "        # Cached ids, most recently added first; kv_queue[-1] is evicted\n",
- "        # first. Holds exactly the keys of kv_dict.\n",
- "        self.kv_queue = []\n",
- "        # Key string (comma-separated keys) -> value text.\n",
- "        self.lorebook_dict = {}\n",
- "        # Counter for naming new entries (Key0, Key1, ...).\n",
- "        self.lorebook_entry_index = 0\n",
- "        # Key/value of the entry currently loaded into the editor fields.\n",
- "        self.lorebook_prev_key = ''\n",
- "        self.lorebook_prev_value = ''\n",
- "\n",
- "    def get_kv(self, id):\n",
- "        # Cached past_key_values for position id, or None.\n",
- "        if id in self.kv_dict:\n",
- "            return self.kv_dict[id]\n",
- "        else:\n",
- "            return None\n",
- "\n",
- "    def delete_kv(self, id):\n",
- "        # Drop the cache entry for id, if any.\n",
- "        if id in self.kv_dict:\n",
- "            self.kv_queue.remove(id)\n",
- "            del self.kv_dict[id]\n",
- "\n",
- "    def add_kv(self, id, kv):\n",
- "        # Cache kv for id as the most recent entry. If id is not cached yet\n",
- "        # and the cache already holds kv_cache_size entries (always true when\n",
- "        # kv_cache_size is 0), nothing is stored; send() calls remove_last_kv()\n",
- "        # before generating to make room.\n",
- "        if id in self.kv_dict:\n",
- "            self.kv_queue.remove(id)\n",
- "        elif self.get_len_kv() >= kv_cache_size:\n",
- "            # Storing kv here without queueing it would leave an entry that\n",
- "            # can never be evicted and that delete_kv() cannot remove\n",
- "            # (kv_queue.remove() would raise ValueError).\n",
- "            return\n",
- "        self.kv_queue.insert(0, id)\n",
- "        self.kv_dict[id] = kv\n",
- "\n",
- "    def remove_last_kv(self):\n",
- "        # Evict the least recently added entry, if any.\n",
- "        if self.kv_queue:\n",
- "            old_id = self.kv_queue.pop()\n",
- "            del self.kv_dict[old_id]\n",
- "\n",
- "    def get_len_kv(self):\n",
- "        return len(self.kv_queue)\n",
- "\n",
- "    def clear_kv(self):\n",
- "        # Drop every cache entry.\n",
- "        for id in self.kv_queue:\n",
- "            del self.kv_dict[id]\n",
- "        self.kv_queue = []\n",
- "\n",
- "class Position:\n",
- "    # A node in the undo/redo tree. pred: parent Position (None for the\n",
- "    # root); succs: child Positions; succ_idx: index of the child that\n",
- "    # redo/retry move to (-1 when there are none); text: story text at this\n",
- "    # node; id: unique, drawn from the class-wide counter cur_id.\n",
- "    cur_id = 0\n",
- "\n",
- "    def __init__(self):\n",
- "        self.id, Position.cur_id = Position.cur_id, Position.cur_id + 1\n",
- "        self.pred = None\n",
- "        self.succs = []\n",
- "        self.succ_idx = -1\n",
- "        self.text = ''\n",
- "\n",
- "# The session starts at an empty root position in INPUT mode.\n",
- "init_pos = Position()\n",
- "cur_state = State(init_pos, Mode.INPUT)\n",
- "\n",
- "def build_context():\n",
- "    # Build the prompt text for the model from memory, lorebook and input.\n",
- "    #\n",
- "    # Layout: BOS, then the memory (cur_state.mem plus the values of lorebook\n",
- "    # entries whose keys occur in the last lorebook_token_search_range input\n",
- "    # tokens) followed by a newline, then as many trailing input tokens as fit\n",
- "    # in model.seqlen-1-max_gen_len. If the memory alone does not fit, only its\n",
- "    # tail is kept. Returns the tokens decoded back to text.\n",
- "\n",
- "    inp_tokenized = tokenizer.encode(input_text_area.value, return_tensors='pt', add_special_tokens=False)[0].tolist()\n",
- "\n",
- "    # Search for keys in lorebook, append results to memory. An entry matches\n",
- "    # when any of its comma-separated keys occurs in the searched text.\n",
- "    lorebook_search_text = tokenizer.decode(inp_tokenized[-lorebook_token_search_range:])\n",
- "    lorebook_values = []\n",
- "    for key_str, value in cur_state.lorebook_dict.items():\n",
- "        # Blank keys (e.g. from a trailing comma) are skipped: '' is a\n",
- "        # substring of every text and made the entry match unconditionally.\n",
- "        keys = [key.strip() for key in key_str.split(',') if key.strip()]\n",
- "        if any(key in lorebook_search_text for key in keys):\n",
- "            lorebook_values.append(value)\n",
- "\n",
- "    # Empty parts are skipped so an empty memory does not start the context\n",
- "    # with a blank line when only lorebook entries match.\n",
- "    memory = '\\n'.join(part for part in [cur_state.mem, *lorebook_values] if part)\n",
- "\n",
- "    if memory:\n",
- "        mem_tokenized = tokenizer.encode(memory + '\\n', return_tensors='pt', add_special_tokens=False)[0].tolist()\n",
- "    else:\n",
- "        mem_tokenized = []\n",
- "\n",
- "    # Token budget for memory + input, leaving room for BOS and the generation.\n",
- "    budget = model.seqlen-1-max_gen_len\n",
- "    num_inp_tokens = max(budget-len(mem_tokenized), 0)\n",
- "\n",
- "    if num_inp_tokens > 0:\n",
- "        tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]\n",
- "    elif len(mem_tokenized) > 0 and budget > 0:\n",
- "        tokenized = mem_tokenized[-budget:]\n",
- "    else:\n",
- "        # No room left. (With budget <= 0 the old slice used a non-positive\n",
- "        # bound and returned all or an arbitrary part of the memory.)\n",
- "        tokenized = []\n",
- "\n",
- "    tokenized.insert(0, tokenizer.bos_token_id)\n",
- "    detokenized = tokenizer.decode(tokenized)\n",
- "    return detokenized\n",
- "\n",
- "def generate():\n",
- " # Create the context and send it to the model, update the text area.\n",
- " \n",
- " gen_context = build_context()\n",
- " retokenized = tokenizer.encode(gen_context, return_tensors='pt').to(DEV)\n",
- " prev_num_tokens = len(retokenized[0])\n",
- "\n",
- " output = ''\n",
- " past_key_values = None\n",
- " num_characters = 0\n",
- "\n",
- " with torch.no_grad():\n",
- " out_tokens = retokenized[0].tolist()\n",
- " gen = stm_next_tokens(model, tokenizer, retokenized, model.seqlen,\n",
- " max_gen_len, 1, temperature=temperature, top_p=top_p, tfs=tfs,\n",
- " typical=typical, penalty_range=penalty_range,\n",
- " penalty_slope=penalty_slope, penalty=penalty,\n",
- " past_key_values=cur_state.get_kv(cur_state.pos.id),\n",
- " return_past_key_values=kv_cache_size != 0)\n",
- " for tkn, pkv in gen:\n",
- " if tkn is None:\n",
- " past_key_values = pkv\n",
- " else:\n",
- " out_tokens.append(tkn.item())\n",
- " output = tokenizer.decode(out_tokens, skip_special_tokens=True)\n",
- " num_characters = len(output) - len(gen_context)\n",
- " input_text_area.value = cur_state.pos.text + output[-num_characters:]\n",
- " torch.cuda.empty_cache()\n",
- " return output[-num_characters:], past_key_values\n",
- "\n",
- "def on_update_input_text_area(change):\n",
- " # Input mode: Destroy all successors in the node list.\n",
- " #\n",
- " # Memory mode: n/a.\n",
- " #\n",
- " # Action allowed criterion: state.mode == 'input' or state.mode == 'memory'.\n",
- "\n",
- " if cur_state.mode == Mode.INPUT and (cur_state.pos.succs or cur_state.get_kv(cur_state.pos.id)) and cur_state.pos.text != input_text_area.value:\n",
- " if cur_state.pos.succs:\n",
- " del cur_state.pos.succs\n",
- " cur_state.pos.succs = []\n",
- " cur_state.pos.succ_idx = -1\n",
- " update_buttons_visible()\n",
- " if cur_state.get_kv(cur_state.pos.id):\n",
- " cur_state.delete_kv(cur_state.pos.id)\n",
- "\n",
- "def send():\n",
- "    # Generate from the current text and advance to a new successor position.\n",
- "    #\n",
- "    # Saves the text area into the current position, generates with the UI\n",
- "    # locked (GENERATING mode), caches the returned prompt KV if any, then\n",
- "    # appends a successor holding the resulting text and jumps to it.\n",
- "    cur_state.pos.text = input_text_area.value\n",
- "    cur_state.mode = Mode.GENERATING\n",
- "    update_buttons_visible()\n",
- "\n",
- "    # Make room in the KV cache if this position is not cached yet.\n",
- "    if cur_state.get_len_kv() == kv_cache_size and not cur_state.get_kv(cur_state.pos.id):\n",
- "        cur_state.remove_last_kv()\n",
- "\n",
- "    try:\n",
- "        generation, past_key_values = generate()\n",
- "    except BaseException:\n",
- "        # An error or interrupt during generation (e.g. CUDA OOM) used to\n",
- "        # leave the UI stuck in GENERATING mode with every control disabled.\n",
- "        # Restore the saved text (still GENERATING, so the text-area observer\n",
- "        # ignores it), unlock the UI and re-raise.\n",
- "        input_text_area.value = cur_state.pos.text\n",
- "        cur_state.mode = Mode.INPUT\n",
- "        update_buttons_visible()\n",
- "        raise\n",
- "\n",
- "    new_succ = Position()\n",
- "    new_succ.pred = cur_state.pos\n",
- "    # generate() left the saved text plus the generation in the text area.\n",
- "    new_succ.text = input_text_area.value\n",
- "    cur_state.pos.succs.append(new_succ)\n",
- "    cur_state.pos.succ_idx = len(cur_state.pos.succs) - 1\n",
- "    if past_key_values is not None:\n",
- "        cur_state.add_kv(cur_state.pos.id, past_key_values)\n",
- "\n",
- "    # jump_to() runs while still in GENERATING mode; the text-area observer\n",
- "    # only acts in INPUT mode.\n",
- "    jump_to(new_succ)\n",
- "\n",
- "    cur_state.mode = Mode.INPUT\n",
- "    update_buttons_visible()\n",
- "\n",
- "def send_button_clicked(b):\n",
- "    # 'Send' handler (enabled only in INPUT mode): generate from the current\n",
- "    # text and move to the new successor position.\n",
- "    send()\n",
- "\n",
- "def undo_button_clicked(b):\n",
- "    # 'Undo' handler (enabled only when a predecessor exists): go back one step.\n",
- "    previous = cur_state.pos.pred\n",
- "    jump_to(previous)\n",
- "\n",
- "def redo_button_clicked(b):\n",
- "    # 'Redo' handler (enabled only when successors exist): go forward to the\n",
- "    # currently selected successor.\n",
- "    here = cur_state.pos\n",
- "    jump_to(here.succs[here.succ_idx])\n",
- "\n",
- "def retry_button_clicked(b):\n",
- "    # 'Retry' handler: go back to the parent position, then show the parent's\n",
- "    # next alternative if one exists, otherwise generate a new one with send().\n",
- "    parent = cur_state.pos.pred\n",
- "    jump_to(parent)\n",
- "\n",
- "    has_next_alternative = parent.succ_idx < len(parent.succs) - 1\n",
- "    if not has_next_alternative:\n",
- "        send()\n",
- "        return\n",
- "    parent.succ_idx += 1\n",
- "    jump_to(parent.succs[parent.succ_idx])\n",
- "\n",
- "def prev_retry_button_clicked(b):\n",
- "    # 'Previous Retry' handler (enabled only when the parent's succ_idx > 0):\n",
- "    # go back to the parent and show its previous alternative.\n",
- "    parent = cur_state.pos.pred\n",
- "    jump_to(parent)\n",
- "    parent.succ_idx = parent.succ_idx - 1\n",
- "    jump_to(parent.succs[parent.succ_idx])\n",
- "\n",
- "def memory_button_clicked(b):\n",
- " # Input mode: switch modes to 'memory', save current state.\n",
- " #\n",
- " # Memory mode: switch modes to 'input', save memory, restore current state.\n",
- " #\n",
- " # Action allowed criterion: state.mode == 'input' or state.mode == 'memory'.\n",
- "\n",
- " if cur_state.mode == Mode.INPUT:\n",
- " cur_state.mode = Mode.MEMORY\n",
- " cur_state.saved_input = input_text_area.value\n",
- " input_text_area.value = cur_state.mem\n",
- " update_buttons_visible()\n",
- " elif cur_state.mode == Mode.MEMORY:\n",
- " if cur_state.mem != input_text_area.value:\n",
- " cur_state.clear_kv()\n",
- " cur_state.mode = Mode.INPUT\n",
- " cur_state.mem = input_text_area.value\n",
- " input_text_area.value = cur_state.saved_input\n",
- " update_buttons_visible()\n",
- "\n",
- "def context_button_clicked(b):\n",
- " # Input mode: switch modes to 'context', save current state.\n",
- " #\n",
- " # Context mode: switch mode to 'input', restore current state.\n",
- " #\n",
- " # Action allowed criterion: state.mode == 'input' or state.mode == 'context'.\n",
- "\n",
- " if cur_state.mode == Mode.INPUT:\n",
- " cur_state.mode = Mode.CONTEXT\n",
- " cur_state.saved_input = input_text_area.value\n",
- " input_text_area.value = build_context()\n",
- " update_buttons_visible()\n",
- " elif cur_state.mode == Mode.CONTEXT:\n",
- " cur_state.mode = Mode.INPUT\n",
- " input_text_area.value = cur_state.saved_input\n",
- " update_buttons_visible()\n",
- "\n",
- "def jump_to(pos):\n",
- " cur_state.pos = pos\n",
- " input_text_area.value = pos.text\n",
- " update_buttons_visible()\n",
- "\n",
- "def update_buttons_visible():\n",
- "    # Enable or disable every control according to the current mode, the\n",
- "    # undo/redo tree around the current position and the lorebook contents.\n",
- "    mode = cur_state.mode\n",
- "    pos = cur_state.pos\n",
- "    idle = mode == Mode.INPUT\n",
- "    busy = mode == Mode.GENERATING\n",
- "    no_entries = len(cur_state.lorebook_dict) == 0\n",
- "\n",
- "    # Generation and history navigation: only while idle in INPUT mode.\n",
- "    send_button.disabled = not idle\n",
- "    undo_button.disabled = not (idle and pos.pred)\n",
- "    redo_button.disabled = not (idle and pos.succs)\n",
- "    retry_button.disabled = not (idle and pos.pred)\n",
- "    prev_retry_button.disabled = not (idle and pos.pred and pos.pred.succ_idx > 0)\n",
- "\n",
- "    # Mode toggles stay usable while their own mode is active.\n",
- "    memory_button.disabled = mode not in (Mode.INPUT, Mode.MEMORY)\n",
- "    context_button.disabled = mode not in (Mode.INPUT, Mode.CONTEXT)\n",
- "    input_text_area.disabled = mode in (Mode.GENERATING, Mode.CONTEXT)\n",
- "\n",
- "    # Lorebook editor: locked while generating; entry fields need an entry.\n",
- "    unchanged = (lorebook_key_field.value == cur_state.lorebook_prev_key and\n",
- "                 lorebook_value_field.value == cur_state.lorebook_prev_value)\n",
- "    lorebook_key_field.disabled = busy or no_entries\n",
- "    lorebook_value_field.disabled = busy or no_entries\n",
- "    lorebook_add_button.disabled = busy\n",
- "    lorebook_remove_button.disabled = busy or no_entries\n",
- "    lorebook_apply_changes_button.disabled = busy or no_entries or unchanged\n",
- "    lorebook_dropdown.disabled = busy\n",
- "\n",
- "def apply_changes_to_entry(set_prev_key_to_cur):\n",
- " lorebook_cur_key = lorebook_key_field.value\n",
- " lorebook_cur_value = lorebook_value_field.value\n",
- " if not cur_state.lorebook_prev_key or (cur_state.lorebook_prev_key == lorebook_cur_key and cur_state.lorebook_prev_value == lorebook_cur_value) or cur_state.lorebook_prev_key not in cur_state.lorebook_dict:\n",
- " return\n",
- "\n",
- " if cur_state.lorebook_prev_key != lorebook_cur_key:\n",
- " del cur_state.lorebook_dict[cur_state.lorebook_prev_key]\n",
- " cur_state.lorebook_dict[lorebook_cur_key] = lorebook_cur_value\n",
- " new_options = list(lorebook_dropdown.options)\n",
- " lorebook_key_idx = new_options.index(cur_state.lorebook_prev_key)\n",
- " new_options[lorebook_key_idx] = lorebook_cur_key\n",
- " lorebook_dropdown.options = tuple(new_options)\n",
- " lorebook_dropdown.value = lorebook_cur_key\n",
- " elif cur_state.lorebook_prev_value != lorebook_cur_value:\n",
- " cur_state.lorebook_dict[lorebook_cur_key] = lorebook_cur_value\n",
- " \n",
- " if set_prev_key_to_cur:\n",
- " cur_state.lorebook_prev_key = lorebook_cur_key\n",
- " cur_state.lorebook_prev_value = lorebook_cur_value\n",
- "\n",
- "def add_new_entry():\n",
- " new_key = 'Key{}'.format(cur_state.lorebook_entry_index)\n",
- " new_val = 'Value{}'.format(cur_state.lorebook_entry_index)\n",
- " cur_state.lorebook_dict[new_key] = new_val\n",
- " cur_state.lorebook_entry_index += 1\n",
- "\n",
- " if len(cur_state.lorebook_dict) == 1:\n",
- " lorebook_dropdown.options = (new_key,)\n",
- " else:\n",
- " new_options = list(lorebook_dropdown.options)\n",
- " new_options.append(new_key)\n",
- " lorebook_dropdown.options = tuple(new_options)\n",
- " lorebook_dropdown.value = new_key\n",
- "\n",
- "def remove_entry():\n",
- "    # Delete the lorebook entry loaded in the editor and select a neighbour.\n",
- "    # Expects cur_state.lorebook_prev_key to name an existing entry (the\n",
- "    # Remove button is disabled while the lorebook is empty).\n",
- "    removed_key = cur_state.lorebook_prev_key\n",
- "\n",
- "    del cur_state.lorebook_dict[cur_state.lorebook_prev_key]\n",
- "    cur_state.lorebook_prev_key = ''\n",
- "    # Was `lorebook_prev_val` (typo): it created a stray attribute and left\n",
- "    # lorebook_prev_value holding the removed entry's value.\n",
- "    cur_state.lorebook_prev_value = ''\n",
- "\n",
- "    if len(cur_state.lorebook_dict) == 0:\n",
- "        lorebook_dropdown.options = ('(Empty)',)\n",
- "    else:\n",
- "        new_options = list(lorebook_dropdown.options)\n",
- "        removed_idx = new_options.index(removed_key)\n",
- "        new_options.remove(removed_key)\n",
- "        lorebook_dropdown.options = tuple(new_options)\n",
- "        # Select the entry before the removed one (or the first entry).\n",
- "        lorebook_dropdown.value = lorebook_dropdown.options[max(removed_idx-1, 0)]\n",
- "\n",
- "def _after_lorebook_edit():\n",
- "    # Shared tail of the lorebook button handlers: refresh control states and\n",
- "    # drop cached KVs, which may no longer match the context.\n",
- "    update_buttons_visible()\n",
- "    cur_state.clear_kv()\n",
- "\n",
- "def lorebook_add_clicked(b):\n",
- "    # 'Add New' handler: create a placeholder entry and select it.\n",
- "    add_new_entry()\n",
- "    _after_lorebook_edit()\n",
- "\n",
- "def lorebook_remove_clicked(b):\n",
- "    # 'Remove' handler: delete the selected entry.\n",
- "    remove_entry()\n",
- "    _after_lorebook_edit()\n",
- "\n",
- "def lorebook_apply_changes_clicked(b):\n",
- "    # 'Apply Changes' handler: write the key/value fields into the entry.\n",
- "    apply_changes_to_entry(True)\n",
- "    _after_lorebook_edit()\n",
- "\n",
- "def on_lorebook_dropdown_select(b):\n",
- " apply_changes_to_entry(False)\n",
- " if len(cur_state.lorebook_dict) > 0:\n",
- " lorebook_key_field.value = lorebook_dropdown.value\n",
- " lorebook_value_field.value = cur_state.lorebook_dict[lorebook_key_field.value]\n",
- " cur_state.lorebook_prev_key = lorebook_key_field.value\n",
- " cur_state.lorebook_prev_value = lorebook_value_field.value\n",
- " else:\n",
- " lorebook_key_field.value = ''\n",
- " lorebook_value_field.value = ''\n",
- " update_buttons_visible()\n",
- " cur_state.clear_kv()\n",
- "\n",
- "def on_update_lorebook_key_field(b):\n",
- "    # Observer for the key field: any edit enables 'Apply Changes'.\n",
- "    if not lorebook_apply_changes_button.disabled:\n",
- "        return\n",
- "    lorebook_apply_changes_button.disabled = False\n",
- "\n",
- "def on_update_lorebook_value_field(b):\n",
- "    # Observer for the value field: same behaviour as the key field.\n",
- "    on_update_lorebook_key_field(b)\n",
- "\n",
- "# Button click handlers.\n",
- "for button, handler in [(send_button, send_button_clicked),\n",
- "                        (undo_button, undo_button_clicked),\n",
- "                        (redo_button, redo_button_clicked),\n",
- "                        (retry_button, retry_button_clicked),\n",
- "                        (prev_retry_button, prev_retry_button_clicked),\n",
- "                        (lorebook_add_button, lorebook_add_clicked),\n",
- "                        (lorebook_remove_button, lorebook_remove_clicked),\n",
- "                        (lorebook_apply_changes_button, lorebook_apply_changes_clicked)]:\n",
- "    button.on_click(handler)\n",
- "\n",
- "# 'value' observers for the toggle buttons, the text area and the lorebook\n",
- "# editor widgets.\n",
- "for widget, observer in [(memory_button, memory_button_clicked),\n",
- "                         (context_button, context_button_clicked),\n",
- "                         (input_text_area, on_update_input_text_area),\n",
- "                         (lorebook_dropdown, on_lorebook_dropdown_select),\n",
- "                         (lorebook_key_field, on_update_lorebook_key_field),\n",
- "                         (lorebook_value_field, on_update_lorebook_value_field)]:\n",
- "    widget.observe(observer, names='value')\n",
- "\n",
- "update_buttons_visible()\n",
- "\n",
- "display(main_panel, output)"
- ],
- "metadata": {
- "id": "JNrRMsUkoX99"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement