Guest User

Untitled

a guest
Sep 26th, 2024
1,066
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.31 KB | None | 0 0
  1. from flask import Flask, request, jsonify, Response
  2. import requests
  3. import threading
  4. import logging
  5. import json
  6.  
  7. app = Flask(__name__)
  8.  
  9. # Configure logging
  10. logging.basicConfig(level=logging.INFO)
  11.  
  12. # Target server configuration
  13. TARGET_SERVER = 'http://localhost:8080'
  14.  
  15. # Modify this function to add max_tokens to the request body
  16. def modify_request_body(original_body):
  17. try:
  18. data = json.loads(original_body)
  19. # Add or overwrite the max_tokens field
  20. data['max_tokens'] = 8192
  21. modified_body = json.dumps(data)
  22. return modified_body
  23. except json.JSONDecodeError:
  24. logging.error("Failed to decode JSON from the request body.")
  25. return original_body
  26.  
  27. @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'])
  28. def proxy(path):
  29. # Construct the full URL to the target server
  30. url = f"{TARGET_SERVER}/{path}"
  31.  
  32. # Log the incoming request
  33. logging.info(f"Incoming request to: {request.method} {request.url}")
  34.  
  35. # Headers to send to the target server
  36. headers = {key: value for key, value in request.headers if key != 'Host'}
  37.  
  38. # Get the request data
  39. if request.data:
  40. # Modify the request body to add max_tokens
  41. modified_data = modify_request_body(request.data)
  42. else:
  43. modified_data = None
  44.  
  45. # Send the request to the target server
  46. response = requests.request(
  47. method=request.method,
  48. url=url,
  49. headers=headers,
  50. data=modified_data,
  51. params=request.args,
  52. stream=True
  53. )
  54.  
  55. # Prepare the response to relay back to the client
  56. excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection']
  57. headers = [(name, value) for (name, value) in response.raw.headers.items()
  58. if name.lower() not in excluded_headers]
  59.  
  60. def generate():
  61. for line in response.iter_lines(decode_unicode=True):
  62. # Each line corresponds to a complete SSE event
  63. if line.startswith('data: '):
  64. data = line[len('data: '):]
  65. if data == '[DONE]':
  66. yield f'data: {data}\n\n'
  67. else:
  68. try:
  69. # Parse the JSON data
  70. json_data = json.loads(data)
  71. # Modify the finish_reason if it is 'length'
  72. choices = json_data.get('choices', [])
  73. for choice in choices:
  74. if choice.get('finish_reason') == 'length':
  75. choice['finish_reason'] = 'stop'
  76. # Serialize back to JSON
  77. modified_data = json.dumps(json_data)
  78. # Yield the modified event
  79. yield f'data: {modified_data}\n\n'
  80. except json.JSONDecodeError:
  81. # If parsing fails, yield the line as is
  82. yield line + '\n\n'
  83. else:
  84. # Yield non-data lines as is
  85. yield line + '\n\n'
  86.  
  87. return Response(generate(), status=response.status_code, headers=headers)
  88.  
  89. if __name__ == '__main__':
  90. # Run the proxy server on port 8090
  91. app.run(host='0.0.0.0', port=8090, threaded=True)
Advertisement
Add Comment
Please, Sign In to add comment