
FastAPI Custom Middleware

Summary: using middleware to process FastAPI requests and responses

Created: 2023.11.01 22:50:45

Updated: 2023.11.02 00:08:39

What is middleware

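In FastAPI, middleware is code that runs on every request before it reaches an API endpoint, and on every response before it is sent back:

  • For a request, middleware runs before the request reaches the API endpoint
  • For a response, middleware runs before the response is returned to the client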

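By writing custom middleware, we can therefore preprocess every request before it reaches an API endpoint. Typical tasks are data validation and request interception. This keeps the shared logic in one place instead of repeating it in each endpoint, which improves readability and reuse.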

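Built-in middleware in FastAPI (provided by Starlette)

  • CORSMiddleware: handles Cross-Origin Resource Sharing (CORS) requests
  • TrustedHostMiddleware: validates the Host header of incoming requests, to guard against HTTP Host header attacks
  • SessionMiddleware: adds signed, cookie-based HTTP sessions; the client can read the session data but cannot modify it without invalidating the signature
  • GZipMiddleware: gzip-compresses response bodies for clients that accept gzip, reducing the amount of data transferred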

Custom middleware

Basic example

Python
# app.py
from fastapi import FastAPI

app = FastAPI()


@app.get("/api/info")
async def hello():
    return {"message": "Hello, World!"}


@app.get("/apiv2/info")
async def hellov2():
    return {"message": "Hello, World from V2"}
Run uvicorn app:app --reload --port 9999 to serve the endpoints defined above at http://127.0.0.1:9999/api/info and http://127.0.0.1:9999/apiv2/info:
Bash
 uvicorn app:app --reload --port 9999
INFO:     Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO:     Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO:     Started reloader process [53964] using statreload
INFO:     Started server process [53966]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:58937 - "GET /api/info HTTP/1.1" 200 OK

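Function-based middleware

Function-based middleware is created with the @app.middleware("http") decorator. The decorated function receives two arguments:

  1. request: the incoming HTTP Request object
  2. call_next: a function that passes request on to the matching endpoint (or the next middleware) and returns the resulting response; await it to get the response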

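The middleware below routes requests for /api/... to the corresponding /apiv2/... endpoints by rewriting the request path before routing (no HTTP redirect is sent):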

Python
# app.py
# at the import level
from fastapi import FastAPI, Request

# ......

# after the `app` variable
@app.middleware("http")
async def modify_request_response_middleware(request: Request, call_next):
    # Intercept and modify the incoming request: rewrite only the leading
    # "/api/" prefix, so "/apiv2/..." paths are not rewritten a second time
    path = request.scope["path"]
    if path.startswith("/api/"):
        request.scope["path"] = "/apiv2/" + path[len("/api/"):]
    # Pass the (modified) request to the next middleware / endpoint
    response = await call_next(request)
    # Transform the outgoing response; depending on the Starlette version, the object
    # returned by call_next may not be a StreamingResponse, so set the header unconditionally
    response.headers["X-Custom-Header"] = "Modified"
    return response
The access log shows that the request was handled as /apiv2/info:
Bash
 uvicorn app:app --reload --port 9999
INFO:     Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO:     Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO:     Started reloader process [54189] using statreload
INFO:     Started server process [54191]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:63141 - "GET /apiv2/info HTTP/1.1" 200 OK


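Class-based middleware

Class-based middleware is implemented by subclassing BaseHTTPMiddleware and overriding its async def dispatch(self, request, call_next) method. The example below uses it to add simple rate limiting to FastAPI. Each IP may make at most 3 requests per minute, and further requests are rejected with a 429 Too Many Requests error. Strictly speaking, the one-minute window in this code is measured from the client's last accepted request.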

Python
# app.py
# at the import level
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from datetime import datetime, timedelta

# immediately after imports
class RateLimitingMiddleware(BaseHTTPMiddleware):
    # Rate limiting configurations
    RATE_LIMIT_DURATION = timedelta(minutes=1)
    RATE_LIMIT_REQUESTS = 3

    def __init__(self, app):
        super().__init__(app)
        # Dictionary to store request counts for each IP
        self.request_counts = {}

    async def dispatch(self, request, call_next):
        # Get the client's IP address
        client_ip = request.client.host

        # Check if IP is already present in request_counts
        request_count, last_request = self.request_counts.get(client_ip, (0, datetime.min))

        # Calculate the time elapsed since the last request
        elapsed_time = datetime.now() - last_request

        if elapsed_time > self.RATE_LIMIT_DURATION:
            # If the elapsed time is greater than the rate limit duration, reset the count
            request_count = 1
        else:
            if request_count >= self.RATE_LIMIT_REQUESTS:
                # If the request count has reached the rate limit, return a JSON response with an error message
                return JSONResponse(
                    status_code=429,
                    content={"message": "Rate limit exceeded. Please try again later."}
                )
            request_count += 1

        # Update the request count and last request timestamp for the IP
        self.request_counts[client_ip] = (request_count, datetime.now())

        # Proceed with the request
        response = await call_next(request)
        return response

app = FastAPI()
app.add_middleware(RateLimitingMiddleware)
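In the code above, overriding dispatch gives us an IP-based rate-limiting middleware. When the same client sends requests in quick succession, the first three succeed and the rest are rejected with 429: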
Bash
 uvicorn app:app --reload --port 9999
INFO:     Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO:     Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO:     Started reloader process [55824] using statreload
INFO:     Started server process [55826]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO:     127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests

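Middleware execution order

In short, middleware runs in the reverse order of registration. The middleware registered last is the outermost layer: it sees the request first and the response last.

  • Function-based middleware: runs in the reverse order of the definitions in the code, i.e. bottom to top
  • Class-based middleware: runs in the reverse order of the app.add_middleware calls
  • Function-based and class-based middleware share one stack, because @app.middleware("http") simply registers the function through app.add_middleware(BaseHTTPMiddleware, dispatch=...). When both kinds are present, whichever was registered later runs first. In this article's app.py, the rate-limiting class is added right after app is created and the function-based middleware is declared afterwards, so the function-based middleware runs first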

Unit tests

Python
# app.py
# import level
from fastapi.testclient import TestClient

# ......

# at the end of the file
client = TestClient(app)


def test_modify_request_response_middleware():
    # Send a GET request to the /api/info endpoint
    response = client.get("/api/info")
    # Assert the response status code is 200
    assert response.status_code == 200
    # Assert the middleware has added the custom header
    assert response.headers.get("X-Custom-Header") == "Modified"
    # Assert the request was rewritten to the /apiv2/info endpoint
    assert response.json() == {"message": "Hello, World from V2"}


def test_rate_limiting_middleware():
    # The limiter's counters live in the middleware instance, so requests sent by
    # earlier tests in the same process (same TestClient address) also count
    statuses = [client.get("/api/info").status_code for _ in range(4)]
    # At most RATE_LIMIT_REQUESTS (3) requests succeed within one minute ...
    assert statuses.count(200) <= RateLimitingMiddleware.RATE_LIMIT_REQUESTS
    # ... so the 4th request in a row is rejected
    assert statuses[-1] == 429

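Best practices

  • Keep middleware lightweight: middleware runs on every request, so heavy computation in it slows down every response
  • Order middleware deliberately: avoid duplicated work, and let cheap checks that can reject a request early (such as authentication and authorization) run first. Because the middleware registered last runs first, such middleware should be registered last
  • Use caching and memoization: avoid repeating computations and repeated calls to external services such as databases
  • Benchmark: use tools such as cProfile to measure how much a newly added middleware affects the service
  • Security: validate input data with pydantic; handle authentication and authorization with JWT or similar methods; restrict access to sensitive endpoints with Role-Based Access Control (RBAC)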

References