Guest User

rough draft 2

a guest
Aug 7th, 2025
18
0
122 days
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Diff 19.03 KB | None | 0 0
  1. diff --git a/tools/server/server.cpp b/tools/server/server.cpp
  2. index b23e35d3..53bdfd75 100644
  3. --- a/tools/server/server.cpp
  4. +++ b/tools/server/server.cpp
  5. @@ -4228,57 +4228,16 @@ int main(int argc, char ** argv) {
  6.              // TODO: this log can become very long, put it behind a flag or think about a more compact format
  7.              //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
  8.  
  9. -            // process files
  10. -            mtmd::bitmaps bitmaps;
  11. -            const bool has_mtmd = ctx_server.mctx != nullptr;
  12. -            {
  13. -                if (!has_mtmd && !files.empty()) {
  14. -                    throw std::runtime_error("This server does not support multimodal");
  15. -                }
  16. -                for (auto & file : files) {
  17. -                    mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
  18. -                    if (!bmp.ptr) {
  19. -                        throw std::runtime_error("Failed to load image or audio file");
  20. -                    }
  21. -                    // calculate bitmap hash (for KV caching)
  22. -                    std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
  23. -                    bmp.set_id(hash.c_str());
  24. -                    bitmaps.entries.push_back(std::move(bmp));
  25. -                }
  26. -            }
  27. -
  28.              // process prompt
  29.              std::vector<server_tokens> inputs;
  30.  
  31. -            if (has_mtmd) {
  32. -                // multimodal
  33. -                std::string prompt_str = prompt.get<std::string>();
  34. -                mtmd_input_text inp_txt = {
  35. -                    prompt_str.c_str(),
  36. -                    /* add_special */   true,
  37. -                    /* parse_special */ true,
  38. -                };
  39. -                mtmd::input_chunks chunks(mtmd_input_chunks_init());
  40. -                auto bitmaps_c_ptr = bitmaps.c_ptr();
  41. -                int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
  42. -                                                    chunks.ptr.get(),
  43. -                                                    &inp_txt,
  44. -                                                    bitmaps_c_ptr.data(),
  45. -                                                    bitmaps_c_ptr.size());
  46. -                if (tokenized != 0) {
  47. -                    throw std::runtime_error("Failed to tokenize prompt");
  48. -                }
  49. -
  50. -                server_tokens tmp(chunks, true);
  51. -                inputs.push_back(std::move(tmp));
  52. -            } else {
  53. -                // non-multimodal version
  54. -                auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
  55. -                for (auto & p : tokenized_prompts) {
  56. -                    auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
  57. -                    inputs.push_back(std::move(tmp));
  58. -                }
  59. -            }
  60. +            if (oaicompat && ctx_server.mctx != nullptr) {
  61. +                   // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
  62. +                   inputs.push_back(std::move(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files)));
  63. +           } else {
  64. +                   // Everything else, including multimodal completions.
  65. +                   inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  66. +           }
  67.  
  68.              tasks.reserve(inputs.size());
  69.              for (size_t i = 0; i < inputs.size(); i++) {
  70. @@ -4451,7 +4410,7 @@ int main(int argc, char ** argv) {
  71.          data["input_extra"] = input_extra; // default to empty array if it's not exist
  72.  
  73.          std::string prompt = json_value(data, "prompt", std::string());
  74. -        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
  75. +        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
  76.          SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
  77.          data["prompt"] = format_infill(
  78.              ctx_server.vocab,
  79. @@ -4462,7 +4421,7 @@ int main(int argc, char ** argv) {
  80.              ctx_server.params_base.n_predict,
  81.              ctx_server.slots[0].n_ctx, // TODO: there should be a better way
  82.              ctx_server.params_base.spm_infill,
  83. -            tokenized_prompts[0]
  84. +            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
  85.          );
  86.  
  87.          std::vector<raw_buffer> files; // dummy
  88. @@ -4640,7 +4599,7 @@ int main(int argc, char ** argv) {
  89.              }
  90.          }
  91.  
  92. -        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
  93. +        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
  94.          for (const auto & tokens : tokenized_prompts) {
  95.              // this check is necessary for models that do not add BOS token to the input
  96.              if (tokens.empty()) {
  97. @@ -4668,7 +4627,7 @@ int main(int argc, char ** argv) {
  98.  
  99.                  task.id            = ctx_server.queue_tasks.get_new_id();
  100.                  task.index         = i;
  101. -                task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
  102. +                task.prompt_tokens = std::move(tokenized_prompts[i]);
  103.  
  104.                  // OAI-compat
  105.                  task.params.oaicompat = oaicompat;
  106. @@ -4755,7 +4714,7 @@ int main(int argc, char ** argv) {
  107.              return;
  108.          }
  109.  
  110. -        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
  111. +        server_tokens tokenized_query = std::move(tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true)[0]);
  112.  
  113.          // create and queue the task
  114.          json responses = json::array();
  115. @@ -4763,14 +4722,14 @@ int main(int argc, char ** argv) {
  116.          std::unordered_set<int> task_ids;
  117.          {
  118.              std::vector<server_task> tasks;
  119. -            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
  120. +            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
  121.              tasks.reserve(tokenized_docs.size());
  122.              for (size_t i = 0; i < tokenized_docs.size(); i++) {
  123.                  auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
  124.                  server_task task   = server_task(SERVER_TASK_TYPE_RERANK);
  125.                  task.id            = ctx_server.queue_tasks.get_new_id();
  126.                  task.index         = i;
  127. -                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
  128. +                task.prompt_tokens = std::move(tmp);
  129.                  tasks.push_back(std::move(task));
  130.              }
  131.  
  132. diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
  133. index be3a0052..c3317140 100644
  134. --- a/tools/server/tests/unit/test_completion.py
  135. +++ b/tools/server/tests/unit/test_completion.py
  136. @@ -231,6 +231,27 @@ def test_nocache_long_input_prompt():
  137.      })
  138.      assert res.status_code == 200
  139.  
  140. +def test_nocache_json_prompt():
  141. +    global server
  142. +    server.start()
  143. +    res = server.make_request("POST", "/completion", data={
  144. +        "prompt": { "prompt": "I believe the meaning of life is" },
  145. +        "seed": 42,
  146. +        "temperature": 1.0,
  147. +        "cache_prompt": False,
  148. +    })
  149. +    assert res.status_code == 200
  150. +
  151. +def test_nocache_multimodal_prompt():
  152. +    global server
  153. +    server.start()
  154. +    res = server.make_request("POST", "/completion", data={
  155. +        "prompt": { "prompt": "I believe the meaning of life is", "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
  156. +        "seed": 42,
  157. +        "temperature": 1.0,
  158. +        "cache_prompt": False,
  159. +    })
  160. +    assert res.status_code == 200
  161.  
  162.  def test_completion_with_tokens_input():
  163.      global server
  164. @@ -269,6 +290,15 @@ def test_completion_with_tokens_input():
  165.      assert len(res.body) == 2
  166.      assert res.body[0]["content"] == res.body[1]["content"]
  167.  
  168. +    # mixed multimodal and tokens
  169. +    res = server.make_request("POST", "/completion", data={
  170. +        "prompt": [tokens, { "prompt": "My name is ", "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" }],
  171. +    })
  172. +    assert res.status_code == 200
  173. +    assert type(res.body) == list
  174. +    assert len(res.body) == 2
  175. +    assert res.body[0]["content"] == res.body[1]["content"]
  176. +
  177.      # mixed string and tokens in one sequence
  178.      res = server.make_request("POST", "/completion", data={
  179.          "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
  180. diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
  181. index f3dfc822..2865d977 100644
  182. --- a/tools/server/utils.hpp
  183. +++ b/tools/server/utils.hpp
  184. @@ -186,48 +186,6 @@ static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_
  185.      return prompt_tokens;
  186.  }
  187.  
  188. -/**
  189. - * break the input "prompt" object into multiple prompt if needed, then tokenize them
  190. - * this supports these cases:
  191. - * - "prompt": "string"
  192. - * - "prompt": [12, 34, 56]
  193. - * - "prompt": [12, 34, "string", 56, 78]
  194. - * and multiple prompts (multi-tasks):
  195. - * - "prompt": ["string1", "string2"]
  196. - * - "prompt": ["string1", [12, 34, 56]]
  197. - * - "prompt": [[12, 34, 56], [78, 90, 12]]
  198. - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  199. - */
  200. -static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
  201. -    std::vector<llama_tokens> result;
  202. -    if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
  203. -        // string or mixed
  204. -        result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
  205. -    } else if (json_is_array_of_numbers(json_prompt)) {
  206. -        // array of tokens
  207. -        result.push_back(json_prompt.get<llama_tokens>());
  208. -    } else if (json_prompt.is_array()) {
  209. -        // array of prompts
  210. -        result.reserve(json_prompt.size());
  211. -        for (const auto & p : json_prompt) {
  212. -            if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
  213. -                result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
  214. -            } else if (json_is_array_of_numbers(p)) {
  215. -                // array of tokens
  216. -                result.push_back(p.get<llama_tokens>());
  217. -            } else {
  218. -                throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
  219. -            }
  220. -        }
  221. -    } else {
  222. -        throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
  223. -    }
  224. -    if (result.empty()) {
  225. -        throw std::runtime_error("\"prompt\" must not be empty");
  226. -    }
  227. -    return result;
  228. -}
  229. -
  230.  // return the last index of character that can form a valid string
  231.  // if the last character is potentially cut in half, return the index before the cut
  232.  // if validate_utf8(text) == text.size(), then the whole text is valid utf8
  233. @@ -262,35 +220,6 @@ static size_t validate_utf8(const std::string& text) {
  234.  // template utils
  235.  //
  236.  
  237. -// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
  238. -static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
  239. -    llama_tokens result;
  240. -
  241. -    // Get EOS token - use SEP token as fallback if EOS is not available
  242. -    llama_token eos_token = llama_vocab_eos(vocab);
  243. -    if (eos_token == LLAMA_TOKEN_NULL) {
  244. -        eos_token = llama_vocab_sep(vocab);
  245. -    }
  246. -
  247. -    result.reserve(doc.size() + query.size() + 4);
  248. -    if (llama_vocab_get_add_bos(vocab)) {
  249. -        result.push_back(llama_vocab_bos(vocab));
  250. -    }
  251. -    result.insert(result.end(), query.begin(), query.end());
  252. -    if (llama_vocab_get_add_eos(vocab)) {
  253. -        result.push_back(eos_token);
  254. -    }
  255. -    if (llama_vocab_get_add_sep(vocab)) {
  256. -        result.push_back(llama_vocab_sep(vocab));
  257. -    }
  258. -    result.insert(result.end(), doc.begin(), doc.end());
  259. -    if (llama_vocab_get_add_eos(vocab)) {
  260. -        result.push_back(eos_token);
  261. -    }
  262. -
  263. -    return result;
  264. -}
  265. -
  266.  // format infill task
  267.  static llama_tokens format_infill(
  268.          const llama_vocab * vocab,
  269. @@ -1186,6 +1115,18 @@ public:
  270.          }
  271.      }
  272.  
  273. +    // appends server tokens, updates the media map. destroys server tokens.
  274. +    void push_back(const server_tokens & tokens) {
  275. +           size_t start_size = tokens.size();
  276. +           for (size_t i = 0; i < start_size; i++) {
  277. +                   push_back(tokens[i]);
  278. +           }
  279. +           // TODO, currently this breaks multimodal document ranking!
  280. +           //for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
  281. +           //      map_pos_to_media[start_size+it->first]=std::move(it->second);
  282. +           //}
  283. +    }
  284. +
  285.      // for compatibility with context shift and prompt truncation
  286.      void insert(const llama_tokens & inp_tokens) {
  287.          GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
  288. @@ -1356,3 +1297,146 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
  289.      }
  290.      return std::to_string(hash);
  291.  }
  292. +
  293. +
  294. +// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
  295. +static server_tokens format_rerank(const struct llama_vocab * vocab, const server_tokens & query, const server_tokens & doc) {
  296. +       server_tokens result = {};
  297. +
  298. +       // Get EOS token - use SEP token as fallback if EOS is not available
  299. +       llama_token eos_token = llama_vocab_eos(vocab);
  300. +       if (eos_token == LLAMA_TOKEN_NULL) {
  301. +               eos_token = llama_vocab_sep(vocab);
  302. +       }
  303. +       if (llama_vocab_get_add_bos(vocab)) {
  304. +               result.push_back(llama_vocab_bos(vocab));
  305. +       }
  306. +       result.push_back(query);
  307. +       if (llama_vocab_get_add_eos(vocab)) {
  308. +               result.push_back(eos_token);
  309. +       }
  310. +       if (llama_vocab_get_add_sep(vocab)) {
  311. +               result.push_back(llama_vocab_sep(vocab));
  312. +       }
  313. +       result.push_back(doc);
  314. +       if (llama_vocab_get_add_eos(vocab)) {
  315. +               result.push_back(eos_token);
  316. +       }
  317. +       return result;
  318. +}
  319. +
  320. +
  321. +static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
  322. +       mtmd::bitmaps bitmaps;
  323. +       for (auto & file : files) {
  324. +               mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size()));
  325. +               if (!bmp.ptr) {
  326. +                       throw std::runtime_error("Failed to load image or audio file");
  327. +               }
  328. +               // calculate bitmap hash (for KV caching)
  329. +               std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
  330. +               bmp.set_id(hash.c_str());
  331. +               bitmaps.entries.push_back(std::move(bmp));
  332. +       }
  333. +       // process prompt
  334. +       std::vector<server_tokens> inputs;
  335. +       // multimodal
  336. +       mtmd_input_text inp_txt = {
  337. +               prompt.c_str(),
  338. +               /* add_special */   true,
  339. +               /* parse_special */ true,
  340. +       };
  341. +       mtmd::input_chunks chunks(mtmd_input_chunks_init());
  342. +       auto bitmaps_c_ptr = bitmaps.c_ptr();
  343. +       int32_t tokenized = mtmd_tokenize(mctx,
  344. +                                         chunks.ptr.get(),
  345. +                                         &inp_txt,
  346. +                                         bitmaps_c_ptr.data(),
  347. +                                         bitmaps_c_ptr.size());
  348. +       if (tokenized != 0) {
  349. +               throw std::runtime_error("Failed to tokenize prompt");
  350. +       }
  351. +       auto result = server_tokens(chunks,true);
  352. +       return result;
  353. +}
  354. +
  355. +/**
  356. + * break the input "prompt" object into multiple prompt if needed, then tokenize them
  357. + * this supports these cases:
  358. + * - "prompt": "string"
  359. + * - "prompt": [12, 34, 56]
  360. + * - "prompt": [12, 34, "string", 56, 78]
  361. + * and multiple prompts (multi-tasks):
  362. + * - "prompt": ["string1", "string2"]
  363. + * - "prompt": ["string1", [12, 34, 56]]
  364. + * - "prompt": [[12, 34, 56], [78, 90, 12]]
  365. + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  366. + */
  367. +static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) {
  368. +       std::vector<server_tokens> result;
  369. +       bool has_mtmd = mctx != nullptr;
  370. +       if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
  371. +               // string or mixed
  372. +               llama_tokens toks = tokenize_mixed(vocab, json_prompt, add_special, parse_special);
  373. +               auto tmp = server_tokens(toks, false);
  374. +               result.push_back(std::move(tmp));
  375. +       } else if (json_is_array_of_numbers(json_prompt)) {
  376. +               // array of tokens
  377. +               llama_tokens toks = json_prompt.get<llama_tokens>();
  378. +               auto tmp = server_tokens(toks, false);
  379. +               result.push_back(std::move(tmp));
  380. +       } else if (json_prompt.find("prompt") != json_prompt.end()) {
  381. +               // JSON object with prompt key.
  382. +               if (has_mtmd && json_prompt.find("multimodal_data") != json_prompt.end()) {
  383. +                       // JSON object with prompt and multimodal key.
  384. +                       std::vector<raw_buffer> files;
  385. +                       for (const auto& entry : json_prompt.at("multimodal_data")) {
  386. +                               files.push_back(base64_decode(entry));
  387. +                       }
  388. +                       result.push_back(std::move(process_mtmd_prompt(mctx, json_prompt.at("prompt"), files)));
  389. +               } else {
  390. +                       // Not multimodal, but contains a subobject.
  391. +                       llama_tokens toks = tokenize_mixed(vocab, json_prompt.at("prompt"), add_special, parse_special);
  392. +                       auto tmp = server_tokens(toks, false);
  393. +                       result.push_back(std::move(tmp));
  394. +               }
  395. +       } else if (json_prompt.is_array()) {
  396. +               // array of prompts
  397. +               result.reserve(json_prompt.size());
  398. +               for (const auto & p : json_prompt) {
  399. +                       if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
  400. +                               llama_tokens toks = tokenize_mixed(vocab, p, add_special, parse_special);
  401. +                               auto tmp = server_tokens(toks, false);
  402. +                               result.push_back(std::move(tmp));
  403. +                       } else if (json_is_array_of_numbers(p)) {
  404. +                               // array of tokens
  405. +                               llama_tokens toks = p.get<llama_tokens>();
  406. +                               auto tmp = server_tokens(toks,false);
  407. +                               result.push_back(std::move(tmp));
  408. +                       } else if (has_mtmd && p.find("prompt") != p.end()) {
  409. +                               if (p.find("multimodal_data") != p.end()) {
  410. +                                       // Multimodal JSON object.
  411. +                                       std::vector<raw_buffer> files;
  412. +                                       for (const auto& entry : p.at("multimodal_data")) {
  413. +                                               files.push_back(base64_decode(entry));
  414. +                                       }
  415. +                                       result.push_back(process_mtmd_prompt(mctx, json_prompt.at("prompt"), files));
  416. +                               } else {
  417. +                                       // Non-multimodal JSON object.
  418. +                                       llama_tokens toks = tokenize_mixed(vocab, p, add_special, parse_special);
  419. +                                       auto tmp = server_tokens(toks, false);
  420. +                                       result.push_back(std::move(tmp));
  421. +                               }
  422. +                       } else {
  423. +                               throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
  424. +                       }
  425. +               }
  426. +       } else {
  427. +               throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
  428. +       }
  429. +       if (result.empty()) {
  430. +               throw std::runtime_error("\"prompt\" must not be empty");
  431. +       }
  432. +       return result;
  433. +}
  434. +
  435.  
Add Comment
Please, Sign In to add comment