Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
@csrf_protect
@api_view(['POST'])
@permission_classes([IsAuthenticated])
async def generate_report_batch(request):
    """
    Create several reports and generate their text with concurrent AI calls.

    Expected payload (``request.data``):
        reports: non-empty dict whose values are individual report payloads
            for ``ReportSerializer`` (the dict keys are not used).
        word_count: truthy value forwarded to ``prompt.call_ai_async``.

    Responses:
        400 if the payload is malformed, or if any report fails validation
            (``details`` is a list aligned with the input order; ``{}`` marks
            an entry that passed).
        402 if the institution cannot generate or has fewer remaining reports
            than the batch size.
        201 with ``{report_id: serialized GeneratedReport}`` for every report
            whose AI call and save succeeded. Failed generations are logged
            and left out of the response.

    Side effects: creates one ``Report`` row per entry and one
    ``GeneratedReport`` row per success, increments the institution's and the
    user's usage counters by the number of successes, and adds the AI token
    usage/spend to the user's running totals.
    """
    # NOTE(review): stock DRF ``@api_view`` does not await async views; this
    # presumably relies on an async-capable ``api_view``/serializer layer
    # (e.g. adrf, which provides ``asave``) — confirm.
    payload = request.data if isinstance(request.data, dict) else {}
    student_reports_data = payload.get('reports')
    word_count = payload.get('word_count')
    # An empty "reports" dict is rejected too: it would otherwise return 201
    # without creating anything.
    if not isinstance(student_reports_data, dict) or not student_reports_data or not word_count:
        return Response({'error': 'Invalid payload format. "reports" and "word_count" are required.'}, status=status.HTTP_400_BAD_REQUEST)

    report_list_data = list(student_reports_data.values())
    batch_size = len(report_list_data)
    # Log only the shape of the request; the raw payload carries student data
    # that should not end up in log files.
    logger.debug(f"Async batch generation request received: user={request.user.pk}, reports={batch_size}, word_count={word_count}")

    institution = request.user.institution
    # NOTE(review): check-then-act — two concurrent batches can both pass this
    # check before either one increments the usage counters.
    if not institution.can_generate() or institution.remaining_reports < batch_size:
        return Response({'error': 'Not enough report tokens to generate the entire batch.'}, status=status.HTTP_402_PAYMENT_REQUIRED)

    # Validate every entry BEFORE writing anything, so a bad entry never leaves
    # partially created Report rows behind. ``errors`` stays index-aligned with
    # the input; ``{}`` marks an entry that passed.
    valid_serializers = []
    errors = []
    for report_data in report_list_data:
        serializer = ReportSerializer(data=report_data, context={'request': request})
        if not serializer.is_valid():
            errors.append(serializer.errors)
            continue
        # Filter on the institution inside the query rather than reading
        # ``student.institution`` afterwards: lazily loading a related object
        # is a synchronous DB query, which Django refuses to run in async code.
        try:
            owns_student = await Student.objects.filter(
                pk=report_data.get('student'),
                institution=institution,
            ).aexists()
        except (ValueError, TypeError):
            # Malformed primary key value.
            owns_student = False
        if not owns_student:
            errors.append({'student': 'Student not found or does not belong to your institution.'})
            continue
        valid_serializers.append(serializer)
        errors.append({})

    if len(valid_serializers) != batch_size:
        return Response({'error': 'Validation failed for one or more reports.', 'details': errors}, status=status.HTTP_400_BAD_REQUEST)

    # Everything validated: create the Report rows.
    created_reports = []
    for serializer in valid_serializers:
        created_reports.append(await serializer.asave(user=request.user))

    # One coroutine per report, run concurrently. return_exceptions=True keeps
    # one failure from affecting the others; failures come back as results.
    # NOTE(review): the whole batch is sent at once with no concurrency cap.
    tasks = [prompt.call_ai_async(report.create_prompt(), word_count) for report in created_reports]
    ai_results = await asyncio.gather(*tasks, return_exceptions=True)

    # Process results (gather preserves order) and save GeneratedReport objects.
    # NOTE(review): Report rows whose generation fails are kept; confirm that
    # is intended (e.g. to allow a retry).
    final_response_data = {}
    successful_reports = 0
    usage_recorded = False
    total_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0, "total_spending_usd": Decimal(0), "total_spending_gbp": Decimal(0)}
    for report_instance, result in zip(created_reports, ai_results):
        # gather() can also return CancelledError, which derives from
        # BaseException rather than Exception.
        if isinstance(result, BaseException):
            logger.error(f"Error generating report for report_id {report_instance.id}: {result}")
            continue

        # Tokens were spent on this call whatever happens next, so count them.
        usage = prompt.get_token_usage_and_cost(result)
        for key in total_usage:
            total_usage[key] += usage[key]
        usage_recorded = True

        # An empty ``choices`` list or a ``None`` content must not abort the
        # whole batch after other reports were already saved.
        try:
            student_report_text = result.choices[0].message.content.strip()
        except (AttributeError, IndexError, TypeError) as exc:
            logger.error(f"Unusable AI response for report_id {report_instance.id}: {exc!r}")
            continue

        de_anon_report = report_instance.de_anon_report(student_report_text)
        gen_report_data = {
            'student_id': report_instance.student.id,
            'user': request.user.id,
            'report': report_instance.id,
            'report_content': de_anon_report,
            'word_count': len(de_anon_report.split()),
            'char_count': len(de_anon_report),
            'batch_mode': True,
        }
        gen_report_serializer = GeneratedReportSerializer(data=gen_report_data, context={'request': request})
        if not gen_report_serializer.is_valid():
            logger.error(f"Failed to serialize generated report for report_id {report_instance.id}: {gen_report_serializer.errors}")
            continue
        saved_gen_report = await gen_report_serializer.asave()
        # Re-serialize the saved instance so the response includes every field
        # (e.g. the nested student object).
        final_response_data[report_instance.id] = GeneratedReportSerializer(saved_gen_report, context={'request': request}).data
        successful_reports += 1

    user = request.user
    # Only successfully generated reports consume the institution's quota.
    if successful_reports > 0:
        await institution.aincrement_reports_used(count=successful_reports)
        await user.aincrement_generated_count(count=successful_reports)
    # Token usage/spend is recorded for every AI call that returned, even if
    # saving its generated report failed afterwards.
    if usage_recorded:
        await User.objects.filter(pk=user.pk).aupdate(
            total_input_tokens=F('total_input_tokens') + total_usage["input_tokens"],
            total_output_tokens=F('total_output_tokens') + total_usage["output_tokens"],
            total_tokens=F('total_tokens') + total_usage["total_tokens"],
            total_spending_usd=Coalesce(F('total_spending_usd'), 0, output_field=DecimalField()) + total_usage["total_spending_usd"],
            total_spending_gbp=Coalesce(F('total_spending_gbp'), 0, output_field=DecimalField()) + total_usage["total_spending_gbp"]
        )
    return Response(final_response_data, status=status.HTTP_201_CREATED)
Advertisement
Add Comment
Please, Sign In to add comment