Advertisement
tastypear

llama bluelm hack

Nov 20th, 2023 (edited)
1,238
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 43.10 KB | None | 0 0
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3.  
  4. import json
  5. import os
  6. import shutil
  7. import struct
  8. import sys
  9. import tempfile
  10. from enum import Enum, IntEnum, auto
  11. from io import BufferedWriter
  12. from pathlib import Path
  13. from typing import IO, Any, BinaryIO, Callable, Sequence
  14.  
  15. import numpy as np
  16.  
  17. #
  18. # constants
  19. #
  20.  
  21. GGUF_MAGIC             = 0x46554747
  22. GGUF_VERSION           = 3
  23. GGUF_DEFAULT_ALIGNMENT = 32
  24.  
  25.  
  26. # general
  27. KEY_GENERAL_ARCHITECTURE         = "general.architecture"
  28. KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version"
  29. KEY_GENERAL_ALIGNMENT            = "general.alignment"
  30. KEY_GENERAL_NAME                 = "general.name"
  31. KEY_GENERAL_AUTHOR               = "general.author"
  32. KEY_GENERAL_URL                  = "general.url"
  33. KEY_GENERAL_DESCRIPTION          = "general.description"
  34. KEY_GENERAL_LICENSE              = "general.license"
  35. KEY_GENERAL_SOURCE_URL           = "general.source.url"
  36. KEY_GENERAL_SOURCE_HF_REPO       = "general.source.huggingface.repository"
  37. KEY_GENERAL_FILE_TYPE            = "general.file_type"
  38.  
  39. # LLM
  40. KEY_CONTEXT_LENGTH        = "{arch}.context_length"
  41. KEY_EMBEDDING_LENGTH      = "{arch}.embedding_length"
  42. KEY_BLOCK_COUNT           = "{arch}.block_count"
  43. KEY_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
  44. KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
  45. KEY_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
  46.  
  47. # attention
  48. KEY_ATTENTION_HEAD_COUNT        = "{arch}.attention.head_count"
  49. KEY_ATTENTION_HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
  50. KEY_ATTENTION_MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
  51. KEY_ATTENTION_CLAMP_KQV         = "{arch}.attention.clamp_kqv"
  52. KEY_ATTENTION_LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
  53. KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
  54.  
  55. # RoPE
  56. KEY_ROPE_DIMENSION_COUNT         = "{arch}.rope.dimension_count"
  57. KEY_ROPE_FREQ_BASE               = "{arch}.rope.freq_base"
  58. KEY_ROPE_SCALING_TYPE            = "{arch}.rope.scaling.type"
  59. KEY_ROPE_SCALING_FACTOR          = "{arch}.rope.scaling.factor"
  60. KEY_ROPE_SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
  61. KEY_ROPE_SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
  62.  
  63. # tokenization
  64. KEY_TOKENIZER_MODEL      = "tokenizer.ggml.model"
  65. KEY_TOKENIZER_LIST       = "tokenizer.ggml.tokens"
  66. KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"
  67. KEY_TOKENIZER_SCORES     = "tokenizer.ggml.scores"
  68. KEY_TOKENIZER_MERGES     = "tokenizer.ggml.merges"
  69. KEY_TOKENIZER_BOS_ID     = "tokenizer.ggml.bos_token_id"
  70. KEY_TOKENIZER_EOS_ID     = "tokenizer.ggml.eos_token_id"
  71. KEY_TOKENIZER_UNK_ID     = "tokenizer.ggml.unknown_token_id"
  72. KEY_TOKENIZER_SEP_ID     = "tokenizer.ggml.seperator_token_id"
  73. KEY_TOKENIZER_PAD_ID     = "tokenizer.ggml.padding_token_id"
  74. KEY_TOKENIZER_HF_JSON    = "tokenizer.huggingface.json"
  75. KEY_TOKENIZER_RWKV       = "tokenizer.rwkv.world"
  76.  
  77.  
  78. #
  79. # recommended mapping of model tensor names for storage in gguf
  80. #
  81.  
  82.  
  83. class MODEL_ARCH(IntEnum):
  84.     LLAMA         : int = auto()
  85.     FALCON        : int = auto()
  86.     BAICHUAN      : int = auto()
  87.     GPT2          : int = auto()
  88.     GPTJ          : int = auto()
  89.     GPTNEOX       : int = auto()
  90.     MPT           : int = auto()
  91.     STARCODER     : int = auto()
  92.     PERSIMMON     : int = auto()
  93.     REFACT        : int = auto()
  94.     BERT          : int = auto()
  95.     BLOOM         : int = auto()
  96.  
  97.  
  98. class MODEL_TENSOR(IntEnum):
  99.     TOKEN_EMBD      : int = auto()
  100.     TOKEN_EMBD_NORM : int = auto()
  101.     TOKEN_TYPES     : int = auto()
  102.     POS_EMBD        : int = auto()
  103.     OUTPUT          : int = auto()
  104.     OUTPUT_NORM     : int = auto()
  105.     ROPE_FREQS      : int = auto()
  106.     ATTN_Q          : int = auto()
  107.     ATTN_K          : int = auto()
  108.     ATTN_V          : int = auto()
  109.     ATTN_QKV        : int = auto()
  110.     ATTN_OUT        : int = auto()
  111.     ATTN_NORM       : int = auto()
  112.     ATTN_NORM_2     : int = auto()
  113.     ATTN_ROT_EMBD   : int = auto()
  114.     FFN_GATE        : int = auto()
  115.     FFN_DOWN        : int = auto()
  116.     FFN_UP          : int = auto()
  117.     FFN_NORM        : int = auto()
  118.     ATTN_Q_NORM     : int = auto()
  119.     ATTN_K_NORM     : int = auto()
  120.  
  121.  
  122. MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
  123.     MODEL_ARCH.LLAMA:          "llama",
  124.     MODEL_ARCH.FALCON:         "falcon",
  125.     MODEL_ARCH.BAICHUAN:       "baichuan",
  126.     MODEL_ARCH.GPT2:           "gpt2",
  127.     MODEL_ARCH.GPTJ:           "gptj",
  128.     MODEL_ARCH.GPTNEOX:        "gptneox",
  129.     MODEL_ARCH.MPT:            "mpt",
  130.     MODEL_ARCH.STARCODER:      "starcoder",
  131.     MODEL_ARCH.PERSIMMON:      "persimmon",
  132.     MODEL_ARCH.REFACT:         "refact",
  133.     MODEL_ARCH.BERT:           "bert",
  134.     MODEL_ARCH.BLOOM:          "bloom",
  135. }
  136.  
  137. TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
  138.     MODEL_TENSOR.TOKEN_EMBD:      "token_embd",
  139.     MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
  140.     MODEL_TENSOR.TOKEN_TYPES:     "token_types",
  141.     MODEL_TENSOR.POS_EMBD:        "position_embd",
  142.     MODEL_TENSOR.OUTPUT_NORM:     "output_norm",
  143.     MODEL_TENSOR.OUTPUT:          "output",
  144.     MODEL_TENSOR.ROPE_FREQS:      "rope_freqs",
  145.     MODEL_TENSOR.ATTN_NORM:       "blk.{bid}.attn_norm",
  146.     MODEL_TENSOR.ATTN_NORM_2:     "blk.{bid}.attn_norm_2",
  147.     MODEL_TENSOR.ATTN_QKV:        "blk.{bid}.attn_qkv",
  148.     MODEL_TENSOR.ATTN_Q:          "blk.{bid}.attn_q",
  149.     MODEL_TENSOR.ATTN_K:          "blk.{bid}.attn_k",
  150.     MODEL_TENSOR.ATTN_V:          "blk.{bid}.attn_v",
  151.     MODEL_TENSOR.ATTN_OUT:        "blk.{bid}.attn_output",
  152.     MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
  153.     MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
  154.     MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
  155.     MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
  156.     MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
  157.     MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
  158.     MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
  159. }
  160.  
  161. MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
  162.     MODEL_ARCH.LLAMA: [
  163.         MODEL_TENSOR.TOKEN_EMBD,
  164.         MODEL_TENSOR.OUTPUT_NORM,
  165.         MODEL_TENSOR.OUTPUT,
  166.         MODEL_TENSOR.ROPE_FREQS,
  167.         MODEL_TENSOR.ATTN_NORM,
  168.         MODEL_TENSOR.ATTN_Q,
  169.         MODEL_TENSOR.ATTN_K,
  170.         MODEL_TENSOR.ATTN_V,
  171.         MODEL_TENSOR.ATTN_OUT,
  172.         MODEL_TENSOR.ATTN_ROT_EMBD,
  173.         MODEL_TENSOR.FFN_NORM,
  174.         MODEL_TENSOR.FFN_GATE,
  175.         MODEL_TENSOR.FFN_DOWN,
  176.         MODEL_TENSOR.FFN_UP,
  177.     ],
  178.     MODEL_ARCH.GPTNEOX: [
  179.         MODEL_TENSOR.TOKEN_EMBD,
  180.         MODEL_TENSOR.OUTPUT_NORM,
  181.         MODEL_TENSOR.OUTPUT,
  182.         MODEL_TENSOR.ATTN_NORM,
  183.         MODEL_TENSOR.ATTN_QKV,
  184.         MODEL_TENSOR.ATTN_OUT,
  185.         MODEL_TENSOR.FFN_NORM,
  186.         MODEL_TENSOR.FFN_DOWN,
  187.         MODEL_TENSOR.FFN_UP,
  188.     ],
  189.     MODEL_ARCH.FALCON: [
  190.         MODEL_TENSOR.TOKEN_EMBD,
  191.         MODEL_TENSOR.OUTPUT_NORM,
  192.         MODEL_TENSOR.OUTPUT,
  193.         MODEL_TENSOR.ATTN_NORM,
  194.         MODEL_TENSOR.ATTN_NORM_2,
  195.         MODEL_TENSOR.ATTN_QKV,
  196.         MODEL_TENSOR.ATTN_OUT,
  197.         MODEL_TENSOR.FFN_DOWN,
  198.         MODEL_TENSOR.FFN_UP,
  199.     ],
  200.     MODEL_ARCH.BAICHUAN: [
  201.         MODEL_TENSOR.TOKEN_EMBD,
  202.         MODEL_TENSOR.OUTPUT_NORM,
  203.         MODEL_TENSOR.OUTPUT,
  204.         MODEL_TENSOR.ROPE_FREQS,
  205.         MODEL_TENSOR.ATTN_NORM,
  206.         MODEL_TENSOR.ATTN_Q,
  207.         MODEL_TENSOR.ATTN_K,
  208.         MODEL_TENSOR.ATTN_V,
  209.         MODEL_TENSOR.ATTN_OUT,
  210.         MODEL_TENSOR.ATTN_ROT_EMBD,
  211.         MODEL_TENSOR.FFN_NORM,
  212.         MODEL_TENSOR.FFN_GATE,
  213.         MODEL_TENSOR.FFN_DOWN,
  214.         MODEL_TENSOR.FFN_UP,
  215.     ],
  216.     MODEL_ARCH.STARCODER: [
  217.         MODEL_TENSOR.TOKEN_EMBD,
  218.         MODEL_TENSOR.POS_EMBD,
  219.         MODEL_TENSOR.OUTPUT_NORM,
  220.         MODEL_TENSOR.OUTPUT,
  221.         MODEL_TENSOR.ATTN_NORM,
  222.         MODEL_TENSOR.ATTN_QKV,
  223.         MODEL_TENSOR.ATTN_OUT,
  224.         MODEL_TENSOR.FFN_NORM,
  225.         MODEL_TENSOR.FFN_DOWN,
  226.         MODEL_TENSOR.FFN_UP,
  227.     ],
  228.     MODEL_ARCH.BERT: [
  229.         MODEL_TENSOR.TOKEN_EMBD,
  230.         MODEL_TENSOR.TOKEN_TYPES,
  231.         MODEL_TENSOR.POS_EMBD,
  232.         MODEL_TENSOR.OUTPUT_NORM,
  233.         MODEL_TENSOR.ATTN_NORM,
  234.         MODEL_TENSOR.ATTN_Q,
  235.         MODEL_TENSOR.ATTN_K,
  236.         MODEL_TENSOR.ATTN_V,
  237.         MODEL_TENSOR.ATTN_OUT,
  238.         MODEL_TENSOR.FFN_NORM,
  239.         MODEL_TENSOR.FFN_DOWN,
  240.         MODEL_TENSOR.FFN_UP,
  241.     ],
  242.     MODEL_ARCH.MPT: [
  243.         MODEL_TENSOR.TOKEN_EMBD,
  244.         MODEL_TENSOR.OUTPUT_NORM,
  245.         MODEL_TENSOR.OUTPUT,
  246.         MODEL_TENSOR.ATTN_NORM,
  247.         MODEL_TENSOR.ATTN_QKV,
  248.         MODEL_TENSOR.ATTN_OUT,
  249.         MODEL_TENSOR.FFN_NORM,
  250.         MODEL_TENSOR.FFN_DOWN,
  251.         MODEL_TENSOR.FFN_UP,
  252.     ],
  253.     MODEL_ARCH.GPTJ: [
  254.         MODEL_TENSOR.TOKEN_EMBD,
  255.         MODEL_TENSOR.OUTPUT_NORM,
  256.         MODEL_TENSOR.OUTPUT,
  257.         MODEL_TENSOR.ATTN_NORM,
  258.         MODEL_TENSOR.ATTN_Q,
  259.         MODEL_TENSOR.ATTN_K,
  260.         MODEL_TENSOR.ATTN_V,
  261.         MODEL_TENSOR.ATTN_OUT,
  262.         MODEL_TENSOR.FFN_DOWN,
  263.         MODEL_TENSOR.FFN_UP,
  264.     ],
  265.     MODEL_ARCH.PERSIMMON: [
  266.         MODEL_TENSOR.TOKEN_EMBD,
  267.         MODEL_TENSOR.OUTPUT,
  268.         MODEL_TENSOR.OUTPUT_NORM,
  269.         MODEL_TENSOR.ATTN_NORM,
  270.         MODEL_TENSOR.ATTN_QKV,
  271.         MODEL_TENSOR.ATTN_OUT,
  272.         MODEL_TENSOR.FFN_NORM,
  273.         MODEL_TENSOR.FFN_DOWN,
  274.         MODEL_TENSOR.FFN_UP,
  275.         MODEL_TENSOR.ATTN_Q_NORM,
  276.         MODEL_TENSOR.ATTN_K_NORM,
  277.         MODEL_TENSOR.ATTN_ROT_EMBD,
  278.     ],
  279.     MODEL_ARCH.REFACT: [
  280.         MODEL_TENSOR.TOKEN_EMBD,
  281.         MODEL_TENSOR.OUTPUT_NORM,
  282.         MODEL_TENSOR.OUTPUT,
  283.         MODEL_TENSOR.ATTN_NORM,
  284.         MODEL_TENSOR.ATTN_Q,
  285.         MODEL_TENSOR.ATTN_K,
  286.         MODEL_TENSOR.ATTN_V,
  287.         MODEL_TENSOR.ATTN_OUT,
  288.         MODEL_TENSOR.FFN_NORM,
  289.         MODEL_TENSOR.FFN_GATE,
  290.         MODEL_TENSOR.FFN_DOWN,
  291.         MODEL_TENSOR.FFN_UP,
  292.     ],
  293.     MODEL_ARCH.BLOOM: [
  294.         MODEL_TENSOR.TOKEN_EMBD,
  295.         MODEL_TENSOR.TOKEN_EMBD_NORM,
  296.         MODEL_TENSOR.OUTPUT_NORM,
  297.         MODEL_TENSOR.OUTPUT,
  298.         MODEL_TENSOR.ATTN_NORM,
  299.         MODEL_TENSOR.ATTN_QKV,
  300.         MODEL_TENSOR.ATTN_OUT,
  301.         MODEL_TENSOR.FFN_NORM,
  302.         MODEL_TENSOR.FFN_DOWN,
  303.         MODEL_TENSOR.FFN_UP,
  304.     ],
  305.     MODEL_ARCH.GPT2: [
  306.         # TODO
  307.     ],
  308.     # TODO
  309. }
  310.  
  311. # tensors that will not be serialized
  312. MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
  313.     MODEL_ARCH.LLAMA: [
  314.         MODEL_TENSOR.ROPE_FREQS,
  315.         MODEL_TENSOR.ATTN_ROT_EMBD,
  316.     ],
  317.     MODEL_ARCH.BAICHUAN: [
  318.         MODEL_TENSOR.ROPE_FREQS,
  319.         MODEL_TENSOR.ATTN_ROT_EMBD,
  320.     ],
  321.     MODEL_ARCH.PERSIMMON: [
  322.         MODEL_TENSOR.ROPE_FREQS,
  323.     ]
  324. }
  325.  
  326.  
  327. class TensorNameMap:
  328.     mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
  329.         # Token embeddings
  330.         MODEL_TENSOR.TOKEN_EMBD: (
  331.             "gpt_neox.embed_in",                        # gptneox
  332.             "transformer.wte",                          # gpt2 gpt-j mpt refact
  333.             "transformer.word_embeddings",              # falcon
  334.             "word_embeddings",                          # bloom
  335.             "model.embed_tokens",                       # llama-hf
  336.             "tok_embeddings",                           # llama-pth
  337.             "embeddings.word_embeddings",               # bert
  338.             "language_model.embedding.word_embeddings", # persimmon
  339.         ),
  340.  
  341.         # Token type embeddings
  342.         MODEL_TENSOR.TOKEN_TYPES: (
  343.             "embeddings.token_type_embeddings",  # bert
  344.         ),
  345.  
  346.         # Normalization of token embeddings
  347.         MODEL_TENSOR.TOKEN_EMBD_NORM: (
  348.             "word_embeddings_layernorm",  # bloom
  349.             # "model.embed_layer_norm",  # bluelm
  350.         ),
  351.  
  352.         # Position embeddings
  353.         MODEL_TENSOR.POS_EMBD: (
  354.             "transformer.wpe",                 # gpt2
  355.             "embeddings.position_embeddings",  # bert
  356.         ),
  357.  
  358.         # Output
  359.         MODEL_TENSOR.OUTPUT: (
  360.             "embed_out",                # gptneox
  361.             "lm_head",                  # gpt2 mpt falcon llama-hf baichuan
  362.             "output",                   # llama-pth bloom
  363.             "word_embeddings_for_head", # persimmon
  364.         ),
  365.  
  366.         # Output norm
  367.         MODEL_TENSOR.OUTPUT_NORM: (
  368.             "gpt_neox.final_layer_norm",              # gptneox
  369.             #"model.embed_layer_norm",                  # BlueLM
  370.             "transformer.ln_f",                       # gpt2 gpt-j falcon
  371.             "model.norm",                             # llama-hf baichuan
  372.             "norm",                                   # llama-pth
  373.             "embeddings.LayerNorm",                   # bert
  374.             "transformer.norm_f",                     # mpt
  375.             "ln_f",                                   # refact bloom
  376.             "language_model.encoder.final_layernorm", # persimmon
  377.         ),
  378.  
  379.         # Rope frequencies
  380.         MODEL_TENSOR.ROPE_FREQS: (
  381.             "rope.freqs", # llama-pth
  382.         ),
  383.     }
  384.  
  385.     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
  386.         # Attention norm
  387.         MODEL_TENSOR.ATTN_NORM: (
  388.             "gpt_neox.layers.{bid}.input_layernorm",               # gptneox
  389.             "transformer.h.{bid}.ln_1",                            # gpt2 gpt-j refact
  390.             "transformer.blocks.{bid}.norm_1",                     # mpt
  391.             "transformer.h.{bid}.input_layernorm",                 # falcon7b
  392.             "h.{bid}.input_layernorm",                             # bloom
  393.             "transformer.h.{bid}.ln_mlp",                          # falcon40b
  394.             "model.layers.{bid}.input_layernorm",                  # llama-hf
  395.             "layers.{bid}.attention_norm",                         # llama-pth
  396.             "encoder.layer.{bid}.attention.output.LayerNorm",      # bert
  397.             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
  398.             "model.layers.{bid}.ln1",                              # yi
  399.         ),
  400.  
  401.         # Attention norm 2
  402.         MODEL_TENSOR.ATTN_NORM_2: (
  403.             "transformer.h.{bid}.ln_attn", # falcon40b
  404.         ),
  405.  
  406.         # Attention query-key-value
  407.         MODEL_TENSOR.ATTN_QKV: (
  408.             "gpt_neox.layers.{bid}.attention.query_key_value",                    # gptneox
  409.             "transformer.h.{bid}.attn.c_attn",                                    # gpt2
  410.             "transformer.blocks.{bid}.attn.Wqkv",                                 # mpt
  411.             "transformer.h.{bid}.self_attention.query_key_value",                 # falcon
  412.             "h.{bid}.self_attention.query_key_value",                             # bloom
  413.             "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
  414.         ),
  415.  
  416.         # Attention query
  417.         MODEL_TENSOR.ATTN_Q: (
  418.             "model.layers.{bid}.self_attn.q_proj",       # llama-hf
  419.             "layers.{bid}.attention.wq",                 # llama-pth
  420.             "encoder.layer.{bid}.attention.self.query",  # bert
  421.             "transformer.h.{bid}.attn.q_proj",           # gpt-j
  422.         ),
  423.  
  424.         # Attention key
  425.         MODEL_TENSOR.ATTN_K: (
  426.             "model.layers.{bid}.self_attn.k_proj",     # llama-hf
  427.             "layers.{bid}.attention.wk",               # llama-pth
  428.             "encoder.layer.{bid}.attention.self.key",  # bert
  429.             "transformer.h.{bid}.attn.k_proj",         # gpt-j
  430.         ),
  431.  
  432.         # Attention value
  433.         MODEL_TENSOR.ATTN_V: (
  434.             "model.layers.{bid}.self_attn.v_proj",       # llama-hf
  435.             "layers.{bid}.attention.wv",                 # llama-pth
  436.             "encoder.layer.{bid}.attention.self.value",  # bert
  437.             "transformer.h.{bid}.attn.v_proj",           # gpt-j
  438.         ),
  439.  
  440.         # Attention output
  441.         MODEL_TENSOR.ATTN_OUT: (
  442.             "gpt_neox.layers.{bid}.attention.dense",                   # gptneox
  443.             "transformer.h.{bid}.attn.c_proj",                         # gpt2 refact
  444.             "transformer.blocks.{bid}.attn.out_proj",                  # mpt
  445.             "transformer.h.{bid}.self_attention.dense",                # falcon
  446.             "h.{bid}.self_attention.dense",                            # bloom
  447.             "model.layers.{bid}.self_attn.o_proj",                     # llama-hf
  448.             "layers.{bid}.attention.wo",                               # llama-pth
  449.             "encoder.layer.{bid}.attention.output.dense",              # bert
  450.             "transformer.h.{bid}.attn.out_proj",                       # gpt-j
  451.             "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
  452.         ),
  453.  
  454.         # Rotary embeddings
  455.         MODEL_TENSOR.ATTN_ROT_EMBD: (
  456.             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",  # llama-hf
  457.             "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
  458.         ),
  459.  
  460.         # Feed-forward norm
  461.         MODEL_TENSOR.FFN_NORM: (
  462.             "gpt_neox.layers.{bid}.post_attention_layernorm",               # gptneox
  463.             "transformer.h.{bid}.ln_2",                                     # gpt2 refact
  464.             "h.{bid}.post_attention_layernorm",                             # bloom
  465.             "transformer.blocks.{bid}.norm_2",                              # mpt
  466.             "model.layers.{bid}.post_attention_layernorm",                  # llama-hf
  467.             "layers.{bid}.ffn_norm",                                        # llama-pth
  468.             "encoder.layer.{bid}.output.LayerNorm",                         # bert
  469.             "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
  470.             "model.layers.{bid}.ln2",                                       # yi
  471.         ),
  472.  
  473.         # Feed-forward up
  474.         MODEL_TENSOR.FFN_UP: (
  475.             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",               # gptneox
  476.             "transformer.h.{bid}.mlp.c_fc",                          # gpt2
  477.             "transformer.blocks.{bid}.ffn.up_proj",                  # mpt
  478.             "transformer.h.{bid}.mlp.dense_h_to_4h",                 # falcon
  479.             "h.{bid}.mlp.dense_h_to_4h",                             # bloom
  480.             "model.layers.{bid}.mlp.up_proj",                        # llama-hf refact
  481.             "layers.{bid}.feed_forward.w3",                          # llama-pth
  482.             "encoder.layer.{bid}.intermediate.dense",                # bert
  483.             "transformer.h.{bid}.mlp.fc_in",                         # gpt-j
  484.             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
  485.         ),
  486.  
  487.         # Feed-forward gate
  488.         MODEL_TENSOR.FFN_GATE: (
  489.             "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
  490.             "layers.{bid}.feed_forward.w1",     # llama-pth
  491.         ),
  492.  
  493.         # Feed-forward down
  494.         MODEL_TENSOR.FFN_DOWN: (
  495.             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",               # gptneox
  496.             "transformer.h.{bid}.mlp.c_proj",                        # gpt2 refact
  497.             "transformer.blocks.{bid}.ffn.down_proj",                # mpt
  498.             "transformer.h.{bid}.mlp.dense_4h_to_h",                 # falcon
  499.             "h.{bid}.mlp.dense_4h_to_h",                             # bloom
  500.             "model.layers.{bid}.mlp.down_proj",                      # llama-hf
  501.             "layers.{bid}.feed_forward.w2",                          # llama-pth
  502.             "encoder.layer.{bid}.output.dense",                      # bert
  503.             "transformer.h.{bid}.mlp.fc_out",                        # gpt-j
  504.             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
  505.         ),
  506.  
  507.         MODEL_TENSOR.ATTN_Q_NORM: (
  508.             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
  509.         ),
  510.  
  511.         MODEL_TENSOR.ATTN_K_NORM: (
  512.             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
  513.         ),
  514.  
  515.         MODEL_TENSOR.ROPE_FREQS: (
  516.             "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
  517.         )
  518.     }
  519.  
  520.     mapping: dict[str, tuple[MODEL_TENSOR, str]]
  521.  
  522.     def __init__(self, arch: MODEL_ARCH, n_blocks: int):
  523.         self.mapping = {}
  524.         for tensor, keys in self.mappings_cfg.items():
  525.             if tensor not in MODEL_TENSORS[arch]:
  526.                 continue
  527.             tensor_name = TENSOR_NAMES[tensor]
  528.             self.mapping[tensor_name] = (tensor, tensor_name)
  529.             for key in keys:
  530.                 self.mapping[key] = (tensor, tensor_name)
  531.         for bid in range(n_blocks):
  532.             for tensor, keys in self.block_mappings_cfg.items():
  533.                 if tensor not in MODEL_TENSORS[arch]:
  534.                     continue
  535.                 tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
  536.                 self.mapping[tensor_name] = (tensor, tensor_name)
  537.                 for key in keys:
  538.                     key = key.format(bid = bid)
  539.                     self.mapping[key] = (tensor, tensor_name)
  540.  
  541.     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
  542.         result = self.mapping.get(key)
  543.         if result is not None:
  544.             return result
  545.         for suffix in try_suffixes:
  546.             if key.endswith(suffix):
  547.                 result = self.mapping.get(key[:-len(suffix)])
  548.                 if result is not None:
  549.                     return (result[0], result[1] + suffix)
  550.         return None
  551.  
  552.     def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
  553.         result = self.get_type_and_name(key, try_suffixes = try_suffixes)
  554.         if result is None:
  555.             return None
  556.         return result[1]
  557.  
  558.     def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
  559.         result = self.get_type_and_name(key, try_suffixes = try_suffixes)
  560.         if result is None:
  561.             return None
  562.         return result[0]
  563.  
  564.     def __getitem__(self, key: str) -> str:
  565.         try:
  566.             return self.mapping[key][1]
  567.         except KeyError:
  568.             raise KeyError(key)
  569.  
  570.     def __contains__(self, key: str) -> bool:
  571.         return key in self.mapping
  572.  
  573.     def __repr__(self) -> str:
  574.         return repr(self.mapping)
  575.  
  576. def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
  577.     return TensorNameMap(arch, n_blocks)
  578.  
  579. class TokenType(IntEnum):
  580.     NORMAL       = 1
  581.     UNKNOWN      = 2
  582.     CONTROL      = 3
  583.     USER_DEFINED = 4
  584.     UNUSED       = 5
  585.     BYTE         = 6
  586.  
  587. class RopeScalingType(Enum):
  588.     NONE   = 'none'
  589.     LINEAR = 'linear'
  590.     YARN   = 'yarn'
  591.  
  592. #
  593. # implementation
  594. #
  595.  
  596.  
  597. class GGMLQuantizationType(IntEnum):
  598.     F32  = 0
  599.     F16  = 1
  600.     Q4_0 = 2
  601.     Q4_1 = 3
  602.     Q5_0 = 6
  603.     Q5_1 = 7
  604.     Q8_0 = 8
  605.     Q8_1 = 9
  606.     Q2_K = 10
  607.     Q3_K = 11
  608.     Q4_K = 12
  609.     Q5_K = 13
  610.     Q6_K = 14
  611.     Q8_K = 15
  612.  
  613. class GGUFEndian(IntEnum):
  614.     LITTLE = 0
  615.     BIG = 1
  616.  
  617.  
  618. class GGUFValueType(IntEnum):
  619.     UINT8   = 0
  620.     INT8    = 1
  621.     UINT16  = 2
  622.     INT16   = 3
  623.     UINT32  = 4
  624.     INT32   = 5
  625.     FLOAT32 = 6
  626.     BOOL    = 7
  627.     STRING  = 8
  628.     ARRAY   = 9
  629.     UINT64  = 10
  630.     INT64   = 11
  631.     FLOAT64 = 12
  632.  
  633.     @staticmethod
  634.     def get_type(val):
  635.         if isinstance(val, str) or isinstance(val, bytes) or isinstance(val, bytearray):
  636.             return GGUFValueType.STRING
  637.         elif isinstance(val, list):
  638.             return GGUFValueType.ARRAY
  639.         elif isinstance(val, float):
  640.             return GGUFValueType.FLOAT32
  641.         elif isinstance(val, bool):
  642.             return GGUFValueType.BOOL
  643.         elif isinstance(val, int):
  644.             return GGUFValueType.INT32
  645.         # TODO: need help with 64-bit types in Python
  646.         else:
  647.             print("Unknown type: "+str(type(val)))
  648.             sys.exit()
  649.  
  650.  
  651. class GGUFWriter:
  652.     fout: BufferedWriter
  653.     arch: str
  654.     offset_tensor = 0
  655.     data_alignment = GGUF_DEFAULT_ALIGNMENT
  656.     kv_data = b""
  657.     kv_data_count = 0
  658.     ti_data = b""
  659.     ti_data_count = 0
  660.     use_temp_file: bool
  661.     temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
  662.     tensors: list[tuple[np.ndarray[Any, Any], int]]
  663.  
  664.     @property
  665.     def pack_prefix(self):
  666.         if self.endianess==GGUFEndian.LITTLE:
  667.             return "<"
  668.         else:
  669.             return ">"
  670.  
  671.     def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True, endianess=GGUFEndian.LITTLE):
  672.         self.fout = open(path, "wb")
  673.         self.arch = arch
  674.         self.endianess = endianess
  675.         self._simple_value_packing = {
  676.             GGUFValueType.UINT8:   f"{self.pack_prefix}B",
  677.             GGUFValueType.INT8:    f"{self.pack_prefix}b",
  678.             GGUFValueType.UINT16:  f"{self.pack_prefix}H",
  679.             GGUFValueType.INT16:   f"{self.pack_prefix}h",
  680.             GGUFValueType.UINT32:  f"{self.pack_prefix}I",
  681.             GGUFValueType.INT32:   f"{self.pack_prefix}i",
  682.             GGUFValueType.FLOAT32: f"{self.pack_prefix}f",
  683.             GGUFValueType.UINT64:  f"{self.pack_prefix}Q",
  684.             GGUFValueType.INT64:   f"{self.pack_prefix}q",
  685.             GGUFValueType.FLOAT64: f"{self.pack_prefix}d",
  686.             GGUFValueType.BOOL:    "?" ,
  687.         }
  688.         self.add_architecture()
  689.         self.use_temp_file = use_temp_file
  690.         self.tensors = []
  691.         endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian"
  692.         print(f"This gguf file is for {endianess_str} only")
  693.  
  694.     def write_header_to_file(self):
  695.         self.fout.write(struct.pack("<I", GGUF_MAGIC))
  696.         self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION))
  697.         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count))
  698.         self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count))
  699.         self.flush()
  700. #        print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
  701.  
  702.     def write_kv_data_to_file(self):
  703.         self.fout.write(self.kv_data)
  704.         self.flush()
  705.  
  706.     def write_ti_data_to_file(self):
  707.         self.fout.write(self.ti_data)
  708.         self.flush()
  709.  
  710.     def add_key(self, key: str):
  711.         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
  712.  
  713.     def add_uint8(self, key: str, val: int):
  714.         self.add_key(key)
  715.         self.add_val(val, GGUFValueType.UINT8)
  716.  
  717.     def add_int8(self, key: str, val: int):
  718.         self.add_key(key)
  719.         self.add_val(val, GGUFValueType.INT8)
  720.  
  721.     def add_uint16(self, key: str, val: int):
  722.         self.add_key(key)
  723.         self.add_val(val, GGUFValueType.UINT16)
  724.  
  725.     def add_int16(self, key: str, val: int):
  726.         self.add_key(key)
  727.         self.add_val(val, GGUFValueType.INT16)
  728.  
  729.     def add_uint32(self, key: str, val: int):
  730.         self.add_key(key)
  731.         self.add_val(val, GGUFValueType.UINT32)
  732.  
  733.     def add_int32(self, key: str, val: int):
  734.         self.add_key(key)
  735.         self.add_val(val, GGUFValueType.INT32)
  736.  
  737.     def add_float32(self, key: str, val: float):
  738.         self.add_key(key)
  739.         self.add_val(val, GGUFValueType.FLOAT32)
  740.  
  741.     def add_uint64(self, key: str, val: int):
  742.         self.add_key(key)
  743.         self.add_val(val, GGUFValueType.UINT64)
  744.  
  745.     def add_int64(self, key: str, val: int):
  746.         self.add_key(key)
  747.         self.add_val(val, GGUFValueType.INT64)
  748.  
  749.     def add_float64(self, key: str, val: float):
  750.         self.add_key(key)
  751.         self.add_val(val, GGUFValueType.FLOAT64)
  752.  
  753.     def add_bool(self, key: str, val: bool):
  754.         self.add_key(key)
  755.         self.add_val(val, GGUFValueType.BOOL)
  756.  
  757.     def add_string(self, key: str, val: str):
  758.         if len(val) == 0:
  759.             return
  760.         self.add_key(key)
  761.         self.add_val(val, GGUFValueType.STRING)
  762.  
  763.     def add_array(self, key: str, val: Sequence[Any]):
  764.         if not isinstance(val, Sequence):
  765.             raise ValueError("Value must be a sequence for array type")
  766.  
  767.         self.add_key(key)
  768.         self.add_val(val, GGUFValueType.ARRAY)
  769.  
  770.     def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
  771.         if vtype is None:
  772.             vtype = GGUFValueType.get_type(val)
  773.  
  774.         if add_vtype:
  775.             self.kv_data += struct.pack(f"{self.pack_prefix}I", vtype)
  776.             self.kv_data_count += 1
  777.  
  778.         pack_fmt = self._simple_value_packing.get(vtype)
  779.         if pack_fmt is not None:
  780.             self.kv_data += struct.pack(pack_fmt, val)
  781.         elif vtype == GGUFValueType.STRING:
  782.             encoded_val = val.encode("utf8") if isinstance(val, str) else val
  783.             self.kv_data += struct.pack(f"{self.pack_prefix}Q", len(encoded_val))
  784.             self.kv_data += encoded_val
  785.         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
  786.             ltype = GGUFValueType.get_type(val[0])
  787.             if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
  788.                 raise ValueError("All items in a GGUF array should be of the same type")
  789.             self.kv_data += struct.pack(f"{self.pack_prefix}I", ltype)
  790.             self.kv_data += struct.pack(f"{self.pack_prefix}Q", len(val))
  791.             for item in val:
  792.                 self.add_val(item, add_vtype=False)
  793.         else:
  794.             raise ValueError("Invalid GGUF metadata value type or value")
  795.  
  796.     @staticmethod
  797.     def ggml_pad(x: int, n: int) -> int:
  798.         return ((x + n - 1) // n) * n
  799.  
  800.     def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
  801.         assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
  802.  
  803.         encoded_name = name.encode("utf8")
  804.         self.ti_data += struct.pack(f"{self.pack_prefix}Q", len(encoded_name))
  805.         self.ti_data += encoded_name
  806.         n_dims = len(tensor_shape)
  807.         self.ti_data += struct.pack(f"{self.pack_prefix}I", n_dims)
  808.         for i in range(n_dims):
  809.             self.ti_data += struct.pack(f"{self.pack_prefix}Q", tensor_shape[n_dims - 1 - i])
  810.         if raw_dtype is None:
  811.             dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
  812.         else:
  813.             dtype = raw_dtype
  814.         self.ti_data += struct.pack(f"{self.pack_prefix}I", dtype)
  815.         self.ti_data += struct.pack(f"{self.pack_prefix}Q", self.offset_tensor)
  816.         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
  817.         self.ti_data_count += 1
  818.  
  819.     def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
  820.         if self.endianess == GGUFEndian.BIG:
  821.             tensor.byteswap(inplace=True)
  822.         if self.use_temp_file and self.temp_file is None:
  823.             fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
  824.             fp.seek(0)
  825.             self.temp_file = fp
  826.  
  827.         shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
  828.         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
  829.  
  830.         pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
  831.  
  832.         if  self.temp_file is None:
  833.             self.tensors.append((tensor, pad))
  834.             return
  835.  
  836.         tensor.tofile(self.temp_file)
  837.  
  838.         if pad != 0:
  839.             self.temp_file.write(bytes([0] * pad))
  840.  
  841.     def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
  842.         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
  843.         if pad != 0:
  844.             fp.write(bytes([0] * pad))
  845.  
  846.     def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
  847.         if self.endianess==GGUFEndian.BIG:
  848.             tensor.byteswap(inplace=True)
  849.         self.write_padding(self.fout, self.fout.tell())
  850.         tensor.tofile(self.fout)
  851.         self.write_padding(self.fout, tensor.nbytes)
  852.  
  853.     def write_tensors_to_file(self):
  854.         self.write_ti_data_to_file()
  855.  
  856.         self.write_padding(self.fout, self.fout.tell())
  857.  
  858.         if self.temp_file is None:
  859.             for (currtensor, currpad) in self.tensors:
  860.                 currtensor.tofile(self.fout)
  861.                 if currpad != 0:
  862.                     self.fout.write(bytes([0] * currpad))
  863.             return
  864.  
  865.         self.temp_file.seek(0)
  866.  
  867.         shutil.copyfileobj(self.temp_file, self.fout)
  868.         self.flush()
  869.         self.temp_file.close()
  870.  
  871.     def flush(self):
  872.         self.fout.flush()
  873.  
  874.     def close(self):
  875.         self.fout.close()
  876.  
  877.     def add_architecture(self):
  878.         self.add_string(KEY_GENERAL_ARCHITECTURE, self.arch)
  879.  
  880.     def add_author(self, author: str):
  881.         self.add_string(KEY_GENERAL_AUTHOR, author)
  882.  
  883.     def add_tensor_data_layout(self, layout: str):
  884.         self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
  885.  
  886.     def add_url(self, url: str):
  887.         self.add_string(KEY_GENERAL_URL, url)
  888.  
  889.     def add_description(self, description: str):
  890.         self.add_string(KEY_GENERAL_DESCRIPTION, description)
  891.  
  892.     def add_source_url(self, url: str):
  893.         self.add_string(KEY_GENERAL_SOURCE_URL, url)
  894.  
  895.     def add_source_hf_repo(self, repo: str):
  896.         self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)
  897.  
  898.     def add_file_type(self, ftype: int):
  899.         self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)
  900.  
  901.     def add_name(self, name: str):
  902.         self.add_string(KEY_GENERAL_NAME, name)
  903.  
  904.     def add_quantization_version(self, quantization_version: GGMLQuantizationType):
  905.         self.add_uint32(
  906.             KEY_GENERAL_QUANTIZATION_VERSION, quantization_version)
  907.  
  908.     def add_custom_alignment(self, alignment: int):
  909.         self.data_alignment = alignment
  910.         self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment)
  911.  
  912.     def add_context_length(self, length: int):
  913.         self.add_uint32(
  914.             KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
  915.  
  916.     def add_embedding_length(self, length: int):
  917.         self.add_uint32(
  918.             KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
  919.  
  920.     def add_block_count(self, length: int):
  921.         self.add_uint32(
  922.             KEY_BLOCK_COUNT.format(arch=self.arch), length)
  923.  
  924.     def add_feed_forward_length(self, length: int):
  925.         self.add_uint32(
  926.             KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
  927.  
  928.     def add_parallel_residual(self, use: bool):
  929.         self.add_bool(
  930.             KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
  931.  
  932.     def add_head_count(self, count: int):
  933.         self.add_uint32(
  934.             KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
  935.  
  936.     def add_head_count_kv(self, count: int):
  937.         self.add_uint32(
  938.             KEY_ATTENTION_HEAD_COUNT_KV.format(arch=self.arch), count)
  939.  
  940.     def add_max_alibi_bias(self, bias: float):
  941.         self.add_float32(
  942.             KEY_ATTENTION_MAX_ALIBI_BIAS.format(arch=self.arch), bias)
  943.  
  944.     def add_clamp_kqv(self, value: float):
  945.         self.add_float32(
  946.             KEY_ATTENTION_CLAMP_KQV.format(arch=self.arch), value)
  947.  
  948.     def add_layer_norm_eps(self, value: float):
  949.         self.add_float32(
  950.             KEY_ATTENTION_LAYERNORM_EPS.format(arch=self.arch), value)
  951.  
  952.     def add_layer_norm_rms_eps(self, value: float):
  953.         self.add_float32(
  954.             KEY_ATTENTION_LAYERNORM_RMS_EPS.format(arch=self.arch), value)
  955.  
  956.     def add_rope_dimension_count(self, count: int):
  957.         self.add_uint32(
  958.             KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)
  959.  
  960.     def add_rope_freq_base(self, value: float):
  961.         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
  962.  
  963.     def add_rope_scaling_type(self, value: RopeScalingType):
  964.         self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
  965.  
  966.     def add_rope_scaling_factor(self, value: float):
  967.         self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
  968.  
  969.     def add_rope_scaling_orig_ctx_len(self, value: int):
  970.         self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
  971.  
  972.     def add_rope_scaling_finetuned(self, value: bool):
  973.         self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
  974.  
  975.     def add_tokenizer_model(self, model: str):
  976.         self.add_string(KEY_TOKENIZER_MODEL, model)
  977.  
  978.     def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
  979.         self.add_array(KEY_TOKENIZER_LIST, tokens)
  980.  
  981.     def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
  982.         self.add_array(KEY_TOKENIZER_MERGES, merges)
  983.  
  984.     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
  985.         self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
  986.  
  987.     def add_token_scores(self, scores: Sequence[float]):
  988.         self.add_array(KEY_TOKENIZER_SCORES, scores)
  989.  
  990.     def add_bos_token_id(self, id: int):
  991.         self.add_uint32(KEY_TOKENIZER_BOS_ID, id)
  992.  
  993.     def add_eos_token_id(self, id: int):
  994.         self.add_uint32(KEY_TOKENIZER_EOS_ID, id)
  995.  
  996.     def add_unk_token_id(self, id: int):
  997.         self.add_uint32(KEY_TOKENIZER_UNK_ID, id)
  998.  
  999.     def add_sep_token_id(self, id: int):
  1000.         self.add_uint32(KEY_TOKENIZER_SEP_ID, id)
  1001.  
  1002.     def add_pad_token_id(self, id: int):
  1003.         self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
  1004.  
  1005.  
  1006. class SpecialVocab:
  1007.     load_merges: bool = False
  1008.     merges: list[str] = []
  1009.     special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
  1010.     special_token_ids: dict[str, int] = {}
  1011.     n_vocab: int | None = None
  1012.  
  1013.     def __init__(
  1014.         self, path: str | os.PathLike[str], load_merges: bool = False,
  1015.         special_token_types: tuple[str, ...] | None = None,
  1016.         n_vocab: int | None = None,
  1017.     ):
  1018.         self.special_token_ids = {}
  1019.         self.n_vocab = n_vocab
  1020.         self.load_merges = load_merges
  1021.         if special_token_types is not None:
  1022.             self.special_token_types = special_token_types
  1023.         self._load(Path(path))
  1024.  
  1025.     def _load(self, path: Path) -> None:
  1026.         if not self._try_load_from_tokenizer_json(path):
  1027.             self._try_load_from_config_json(path)
  1028.         if self.load_merges and len(self.merges) == 0:
  1029.             self._try_load_merges_txt(path)
  1030.  
  1031.     def _try_load_merges_txt(self, path: Path) -> bool:
  1032.         merges_file = path / 'merges.txt'
  1033.         if not merges_file.is_file():
  1034.             return False
  1035.         with open(merges_file, 'r') as fp:
  1036.             first_line = next(fp, '').strip()
  1037.             if not first_line.startswith('#'):
  1038.                 fp.seek(0)
  1039.                 line_num = 0
  1040.             else:
  1041.                 line_num = 1
  1042.             merges = []
  1043.             for line in fp:
  1044.                 line_num += 1
  1045.                 line = line.strip()
  1046.                 if len(line) == 0:
  1047.                     continue
  1048.                 parts = line.split(None, 3)
  1049.                 if len(parts) != 2:
  1050.                     print(f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
  1051.                         file = sys.stderr)
  1052.                     continue
  1053.                 merges.append(f'{parts[0]} {parts[1]}')
  1054.         self.merges = merges
  1055.         return True
  1056.  
  1057.  
  1058.     def _set_special_token(self, typ: str, tid: Any):
  1059.         if not isinstance(tid, int) or tid < 0:
  1060.             return
  1061.         if self.n_vocab is None or tid < self.n_vocab:
  1062.             self.special_token_ids[typ] = tid
  1063.             return
  1064.         print(f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
  1065.             file = sys.stderr)
  1066.  
  1067.  
  1068.     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
  1069.         tokenizer_file = path / 'tokenizer.json'
  1070.         if not tokenizer_file.is_file():
  1071.             return False
  1072.         with open(tokenizer_file, encoding = 'utf-8') as f:
  1073.             tokenizer = json.load(f)
  1074.         if self.load_merges:
  1075.             merges = tokenizer.get('model', {}).get('merges')
  1076.             if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
  1077.                 self.merges = merges
  1078.         tokenizer_config_file = path / 'tokenizer_config.json'
  1079.         added_tokens = tokenizer.get('added_tokens')
  1080.         if added_tokens is None or not tokenizer_config_file.is_file():
  1081.             return True
  1082.         with open(tokenizer_config_file, encoding = 'utf-8') as f:
  1083.             tokenizer_config = json.load(f)
  1084.         for typ in self.special_token_types:
  1085.             entry = tokenizer_config.get(f'{typ}_token')
  1086.             if isinstance(entry, str):
  1087.                 tc_content = entry
  1088.             elif isinstance(entry, dict):
  1089.                 entry_content = entry.get('content')
  1090.                 if not isinstance(entry_content, str):
  1091.                     continue
  1092.                 tc_content = entry_content
  1093.             else:
  1094.                 continue
  1095.             # We only need the first match here.
  1096.             maybe_token_id = next((
  1097.                 atok.get('id') for atok in added_tokens
  1098.                 if atok.get('content') == tc_content), None)
  1099.             self._set_special_token(typ, maybe_token_id)
  1100.         return True
  1101.  
  1102.     def _try_load_from_config_json(self, path: Path) -> bool:
  1103.         config_file = path / 'config.json'
  1104.         if not config_file.is_file():
  1105.             return False
  1106.         with open(config_file, encoding = 'utf-8') as f:
  1107.             config = json.load(f)
  1108.         for typ in self.special_token_types:
  1109.             self._set_special_token(typ, config.get(f'{typ}_token_id'))
  1110.         return True
  1111.  
  1112.     def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
  1113.         if len(self.merges) > 0:
  1114.             if not quiet:
  1115.                 print(f'gguf: Adding {len(self.merges)} merge(s).')
  1116.             gw.add_token_merges(self.merges)
  1117.         elif self.load_merges:
  1118.             print('gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
  1119.                 file = sys.stderr)
  1120.         for typ, tokid in self.special_token_ids.items():
  1121.             handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
  1122.             if handler is None:
  1123.                 print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', file = sys.stderr)
  1124.                 continue
  1125.             if not quiet:
  1126.                 print(f'gguf: Setting special token type {typ} to {tokid}')
  1127.             handler(tokid)
  1128.  
  1129.     def __repr__(self) -> str:
  1130.         return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
  1131.  
  1132.  
  1133. # Example usage:
  1134. if __name__ == "__main__":
  1135.     # Example usage with a file
  1136.     gguf_writer = GGUFWriter("example.gguf", "llama")
  1137.  
  1138.     gguf_writer.add_architecture()
  1139.     gguf_writer.add_block_count(12)
  1140.     gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
  1141.     gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
  1142.     gguf_writer.add_custom_alignment(64)
  1143.  
  1144.     tensor1 = np.ones((32,), dtype=np.float32) * 100.0
  1145.     tensor2 = np.ones((64,), dtype=np.float32) * 101.0
  1146.     tensor3 = np.ones((96,), dtype=np.float32) * 102.0
  1147.  
  1148.     gguf_writer.add_tensor("tensor1", tensor1)
  1149.     gguf_writer.add_tensor("tensor2", tensor2)
  1150.     gguf_writer.add_tensor("tensor3", tensor3)
  1151.  
  1152.     gguf_writer.write_header_to_file()
  1153.     gguf_writer.write_kv_data_to_file()
  1154.     gguf_writer.write_tensors_to_file()
  1155.  
  1156.     gguf_writer.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement