
#233: SQLAlchemy and FastAPI

Until now we kept our data in a variable. That worked for an example application, but the data vanishes as soon as we restart our API. To make the application more realistic, we need to persist the data across restarts. Let us explore how we can integrate SQLAlchemy with FastAPI.

The extended to-do application

The current state of the to-do application is a good basic example for FastAPI. I copied the application as it was after the introduction of the router into a new folder named extended_todo. This is where we will add the new features for the rest of this series on FastAPI.

Install SQLAlchemy

I covered SQLAlchemy in great detail a few years back. Since then, version 2 has been released, which requires a few changes to our examples. We can install (or update) SQLAlchemy with this command:

pip install -U SQLAlchemy

How can we integrate SQLAlchemy and FastAPI?

When it comes to integrating SQLAlchemy into FastAPI, we get a lot of flexibility. For the models and entities, we can choose between these two options:

  1. Dedicated models for Pydantic and SQLAlchemy.
  2. One combined model for Pydantic and SQLAlchemy (for example with SQLModel).

Once we know how to split the models, we need to decide how we want to access the database:

  1. Through a repository that hides the SQLAlchemy commands from our application.
  2. Our models interact directly with the database (as in the Active Record pattern).
  3. SQLAlchemy commands in our API endpoints.

All options have their pros and cons, so it is up to the specific application to make the trade-off.

I like the separation of concerns in my current in-memory data store and its two Pydantic models that split the input from the output. I will keep this design and create a separate SQLAlchemy entity Task that maps to the database table. The data store will continue to use the Pydantic models for input and output, but it will work with the Task entity behind the scenes.

The SQLAlchemy configuration

We need to tell SQLAlchemy which database we want to use and write some setup code to wire everything together. For that, we add the file database.py to the data folder:

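from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Import all entities so that their tables are registered on EntityBase.metadata
from .__all_models import *
from .entitybase import EntityBase

def create_session_factory(db_file: str) -> sessionmaker:
    # check_same_thread=False allows a SQLite connection to be used from
    # other threads, since FastAPI may use a session in a different thread
    engine = create_engine(
        'sqlite:///' + db_file, connect_args={"check_same_thread": False}
    )
    create_tables(engine)
    factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return factory

def create_tables(engine):
    # Creates every registered table that does not exist yet
    EntityBase.metadata.create_all(engine)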

The create_session_factory() function sets everything up and returns a session factory that we will use as our entry point into the database. This is similar to the official tutorial, but I prefer to pass in the file name of the database from the outside.
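A minimal sketch of how such a factory behaves, assuming an in-memory SQLite database instead of a file so that it runs on its own: it mirrors the engine options of create_session_factory(), shows that every call returns a new Session, and uses a session as a context manager so that it is closed automatically.

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

# Same engine options as in database.py, but with an in-memory
# SQLite database (an assumption to avoid a file on disk)
engine = create_engine(
    "sqlite:///:memory:", connect_args={"check_same_thread": False}
)
factory = sessionmaker(autoflush=False, bind=engine)

# Every call to the factory returns a new, independent Session
first, second = factory(), factory()
print(first is second)  # False
first.close()
second.close()

# As a context manager, the session is closed when the block ends
with factory() as session:
    result = session.execute(text("SELECT 1 + 1")).scalar_one()

print(result)  # 2
```

In the application, the fixture and the FastAPI dependency play the role of the with block: they open a session from the factory and close it when they are done.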

The database models

To create classes that match our tables, we need a base class that handles the SQLAlchemy mapping behind the scenes. We can put that code into entitybase.py in the data folder:

from sqlalchemy import orm

EntityBase = orm.declarative_base()

For our Task entity (the class that matches the table), we create the file entities.py in the data folder:

from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String

from .entitybase import EntityBase


class Task(EntityBase):
    __tablename__ = "tasks"

    id = Column(Integer, primary_key=True)
    name = Column(String)
    priority = Column(Integer)
    due_date = Column(Date)
    done = Column(Boolean, default=False)
    created_at = Column(DateTime)

As a final step, we create the file data/__all_models.py as the one place that imports all our entities. Because database.py imports this module, all entities are registered with SQLAlchemy before it creates the tables during initialisation:

from .entities import Task
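Why does this module matter? create_all() only creates tables for entity classes that have been defined, which in practice means their modules have been imported. A small sketch with a throwaway base class and a hypothetical Note entity shows the effect:

```python
from sqlalchemy import Column, Integer, create_engine, inspect
from sqlalchemy.orm import declarative_base

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")

# No entity class is defined yet, so the metadata is empty
# and create_all() has nothing to create
Base.metadata.create_all(engine)
before = inspect(engine).get_table_names()
print(before)  # []


# Defining (or importing) an entity registers its table on Base.metadata
class Note(Base):
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)


Base.metadata.create_all(engine)
after = inspect(engine).get_table_names()
print(after)  # ['notes']
```

If database.py did not import __all_models, the tasks table would only be created when some other module happened to import entities.py first.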

The tests for the new data store

We continue with our test-first approach and start with the tests for the SQLAlchemy-backed data store. As always, write a test, implement just enough code to make it pass, and then repeat. To keep this post short, here are all the tests we need:

import os

from ..data.database import create_session_factory
from ..data.datastore_db import DataStoreDb
from ..models.todo import TaskOutput, TaskInput
from datetime import date, datetime, timedelta
import pytest


@pytest.fixture(scope="session")
def with_db():
    db_file = os.path.join(
        os.path.dirname(__file__),
        '..',
        'db',
        'test_db.sqlite')

    factory = create_session_factory(db_file)
    session = factory()

    try:
        yield session
    finally:
        session.close()


def test_can_add_entry(with_db):
    entry = TaskInput(name="a simple task", priority=1, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)

    data = store.add(entry)

    assert data.name == "a simple task"
    assert data.priority == 1
    assert data.due_date == date.today()
    assert data.done is False
    assert data.created_at == date.today()
    assert data.id >= 1


def test_can_add_multiple_entries(with_db):
    entry_a = TaskInput(name="a simple task", priority=1, due_date=date.today(), done=False)
    entry_b = TaskInput(name="b simple task", priority=2, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)

    data_a = store.add(entry_a)
    data_b = store.add(entry_b)

    assert data_a.id < data_b.id


def test_can_get_specific_entry_back(with_db):
    entry_a = TaskInput(name="Find a specific task", priority=1, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    saved = store.add(entry_a)

    entry = store.get(saved.id)

    assert saved == entry


def test_missing_entry_gets_None_back(with_db):
    store = DataStoreDb(with_db)

    entry = store.get(-1)

    assert entry is None


def test_can_get_all_entries_back(with_db):
    entry_a = TaskInput(name="a simple task", priority=1, due_date=date.today(), done=False)
    entry_b = TaskInput(name="b simple task", priority=2, due_date=date.today(), done=False)
    entry_c = TaskInput(name="c simple task", priority=2, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    store.add(entry_a)
    store.add(entry_b)
    store.add(entry_c)

    entries = store.all()

    assert len(entries) >= 3


def test_can_delete_entry(with_db):
    entry_a = TaskInput(name="a simple task", priority=1, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    saved = store.add(entry_a)

    store.delete(saved.id)

    result = store.get(saved.id)
    assert result is None


def test_can_update_entry(with_db):
    old = TaskInput(name="a simple task", priority=1, due_date=date.today(), done=False)
    store = DataStoreDb(with_db)
    old_saved = store.add(old)

    new = TaskInput(name="b simple task", priority=2, due_date=date.today() + timedelta(days=2), done=True)
    store.update(old_saved.id, new)

    entry = store.get(old_saved.id)
    assert entry.name == "b simple task"
    assert entry.priority == 2
    assert entry.due_date == date.today() + timedelta(days=2)
    assert entry.done is True


def test_non_existing_entry_cannot_be_updated(with_db):
    store = DataStoreDb(with_db)

    new = TaskInput(name="b simple task", priority=2, due_date=date.today() + timedelta(days=2), done=True)
    with pytest.raises(ValueError) as e_info:
        store.update(-123, new)
    assert str(e_info.value) == "no taks known with id '-123'"

The pytest fixture with_db() configures the database with the factory we created above, yields a session to the tests, and closes that session once all tests of the run are done.

Make sure that you have a db folder next to the data and test folders.
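SQLite creates the database file on first use, but it does not create a missing folder; the connection then fails with an "unable to open database file" error. If you prefer not to create the folder by hand, a small helper could create it before create_session_factory() is called. ensure_db_folder() is a hypothetical name for this sketch, not part of the project:

```python
import os
import tempfile


def ensure_db_folder(db_file: str) -> str:
    """Hypothetical helper: create the folder for the SQLite file if needed."""
    folder = os.path.dirname(os.path.abspath(db_file))
    os.makedirs(folder, exist_ok=True)  # no error if the folder already exists
    return db_file


# Usage with a temporary base directory instead of the project folder
base = tempfile.mkdtemp()
db_file = ensure_db_folder(os.path.join(base, "db", "test_db.sqlite"))
print(os.path.isdir(os.path.join(base, "db")))  # True
```

In the fixture, the call would wrap the path before it is passed to create_session_factory().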

Implementing the data store

Guided by our tests, we can now implement the new DataStoreDb class in data/datastore_db.py:

from datetime import datetime
from typing import List, Optional

from sqlalchemy.orm import Session

from ..models.todo import TaskInput, TaskOutput
from .entities import Task

class DataStoreDb:
    def __init__(self, db: Session):
        self.db = db


    def add(self, entry: TaskInput) -> TaskOutput:
        # Copy the fields of the Pydantic model into a new entity
        task = Task(created_at=datetime.now(), **dict(entry))
        self.db.add(task)
        self.db.commit()

        return self.__to_output(task)


    def get(self, id: int) -> Optional[TaskOutput]:
        result = self.db.query(Task) \
            .filter(Task.id == id) \
            .first()

        if result:
            return self.__to_output(result)
        else:
            return None


    def all(self) -> List[TaskOutput]:
        entries = self.db.query(Task).all()
        return [self.__to_output(entry) for entry in entries]


    def delete(self, id: int) -> None:
        entry = self.db.query(Task) \
            .filter(Task.id == id) \
            .first()

        # Deleting an unknown id is silently ignored
        if entry:
            self.db.delete(entry)
            self.db.commit()


    def update(self, id: int, update: TaskInput) -> TaskOutput:
        entry = self.db.query(Task) \
            .filter(Task.id == id) \
            .first()

        if entry:
            entry.name = update.name
            entry.priority = update.priority
            entry.due_date = update.due_date
            entry.done = update.done
            self.db.commit()

            return self.__to_output(entry)
        else:
            raise ValueError(f"no taks known with id '{id}'")


    def __to_output(self, entity: Task) -> TaskOutput:
        # Map the entity to the API model: created_at is stored as a
        # datetime in the database but exposed as a date in TaskOutput
        return TaskOutput(id=entity.id,
                          name=entity.name,
                          priority=entity.priority,
                          due_date=entity.due_date,
                          done=entity.done,
                          created_at=entity.created_at.date())

The private method __to_output() turns a Task entity into a TaskOutput. The created_at field is a DateTime in the database but a date in TaskOutput. This difference is there by design, so that we can see how a converter maps the database objects to the models we use in FastAPI.

With this code in place, we can run our tests. Everything should pass, including the old tests.

Use the data store in the FastAPI application

Since we moved the /todo endpoints into a router file, we need to open routers/todo.py, wire up the database and replace the datastore:

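import os

from fastapi import APIRouter  # together with the other imports you already have

from ..data.database import create_session_factory
from ..data.datastore_db import DataStoreDb

router = APIRouter()

# Create the engine and the session factory once when the module is loaded,
# instead of creating a new engine on every request
db_file = os.path.join(
    os.path.dirname(__file__),
    '..',
    'db',
    'todo_api.sqlite')
session_factory = create_session_factory(db_file)


async def get_db():
    """
    Creates the datastore with a new session for each request
    and closes the session when the request is done
    """
    session = session_factory()
    db = DataStoreDb(session)
    try:
        yield db
    finally:
        session.close()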
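To see when the finally block in get_db() runs, here is a plain-Python sketch that drives an async generator dependency by hand, roughly the way FastAPI handles dependencies with yield: it hands the yielded value to the endpoint and later closes the generator, which runs the cleanup code. FakeSession is a hypothetical stand-in for a real session, and the exact moment FastAPI runs the cleanup depends on its version:

```python
import asyncio


class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy session."""

    def __init__(self, log: list):
        self.log = log

    def close(self) -> None:
        self.log.append("session closed")


async def get_db(log: list):
    # Same try/yield/finally shape as get_db() in routers/todo.py
    session = FakeSession(log)
    try:
        log.append("session opened")
        yield session
    finally:
        session.close()


async def handle_request(log: list) -> None:
    deps = get_db(log)
    db = await deps.__anext__()                      # value injected into the endpoint
    log.append(f"endpoint uses {type(db).__name__}")  # the endpoint body
    await deps.aclose()                               # cleanup: runs the finally block


events = []
asyncio.run(handle_request(events))
print(events)  # ['session opened', 'endpoint uses FakeSession', 'session closed']
```

Because the cleanup sits in a finally block, the session is closed even if the endpoint raises an exception.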

Next, we change all endpoint functions and add our get_db() function as a dependency:

@router.get("/")
async def show_all_tasks(filter: Annotated[dict, Depends(filter_parameters)], 
                         db: DataStoreDb = Depends(get_db)) -> List[TaskOutput]:
    ...


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_task(task: TaskInput, 
                      request: Request, 
                      db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    ...


@router.get("/{id}")
async def show_task(id: int, 
                    db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    ...


@router.put("/{id}")
async def update_task(id: int, 
                      task: TaskInput, 
                      db: DataStoreDb = Depends(get_db)) -> TaskOutput:
    ...


@router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_task(id: int, 
                      db: DataStoreDb = Depends(get_db)) -> None:
    ...

The rest of the code stays the same. Even better, the tests for our endpoints do not need to change at all.

Next

We moved from the in-memory data store to SQLAlchemy, and our API works as before. Our tests confirm that this is not just a claim but reality.

However, one thing could become a problem: our endpoint tests write into the database defined in the router (todo_api.sqlite) instead of the test database that the datastore tests use. Next week we will optimise the database tests and fix that problem.