# gpt2-xl_n1000_l6

a guest
Nov 25th, 2024
161
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.47 KB | None | 0 0
%pip install transformer_lens matplotlib openai python-dotenv datasets

from threading import Thread
import torch
from transformer_lens import HookedTransformer
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Union, List
from openai import OpenAI, RateLimitError
from tqdm import tqdm
from dotenv import load_dotenv
import time
import os
from datasets import load_dataset
import json
import random

load_dotenv()


client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def ask_yes_no(prompt: str) -> bool:
    backoff = 4
    while True:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers yes or no questions. Respond with 'yes' or 'no' only."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=3,
                temperature=0,
            )
            return "yes" in response.choices[0].message.content.lower()
        except RateLimitError as e:
            print(f"Rate limit error: {e}")
            time.sleep(backoff)
            backoff *= 2
            if backoff > 128:
                raise e

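# Illustrative usage (assumes a valid OPENAI_API_KEY in the environment; not part of the
# original run):
#   ask_yes_no("Is this text related to weddings? Text:\n\nThe bride wore white.")  # expected: True
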
def batch_yes_no(batch: List[str], question_fn, progress=True, num_workers=20) -> List[bool]:
    from concurrent.futures import ThreadPoolExecutor

    # Process in parallel batches of num_workers
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(question_fn, text) for text in batch]
        results = list(tqdm(
            (f.result() for f in futures),
            total=len(batch),
            desc="Evaluating batch",
            disable=not progress
        ))

    return results

def related(big_batch: List[str], topic: str, progress=True, num_workers=20) -> List[bool]:
    question_fn = lambda text: ask_yes_no(f"Is this text related to {topic}? Text:\n\n{text}")
    return batch_yes_no(big_batch, question_fn, progress, num_workers)

def coherent(batch: List[str], progress=True, num_workers=20) -> List[bool]:
    # NOTE: despite the name, this asks whether the text is *extremely incoherent*,
    # so True means the judge considered the text incoherent.
    question_fn = lambda text: ask_yes_no(f"Is this text extremely incoherent? Text:\n\n{text}")
    return batch_yes_no(batch, question_fn, progress, num_workers)

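# Illustrative batch usage (assumes a valid OPENAI_API_KEY; not part of the original run):
#   related(["The string quartet played as the guests found their seats."], "weddings")  # expected: [True]
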
model_name = "gpt2-xl"
device = "cuda:0"
model = HookedTransformer.from_pretrained(model_name, device=device)
model.eval()
if torch.cuda.is_available():
    model.to(device)


SEED = 0
sampling_kwargs = dict(temperature=1.0, top_p=0.3, freq_penalty=1.0)
tlen = lambda prompt: model.to_tokens(prompt).shape[1]
pad_right = lambda prompt, length: prompt + " " * (length - tlen(prompt))

def pad_both(p_add, p_sub):
    # Pad the shorter prompt with spaces so both prompts tokenize to the same length.
    l = max(tlen(p_add), tlen(p_sub))
    return pad_right(p_add, l), pad_right(p_sub, l)

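# Illustrative padding check (uses the experiment's own prompt pair; exact lengths depend
# on the GPT-2 tokenizer):
#   a, b = pad_both("I talk about weddings constantly", "I do not talk about weddings constantly")
#   print(tlen(a), tlen(b))  # should match after padding
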
def get_resid_pre(prompt: str, layer: int):
    name = f"blocks.{layer}.hook_resid_pre"
    cache, caching_hooks, _ = model.get_caching_hooks(lambda n: n == name)
    with model.hooks(fwd_hooks=caching_hooks):
        _ = model(prompt)
    return cache[name]


def get_act_diff(prompt_add: str, prompt_sub: str, layer: int):
    act_add = get_resid_pre(prompt_add, layer)
    act_sub = get_resid_pre(prompt_sub, layer)
    return act_add - act_sub  # if this errors, you forgot to pad the prompts


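# Shape sanity check (illustrative; layer 6 is an assumption taken from the paste title):
#   diff = get_act_diff(*pad_both("I talk about weddings constantly",
#                                 "I do not talk about weddings constantly"), layer=6)
#   print(diff.shape)  # [1, prompt_len, 1600] for gpt2-xl (d_model = 1600)
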
from functools import partial

def ave_hook(act_diff, resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return  # caching in model.generate for new tokens

    # We only add to the prompt (first call), not the generated tokens.
    ppos, apos = resid_pre.shape[1], act_diff.shape[1]
    assert apos <= ppos, f"More mod tokens ({apos}) than prompt tokens ({ppos})!"

    # add to the beginning (position-wise) of the activations
    resid_pre[:, :apos, :] += act_diff


def hooked_generate(prompt_batch: List[str], fwd_hooks=[], seed=None, verbose=False, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        r = model.generate(input=tokenized, do_sample=True, verbose=verbose, **kwargs)
    return r


def generate_hooked(prompt_batch: List[str], prompt_add: str, prompt_sub: str, act_name: int, coeff: float, verbose=False, max_new_tokens=50):
    prompt_add, prompt_sub = pad_both(prompt_add, prompt_sub)
    act_diff = coeff * get_act_diff(prompt_add, prompt_sub, act_name)
    editing_hooks = [(f"blocks.{act_name}.hook_resid_pre", partial(ave_hook, act_diff))]
    hooked_res = hooked_generate(prompt_batch, editing_hooks, seed=SEED, verbose=verbose, **sampling_kwargs, max_new_tokens=max_new_tokens)
    return hooked_res


def generate_both(prompt_batch: List[str], prompt_add: str, prompt_sub: str, act_name: int, coeff: float):
    hooked_res = generate_hooked(prompt_batch, prompt_add, prompt_sub, act_name, coeff)
    vanilla_res = hooked_generate(prompt_batch, [], seed=SEED, **sampling_kwargs)
    return hooked_res, vanilla_res


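# Illustrative single-prompt comparison (not part of the original run; layer 6 is an
# assumption taken from the paste title "gpt2-xl_n1000_l6"):
#   hooked, vanilla = generate_both(["I went up to my friend and said"],
#                                   "I talk about weddings constantly",
#                                   "I do not talk about weddings constantly",
#                                   act_name=6, coeff=1.0)
#   print(model.to_string(hooked[:, 1:])[0])
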
TOPICS = ["art", "finance", "music", "politics", "science", "weddings"]

# https://huggingface.co/datasets/stanfordnlp/imdb
ds = load_dataset("stanfordnlp/imdb")

imdb_sample = random.sample(ds['test']['text'], 3000)
imdb_prompts = [x[:256] for x in imdb_sample]
imdb_prompts = [" ".join(x.split(" ")[:-1]) for x in imdb_prompts]  # drop the (possibly truncated) final word
imdb_prompts[:10]

# Drop any prompt the judge already considers related to one of the steering topics,
# so baseline relevance starts near zero.
for topic in TOPICS:
    rel = related(imdb_prompts, topic, num_workers=100, progress=True)
    imdb_prompts = [x for x, y in zip(imdb_prompts, rel) if not y]
    print(f"{topic}: removed {sum(rel)} ({np.mean(rel)*100:.1f}%) new len prompts: {len(imdb_prompts)}")
open('imdb_prompts_filtered.json', 'w').write(json.dumps(imdb_prompts))


def generate_big_hooked_batch(prompt_batch: List[str], padd: str, psub: str, act_name: int, coeff: float):
    "Generate in cases where processing prompt_batch all at once would make us run out of memory."

    batch_size = 60  # XXX: reduce this if you're not on an H100

    res_strs = []
    disable_progress = len(prompt_batch) < batch_size
    for i in tqdm(range(0, len(prompt_batch), batch_size), desc="Prompt batch", disable=disable_progress):
        batch = prompt_batch[i:i+batch_size]
        hooked_res = generate_hooked(batch, padd, psub, act_name, coeff)
        res_strs += model.to_string(hooked_res[:, 1:])  # drop the BOS token before decoding
    return [x.replace("<|endoftext|>", "") for x in res_strs]


COEFFICIENTS = [0.0, 0.5, 1.0, 2.0]
SEED = 0
completions = {}
results = {}

PROMPT_TEMPLATE = "I talk about {topic} constantly"
NEG_PROMPT_TEMPLATE = "I do not talk about {topic} constantly"

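# NOTE: act_name and PROMPT_BATCH are used below but never defined in the original paste.
# The values here are assumptions inferred from the paste title "gpt2-xl_n1000_l6"
# (steer at layer 6, evaluate on 1000 prompts); adjust to match the actual run.
act_name = 6                          # residual-stream layer to inject at (assumed)
PROMPT_BATCH = imdb_prompts[:1000]    # 1000 filtered IMDB prompts (assumed)
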
# Initialize results: one slot per coefficient so scores can be written by index
# regardless of the order in which the scoring threads finish.
for topic in TOPICS:
    results[topic] = [None] * len(COEFFICIENTS)

threads = []

# Generate completions and spawn related-scoring tasks
for topic in tqdm(TOPICS, desc="topics"):
    prompt_add, prompt_sub = pad_both(PROMPT_TEMPLATE.format(topic=topic),
                                      NEG_PROMPT_TEMPLATE.format(topic=topic))
    completions[topic] = {}

    for coeff in COEFFICIENTS:
        hooked_strs = generate_big_hooked_batch(PROMPT_BATCH, prompt_add, prompt_sub,
                                                act_name, coeff)
        completions[topic][coeff] = hooked_strs

        # Default arguments bind the current topic/coeff/completions to this thread.
        def run_related(t=topic, c=coeff, h=hooked_strs):
            rel = related(h, t, num_workers=100, progress=False)
            score = np.mean(rel)
            # Write by coefficient index so out-of-order thread completion
            # cannot scramble the per-coefficient scores.
            results[t][COEFFICIENTS.index(c)] = score
            print(f"topic {t} coeff {c} -- rel: {score}")

        thread = Thread(target=run_related)
        thread.start()
        threads.append(thread)

# Wait for all scoring threads
print("Waiting for threads to finish...")
for thread in threads:
    thread.join()

fname = f"results/{int(time.time())}.json"
os.makedirs("results", exist_ok=True)  # make sure the output directory exists
with open(fname, "w") as f:
    json.dump({
        "results": results,
        "completions": completions,
        "meta": {
            "coeffs": COEFFICIENTS,
            "topics": TOPICS,
            "act_name": act_name,
            "seed": SEED,
            "prompt_template": PROMPT_TEMPLATE,
            "neg_prompt_template": NEG_PROMPT_TEMPLATE,
        },
        "prompts": PROMPT_BATCH
    }, f)

print(f"Saved results to {fname}")

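# Optional: reload a saved run for re-plotting (illustrative; note that json.dump turns
# the float coefficient keys of `completions` into strings):
#   with open(fname) as f:
#       saved = json.load(f)
#   results = saved["results"]
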
# Create the topic steering plot WITH baseline included
plt.figure(figsize=(12, 6))
x = np.arange(len(TOPICS))
width = 0.2

for i, coeff in enumerate(COEFFICIENTS):
    relevance_scores = [results[topic][i] * 100 for topic in TOPICS]
    plt.bar(x + i * width, relevance_scores, width, label=f'c={coeff}')

plt.xlabel('Topic')
plt.ylabel('% evaluated as relevant')
plt.title('gpt-4o-mini scored relevance of ActAdd completions on generic topics')
plt.xticks(x + width * (len(COEFFICIENTS) - 1) / 2, TOPICS)

plt.legend()
plt.tight_layout()
plt.show()


# Create the visualization WITHOUT baseline included (show diffs vs the unedited c=0 run)
plt.figure(figsize=(12, 6))
x = np.arange(len(TOPICS))
width = 0.2

for i, coeff in enumerate(COEFFICIENTS[1:]):
    # Subtract the c=0 baseline; index i+1 skips the baseline entry in results.
    relevance_scores = [(results[topic][i + 1] - results[topic][0]) * 100 for topic in TOPICS]
    plt.bar(x + i * width, relevance_scores, width, label=f'c={coeff}')

plt.xlabel('Topic')
plt.ylabel('Δ% Topic Relevance (vs Unedited)')
plt.title('gpt-4o-mini scored relevance of ActAdd completions on generic topics')
plt.xticks(x + width * (len(COEFFICIENTS[1:]) - 1) / 2, TOPICS)

plt.legend()
plt.tight_layout()
plt.show()