import json
import os

import boto3
import openai

sagemaker_runtime = boto3.client('sagemaker-runtime')

# Read the key from the environment instead of hard-coding a secret in source.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Stop sequences, used both as generation parameters and for cleaning the output.
stops = ["<EOD>", "USER:", "{user}", "{character}"]


def openai_api(messages):
    """Send the conversation to OpenAI's chat completion API and return the reply text."""
    chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
    return chat_completion.choices[0].message.content


def build_llama2_prompt(messages):
    """Convert a chat-style message list into Llama 2's [INST]/<<SYS>> prompt format."""
    startPrompt = "<s>[INST] "
    endPrompt = " [/INST]"
    conversation = []
    for index, message in enumerate(messages):
        if message["role"] == "system" and index == 0:
            conversation.append(f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n")
        elif message["role"] == "user":
            conversation.append(message["content"].strip())
        else:
            # Close the current [INST] block with the assistant reply and open the next turn.
            conversation.append(f" [/INST] {message['content'].strip()}</s><s>[INST] ")
    return startPrompt + "".join(conversation) + endPrompt
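
# Illustrative example (not in the original paste): a system + user exchange
# renders to the standard Llama 2 chat template, e.g.
#
#   build_llama2_prompt([
#       {"role": "system", "content": "You are Esther."},
#       {"role": "user", "content": "Hi"},
#   ])
#   -> "<s>[INST] <<SYS>>\nYou are Esther.\n<</SYS>>\n\nHi [/INST]"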


def genPrompt(messages):
    """Render a chat-style message list into the USER:/ASSISTANT: format the
    fine-tuned endpoint expects, ending with a trailing 'ASSISTANT: ' cue."""
    prompt = ""
    for index, r in enumerate(messages):
        if index == 0 and r['role'] == "system":
            prompt += r['content']
        if r['role'] == "user":
            prompt += 'USER: ' + r['content']
        if r['role'] == "assistant":
            prompt += ' ASSISTANT: ' + r['content'] + ' <EOD> '
    prompt += ' ASSISTANT: '
    return prompt
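
# Illustrative example (not in the original paste):
#
#   genPrompt([
#       {"role": "system", "content": "You are Esther."},
#       {"role": "user", "content": "Hi"},
#   ])
#   -> "You are Esther.USER: Hi ASSISTANT: "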


def cleanStops(result):
    """Remove any leftover stop sequences from the generated text."""
    for stop in stops:
        # str.replace removes the whole substring; str.strip(stop) would only
        # strip individual characters from the ends.
        result = result.replace(stop, '').strip()
    return result


def local_llm(endpoint_name, messages):
    """Query the self-hosted SageMaker endpoint and return the assistant's reply."""
    prompt = genPrompt(messages)
    parameters = {
        "max_new_tokens": 512,
        "temperature": 0.6,
        "top_k": 1,
        "repetition_penalty": 1.17,
        "stop": stops
    }
    payload = json.dumps({"inputs": prompt, "parameters": parameters, "options": {"use_cache": False}})
    try:
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=payload,
            ContentType="application/json"
        )
        result = response['Body'].read().decode('utf-8')
        generated_text = json.loads(result)[0]["generated_text"]
        # The endpoint echoes the prompt, so keep only the final ASSISTANT: turn.
        answer = generated_text.split("ASSISTANT: ")[-1].replace('<EOD>', '').strip()
        return cleanStops(answer)
    except Exception as e:
        print(e)
        return None


def lambda_handler(event, context):
    question = event['text']

    endpoint_name = "aynaa-llm-image-2023-08-29-12-13-47-852"
    local_llm_flag = True  # True: use the SageMaker endpoint; False: fall back to OpenAI.

    # System prompt describing the "Esther" persona, packaged with the Lambda.
    with open('Esther.txt', 'r') as f:
        pre_prompt = f.read()
    messages = [
        {"role": "system", "content": pre_prompt},
        {"role": "user", "content": question},
    ]

    if local_llm_flag:
        answer = local_llm(endpoint_name, messages)
    else:
        answer = openai_api(messages)

    return {
        'statusCode': 200,
        'msg': answer,
    }
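

# Minimal local smoke test (illustrative; the event shape is inferred from
# lambda_handler above, and Esther.txt must exist in the working directory):
if __name__ == "__main__":
    print(lambda_handler({"text": "Hello, who are you?"}, None))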