%pip install transformer_lens matplotlib openai python-dotenv datasets

from threading import Thread
import torch
from transformer_lens import HookedTransformer
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Union, List
from openai import OpenAI, RateLimitError
from tqdm import tqdm
from dotenv import load_dotenv
import time
import os
from datasets import load_dataset
import json
import random

load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def ask_yes_no(prompt: str) -> bool:
    done = False
    backoff = 4
    while not done:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers yes or no questions. Respond with 'yes' or 'no' only."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=3,
                temperature=0,
            )
            return "yes" in response.choices[0].message.content.lower()
        except RateLimitError as e:
            print(f"Rate limit error: {e}")
            time.sleep(backoff)
            backoff *= 2
            if backoff > 128:
                raise e
def batch_yes_no(batch: List[str], question_fn, progress=True, num_workers=20) -> List[bool]:
    from concurrent.futures import ThreadPoolExecutor
    # Process in parallel batches of num_workers
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(question_fn, text) for text in batch]
        results = list(tqdm(
            (f.result() for f in futures),
            total=len(batch),
            desc="Evaluating batch",
            disable=not progress
        ))
    return results

def related(big_batch: List[str], topic: str, progress=True, num_workers=20) -> List[bool]:
    question_fn = lambda text: ask_yes_no(f"Is this text related to {topic}? Text:\n\n{text}")
    return batch_yes_no(big_batch, question_fn, progress, num_workers)

def coherent(batch: List[str], progress=True, num_workers=20) -> List[bool]:
    # NOTE: despite the name, this returns True for texts judged *extremely incoherent*.
    question_fn = lambda text: ask_yes_no(f"Is this text extremely incoherent? Text:\n\n{text}")
    return batch_yes_no(batch, question_fn, progress, num_workers)
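
# Quick sanity check for the yes/no helpers above (a minimal sketch; it assumes
# OPENAI_API_KEY is set and makes two real API calls). The texts and the expected
# rough output [True, False] are illustrative, not from the original paste.
example_texts = ["The bride walked down the aisle in a white dress.",
                 "The Dow Jones fell 300 points after the earnings report."]
print(related(example_texts, "weddings", progress=False, num_workers=2))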
model_name = "gpt2-xl"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained(model_name, device=device)
model.eval()

SEED = 0
sampling_kwargs = dict(temperature=1.0, top_p=0.3, freq_penalty=1.0)

# Token-length helpers: pad the shorter steering prompt with spaces so the two
# prompts tokenize to the same length.
tlen = lambda prompt: model.to_tokens(prompt).shape[1]
pad_right = lambda prompt, length: prompt + " " * (length - tlen(prompt))

def pad_both(p_add, p_sub):
    l = max(tlen(p_add), tlen(p_sub))
    return pad_right(p_add, l), pad_right(p_sub, l)
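
# Optional check of the space-padding trick (the example prompts are illustrative):
# after pad_both, the two steering prompts should have equal token length, which
# get_act_diff below relies on. Print the lengths to verify for your own prompts.
_pa, _ps = pad_both("I talk about weddings constantly", "I do not talk about weddings constantly")
print(tlen(_pa), tlen(_ps))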
def get_resid_pre(prompt: str, layer: int):
    """Run the model on `prompt` and return the residual stream entering block `layer`."""
    name = f"blocks.{layer}.hook_resid_pre"
    cache, caching_hooks, _ = model.get_caching_hooks(lambda n: n == name)
    with model.hooks(fwd_hooks=caching_hooks):
        _ = model(prompt)
    return cache[name]

def get_act_diff(prompt_add: str, prompt_sub: str, layer: int):
    act_add = get_resid_pre(prompt_add, layer)
    act_sub = get_resid_pre(prompt_sub, layer)
    return act_add - act_sub  # if this errors, the prompts were not padded to equal length

from functools import partial

def ave_hook(act_diff, resid_pre, hook):
    # model.generate calls the model with a single position when caching new tokens;
    # we only add to the prompt (first call), not to the generated tokens.
    if resid_pre.shape[1] == 1:
        return
    ppos, apos = resid_pre.shape[1], act_diff.shape[1]
    assert apos <= ppos, f"More mod tokens ({apos}) than prompt tokens ({ppos})!"
    # Add the steering vector to the beginning (position-wise) of the activations.
    resid_pre[:, :apos, :] += act_diff
def hooked_generate(prompt_batch: List[str], fwd_hooks=[], seed=None, verbose=False, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)
    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        r = model.generate(input=tokenized, do_sample=True, verbose=verbose, **kwargs)
    return r

def generate_hooked(prompt_batch: List[str], prompt_add: str, prompt_sub: str, act_name: int, coeff: float, verbose=False, max_new_tokens=50):
    prompt_add, prompt_sub = pad_both(prompt_add, prompt_sub)
    act_diff = coeff * get_act_diff(prompt_add, prompt_sub, act_name)
    editing_hooks = [(f"blocks.{act_name}.hook_resid_pre", partial(ave_hook, act_diff))]
    hooked_res = hooked_generate(prompt_batch, editing_hooks, seed=SEED, verbose=verbose, **sampling_kwargs, max_new_tokens=max_new_tokens)
    return hooked_res

def generate_both(prompt_batch: List[str], prompt_add: str, prompt_sub: str, act_name: int, coeff: float):
    hooked_res = generate_hooked(prompt_batch, prompt_add, prompt_sub, act_name, coeff)
    vanilla_res = hooked_generate(prompt_batch, [], seed=SEED, **sampling_kwargs)
    return hooked_res, vanilla_res
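
# Minimal single-prompt demo of the steering helpers above. The layer (6) and
# coefficient (10.0) are placeholder values chosen for illustration, not tuned here.
demo_layer, demo_coeff = 6, 10.0
demo_out = generate_hooked(["I went up to my friend and said"],
                           "I talk about weddings constantly",
                           "I do not talk about weddings constantly",
                           demo_layer, demo_coeff, max_new_tokens=40)
print(model.to_string(demo_out[:, 1:]))  # drop the BOS token before decoding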
TOPICS = ["art", "finance", "music", "politics", "science", "weddings"]

# https://huggingface.co/datasets/stanfordnlp/imdb
ds = load_dataset("stanfordnlp/imdb")
random.seed(SEED)  # make the IMDB sample reproducible
imdb_sample = random.sample(ds['test']['text'], 3000)
imdb_prompts = [x[:256] for x in imdb_sample]
imdb_prompts = [" ".join(x.split(" ")[:-1]) for x in imdb_prompts]  # drop the trailing, possibly truncated word
imdb_prompts[:10]

# Remove any review that is already related to one of the steering topics,
# so the unsteered baseline is not already on-topic.
for topic in TOPICS:
    rel = related(imdb_prompts, topic, num_workers=100, progress=True)
    imdb_prompts = [x for x, y in zip(imdb_prompts, rel) if not y]
    print(f"{topic}: removed {sum(rel)} ({np.mean(rel)*100:.1f}%), remaining prompts: {len(imdb_prompts)}")

with open('imdb_prompts_filtered.json', 'w') as f:
    json.dump(imdb_prompts, f)
def generate_big_hooked_batch(prompt_batch: List[str], padd: str, psub: str, act_name: int, coeff: float):
    "Generate in cases where processing prompt_batch all at once would make us run out of memory."
    batch_size = 60  # reduce this if you are not on an H100 (or a similarly large GPU)
    res_strs = []
    disable_progress = len(prompt_batch) < batch_size
    for i in tqdm(range(0, len(prompt_batch), batch_size), desc="Prompt batch", disable=disable_progress):
        batch = prompt_batch[i:i+batch_size]
        hooked_res = generate_hooked(batch, padd, psub, act_name, coeff)
        res_strs += model.to_string(hooked_res[:, 1:])  # drop the BOS token before decoding
    return [x.replace("<|endoftext|>", "") for x in res_strs]
COEFFICIENTS = [0.0, 0.5, 1.0, 2.0]
SEED = 0
completions = {}
results = {}
PROMPT_TEMPLATE = "I talk about {topic} constantly"
NEG_PROMPT_TEMPLATE = "I do not talk about {topic} constantly"

# Initialize results with one slot per coefficient so scores can be written by index,
# independent of the order in which the scoring threads finish.
for topic in TOPICS:
    results[topic] = [None] * len(COEFFICIENTS)
threads = []
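
# NOTE: the original paste uses PROMPT_BATCH and act_name below without defining them.
# The values here are placeholder assumptions: steer at an illustrative layer and reuse
# a slice of the topic-filtered IMDB prompts as the evaluation batch.
act_name = 6                       # residual-stream layer to steer at (assumed value)
PROMPT_BATCH = imdb_prompts[:600]  # subset of the filtered IMDB prompts (assumed size)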
# Generate completions and spawn the gpt-4o-mini relevance scoring in background threads.
for topic in tqdm(TOPICS, desc="topics"):
    prompt_add, prompt_sub = pad_both(PROMPT_TEMPLATE.format(topic=topic),
                                      NEG_PROMPT_TEMPLATE.format(topic=topic))
    completions[topic] = {}
    for coeff in COEFFICIENTS:
        hooked_strs = generate_big_hooked_batch(PROMPT_BATCH, prompt_add, prompt_sub,
                                                act_name, coeff)
        completions[topic][coeff] = hooked_strs

        # Default arguments bind the current topic/coeff/completions to this thread.
        def run_related(t=topic, c=coeff, h=hooked_strs):
            rel = related(h, t, num_workers=100, progress=False)
            score = np.mean(rel)
            results[t][COEFFICIENTS.index(c)] = score  # store by coefficient index, not thread finish order
            print(f"topic {t} coeff {c} -- rel: {score}")

        thread = Thread(target=run_related)
        thread.start()
        threads.append(thread)
# Wait for all threads
print("Waiting for threads to finish...")
for thread in threads:
    thread.join()

os.makedirs("results", exist_ok=True)
fname = f"results/{int(time.time())}.json"
with open(fname, "w") as f:
    json.dump({
        "results": results,
        "completions": completions,
        "meta": {
            "coeffs": COEFFICIENTS,
            "topics": TOPICS,
            "act_name": act_name,
            "seed": SEED,
            "prompt_template": PROMPT_TEMPLATE,
            "neg_prompt_template": NEG_PROMPT_TEMPLATE,
        },
        "prompts": PROMPT_BATCH
    }, f)
print(f"Saved results to {fname}")
# Topic-steering plot WITH the unsteered baseline (c=0.0) included.
plt.figure(figsize=(12, 6))
x = np.arange(len(TOPICS))
width = 0.2

for i, coeff in enumerate(COEFFICIENTS):
    relevance_scores = [results[topic][i] * 100 for topic in TOPICS]
    plt.bar(x + i * width, relevance_scores, width, label=f'c={coeff}')

plt.xlabel('Topic')
plt.ylabel('% evaluated as relevant')
plt.title('gpt-4o-mini scored relevance of ActAdd completions on generic topics')
plt.xticks(x + width * (len(COEFFICIENTS) - 1) / 2, TOPICS)
plt.legend()
plt.tight_layout()
plt.show()
# Same plot WITHOUT the baseline: show each coefficient's relevance minus the c=0.0 baseline.
plt.figure(figsize=(12, 6))
x = np.arange(len(TOPICS))
width = 0.2

for i, coeff in enumerate(COEFFICIENTS[1:]):
    # i indexes COEFFICIENTS[1:], so the raw scores live at results[topic][i + 1];
    # subtract the unsteered baseline at index 0.
    relevance_scores = [(results[topic][i + 1] - results[topic][0]) * 100 for topic in TOPICS]
    plt.bar(x + i * width, relevance_scores, width, label=f'c={coeff}')

plt.xlabel('Topic')
plt.ylabel('Δ% Topic Relevance (vs Unedited)')
plt.title('gpt-4o-mini scored relevance of ActAdd completions on generic topics')
plt.xticks(x + width * (len(COEFFICIENTS[1:]) - 1) / 2, TOPICS)
plt.legend()
plt.tight_layout()
plt.show()