Dependency Injection in FastAPI: Complete Guide

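Dependency Injection (DI) is one of FastAPI's most powerful features, yet it's often misunderstood. Instead of building things like database sessions, the current user, or validated query parameters inside each endpoint, you declare what the endpoint needs and FastAPI creates and passes those objects in. The result is modular code that is easy to test. Let's master it from the ground up.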

What is Dependency Injection?

Dependency Injection is a design pattern where components receive their dependencies from external sources rather than creating them internally.

# ❌ Without Dependency Injection
def get_user(user_id: int):
    db = Database()  # Creating dependency internally
    return db.query(User).filter(User.id == user_id).first()
 
# ✅ With Dependency Injection
def get_user(user_id: int, db: Database = Depends(get_db)):
    return db.query(User).filter(User.id == user_id).first()
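The pattern itself does not depend on FastAPI. Here is a minimal, framework-free sketch (the `UserStore`, `InMemoryStore`, and `get_username` names are hypothetical) in which the caller supplies the data source, so a test can hand in a fake instead of a real database:

```python
from typing import Optional, Protocol


class UserStore(Protocol):
    """Anything that can look up a username by id."""

    def find_username(self, user_id: int) -> Optional[str]: ...


class InMemoryStore:
    """A fake store, e.g. for tests."""

    def __init__(self, users: dict[int, str]):
        self._users = users

    def find_username(self, user_id: int) -> Optional[str]:
        return self._users.get(user_id)


def get_username(user_id: int, store: UserStore) -> Optional[str]:
    # The dependency is passed in instead of being constructed here
    return store.find_username(user_id)


store = InMemoryStore({1: "rick", 2: "morty"})
print(get_username(1, store))  # rick
print(get_username(3, store))  # None
```

`Depends(get_db)` automates exactly this hand-off: FastAPI calls the provider and passes the result in.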

Benefits

  1. Testability - Easy to mock dependencies
  2. Reusability - Share code across endpoints
  3. Separation of Concerns - Clear component boundaries
  4. Maintainability - Changes in one place
  5. Type Safety - Full editor support

Basic Dependencies

Simple Function Dependency

from typing import Optional

from fastapi import Depends, FastAPI
 
app = FastAPI()
 
def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
    """Common query parameters."""
    return {"q": q, "skip": skip, "limit": limit}
 
@app.get("/items/")
async def read_items(commons: dict = Depends(common_parameters)):
    """Use common parameters dependency."""
    return commons
 
@app.get("/users/")
async def read_users(commons: dict = Depends(common_parameters)):
    """Reuse the same dependency."""
    return commons

When you call /items/?q=test&skip=10&limit=50, FastAPI:

  1. Reads q, skip, and limit from the query string and converts them to the declared types
  2. Calls common_parameters(q="test", skip=10, limit=50)
  3. Passes the returned dict to read_items as the commons argument
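To make those steps concrete, here is a toy imitation using only the standard library. `call_with_query` is a hypothetical helper, not FastAPI's implementation (which also validates values via Pydantic and returns 422 errors); it only shows the idea of reading a function's signature, pulling matching values from the query string, converting them, and falling back to defaults:

```python
import inspect
from typing import Optional
from urllib.parse import parse_qs


def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
    return {"q": q, "skip": skip, "limit": limit}


def call_with_query(func, query_string: str):
    """Toy imitation: fill func's parameters from a query string."""
    raw = {k: v[0] for k, v in parse_qs(query_string).items()}
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if name not in raw:
            continue  # fall back to the declared default
        # Convert int-annotated parameters; leave everything else as str
        kwargs[name] = int(raw[name]) if param.annotation is int else raw[name]
    return func(**kwargs)


print(call_with_query(common_parameters, "q=test&skip=10&limit=50"))
# {'q': 'test', 'skip': 10, 'limit': 50}
```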

Database Session Dependency

from fastapi import Depends, HTTPException
from sqlalchemy.orm import Session
from database import SessionLocal  # your own module with the sessionmaker
 
def get_db():
    """Database session dependency."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
 
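@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
    """Get user with automatic DB session management.

    Plain `def`: FastAPI runs it in a threadpool, so the blocking
    SQLAlchemy query doesn't stall the event loop.
    """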
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user

Key Points:

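  • yield turns the function into a generator, which FastAPI runs like a context manager
  • Code after yield runs once the path operation has finished (whether that is before or after the response is sent has changed between FastAPI versions)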
  • Perfect for cleanup (closing connections, files, etc.)
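To see what "code after yield runs later" means, here is a stdlib-only sketch that drives a yield dependency by hand, roughly as a framework does: `contextlib.contextmanager` turns the generator into a context manager, and an exception raised by the "endpoint" is thrown back in at the `yield`. `fake_get_db`, `run_endpoint`, and the `events` list are illustrative stand-ins, not FastAPI internals:

```python
from contextlib import contextmanager

events = []


def fake_get_db():
    events.append("open")
    try:
        yield "session"
    finally:
        events.append("close")  # runs on success and on error


def run_endpoint(fail: bool):
    # Wrap the generator the way a framework might
    with contextmanager(fake_get_db)() as db:
        events.append(f"endpoint uses {db}")
        if fail:
            raise RuntimeError("boom")
        return "ok"


print(run_endpoint(fail=False), events)
# ok ['open', 'endpoint uses session', 'close']

events.clear()
try:
    run_endpoint(fail=True)
except RuntimeError:
    pass
print(events)
# ['open', 'endpoint uses session', 'close']
```

The `finally` is what guarantees cleanup on the error path; without it, the exception thrown in at `yield` would skip the code after it.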

Class-Based Dependencies

from typing import Optional
 
class CommonQueryParams:
    """Reusable query parameters as a class."""
    
    def __init__(
        self,
        q: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ):
        self.q = q
        self.skip = skip
        self.limit = limit
 
@app.get("/items/")
async def read_items(commons: CommonQueryParams = Depends(CommonQueryParams)):
    """Use class-based dependency."""
    response = {"skip": commons.skip, "limit": commons.limit}
    if commons.q:
        response["q"] = commons.q
    return response
 
# Shorthand syntax (same result)
@app.get("/items/")
async def read_items(commons: CommonQueryParams = Depends()):
    """FastAPI infers the dependency from the type hint."""
    return commons

Sub-Dependencies (Dependencies of Dependencies)

from fastapi import Header, HTTPException
 
def verify_token(x_token: str = Header(...)):
    """Verify authentication token."""
    if x_token != "fake-super-secret-token":
        raise HTTPException(status_code=400, detail="Invalid X-Token header")
    return x_token
 
def verify_key(x_key: str = Header(...)):
    """Verify API key."""
    if x_key != "fake-super-secret-key":
        raise HTTPException(status_code=400, detail="Invalid X-Key header")
    return x_key
 
def get_current_user(
    token: str = Depends(verify_token),
    key: str = Depends(verify_key),
    db: Session = Depends(get_db)
):
    """
    Get current user - depends on verify_token, verify_key, and get_db.
    FastAPI resolves all sub-dependencies automatically.
    """
    # Both token and key are verified before this runs
    user = db.query(User).filter(User.token == token).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user
 
@app.get("/users/me")
async def read_user_me(current_user: User = Depends(get_current_user)):
    """
    This endpoint requires:
    1. Valid X-Token header (via verify_token)
    2. Valid X-Key header (via verify_key)
    3. Database session (via get_db)
    4. User exists in database (via get_current_user)
    
    All checked automatically by FastAPI!
    """
    return current_user

Dependency Tree:

read_user_me
    └── get_current_user
            ├── verify_token
            ├── verify_key
            └── get_db
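FastAPI walks this tree for you. The toy resolver below (the `Dep` marker and `resolve` function are hypothetical, written only for illustration) shows the core idea: inspect each parameter, recursively resolve anything marked as a dependency, and cache results for one request. That cache mirrors FastAPI's default behaviour of calling a dependency only once per request even when several parts of the tree ask for it, which is why an endpoint and its sub-dependencies share the same database session:

```python
import inspect

calls = []


class Dep:
    """Marker similar in spirit to Depends(): wraps a dependency callable."""

    def __init__(self, func):
        self.func = func


def resolve(func, cache):
    """Call func, first resolving any Dep(...) defaults recursively."""
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if isinstance(param.default, Dep):
            dep = param.default.func
            if dep not in cache:  # one call per dependency per request
                cache[dep] = resolve(dep, cache)
            kwargs[name] = cache[dep]
    return func(**kwargs)


def get_db():
    calls.append("get_db")
    return "db-session"


def verify_token():
    calls.append("verify_token")
    return "token"


def get_current_user(token=Dep(verify_token), db=Dep(get_db)):
    calls.append("get_current_user")
    return {"token": token, "db": db}


def read_user_me(user=Dep(get_current_user), db=Dep(get_db)):
    # Same db object as the one get_current_user received
    return user["db"] is db


print(resolve(read_user_me, cache={}), calls)
# True ['verify_token', 'get_db', 'get_current_user']
```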

Dependencies in Path Operation Decorators

from fastapi import Depends
 
# Run these dependencies for this route without passing their return values in
@app.get("/items/", dependencies=[Depends(verify_token), Depends(verify_key)])
async def read_items():
    """This route requires token and key, but doesn't need their return values."""
    return [{"item_id": "Foo"}]
 
# Multiple dependencies
@app.get(
    "/users/",
    dependencies=[
        Depends(verify_token),
        Depends(verify_key),
        Depends(check_subscription)  # hypothetical extra check, defined elsewhere
    ]
)
async def read_users():
    """Multiple dependencies for access control."""
    return [{"username": "Rick"}, {"username": "Morty"}]

Use this when:

  • You need the dependency to run (e.g., validation)
  • You don't need the return value
  • You want cleaner function signatures

Global Dependencies

from fastapi import FastAPI, Depends
 
app = FastAPI(
    dependencies=[Depends(verify_token)]
)
 
# Now ALL routes require the token
@app.get("/items/")
async def read_items():
    return [{"item_id": "Foo"}]
 
@app.get("/users/")
async def read_users():
    return [{"username": "Rick"}]

Router-level dependencies:

from fastapi import APIRouter
 
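router = APIRouter(
    prefix="/admin",
    dependencies=[Depends(verify_admin)]  # verify_admin: your own admin check
)
 
@router.get("/users")
async def admin_get_users():
    """Only admins can access this."""
    return [{"username": "Rick"}]  # placeholder data
 
@router.delete("/users/{user_id}")
async def admin_delete_user(user_id: int):
    """Only admins can access this."""
    return {"deleted": user_id}
 
# Register the router; its dependencies then apply to every route in it
app.include_router(router)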

Yield Dependencies (Context Managers)

Database with Transactions

from sqlalchemy.orm import Session
 
def get_db_with_transaction():
    """Database with automatic transaction management."""
    db = SessionLocal()
    try:
        yield db
        db.commit()  # Commit on success
    except Exception:
        db.rollback()  # Rollback on error
        raise
    finally:
        db.close()  # Always close
 
@app.post("/users/")
async def create_user(
    user: UserCreate,
    db: Session = Depends(get_db_with_transaction)
):
    """User is created in a transaction."""
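    new_user = User(**user.model_dump())  # Pydantic v2; use .dict() on v1
    db.add(new_user)
    db.flush()  # send the INSERT now so generated fields like id are populated
    # Commit happens in the dependency's cleanup if no exception was raised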
    return new_user
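The commit-or-rollback shape is easy to verify with the standard library. This sketch swaps SQLAlchemy for `sqlite3` (an in-memory database) and drives the generator with `contextlib.contextmanager` in place of FastAPI; `get_conn_with_transaction` and `create_user` are hypothetical names:

```python
import sqlite3
from contextlib import contextmanager

# One shared in-memory database for the demo
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")


def get_conn_with_transaction():
    try:
        yield conn
        conn.commit()    # reached only if the endpoint did not raise
    except Exception:
        conn.rollback()  # undo partial writes
        raise


def create_user(name: str, fail: bool = False):
    with contextmanager(get_conn_with_transaction)() as db:
        db.execute("INSERT INTO users VALUES (?)", (name,))
        if fail:
            raise ValueError("validation failed after insert")


create_user("rick")
try:
    create_user("morty", fail=True)
except ValueError:
    pass

print(conn.execute("SELECT name FROM users").fetchall())
# [('rick',)]
```

The failed request's INSERT is rolled back, while the successful one is committed.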

Resource Cleanup

import asyncio
import time
 
class Timer:
    """Timer context manager for performance monitoring."""
    
    def __enter__(self):
        self.start = time.perf_counter()  # monotonic, high-resolution clock
        return self
    
    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.elapsed = self.end - self.start
 
def get_timer():
    """Timer dependency: measures from setup until FastAPI runs its cleanup."""
    timer = Timer()
    try:
        with timer:
            yield timer
    finally:
        print(f"Request took {timer.elapsed:.2f} seconds")
 
@app.get("/slow-operation")
async def slow_operation(timer: Timer = Depends(get_timer)):
    """Automatically timed."""
    # Simulate slow work without blocking the event loop
    # (time.sleep() in an async endpoint would stall every other request)
    await asyncio.sleep(2)
    return {"message": "Done"}

File Handling

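import os
import tempfile
from pathlib import Path
 
def get_temp_file():
    """Per-request temporary file with automatic cleanup."""
    # mkstemp creates a unique file, so concurrent requests never share a path
    fd, name = tempfile.mkstemp(suffix=".txt")
    os.close(fd)  # we access the file through Path instead
    temp_file = Path(name)
    try:
        yield temp_file
    finally:
        temp_file.unlink(missing_ok=True)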
 
@app.post("/process-file/")
async def process_file(
    content: str,
    temp_file: Path = Depends(get_temp_file)
):
    """Process content with temporary file."""
    temp_file.write_text(content)
    # Process file...
    result = temp_file.read_text()
    # File automatically deleted after response
    return {"processed": result}

Dependencies with Parameters

from fastapi import Depends, HTTPException
 
def pagination(page: int = 1, size: int = 50):
    """Pagination parameters with validation."""
    if page < 1:
        raise HTTPException(status_code=400, detail="Page must be >= 1")
    if size < 1 or size > 100:
        raise HTTPException(status_code=400, detail="Size must be 1-100")
    
    skip = (page - 1) * size
    return {"skip": skip, "limit": size}
 
@app.get("/items/")
async def list_items(
    paging: dict = Depends(pagination),
    db: Session = Depends(get_db)
):
    """List items with pagination."""
    items = db.query(Item).offset(paging["skip"]).limit(paging["limit"]).all()
    return items
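The page-to-offset conversion is plain arithmetic, so pulling it into pure functions (the names `page_to_offset` and `offset_to_page` are hypothetical) makes the edge cases easy to check without a database or a running app. A minimal sketch:

```python
def page_to_offset(page: int, size: int) -> dict:
    """Translate 1-based page numbers into SQL-style offset/limit."""
    if page < 1:
        raise ValueError("page must be >= 1")
    if not 1 <= size <= 100:
        raise ValueError("size must be 1-100")
    return {"skip": (page - 1) * size, "limit": size}


def offset_to_page(skip: int, limit: int) -> int:
    """Inverse mapping, useful when echoing the current page back."""
    return skip // limit + 1


for page in (1, 2, 3):
    paging = page_to_offset(page, 50)
    print(page, paging, offset_to_page(**paging))
# 1 {'skip': 0, 'limit': 50} 1
# 2 {'skip': 50, 'limit': 50} 2
# 3 {'skip': 100, 'limit': 50} 3
```

The dependency can then call `page_to_offset` and translate `ValueError` into an `HTTPException`.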

Parameterized Dependencies

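import time
from collections import defaultdict, deque

from fastapi import HTTPException, Request
 
class RateLimiter:
    """Sliding-window rate limiting dependency (in-memory, per process)."""
    
    def __init__(self, max_calls: int = 10, period: int = 60):
        self.max_calls = max_calls
        self.period = period
        # client IP -> timestamps of that client's recent accepted calls.
        # State is per process: with several workers, use a shared store (e.g. Redis).
        self.calls: dict[str, deque] = defaultdict(deque)
    
    async def __call__(self, request: Request):
        """Check rate limit for IP.

        Declared async so it runs on the event loop rather than in a
        threadpool, which keeps requests from mutating self.calls concurrently.
        """
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        window = self.calls[client_ip]
        
        # Drop timestamps that have slid out of the window
        while window and window[0] <= now - self.period:
            window.popleft()
        
        if len(window) >= self.max_calls:
            raise HTTPException(
                status_code=429,
                detail="Rate limit exceeded"
            )
        window.append(now)
        return True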
 
# Create rate limiter instances
rate_limit_strict = RateLimiter(max_calls=5, period=60)
rate_limit_relaxed = RateLimiter(max_calls=100, period=60)
 
@app.get("/api/strict", dependencies=[Depends(rate_limit_strict)])
async def strict_endpoint():
    """Maximum 5 calls per minute."""
    return {"message": "Success"}
 
@app.get("/api/relaxed", dependencies=[Depends(rate_limit_relaxed)])
async def relaxed_endpoint():
    """Maximum 100 calls per minute."""
    return {"message": "Success"}
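The time-window logic is the fiddly part, and it is easiest to test when the clock itself is injected rather than read from `time` directly. A stdlib-only sketch of the same sliding-window idea (the `SlidingWindow` class and its `clock` parameter are hypothetical), using a fake clock so the behaviour is deterministic:

```python
import time
from collections import defaultdict, deque
from typing import Callable


class SlidingWindow:
    """Allow at most max_calls per key within any `period`-second window."""

    def __init__(self, max_calls: int, period: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_calls = max_calls
        self.period = period
        self.clock = clock  # injected so tests can control time
        self.calls: dict[str, deque] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        now = self.clock()
        window = self.calls[key]
        # Forget timestamps that have left the window
        while window and window[0] <= now - self.period:
            window.popleft()
        if len(window) >= self.max_calls:
            return False
        window.append(now)
        return True


fake_now = [0.0]
limiter = SlidingWindow(max_calls=2, period=60, clock=lambda: fake_now[0])

print(limiter.allow("1.2.3.4"), limiter.allow("1.2.3.4"), limiter.allow("1.2.3.4"))
# True True False
fake_now[0] = 60.0  # the calls made at t=0 have now left the 60 s window
print(limiter.allow("1.2.3.4"))
# True
```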

Caching Dependencies

from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
 
class Settings(BaseSettings):
    """Application settings."""
    # Pydantic v2 style; replaces the older inner `class Config`
    model_config = SettingsConfigDict(env_file=".env")

    app_name: str = "FastAPI App"
    database_url: str
    secret_key: str
 
@lru_cache()
def get_settings():
    """
    Cached settings - loaded only once.
    Perfect for configuration that doesn't change during runtime.
    """
    return Settings()
 
@app.get("/info")
async def info(settings: Settings = Depends(get_settings)):
    """Get app info - settings loaded from cache."""
    return {
        "app_name": settings.app_name,
        # Don't return database_url or secret_key: they often contain credentials
    }

Important:

  • Use @lru_cache() for expensive operations
  • Only cache dependencies that don't change
  • Don't cache database sessions or request-specific data

Dependency Overrides (Testing)

# main.py
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
 
@app.get("/users/{user_id}")
async def get_user(user_id: int, db: Session = Depends(get_db)):
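    return db.query(User).filter(User.id == user_id).first()
 
# test_main.py
from fastapi.testclient import TestClient
from main import app, get_db
# TestingSessionLocal: a sessionmaker bound to a test database,
# set up as in the conftest.py example later in this post
 
# Mock database
def override_get_db():
    """Override database with test database."""
    db = TestingSessionLocal()  # create outside try: if this fails, there is nothing to close
    try:
        yield db
    finally:
        db.close()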
 
# Override the dependency
app.dependency_overrides[get_db] = override_get_db
 
client = TestClient(app)
 
def test_get_user():
    response = client.get("/users/1")
    assert response.status_code == 200
    # Uses test database instead of production database!

Benefits:

  • Test with mock data
  • Use test database
  • Override authentication
  • Speed up tests
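`app.dependency_overrides` is a plain dict that maps the original callable to its replacement. A stdlib-only sketch of the lookup idea (the `overrides` dict and `call_dependency` helper are hypothetical, not FastAPI's code):

```python
overrides = {}


def get_db():
    return "production-db"


def override_get_db():
    return "test-db"


def call_dependency(dep):
    # Use the replacement if one is registered, otherwise the original
    return overrides.get(dep, dep)()


print(call_dependency(get_db))  # production-db
overrides[get_db] = override_get_db
print(call_dependency(get_db))  # test-db
overrides.clear()               # reset between tests
print(call_dependency(get_db))  # production-db
```

Because the key is the function object itself, the override must use the exact `get_db` your routes depend on, which is why the test imports it from `main` instead of redefining it.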

Real-World Examples

1. Current User Dependency

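from jose import JWTError, jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
 
SECRET_KEY = "change-me"  # placeholder: load from settings/environment in real code
ALGORITHM = "HS256"
 
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")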
 
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> User:
    """Extract and validate user from JWT token."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise credentials_exception
    
    return user
 
async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Ensure user is active."""
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
 
async def get_current_admin_user(
    current_user: User = Depends(get_current_active_user)
) -> User:
    """Ensure user is admin."""
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return current_user
 
# Use in routes
@app.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_active_user)):
    return current_user
 
@app.get("/admin/users")
async def read_all_users(
    admin_user: User = Depends(get_current_admin_user),
    db: Session = Depends(get_db)
):
    return db.query(User).all()

2. Pagination Dependency

from fastapi import Depends, Query
 
class PaginationParams:
    """Pagination parameters; Query() enforces the allowed ranges."""
    
    def __init__(
        self,
        page: int = Query(1, ge=1),
        page_size: int = Query(50, ge=1, le=100),
    ):
        # Out-of-range values are rejected with a 422 before this runs
        self.skip = (page - 1) * page_size
        self.limit = page_size
 
@app.get("/items/")
async def list_items(
    pagination: PaginationParams = Depends(),
    db: Session = Depends(get_db)
):
    items = (
        db.query(Item)
        .offset(pagination.skip)
        .limit(pagination.limit)
        .all()
    )
    total = db.query(Item).count()
    
    return {
        "items": items,
        "total": total,
        "page": (pagination.skip // pagination.limit) + 1,
        "page_size": pagination.limit
    }

3. Request Validation

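from fastapi import HTTPException, Request
 
async def validate_content_type(request: Request):
    """Ensure the request declares a JSON body."""
    content_type = request.headers.get("content-type", "")
    # Compare only the media type, so "application/json; charset=utf-8" passes
    media_type = content_type.split(";")[0].strip().lower()
    if media_type != "application/json":
        raise HTTPException(
            status_code=415,  # Unsupported Media Type
            detail="Content-Type must be application/json"
        )
    return content_type
 
async def validate_request_size(request: Request):
    """Reject requests whose declared body size is too large."""
    content_length = request.headers.get("content-length")
    # Content-Length can be missing (chunked uploads) or wrong, and FastAPI may
    # already have read the body, so enforce a hard limit at the server/proxy too
    if content_length is not None:
        if not content_length.isdigit():
            raise HTTPException(status_code=400, detail="Invalid Content-Length")
        if int(content_length) > 1_000_000:  # 1MB
            raise HTTPException(
                status_code=413,
                detail="Request body too large"
            )
    return True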
 
@app.post("/upload/", dependencies=[
    Depends(validate_content_type),
    Depends(validate_request_size)
])
async def upload_data(data: dict):
    """Upload with validation."""
    return {"received": data}

4. Service Layer Pattern

class UserService:
    """User business logic."""
    
    def __init__(self, db: Session = Depends(get_db)):
        self.db = db
    
    def get_user(self, user_id: int) -> User:
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        return user
    
    def create_user(self, user_data: UserCreate) -> User:
        # Check if user exists
        existing = self.db.query(User).filter(
            User.email == user_data.email
        ).first()
        if existing:
            raise HTTPException(status_code=400, detail="Email already exists")
        
        # Create user
        user = User(**user_data.model_dump())  # Pydantic v2; use .dict() on v1
        self.db.add(user)
        self.db.commit()
        self.db.refresh(user)
        return user
    
    def update_user(self, user_id: int, user_data: UserUpdate) -> User:
        user = self.get_user(user_id)
        for field, value in user_data.model_dump(exclude_unset=True).items():
            setattr(user, field, value)
        self.db.commit()
        self.db.refresh(user)
        return user
 
# Use in routes
@app.get("/users/{user_id}")
async def get_user(
    user_id: int,
    user_service: UserService = Depends()
):
    return user_service.get_user(user_id)
 
@app.post("/users/")
async def create_user(
    user_data: UserCreate,
    user_service: UserService = Depends()
):
    return user_service.create_user(user_data)
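One payoff of the service layer is that business rules can be tested without FastAPI or a database. A stdlib-only sketch, assuming a hypothetical in-memory repository (`InMemoryUserRepo`) in place of the SQLAlchemy session and a plain `ValueError` in place of `HTTPException`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class User:
    id: int
    email: str


class InMemoryUserRepo:
    """Stand-in for the database session."""

    def __init__(self):
        self.users: dict[int, User] = {}

    def by_email(self, email: str) -> Optional[User]:
        return next((u for u in self.users.values() if u.email == email), None)

    def add(self, email: str) -> User:
        user = User(id=len(self.users) + 1, email=email)
        self.users[user.id] = user
        return user


class UserService:
    def __init__(self, repo: InMemoryUserRepo):
        self.repo = repo  # injected, like db in the FastAPI version

    def create_user(self, email: str) -> User:
        if self.repo.by_email(email):
            raise ValueError("Email already exists")
        return self.repo.add(email)


service = UserService(InMemoryUserRepo())
print(service.create_user("rick@example.com"))
# User(id=1, email='rick@example.com')
```

In the app, a dependency supplies the real storage; in tests, you construct the service directly with a fake.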

Advanced Patterns

1. Dependency with State

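class ConnectionPool:
    """Connection pool as a dependency (simplified for illustration)."""
    
    def __init__(self):
        self.connections = []  # idle connections ready for reuse
        self.max_size = 10     # max idle connections kept (not a cap on open ones)
    
    async def get_connection(self):
        """Get connection from pool."""
        if not self.connections:
            # Create new connection; create_connection() stands in for
            # your database driver's async connect call
            return await create_connection()
        return self.connections.pop()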
    
    async def return_connection(self, conn):
        """Return connection to pool."""
        if len(self.connections) < self.max_size:
            self.connections.append(conn)
        else:
            await conn.close()
 
# Single instance for the app
pool = ConnectionPool()
 
async def get_connection():
    """Get connection from pool with automatic return."""
    conn = await pool.get_connection()
    try:
        yield conn
    finally:
        await pool.return_connection(conn)
 
@app.get("/data")
async def get_data(conn = Depends(get_connection)):
    return await conn.fetch("SELECT * FROM data")
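The pool above caps idle connections but not the total number open at once. Here is a stdlib-only sketch of a bounded variant. `BoundedPool`, `FakeConnection`, and `simulate` are hypothetical; a real app would normally use its driver's built-in pool where one exists, such as asyncpg's `create_pool()`. An `asyncio.Semaphore` limits how many connections are checked out, and the async generator has the same shape as the `get_connection` dependency:

```python
import asyncio
from contextlib import asynccontextmanager


class FakeConnection:
    """Stand-in for a real driver connection."""


class BoundedPool:
    """Reuses idle connections and never hands out more than max_open at once."""

    def __init__(self, max_open: int = 2):
        self._idle: list[FakeConnection] = []
        self._slots = asyncio.Semaphore(max_open)
        self.created = 0  # connections ever opened
        self.in_use = 0
        self.peak = 0     # highest number checked out at the same time

    async def acquire(self) -> FakeConnection:
        await self._slots.acquire()  # waits while max_open are checked out
        self.in_use += 1
        self.peak = max(self.peak, self.in_use)
        if self._idle:
            return self._idle.pop()
        self.created += 1
        return FakeConnection()

    def release(self, conn: FakeConnection) -> None:
        self.in_use -= 1
        self._idle.append(conn)
        self._slots.release()


async def get_connection(pool: BoundedPool):
    """Same shape as the yield dependency above."""
    conn = await pool.acquire()
    try:
        yield conn
    finally:
        pool.release(conn)  # always hand the connection back


async def simulate(requests: int = 5, max_open: int = 2) -> BoundedPool:
    pool = BoundedPool(max_open)  # created inside the running event loop

    async def handle_request():
        # Drive the async generator the way a framework would
        async with asynccontextmanager(get_connection)(pool):
            await asyncio.sleep(0.01)  # simulate a query

    await asyncio.gather(*(handle_request() for _ in range(requests)))
    return pool


pool = asyncio.run(simulate())
print(pool.created, pool.peak, pool.in_use)  # 2 2 0
```

Five concurrent requests share two connections; the others wait on the semaphore instead of opening more.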

2. Conditional Dependencies

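from typing import Optional

from fastapi.security import OAuth2PasswordBearer

# auto_error=False: no Authorization header -> token is None instead of a 401
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
 
def get_optional_user(
    token: Optional[str] = Depends(oauth2_scheme_optional),
    db: Session = Depends(get_db)
) -> Optional[User]: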
    """Get user if token provided, None otherwise."""
    if not token:
        return None
    
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        return db.query(User).filter(User.username == username).first()
    except JWTError:
        return None
 
@app.get("/items/")
async def list_items(
    user: Optional[User] = Depends(get_optional_user),
    db: Session = Depends(get_db)
):
    """List items - authenticated users also see their own private items."""
    query = db.query(Item)
    
    if user:
        # Public items plus the user's own items
        query = query.filter(
            (Item.is_public == True) | (Item.owner_id == user.id)
        )
    else:
        # Show only public items
        query = query.filter(Item.is_public == True)
    
    return query.all()

3. Dependency Composition

from dataclasses import dataclass
 
@dataclass
class RequestContext:
    """Aggregated request context."""
    user: User
    db: Session
    settings: Settings
    
async def get_request_context(
    user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db),
    settings: Settings = Depends(get_settings)
) -> RequestContext:
    """Combine multiple dependencies into one context."""
    return RequestContext(user=user, db=db, settings=settings)
 
@app.get("/profile")
async def get_profile(ctx: RequestContext = Depends(get_request_context)):
    """Use aggregated context."""
    # Access all dependencies through ctx
    user_data = ctx.db.query(UserProfile).filter(
        UserProfile.user_id == ctx.user.id
    ).first()
    
    return {
        "user": ctx.user,
        "profile": user_data,
        "app_name": ctx.settings.app_name
    }

Best Practices

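1. Keep Dependencies Focused

# ✅ Good: one job (provide a session) with guaranteed cleanup
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
 
# ❌ Bad: unrelated side effect mixed in, and no try/finally,
# so db.close() never runs if the request raises
def get_db():
    db = SessionLocal()
    log_access()  # Unrelated side effect!
    yield db
    db.close()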

2. Use Type Hints

# ✅ Good: Clear types (first() returns None when no row matches)
def get_user(user_id: int, db: Session = Depends(get_db)) -> User | None:
    return db.query(User).filter(User.id == user_id).first()
 
# ❌ Bad: No type hints
def get_user(user_id, db = Depends(get_db)):
    return db.query(User).filter(User.id == user_id).first()

3. Name Dependencies Clearly

# ✅ Good: Clear names
get_db
get_current_user
verify_token
require_admin
 
# ❌ Bad: Vague names
dep1
check
validate
helper

4. Don't Overuse Dependencies

# ❌ Bad: Dependency for simple parameter
def get_user_id(user_id: int = Path(...)):
    return user_id
 
@app.get("/users/{user_id}")
async def get_user(uid: int = Depends(get_user_id)):
    # Just use user_id directly!
    pass
 
# ✅ Good: Use parameter directly
@app.get("/users/{user_id}")
async def get_user(user_id: int):
    # Simple and clear
    pass

5. Use Sub-Dependencies for Composition

# ✅ Good: Composed dependencies
def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> User:
    # Uses sub-dependencies
    pass
 
# ❌ Bad: Everything in one dependency
def get_current_user(
    authorization: str = Header(...),
) -> User:
    # Parsing token, connecting to DB, querying - all in one function
    # Hard to test and reuse
    pass

Testing with Dependencies

# conftest.py
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
 
from main import app, get_db
from database import Base
 
# Test database
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
def override_get_db():
    """Test database dependency."""
    db = TestingSessionLocal()  # create outside try: if this fails, there is nothing to close
    try:
        yield db
    finally:
        db.close()
 
@pytest.fixture
def client():
    """Test client with overridden dependencies."""
    Base.metadata.create_all(bind=engine)
    app.dependency_overrides[get_db] = override_get_db
    
    with TestClient(app) as c:
        yield c
    
    Base.metadata.drop_all(bind=engine)
    app.dependency_overrides.clear()
 
# test_users.py
def test_create_user(client):
    response = client.post("/users/", json={
        "username": "test",
        "email": "test@example.com"
    })
    # 201 assumes the route is declared with status_code=201; the default is 200
    assert response.status_code == 201
    # Uses test database!

Conclusion

Dependency Injection in FastAPI is powerful yet intuitive:

  • Clean code - No global state or tight coupling
  • Testability - Easy to override for testing
  • Reusability - Share logic across endpoints
  • Type safety - Full editor support
  • Automatic execution - FastAPI handles the rest
  • Sub-dependencies - Compose complex logic
  • Context managers - Automatic cleanup with yield

Key Takeaways

  1. Use Depends() for injecting dependencies
  2. Yield for cleanup - Perfect for resources like DB sessions
  3. Sub-dependencies compose well
  4. Classes work great as dependencies
  5. Override in tests for mocking
  6. Cache expensive operations with @lru_cache()
  7. Keep dependencies focused - One job each, with cleanup in finally
  8. Type hints help everyone

Next Steps

  • Implement service layer pattern with DI
  • Create custom dependency decorators
  • Build middleware using dependencies
  • Optimize with caching strategies
  • Add monitoring with dependency wrappers
  • Create dependency validation layers

Master Dependency Injection and your FastAPI code will be cleaner, more testable, and easier to maintain! 🚀
