Advertisement
Guest User

Untitled

a guest
Mar 25th, 2023
256
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.50 KB | Software | 0 0
import collections
import math
import os
import random

import openai
import transformers
  6.  
  7. openai.api_key = '<your_api_key>'
  8. # you can get this from http://mattmahoney.net/dc/enwik8.zip
  9. enwikxmlfile = '/path/to/enwik8.xml'
  10.  
  11. enwikxml = open(enwikxmlfile, 'r').read()
  12. tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
  13. enwikxmltoks = tokenizer.encode(enwikxml)
  14.  
  15. def build_markov(tokens, depth):
  16.     counts = {}
  17.     path = [None for i in range(depth + 1)]
  18.     for token in tokens:
  19.         curr = counts
  20.         path = (path + [token])[1:]
  21.         if None in path:
  22.             continue
  23.         for x in path[:-1]:
  24.             if x not in curr:
  25.                 curr[x] = {}
  26.             curr = curr[x]
  27.         if token not in curr:
  28.             curr[token] = 0
  29.         curr[token] += 1
  30.     return counts
  31.  
  32. def compute_markov_logprobs(model, depth, seq, last_n):
  33.     logprobs = []
  34.     path = seq[:depth+1]
  35.     for token in seq[depth+1:]:
  36.         path = (path + [token])[1:]
  37.         curr = model
  38.         for x in path[:-1]:
  39.             curr = curr[x]
  40.         tot_count = sum(curr.values()) + 1 # somthing something laplace
  41.         logprobs.append(math.log(curr[token] / tot_count))
  42.     return logprobs[-last_n:]
  43.  
  44. def compute_oa_logprobs(model_name, tokens, last_n):
  45.     completion = openai.Completion.create(
  46.         model=model_name,
  47.         logprobs=1,
  48.         max_tokens=0,
  49.         echo=True,
  50.         prompt=tokenizer.decode(tokens)
  51.     )
  52.     return completion.choices[0].logprobs.token_logprobs[-last_n:]
  53.  
  54. markov_chars_0 = build_markov(enwikxml, 0)
  55. markov_chars_1 = build_markov(enwikxml, 1)
  56. markov_chars_2 = build_markov(enwikxml, 2)
  57. markov_chars_3 = build_markov(enwikxml, 3)
  58.  
  59. markov_toks_0 = build_markov(enwikxmltoks, 0)
  60. markov_toks_1 = build_markov(enwikxmltoks, 1)
  61. markov_toks_2 = build_markov(enwikxmltoks, 2)
  62. markov_toks_3 = build_markov(enwikxmltoks, 3)
  63.  
  64. def record_trial(bits_per_char, model_name, get_tot_logprob, txt_len):
  65.     try:
  66.         tot_logprob = get_tot_logprob()
  67.         bit_cost_per_char = -tot_logprob / math.log(2) / txt_len
  68.     except Exception as e:
  69.         print(e)
  70.         bit_cost_per_char = None
  71.     if model_name not in bits_per_char:
  72.         bits_per_char[model_name] = []
  73.     bits_per_char[model_name].append(bit_cost_per_char)
  74.  
  75. oa_model_names = [
  76.     'ada',
  77.     'babbage',
  78.     'curie',
  79.     'davinci',
  80.     'text-ada-001',
  81.     'text-babbage-001',
  82.     'text-curie-001',
  83.     'text-davinci-002',
  84.     'text-davinci-003',
  85. ]
  86.  
  87. # results of 10 trials of looking at the bits-per-char for each model
  88.  
  89. bits_per_char = {
  90. }
  91. samples = []
  92. for i in range(3):
  93.     sample_length = 2048
  94.     sample_offset = int(random.random() * (len(enwikxmltoks) - sample_length))
  95.     sample_toks   = enwikxmltoks[sample_offset:sample_offset+sample_length]
  96.     sample_text   = tokenizer.decode(sample_toks)
  97.     txt_len = len(tokenizer.decode(sample_toks[-1024:]))
  98.     samples.append(sample_text)
  99.     record_trial(bits_per_char, 'markov_chars_0', lambda: sum(compute_markov_logprobs(markov_chars_0, 0, list(sample_text), txt_len)), txt_len)
  100.     record_trial(bits_per_char, 'markov_chars_1', lambda: sum(compute_markov_logprobs(markov_chars_1, 1, list(sample_text), txt_len)), txt_len)
  101.     record_trial(bits_per_char, 'markov_chars_2', lambda: sum(compute_markov_logprobs(markov_chars_2, 2, list(sample_text), txt_len)), txt_len)
  102.     record_trial(bits_per_char, 'markov_chars_3', lambda: sum(compute_markov_logprobs(markov_chars_3, 3, list(sample_text), txt_len)), txt_len)
  103.     record_trial(bits_per_char, 'markov_toks_0', lambda: sum(compute_markov_logprobs(markov_toks_0, 0, sample_toks, 1024)), txt_len)
  104.     record_trial(bits_per_char, 'markov_toks_1', lambda: sum(compute_markov_logprobs(markov_toks_1, 1, sample_toks, 1024)), txt_len)
  105.     record_trial(bits_per_char, 'markov_toks_2', lambda: sum(compute_markov_logprobs(markov_toks_2, 2, sample_toks, 1024)), txt_len)
  106.     record_trial(bits_per_char, 'markov_toks_3', lambda: sum(compute_markov_logprobs(markov_toks_3, 3, sample_toks, 1024)), txt_len)
  107.     for model_name in oa_model_names:
  108.         record_trial(bits_per_char, 'openai:' + model_name, lambda: sum(compute_oa_logprobs(model_name, sample_toks, 1024)), txt_len)
  109.  
  110.  
  111. print(f'    {"Model Name:":24s} Bits per char (stddev)')
  112. for model_name, results in bits_per_char.items():
  113.     mean = sum(results) / len(results)
  114.     variance = sum([(x - mean)**2 for x in results])
  115.     stddev = math.sqrt(variance / len(results))
  116.     print(f'    {model_name+":":24s}: {mean:.2f} ({stddev:.2f})')
  117.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement