
GraphQL-like features in Django Rest Framework

Dani Hodovic May 23, 2024 6 min read

I was recently working for a client that had built their backend with GraphQL and Graphene. They were asking me how to streamline building features in Django because their lead engineers had no prior Django experience.

Having worked with a number of REST backends, and DRF backends in particular, I have to admit that I don't like GraphQL. Or perhaps I don't like how it's implemented in Django land. It feels clunky and doesn't seem to fit the design philosophy of Django.

  • I don't like how GraphQL breaks from traditional HTTP patterns. HTTP as a protocol has stood the test of time across a myriad of changes in web technology: from Java servlets to Rails, microservices and single-page apps. If I had to bet, I'd say HTTP will still be here 10 years from now.
  • As a consequence, I don't like how GraphQL breaks the scaling and caching patterns we've built on top of HTTP and REST (queries typically go to a single endpoint as POST requests, which HTTP caches and CDNs don't cache by default).
  • REST has a much deeper knowledge and talent pool to draw from, and it has been standardized across many languages and frameworks. That makes development driven by Stack Overflow and ChatGPT easier.
  • Lastly, and importantly, I have no experience with GraphQL. I don't like relearning things that already work, and I generally only learn new things if they multiply my productivity many times over.

In particular, I don't like Graphene and graphene-django because:

  • it's not batteries-included. Basic features like auth and caching are scattered across a myriad of repositories, most of which are unmaintained.
  • it doesn't automatically generate schemas and methods the way DRF serializers and viewsets do
  • it lacks CRUD boilerplate, so I spend a substantial amount of time on code that should be automatic or generated
  • it swallows errors, making debugging a mess
  • it doesn't provide built-in docs like DRF and drf-spectacular do
  • it doesn't provide built-in auth like DRF does
  • it doesn't integrate with django-debug-toolbar

On the other hand, frontend engineers tend to like GraphQL because it lets them fetch data across many tables in a single query. This is great. GraphQL is also touted as a performance tool for low-bandwidth devices, but I'm skeptical that this is worth dumping REST for.

Could we keep the benefits of DRF and steal some inspiration from GraphQL?

drf-flex-fields to the rescue!

Robert Singer has written drf-flex-fields, a package that emulates GraphQL-like query features in REST. It allows you to:

  • ask for the fields you want to see, e.g. users/?fields=id,first_name
  • omit the fields you don't want, e.g. users/?omit=id,first_name
  • return related resources in the same request, e.g. users/?expand=organization.owner.roles
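To make the semantics of the three parameters concrete, here is a standalone sketch. It is not drf-flex-fields' actual code, and `parse_expand` and `apply_sparse_fields` are hypothetical names. `apply_sparse_fields` applies `fields`/`omit` to one already-serialized record, and `parse_expand` turns a comma-separated list of dotted paths into a nested tree, since conceptually each nested serializer only needs its own subtree of the `expand` value.

```python
def parse_expand(value):
    """Turn 'organization.owner.roles,organization.members' into a nested dict tree."""
    tree = {}
    for path in value.split(","):
        path = path.strip()
        if not path:
            continue  # tolerate empty values and trailing commas
        node = tree
        for segment in path.split("."):
            # Each dotted segment becomes one level of nesting.
            node = node.setdefault(segment, {})
    return tree


def apply_sparse_fields(record, fields=None, omit=None):
    """Keep only `fields` (when given), then drop anything listed in `omit`."""
    omit = set(omit or ())
    return {
        key: value
        for key, value in record.items()
        if (not fields or key in fields) and key not in omit
    }


user = {"id": 1, "first_name": "Dani", "last_name": "Hodovic"}
print(apply_sparse_fields(user, fields=["id", "first_name"]))  # {'id': 1, 'first_name': 'Dani'}
print(apply_sparse_fields(user, omit=["id", "first_name"]))  # {'last_name': 'Hodovic'}
print(parse_expand("organization.owner.roles,organization.members"))
# {'organization': {'owner': {'roles': {}}, 'members': {}}}
```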

The last feature, returning related resources in a nested graph, is probably the most useful and most requested feature among frontend engineers. They typically don't want to issue multiple requests and wait for all of the responses to return, either because of the added complexity or because of bandwidth limitations (mobile devices).
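From the client's side, this means a single URL carries the whole shape of the data. A minimal sketch for a Python client or test, assuming the default parameter names `fields`, `omit` and `expand`; `build_query_url` is a hypothetical helper, not part of drf-flex-fields.

```python
from urllib.parse import urlencode


def build_query_url(base_url, fields=(), omit=(), expand=()):
    """Build one URL that requests sparse fields and nested resources together."""
    params = {}
    if fields:
        params["fields"] = ",".join(fields)
    if omit:
        params["omit"] = ",".join(omit)
    if expand:
        params["expand"] = ",".join(expand)
    if not params:
        return base_url
    # Keep commas and the "*" wildcard readable instead of percent-encoding them.
    return f"{base_url}?{urlencode(params, safe=',*')}"


print(build_query_url("http://localhost:8000/api/v1/users/", expand=["organization_set.created_by"]))
# http://localhost:8000/api/v1/users/?expand=organization_set.created_by
```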

Below is an example of an app with a user model and a many-to-many relation to an organization model (a user could be a member of many organizations). I've created an example repository to demo the implementation, but below is a snippet.

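from django.contrib.auth import get_user_model
from django.db import models

# Assumes a custom user model with `name` and `created_at` fields, as in the responses below.
User = get_user_model()

class Organization(models.Model):
    name = models.CharField(max_length=50)
    # Members; the default reverse accessor on User is `organization_set`.
    users = models.ManyToManyField(to=User)
    # Inferred from the `created_by` key in the responses below. on_delete and related_name
    # are assumptions; a related_name is needed so it doesn't clash with `organization_set`.
    created_by = models.ForeignKey(
        to=User, on_delete=models.CASCADE, related_name="created_organizations"
    )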

Let's request a list of users and expand each user's related organizations (the reverse organization_set relation):

curl "http://localhost:8000/api/v1/users/?expand=organization_set"
HTTP 200 OK
Allow: GET, HEAD, OPTIONS
Content-Type: application/json

{
    "count": 1,
    "results": [
        {
            "id": 1,
            "name": "Dani Hodovic",
            "created_at": "2024-04-22T18:20:46.263286Z",
            "organization_set": [
                {
                    "id": 1,
                    "name": "My Great Organization",
                    "created_by": 1
                }
            ]
        }
    ]
}

We can go one step further and expand the organization's nested created_by field, which points back to the user model.

curl "http://localhost:8000/api/v1/users/?expand=organization_set.created_by"
HTTP 200 OK
Allow: GET, HEAD, OPTIONS
Content-Type: application/json

{
    "count": 1,
    "results": [
        {
            "id": 1,
            "name": "Dani Hodovic",
            "created_at": "2024-04-22T18:20:46.263286Z",
            "organization_set": [
                {
                    "id": 1,
                    "name": "My Great Organization",
                    "created_by": {
                        "id": 1,
                        "name": "Dani Hodovic",
                        "created_at": "2024-04-22T18:20:46.263286Z"
                    }
                }
            ]
        }
    ]
}

The magic works by inheriting your serializer class from rest_flex_fields.FlexFieldsModelSerializer and setting the expandable_fields property on its Meta class.

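from rest_flex_fields import FlexFieldsModelSerializer

# Organization is the model above; UserSerializer is assumed to be defined elsewhere.
class OrganizationSerializer(FlexFieldsModelSerializer):
    class Meta:
        model = Organization
        fields = ["name", "users"]
        # ?expand=users swaps the user IDs for nested objects; many=True for a to-many relation.
        expandable_fields = {"users": (UserSerializer, {"many": True})}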

The FlexFieldsModelSerializer class takes care of the rest!

I recommend that curious frontend engineers use the star (*) value to discover which fields can be expanded. This will attempt to expand any nested payload that inherits from FlexFieldsModelSerializer.

curl "http://localhost:8000/api/v1/users/?expand=*"

Be careful, however: nested expansion performs no limiting or pagination of the results, so you could end up returning a gigantic array if the related table is large.

drf-flex-fields can be configured through your project's Django settings, where you can rename the query parameters and limit the maximum expansion depth.
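A settings.py fragment might look like this. The option names are the ones documented in the drf-flex-fields README; double-check them against the version you install. The depth value is only an illustration.

```python
# settings.py (fragment). Option names as documented in the drf-flex-fields
# README; verify against your installed version. Values are illustrative.
REST_FLEX_FIELDS = {
    # Query parameter names (these are the defaults).
    "EXPAND_PARAM": "expand",
    "FIELDS_PARAM": "fields",
    "OMIT_PARAM": "omit",
    # Allow at most two levels, e.g. ?expand=organization_set.created_by.
    "MAXIMUM_EXPANSION_DEPTH": 2,
}
```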

Magically discovering expandable_fields

I realized that we could make the expansion magic more interesting. Since Django's model metadata exposes each model's relations to other models, I guessed that we could discover all of the expandable_fields in the graph, no matter the distance (level of nesting) in the model hierarchy. At the bottom of the page is a "magic" snippet that extends FlexFieldsModelSerializer by automatically setting the expandable_fields property to all related fields.
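To illustrate why the relation graph can be walked to any depth, here is a standalone sketch that swaps Django's model metadata for a plain, hypothetical adjacency map (model name to {relation name: related model name}) and walks it breadth-first, listing every expand path up to a given depth. The real snippet only discovers one level per serializer and lets drf-flex-fields handle the nesting; `RELATIONS` and `expand_paths` are illustrative names.

```python
from collections import deque

# Hypothetical stand-in for Model._meta: forward and reverse relations per model.
RELATIONS = {
    "User": {"organization_set": "Organization"},
    "Organization": {"users": "User", "created_by": "User"},
}


def expand_paths(root, relations, max_depth):
    """Return every dotted expand path reachable from `root`, up to `max_depth` levels."""
    paths = []
    queue = deque([(root, "", 0)])  # (model, path so far, depth of that path)
    while queue:
        model, prefix, depth = queue.popleft()
        if depth == max_depth:
            continue  # don't descend past the depth limit
        for name, related in relations.get(model, {}).items():
            path = f"{prefix}.{name}" if prefix else name
            paths.append(path)
            queue.append((related, path, depth + 1))
    return paths


print(expand_paths("User", RELATIONS, 2))
# ['organization_set', 'organization_set.users', 'organization_set.created_by']
```

Because User and Organization point at each other, the graph has a cycle and the number of paths keeps growing with depth (1, 3, 5 and 9 paths for depths 1 to 4), which is why an unbounded depth can get expensive.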

The result is that we can inherit from the MagicFlexFieldsModelSerializer and all related properties will be expandable by default. We don't have to statically specify expandable_fields.

from myapp.api.serializers.utils import MagicFlexFieldsModelSerializer

class UserSerializer(MagicFlexFieldsModelSerializer):
    class Meta:
        model = User
        fields = [
            "id",
            "name",
        ]

This is great if you're running a private API that only your own frontend engineers use, or if you're at an early stage and favor development speed over stability. It could be dangerous for a large public API, however, because attackers could flood your servers with deeply nested, expensive expansions.

Below is the source code for the MagicFlexFieldsModelSerializer:

import difflib
import importlib
import inspect
import logging
from functools import lru_cache

from django.db.models import ForeignKey, ManyToManyField
from rest_flex_fields.serializers import FlexFieldsModelSerializer

logger = logging.getLogger(__name__)

# The module (usually a package __init__) that imports all of your serializers.
SERIALIZERS_DIR = "myapp.api.serializers"


class MagicFlexFieldsModelSerializer(FlexFieldsModelSerializer):
    """
    Magically expands all related fields by inspecting related fields for a model.
    """

    def __init__(self, *args, **kwargs):
        # Discover the expandable relations of this serializer's model on every instantiation.
        self.expandable_fields = discover_expandable_fields(self.Meta.model)
        super().__init__(*args, **kwargs)


def discover_expandable_fields(model):
    """
    Magically builds `expandable_fields` from all of the model's relations,
    forward and reverse. The serializer class for each relation is discovered
    from the related model's name.

    The end result looks like this:
    ```
    {
        'created_by': backend.api.serializers.user_serializer.UserSerializer,
        'workspace': backend.api.serializers.workspace_serializer.WorkspaceSerializer,
        'users': (backend.api.serializers.user_serializer.UserSerializer, {'many': True}),
        'archived_by': (backend.api.serializers.user_serializer.UserSerializer, {'many': True}),
        'notified': (backend.api.serializers.user_serializer.UserSerializer, {'many': True}),
        'conversation_messages': (backend.api.serializers.message_serializer.MessageSerializer, {'many': True})
    }
    ```

    See the example at: https://github.com/rsinger86/drf-flex-fields?tab=readme-ov-file#quick-start
    """
    expandable_fields = {}
    serializer_classes = find_serializer_classes()

    # Forward relations declared on this model (ForeignKey also covers OneToOneField).
    for field in model._meta.get_fields():
        if isinstance(field, (ForeignKey, ManyToManyField)):
            serializer_cls = find_closest_matching_serializer_by_name(
                field.related_model.__name__, serializer_classes
            )
            # Fuzzy matching can pick the wrong class, so verify the model it serializes.
            if serializer_cls is None or field.related_model != serializer_cls.Meta.model:
                logger.debug(
                    "Unable to find the correct serializer class for: %s",
                    field.related_model,
                )
                continue

            if isinstance(field, ForeignKey):
                expandable_fields[field.name] = serializer_cls
            else:
                expandable_fields[field.name] = (
                    serializer_cls,
                    {"many": True},
                )

    # Reverse relations pointing at this model, e.g. `organization_set` on User.
    for field in model._meta.related_objects:
        serializer_cls = find_closest_matching_serializer_by_name(
            field.related_model.__name__, serializer_classes
        )
        if serializer_cls is None or field.related_model != serializer_cls.Meta.model:
            logger.debug(
                "Unable to find the correct serializer class for: %s",
                field.related_model,
            )
            continue

        accessor_name = field.get_accessor_name()
        # A reverse one-to-one accessor returns a single object, not a list.
        expandable_fields[accessor_name] = (
            serializer_cls,
            {"many": not field.one_to_one},
        )

    return expandable_fields


def find_closest_matching_serializer_by_name(input_str, serializer_classes):
    """
    Returns the serializer class whose name is most similar to `input_str`,
    or None if there are no serializer classes at all.
    """
    if not serializer_classes:
        return None
    # Compare against the bare class name ("UserSerializer") rather than str(cls),
    # whose "<class 'myapp.api...'>" module path adds noise to the similarity ratio.
    return max(
        serializer_classes,
        key=lambda cls: difflib.SequenceMatcher(None, input_str, cls.__name__).ratio(),
    )


@lru_cache(maxsize=None)
def find_serializer_classes():
    """
    Technically you could get this from the DRF router object via the registered views.
    """
    init_module = importlib.import_module(SERIALIZERS_DIR)
    # Only keep classes with a Meta.model: base classes such as
    # FlexFieldsModelSerializer would otherwise crash the Meta.model check.
    return tuple(
        obj
        for name, obj in vars(init_module).items()
        if inspect.isclass(obj)
        and "Serializer" in name
        and hasattr(getattr(obj, "Meta", None), "model")
    )
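If the fuzzy name matching feels fragile, one alternative (a swapped-in technique, not part of the snippet above) is to index serializers by their `Meta.model` and look the related model up directly, which removes the need for difflib and the follow-up model check. The sketch uses stand-in classes so it runs without Django; `index_serializers_by_model` is a hypothetical name.

```python
def index_serializers_by_model(serializer_classes):
    """Map each model to the first serializer class whose Meta.model points at it."""
    index = {}
    for cls in serializer_classes:
        model = getattr(getattr(cls, "Meta", None), "model", None)
        if model is not None:
            # setdefault keeps the first serializer found when a model has several.
            index.setdefault(model, cls)
    return index


# Stand-ins for Django models and DRF serializers, so the sketch runs anywhere.
class User:
    pass


class Organization:
    pass


class UserSerializer:
    class Meta:
        model = User


class OrganizationSerializer:
    class Meta:
        model = Organization


class BaseSerializer:  # no Meta.model, so it is skipped
    pass


index = index_serializers_by_model([BaseSerializer, UserSerializer, OrganizationSerializer])
print(index[Organization].__name__)  # OrganizationSerializer
```

Inside `discover_expandable_fields`, a lookup like `index.get(field.related_model)` would then replace the call to `find_closest_matching_serializer_by_name`. If a model has several serializers (say, list and detail variants), you would still need a rule for which one to expand with.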