Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # %%
- import random
- import openai
- import nest_asyncio
- import asyncio
- from pathlib import Path
- import json
- nest_asyncio.apply()
- # %%
def get_list(length: int = 10, low: int = 0, high: int = 10) -> list:
    """Return ``length`` random integers drawn uniformly from ``[low, high]``.

    Both bounds are inclusive (``random.randint`` semantics). The defaults
    reproduce the 10-element lists of values 0..10 used by this experiment,
    with the exact same sequence of ``random`` calls as before.
    """
    return [random.randint(low, high) for _ in range(length)]
# Number of solved demonstrations placed in the few-shot prompt.
k = 32

# Zero-shot prompt: just the question; the model is left to complete the
# text after "Sorted list:".
nothing_template = """
Unsorted list: {unsorted_list}
Sorted list:"""

# A solved demonstration is the zero-shot question followed by its answer.
shot_template = nothing_template + " {sorted_list}"

few_shot_examples = [get_list() for _ in range(k)]

# Every demonstration begins with a newline, so joining on "\n" leaves a
# blank line between consecutive examples.
few_shot_prompt = "\n".join(
    shot_template.format(unsorted_list=example, sorted_list=sorted(example))
    for example in few_shot_examples
)

# Few-shot prompt: the k solved demonstrations, then the unsolved question.
few_shot_template = few_shot_prompt + nothing_template
# "Python" prompt: frames the task as documentation for list.sort() and
# leaves the model to complete the value after "list.sort() =".
python_template = (
    "The sort function can be used to sort a list in ascending, descending or user defined\n"
    "order.\n"
    "To sort the list in ascending order, simply call list.sort(). This will sort a list\n"
    "of integers in ascending order so that the smallest integer will be first in the list\n"
    "and the largest integer will be the last.\n"
    "For example:\n"
    "list = {unsorted_list}\n"
    "list.sort() ="
)
- async def run_prompt(prompt: str, model: str = "davinci-002"):
- for _ in range(20):
- try:
- await asyncio.sleep(random.random() * 5)
- response = await openai.Completion.acreate(
- engine=model,
- prompt=prompt,
- max_tokens=40,
- n=1,
- temperature=0.0,
- )
- return response.choices[0].text.split("\n")[0].strip()
- except Exception as e:
- # print(e)
- continue
- print(f"Failed to get response for {prompt}")
- return ""
# Prompt variants to compare, keyed by the label printed next to each
# accuracy; insertion order is the order in which they are evaluated.
templates = dict(
    [
        ("Nothing", nothing_template),
        ("Few-shot", few_shot_template),
        ("Python", python_template),
    ]
)
- async def run(model="davinci-002"):
- lists = [get_list() for _ in range(200)]
- for name, template in templates.items():
- answers = await asyncio.gather(*[run_prompt(template.format(unsorted_list=shot), model) for shot in lists])
- correctness = [str(sorted(shot)) == ans for shot, ans in zip(lists, answers)]
- acc = sum(correctness) / len(correctness)
- print(f"{name} accuracy: {acc:.2f} +- {1.96 * (acc * (1 - acc) / len(correctness)) ** 0.5:.2f}")
# Guard so importing this module (e.g. to reuse the prompt templates) does not
# launch the 3 x 200 API requests of a full run. Scripts and IPython/Jupyter
# cells run with __name__ == "__main__", so those still execute it.
if __name__ == "__main__":
    asyncio.run(run())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement