Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import inspect
- from collections import OrderedDict
- from json import loads, dumps
- from pprint import pprint
- from typing import Type, Union
- from django.core.exceptions import FieldError
- from django.test import TestCase
- from rest_framework.relations import RelatedField, ManyRelatedField
- from rest_framework.serializers import ModelSerializer, Serializer, BaseSerializer, ListSerializer
- from rest_framework.test import APIRequestFactory
- from rest_framework.utils.serializer_helpers import ReturnList
# This snippet automatically prefetches all related objects a serializer needs.
# Example usage in your ViewSet's get_queryset method:
def get_queryset(self):
    """
    Example ViewSet ``get_queryset``: add the select_related/prefetch_related
    lookups needed by the ViewSet's serializer to the base queryset.

    ``YOUR_MODEL`` is a placeholder; replace it with the ViewSet's model (or
    start from whatever queryset you want to use).
    """
    qs = YOUR_MODEL.objects.all()  # Or whatever queryset you want to use
    # prefetch() takes the queryset first, then the serializer class.
    qs = prefetch(qs, self.get_serializer_class())
    return qs
def prefetch(queryset, serializer: Type[ModelSerializer]):
    """
    Return ``queryset`` extended with the lookups ``serializer`` will need.

    The single-valued relations found by ``_prefetch`` are passed to
    ``select_related`` and the multi-valued ones to ``prefetch_related``.

    :param queryset: the queryset that will be rendered by ``serializer``.
    :param serializer: the serializer class (or instance) used for rendering.
    :return: a new queryset with both sets of lookups applied.
    """
    joins, prefetches = _prefetch(serializer)
    joined_queryset = queryset.select_related(*joins)
    return joined_queryset.prefetch_related(*prefetches)
def _source_lookup(field, prepend):
    """
    Return the ORM lookup path for ``field`` under ``prepend``, or ``None``.

    The path is built from the field's ``source`` (not its serializer field
    name), with dotted sources such as ``'author.profile'`` spelled the ORM
    way (``'author__profile'``). ``None`` is returned for ``source='*'``,
    where the field renders the parent object itself and adds no relation.
    """
    source_attrs = getattr(field, 'source_attrs', None)
    if source_attrs is None:
        # Field not bound yet: split the source the same way DRF's Field.bind does.
        source_attrs = [] if field.source == '*' else field.source.split('.')
    if not source_attrs:
        return None
    return prepend + '__'.join(source_attrs)


def _prefetch(serializer: Union[Type[BaseSerializer], BaseSerializer], path=None, indentation=0):
    """
    Collect the ORM lookups needed to render ``serializer`` without N+1 queries.

    Walks the serializer's fields recursively. Single-valued relations (a
    ``RelatedField``, or a nested serializer without ``many=True``) are
    collected for ``select_related``. Multi-valued relations (``ListSerializer``
    / ``ManyRelatedField``) and everything nested below them are collected for
    ``prefetch_related``.

    :param serializer: a serializer class (instantiated with no arguments) or
        instance, or a field object; objects without ``fields`` contribute nothing.
    :param path: lookup prefix (e.g. ``'author__profile'``) of ``serializer``
        relative to the root model, or ``None`` at the root.
    :param indentation: number of spaces used to indent the debug log lines.
    :return: a ``(select_related, prefetch_related)`` tuple of sets of lookups.
    """
    import logging  # local import so this edit needs no change to the file's import block

    logger = logging.getLogger(__name__)
    pad = ' ' * indentation
    prepend = f'{path}__' if path is not None else ''
    class_name = getattr(serializer, '__name__', serializer.__class__.__name__)
    logger.debug('%sLOOKING AT SERIALIZER: %s from path: %s', pad, class_name, prepend)
    select_related = set()
    prefetch_related = set()

    # Serializer classes are instantiated so that their bound fields can be read.
    serializer_instance = serializer() if inspect.isclass(serializer) else serializer
    # A ListSerializer (many=True) exposes its per-item serializer as ``child``.
    target = getattr(serializer_instance, 'child', serializer_instance)
    try:
        fields = target.fields.fields.items()
    except AttributeError:
        # No further fields, e.g. the PrimaryKeyRelatedField passed in as the
        # child_relation of a ManyRelatedField (a ManyToManyField's nested representation).
        return (set(), set())

    for name, field_instance in fields:
        field_type_name = field_instance.__class__.__name__
        logger.debug('%s Field "%s", type: %s, src: "%s"', pad, name, field_type_name, field_instance.source)
        # Only nested serializers and relational fields can require extra queries.
        if not isinstance(field_instance, (BaseSerializer, RelatedField, ManyRelatedField)):
            continue
        logger.debug('%sFound: %s (%s) - recursing deeper', pad, field_type_name, type(field_instance))
        # Use the field's source, not its name: they differ when source= is given.
        field_path = _source_lookup(field_instance, prepend)

        if field_path is None:
            # source='*': a nested serializer over the *same* object. Its relations
            # hang off the current path, so recurse without extending the path.
            if isinstance(field_instance, BaseSerializer) and not isinstance(field_instance, ListSerializer):
                nested_select, nested_prefetch = _prefetch(field_instance, path, indentation + 4)
                select_related |= nested_select
                prefetch_related |= nested_prefetch
            continue

        if isinstance(field_instance, RelatedField):
            # A single related object rendered by a relational field.
            logger.debug('%s Found related field: %s', pad, field_type_name)
            select_related.add(field_path)
        elif isinstance(field_instance, (ListSerializer, ManyRelatedField)):
            # Multiple related objects need prefetch_related, and so does everything
            # below them: select_related cannot follow a to-many relation.
            logger.debug('%s Found *:m relation: %s', pad, field_type_name)
            prefetch_related.add(field_path)
            # A ManyRelatedField only exposes its per-item field via child_relation.
            nested_field = getattr(field_instance, 'child_relation', field_instance)
            nested_select, nested_prefetch = _prefetch(nested_field, field_path, indentation + 4)
            prefetch_related |= nested_select
            prefetch_related |= nested_prefetch
        else:
            # Nested serializer for a single related object: join it and recurse.
            logger.debug('%s Found *:1 relation: %s', pad, field_type_name)
            select_related.add(field_path)
            nested_select, nested_prefetch = _prefetch(field_instance, field_path, indentation + 4)
            select_related |= nested_select
            prefetch_related |= nested_prefetch

    return (select_related, prefetch_related)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement