Mivik

llm_lv2.py

Nov 9th, 2024
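# Hackergame 2024 LLM challenge solve script (llm_lv2.py): steer a local
# Qwen2.5-3B-Instruct GGUF model, token by token, so that its output
# reproduces the target transcript stored in after.txt. The search below
# backtracks over the model's own top candidate tokens until it succeeds.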
from llama_cpp import Llama
import llama_cpp.llama_chat_format as llama_chat_format

# 1. Assert sha256 of the model file
model_path = '../qwen2.5-3b-instruct-q8_0.gguf'


llm = Llama(
    model_path=model_path,
    n_ctx=1024,  # context window
    seed=200,  # fixed seed so generation is reproducible
)
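# Look up the EOS/BOS token strings: the Jinja2 chat formatter below needs
# them to render the chat template shipped in the GGUF metadata.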
eos_token_id = llm.token_eos()
bos_token_id = llm.token_bos()
eos_token = (
    llm._model.token_get_text(eos_token_id) if eos_token_id != -1 else ''
)
bos_token = (
    llm._model.token_get_text(bos_token_id) if bos_token_id != -1 else ''
)
formatter = llama_chat_format.Jinja2ChatFormatter(
    template=llm.metadata['tokenizer.chat_template'],
    eos_token=eos_token,
    bos_token=bos_token,
    stop_token_ids=[eos_token_id],
)
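# Render the system/user messages into a single prompt string.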
result = formatter(
    messages=[
        {'role': 'system', 'content': 'You are a professional CTF player.'},
        {
            'role': 'user',
            'content': 'Write a short article for Hackergame 2024 (中国科学技术大学 (University of Science and Technology of China) 第十一届信息安全大赛) in English. The more funny and unreal the better. About 500 words.',
        },
    ]
)

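# Index every vocabulary token by text length and text -> token id.
# (Built here but not referenced again in this script.)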
words = {}
for i in range(llm.n_vocab()):
    t = llm._model.token_get_text(i)
    words.setdefault(len(t), {})[t] = i


prompt = result.prompt
# Tokenize the rendered prompt; only add a BOS token if the chat template
# has not already inserted the special tokens itself.
prompt = llm.tokenize(
    prompt.encode('utf-8'),
    add_bos=not result.added_special,
    special=True,
)
appended = []  # token ids generated (and accepted) so far
# Target byte string that the generated output has to reproduce.
with open('after.txt', 'rb') as f:
    s = f.read()

alters = []     # candidate token ids for the current position
alter_cur = []  # index of the candidate chosen at each depth (backtracking stack)


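# Logits processor invoked for every token llm.generate() produces: it
# collects the candidates that keep the output consistent with the target,
# then forces the candidate selected by alter_cur. If the options at this
# depth are exhausted, alters ends up empty so the main loop backtracks.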
def logits_processor(
    input_ids,
    scores,
):
    global alters

    # Look at the 19 highest-scoring tokens for this position.
    wow = scores.argsort()
    for i in wow[-1:-20:-1]:
        ss = llm.detokenize(appended + [i])
        # A candidate is acceptable if the detokenized output still matches
        # the target: an 'x' in the target may be matched by any character
        # of 'hackergame of ustc'; every other byte must match exactly.
        if all(
            (a == ord('x') and b in map(ord, 'hackergame of ustc')) or a == b
            for a, b in zip(s[: len(ss)], ss)
        ):
            alters.append(i)

    if alter_cur[-1] >= len(alters):
        # ran out of options, clear alters to notify the caller
        alters = []
    else:
        # Boost the chosen candidate's logit so it is effectively always sampled.
        scores[alters[alter_cur[-1]]] = 100

    return scores


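# Depth-first search over token choices: follow the current candidates as
# long as they keep matching the target; when a position runs out of
# candidates, drop the last token, advance the previous position to its
# next candidate, and resume generation from the shortened prefix.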
i = 0  # (unused)
while True:
    gen = llm.generate(
        prompt,
        logits_processor=[logits_processor],
    )
    alter_cur.append(0)
    while True:
        alters.clear()
        logic = next(gen)

        if not alters:
            # No candidate works at this depth: backtrack one token and try
            # the next candidate at the previous depth.
            print('# Oh no!! Going back!!!')
            alter_cur.pop()
            appended.pop()
            alter_cur[-1] += 1
            gen = llm.generate(
                prompt + appended,
                logits_processor=[logits_processor],
            )
            continue

        # Commit the forced token and show the decoded output so far.
        appended.append(logic)
        output = llm.detokenize(appended)
        print('=====\n' + output.decode())
        if len(output) >= len(s):
            # The full target has been reproduced; we are done.
            quit()
        alter_cur.append(0)