Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- import json
- import boto3
- import os
- import time
- from multiprocessing import Process
- from io import BytesIO
- # import openai
- import sagemaker
- from sagemaker import Session
# Debug: confirm which json implementation/version the Lambda runtime ships.
print(json.__version__)

# SageMaker runtime client, shared by local_llm() to invoke the hosted endpoint.
sagemaker_runtime = boto3.client('sagemaker-runtime')

# NOTE(review): `import openai` is commented out above, so the original
# `openai.api_key = "sk-..."` assignment raised NameError at import time and
# also hard-coded a secret key in source (it should be rotated). When the
# OpenAI path is re-enabled, restore the import and read the key from the
# environment instead:
# openai.api_key = os.environ["OPENAI_API_KEY"]

# Stop sequences used both as generation stops and for trimming model output.
stops = ["<EOD>", "USER:", "{user}", "{character}"]
def openai_api(messages):
    """Send chat messages to OpenAI (gpt-3.5-turbo) and return the reply text.

    Args:
        messages: list of {"role", "content"} chat messages.
    Returns:
        The assistant's reply content as a string.
    """
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    return response.choices[0].message.content
def build_llama2_prompt(messages):
    """Assemble a Llama-2 chat prompt ([INST]/<<SYS>> format) from messages.

    A "system" message is honored only when it is the first message; "user"
    content is stripped and inlined; any other role is treated as an
    assistant turn and wrapped in the [/INST]...</s><s>[INST] separators.
    """
    pieces = []
    for idx, msg in enumerate(messages):
        role = msg["role"]
        if role == "system" and idx == 0:
            pieces.append("<<SYS>>\n" + msg["content"] + "\n<</SYS>>\n\n")
        elif role == "user":
            pieces.append(msg["content"].strip())
        else:
            pieces.append(" [/INST] " + msg["content"].strip() + "</s><s>[INST] ")
    return "<s>[INST] " + "".join(pieces) + " [/INST]"
def genPrompt(messages):
    """Flatten chat messages into a USER:/ASSISTANT: plain-text prompt.

    A leading "system" message is emitted verbatim; each user/assistant turn
    is tagged, assistant turns terminated with the <EOD> stop marker, and a
    trailing ' ASSISTANT: ' cue is appended for the model to continue from.
    """
    parts = []
    for i, msg in enumerate(messages):
        role, content = msg['role'], msg['content']
        if i == 0 and role == "system":
            parts.append(content)
        if role == "user":
            parts.append('USER: ' + content)
        if role == "assistant":
            parts.append(' ASSISTANT: ' + content + ' <EOD> ')
    parts.append(' ASSISTANT: ')
    return ''.join(parts)
def cleanStops(result, stop_tokens=None):
    """Trim stop sequences (and surrounding whitespace) from both ends of text.

    Bug fix: the original used ``result.strip(stop)``, but ``str.strip`` treats
    its argument as a *set of characters*, not a substring — ``strip("<EOD>")``
    also eats any leading/trailing '<', 'E', 'O', 'D', '>' characters. This
    version removes exact stop substrings, repeatedly, until none remain.

    Args:
        result: generated text to clean.
        stop_tokens: stop strings to remove; defaults to the module-level
            ``stops`` list for backward compatibility.
    Returns:
        The cleaned, whitespace-stripped text.
    """
    if stop_tokens is None:
        stop_tokens = stops
    changed = True
    while changed:
        changed = False
        result = result.strip()
        for stop in stop_tokens:
            if stop and result.startswith(stop):
                result = result[len(stop):].strip()
                changed = True
            if stop and result.endswith(stop):
                result = result[:-len(stop)].strip()
                changed = True
    return result
def local_llm(endpoint_name, messages):
    """Invoke the SageMaker text-generation endpoint and return the cleaned reply.

    Fix: the original created an unused local ``boto3.client('runtime.sagemaker')``
    on every call — the function actually uses the module-level
    ``sagemaker_runtime`` client, so the dead local is removed.

    Args:
        endpoint_name: name of the deployed SageMaker endpoint.
        messages: list of {"role", "content"} chat messages.
    Returns:
        The text after the last "ASSISTANT: " marker, with <EOD> and stop
        tokens removed — or None if the invocation fails.
    """
    prompt = genPrompt(messages)
    parameters = {
        "max_new_tokens": 512,
        "temperature": 0.6,
        "top_k": 1,
        # NOTE(review): "repeatation_penalty" looks like a typo for
        # "repetition_penalty"; kept byte-identical because the custom serving
        # image may expect this exact key — confirm against the container
        # before renaming.
        "repeatation_penalty": 1.17,
        "stop": stops,
    }
    payload = json.dumps({
        "inputs": prompt,
        "parameters": parameters,
        "options": {"use_cache": False},
    })
    try:
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=payload,
            ContentType="application/json",
        )
        body = response['Body'].read().decode('utf-8')
        generated = json.loads(body)[0]["generated_text"]
        # Keep only the text after the final "ASSISTANT: " marker, then strip
        # the <EOD> terminator and any residual stop tokens.
        answer = generated.split("ASSISTANT: ")[-1].replace('<EOD>', '').strip()
        return cleanStops(answer)
    except Exception as e:
        # Best-effort: log the failure and return None so the Lambda handler
        # still produces a response instead of crashing.
        print(e)
        return None
def lambda_handler(event, context):
    """AWS Lambda entry point: answer event['text'] with the configured LLM.

    Fixes: the original leaked a file handle via a bare ``open(...).read()``
    (now closed deterministically with ``with``), compared a bool with
    ``== True``, and carried an unused ``ai_name`` local.

    Args:
        event: Lambda event dict; must contain the user question under 'text'.
        context: Lambda context object (unused).
    Returns:
        dict with 'statusCode' 200 and the model's answer under 'msg'.
    """
    question = event['text']
    endpoint_name = "aynaa-llm-image-2023-08-29-12-13-47-852"
    # Toggle between the SageMaker endpoint and the OpenAI API.
    local_llm_flag = True

    # Persona/system prompt bundled with the deployment package.
    with open('Esther.txt', 'r') as f:
        pre_prompt = f.read()

    messages = [
        {"role": "system", "content": pre_prompt},
        {"role": "user", "content": question},
    ]

    if local_llm_flag:
        answer = local_llm(endpoint_name, messages)
    else:
        answer = openai_api(messages)

    return {
        'statusCode': 200,
        'msg': answer,
    }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement