Guest User

Untitled

a guest
Jan 28th, 2025
52
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 10.25 KB | None | 0 0
#include "ggml.h"
#include "gguf.h"
#include "common.h"

#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

  15. // Helper function to check if a string ends with another string
  16. bool ends_with(const std::string &str, const std::string &suffix) {
  17.     if (str.length() < suffix.length()) return false;
  18.     return str.compare(str.length() - suffix.length(), suffix.length(), suffix) == 0;
  19. }
  20.  
  21. // Helper function to replace all occurrences of a substring
  22. std::string replace_string(const std::string &str, const std::string &from, const std::string &to) {
  23.     std::string result = str;
  24.     size_t start_pos = 0;
  25.     while ((start_pos = result.find(from, start_pos)) != std::string::npos) {
  26.         result.replace(start_pos, from.length(), to);
  27.         start_pos += to.length();
  28.     }
  29.     return result;
  30. }
  31.  
  32. // Write zeros for padding
  33. void zeros(std::ofstream &file, size_t n) {
  34.     char zero = 0;
  35.     for (size_t i = 0; i < n; ++i) {
  36.         file.write(&zero, 1);
  37.     }
  38. }
  39.  
  40. int main(int argc, const char **argv) {
  41.     if (argc != 3) {
  42.         fprintf(stderr, "Usage: %s <input.gguf> <output.gguf>\n", argv[0]);
  43.         return EXIT_FAILURE;
  44.     }
  45.  
  46.     const std::string input_path = argv[1];
  47.     const std::string output_path = argv[2];
  48.  
  49.     // Load the original GGUF file
  50.     struct ggml_context *ctx_meta = nullptr;
  51.     struct gguf_init_params params = {
  52.         /*.no_alloc = */ true,
  53.         /*.ctx      = */ &ctx_meta,
  54.     };
  55.     struct gguf_context *original_ctx = gguf_init_from_file(input_path.c_str(), params);
  56.     if (!original_ctx) {
  57.         fprintf(stderr, "Failed to load input GGUF file: %s\n", input_path.c_str());
  58.         return EXIT_FAILURE;
  59.     }
  60.  
  61.     // Create a new GGUF context
  62.     struct gguf_context *new_ctx = gguf_init_empty();
  63.     if (!new_ctx) {
  64.         fprintf(stderr, "Failed to initialize new GGUF context\n");
  65.         gguf_free(original_ctx);
  66.         return EXIT_FAILURE;
  67.     }
  68.  
  69.     // Copy metadata from original to new context (excluding tensors)
  70.     const int n_kv = gguf_get_n_kv(original_ctx);
  71.     for (int i = 0; i < n_kv; ++i) {
  72.         const char *key = gguf_get_key(original_ctx, i);
  73.         const enum gguf_type type = gguf_get_kv_type(original_ctx, i);
  74.         switch (type) {
  75.             case GGUF_TYPE_UINT8:
  76.                 gguf_set_val_u8(new_ctx, key, gguf_get_val_u8(original_ctx, i));
  77.                 break;
  78.             case GGUF_TYPE_INT8:
  79.                 gguf_set_val_i8(new_ctx, key, gguf_get_val_i8(original_ctx, i));
  80.                 break;
  81.             case GGUF_TYPE_UINT16:
  82.                 gguf_set_val_u16(new_ctx, key, gguf_get_val_u16(original_ctx, i));
  83.                 break;
  84.             case GGUF_TYPE_INT16:
  85.                 gguf_set_val_i16(new_ctx, key, gguf_get_val_i16(original_ctx, i));
  86.                 break;
  87.             case GGUF_TYPE_UINT32:
  88.                 gguf_set_val_u32(new_ctx, key, gguf_get_val_u32(original_ctx, i));
  89.                 break;
  90.             case GGUF_TYPE_INT32:
  91.                 gguf_set_val_i32(new_ctx, key, gguf_get_val_i32(original_ctx, i));
  92.                 break;
  93.             case GGUF_TYPE_FLOAT32:
  94.                 gguf_set_val_f32(new_ctx, key, gguf_get_val_f32(original_ctx, i));
  95.                 break;
  96.             case GGUF_TYPE_UINT64:
  97.                 gguf_set_val_u64(new_ctx, key, gguf_get_val_u64(original_ctx, i));
  98.                 break;
  99.             case GGUF_TYPE_INT64:
  100.                 gguf_set_val_i64(new_ctx, key, gguf_get_val_i64(original_ctx, i));
  101.                 break;
  102.             case GGUF_TYPE_FLOAT64:
  103.                 gguf_set_val_f64(new_ctx, key, gguf_get_val_f64(original_ctx, i));
  104.                 break;
  105.             case GGUF_TYPE_BOOL:
  106.                 gguf_set_val_bool(new_ctx, key, gguf_get_val_bool(original_ctx, i));
  107.                 break;
  108.             case GGUF_TYPE_STRING:
  109.                 gguf_set_val_str(new_ctx, key, gguf_get_val_str(original_ctx, i));
  110.                 break;
  111.             case GGUF_TYPE_ARRAY:
  112.                 // Handle arrays if needed (not typically required for this conversion)
  113.                 break;
  114.             default:
  115.                 fprintf(stderr, "Unhandled metadata type: %d\n", type);
  116.                 break;
  117.         }
  118.     }
  119.  
  120.     std::vector<std::vector<uint8_t>> tensor_data_buffers;
  121.     std::ifstream f_input(input_path, std::ios::binary);
  122.     if (!f_input.is_open()) {
  123.         fprintf(stderr, "Failed to open input file for reading: %s\n", input_path.c_str());
  124.         gguf_free(original_ctx);
  125.         gguf_free(new_ctx);
  126.         return EXIT_FAILURE;
  127.     }
  128.  
  129.     // Retrieve hyperparameters from metadata (replace keys if necessary)
  130.     int num_key_value_heads = 128;   // TODO: Don't hardcode
  131.     int v_head_dim = 128;           // TODO: Don't hardcode
  132.     int qk_nope_head_dim = 128;     // TODO: Don't hardcode
  133.  
  134.     const int n_tensors = gguf_get_n_tensors(original_ctx);
  135.     int split_count = 0;
  136.  
  137.     for (int i = 0; i < n_tensors; ++i) {
  138.         const char *tensor_name = gguf_get_tensor_name(original_ctx, i);
  139.         fprintf(stderr, "Processing tensor_name: %s\n", tensor_name);
  140.         struct ggml_tensor *tensor = ggml_get_tensor(ctx_meta, tensor_name);
  141.  
  142.         if (ends_with(tensor_name, "kv_b.weight")) {
  143.             fprintf(stderr, "GOING TO SPLIT %s\n", tensor_name);
  144.             // Read tensor data
  145.             const size_t data_size = ggml_nbytes(tensor);
  146.             std::vector<uint8_t> data(data_size);
  147.             const size_t offset = gguf_get_data_offset(original_ctx) + gguf_get_tensor_offset(original_ctx, i);
  148.             f_input.seekg(offset);
  149.             f_input.read(reinterpret_cast<char*>(data.data()), data_size);
  150.  
  151.             // Assuming F32 data type
  152.             float *f_data = reinterpret_cast<float*>(data.data());
  153.             const int a = tensor->ne[0];  // rows
  154.             const int b = tensor->ne[1];  // columns
  155.  
  156.             // Validate hyperparameters
  157.             if (a != num_key_value_heads * (v_head_dim + qk_nope_head_dim)) {
  158.                 fprintf(stderr, "Tensor shape does not match hyperparameters, EXPECTED (num_key_value_heads * (v_head_dim + qk_nope_head_dim)):%d, ACTUAL:%d\n", num_key_value_heads * (v_head_dim + qk_nope_head_dim), a);
  159.                 gguf_free(original_ctx);
  160.                 gguf_free(new_ctx);
  161.                 return EXIT_FAILURE;
  162.             }
  163.  
  164.             const int n_head_kv = num_key_value_heads;
  165.             const int qkn = qk_nope_head_dim;
  166.             const int vhd = v_head_dim;
  167.  
  168.             // Prepare new tensors' data
  169.             std::vector<float> k_data(n_head_kv * b * qkn, 0.0f);
  170.             std::vector<float> v_data(n_head_kv * vhd * b, 0.0f);
  171.  
  172.             for (int h = 0; h < n_head_kv; ++h) {
  173.                 // Process k part
  174.                 for (int q = 0; q < qkn; ++q) {
  175.                     for (int c = 0; c < b; ++c) {
  176.                         const size_t original_idx = h * (vhd + qkn) * b + q * b + c;
  177.                         const size_t new_idx = (h * b + c) * qkn + q;
  178.                         k_data[new_idx] = f_data[original_idx];
  179.                     }
  180.                 }
  181.                 // Process v part
  182.                 for (int v_row = 0; v_row < vhd; ++v_row) {
  183.                     const size_t original_start = h * (vhd + qkn) * b + (qkn + v_row) * b;
  184.                     const size_t new_start = (h * vhd + v_row) * b;
  185.                     memcpy(&v_data[new_start], &f_data[original_start], b * sizeof(float));
  186.                 }
  187.             }
  188.  
  189.             // Create new tensor names
  190.             std::string k_name = replace_string(tensor_name, "kv_b", "k_b");
  191.             std::string v_name = replace_string(tensor_name, "kv_b", "v_b");
  192.  
  193.             // Add new tensors to the context
  194.             struct ggml_tensor *k_tensor = ggml_new_tensor_2d(ctx_meta, GGML_TYPE_F32, qkn, n_head_kv * b);
  195.             gguf_add_tensor(new_ctx, k_tensor);
  196.             struct ggml_tensor *v_tensor = ggml_new_tensor_2d(ctx_meta, GGML_TYPE_F32, b, n_head_kv * vhd);
  197.             gguf_add_tensor(new_ctx, v_tensor);
  198.  
  199.             // Store data buffers
  200.             std::vector<uint8_t> k_buffer(reinterpret_cast<uint8_t*>(k_data.data()), reinterpret_cast<uint8_t*>(k_data.data() + k_data.size()));
  201.             std::vector<uint8_t> v_buffer(reinterpret_cast<uint8_t*>(v_data.data()), reinterpret_cast<uint8_t*>(v_data.data() + v_data.size()));
  202.             tensor_data_buffers.push_back(k_buffer);
  203.             tensor_data_buffers.push_back(v_buffer);
  204.  
  205.             split_count++;
  206.         }
  207.             // Add original tensor to new context
  208.             gguf_add_tensor(new_ctx, tensor);
  209.  
  210.             // Read data
  211.             const size_t data_size = ggml_nbytes(tensor);
  212.             std::vector<uint8_t> data(data_size);
  213.             const size_t offset = gguf_get_data_offset(original_ctx) + gguf_get_tensor_offset(original_ctx, i);
  214.             f_input.seekg(offset);
  215.             f_input.read(reinterpret_cast<char*>(data.data()), data_size);
  216.             tensor_data_buffers.push_back(data);
  217.  
  218.     }
  219.     fprintf(stderr, "Finished Processing tensors, will now write output>\n");
  220.  
  221.     // Write the new GGUF file
  222.     std::ofstream f_output(output_path, std::ios::binary);
  223.     if (!f_output.is_open()) {
  224.         fprintf(stderr, "Failed to open output file: %s\n", output_path.c_str());
  225.         gguf_free(original_ctx);
  226.         gguf_free(new_ctx);
  227.         return EXIT_FAILURE;
  228.     }
  229.  
  230.     // Write metadata
  231.     const size_t meta_size = gguf_get_meta_size(new_ctx);
  232.     std::vector<uint8_t> meta_buffer(meta_size);
  233.     gguf_get_meta_data(new_ctx, meta_buffer.data());
  234.     f_output.write(reinterpret_cast<const char*>(meta_buffer.data()), meta_size);
  235.  
  236.     // Write tensor data
  237.     for (const auto &data : tensor_data_buffers) {
  238.         f_output.write(reinterpret_cast<const char*>(data.data()), data.size());
  239.         const size_t pad_size = GGML_PAD(data.size(), GGUF_DEFAULT_ALIGNMENT) - data.size();
  240.         zeros(f_output, pad_size);
  241.     }
  242.  
  243.     // Cleanup
  244.     gguf_free(original_ctx);
  245.     gguf_free(new_ctx);
  246.     f_input.close();
  247.     f_output.close();
  248.  
  249.     fprintf(stderr, "Successfully converted %d tensors. Output written to %s\n", split_count, output_path.c_str());
  250.     return EXIT_SUCCESS;
  251. }
  252.  
Advertisement
Add Comment
Please, Sign In to add comment