FastAPI 自定义中间件
概要: 利用中间件 middleware 处理FastAPI请求和响应
创建时间: 2023.11.01 22:50:45
更新时间: 2023.11.02 00:08:39
中间件是什么
在 FastAPI 中,中间件,即 middleware,用于预处理 API endpoint 的请求 request 和响应 response。
- 对于 API 请求,中间件在到达 API 的 endpoint 之前生效
- 对于 API 响应,中间件在返回客户端前生效

可以看出,通过实现自定义的中间件,我们可以在 API endpoint 到达前进行预处理,实现特定的目的,如数据校验和请求拦截,提高代码的可读性和复用度。
FastAPI 自带的中间件
- 跨域中间件 CORSMiddleware,用于处理跨域请求
- 可信主机中间件 TrustedHostMiddleware,用于校验请求的主机 header 信息,避免被 HTTP Host Header 攻击
- 会话中间件 SessionMiddleware,当会话数据为只读时,实现对基于 cookie 的 HTTP 会话签名
- GZip中间件 GZip Middleware,通过压缩响应数据的负载提高数据传输速度
自定义中间件
基础实例
Python |
---|
| # app.py
from fastapi import FastAPI
app = FastAPI()
@app.get("/api/info")
async def hello():
return {"message": "Hello, World!"}
@app.get("/apiv2/info")
async def hellov2():
return {"message": "Hello, World from V2"}
|
执行uvicorn app:app --reload --port 9999
可以在 http://127.0.0.1:9999/info 上访问我们预设的 API endpoint
Bash |
---|
| ❯ uvicorn app:app --reload --port 9999
INFO: Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO: Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO: Started reloader process [53964] using statreload
INFO: Started server process [53966]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: 127.0.0.1:58937 - "GET /api/info HTTP/1.1" 200 OK
|
基于函数的中间件
基于函数中间件,通过装饰器 @app.middleware("http")
实现,它接受下面两个参数
request
: 即HTTP请求的对象
call_next
: 通过回调此函数,可以无缝衔接响应
下面通过添加中间件,将上面 FastAPI 的 api
自动转到 apiv2
上
Python |
---|
| # app.py
# at the import level
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
# ......
# after the `app` variable
@app.middleware("http")
async def modify_request_response_middleware(request: Request, call_next):
# Intercept and modify the incoming request
request.scope["path"] = str(request.url.path).replace("api", "apiv2")
# Process the modified request
response = await call_next(request)
# Transform the outgoing response
if isinstance(response, StreamingResponse):
response.headers["X-Custom-Header"] = "Modified"
return response
|
通过 FastAPI 的日志可以看到,已经被转到了 apiv2 上
Bash |
---|
| ❯ uvicorn app:app --reload --port 9999
INFO: Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO: Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO: Started reloader process [54189] using statreload
INFO: Started server process [54191]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: 127.0.0.1:63141 - "GET /apiv2/info HTTP/1.1" 200 OK
|

基于类的中间件
基于类的中间件,通过继承 BaseHTTPMiddleware
并覆写 async def dispatch(request, call_next)
方法实现,下面通过基于类的中间件实现一个 FastAPI 的限流功能,同个 IP 在 1 分钟内最多请求 3 次,如果多余 3 次,就返回 429 Too Many Requests 错误
Python |
---|
| # app.py
# at the import level
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from datetime import datetime, timedelta
# immediately after imports
class RateLimitingMiddleware(BaseHTTPMiddleware):
# Rate limiting configurations
RATE_LIMIT_DURATION = timedelta(minutes=1)
RATE_LIMIT_REQUESTS = 3
def __init__(self, app):
super().__init__(app)
# Dictionary to store request counts for each IP
self.request_counts = {}
async def dispatch(self, request, call_next):
# Get the client's IP address
client_ip = request.client.host
# Check if IP is already present in request_counts
request_count, last_request = self.request_counts.get(client_ip, (0, datetime.min))
# Calculate the time elapsed since the last request
elapsed_time = datetime.now() - last_request
if elapsed_time > self.RATE_LIMIT_DURATION:
# If the elapsed time is greater than the rate limit duration, reset the count
request_count = 1
else:
if request_count >= self.RATE_LIMIT_REQUESTS:
# If the request count exceeds the rate limit, return a JSON response with an error message
return JSONResponse(
status_code=429,
content={"message": "Rate limit exceeded. Please try again later."}
)
request_count += 1
# Update the request count and last request timestamp for the IP
self.request_counts[client_ip] = (request_count, datetime.now())
# Proceed with the request
response = await call_next(request)
return response
app = FastAPI()
app.add_middleware(RateLimitingMiddleware)
|
在上面的代码中,我们通过覆写 dispatch
方法,实现了基于 IP 限流的中间件
Bash |
---|
| ❯ uvicorn app:app --reload --port 9999
INFO: Will watch for changes in these directories: ['/Users/lzwang/projects/PyPlayground/fastapi_tests/middle_ware']
INFO: Uvicorn running on http://127.0.0.1:9999 (Press CTRL+C to quit)
INFO: Started reloader process [55824] using statreload
INFO: Started server process [55826]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 200 OK
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
INFO: 127.0.0.1:55880 - "GET /api/info HTTP/1.1" 429 Too Many Requests
|
中间件调用顺序
简单来说,中间件调用顺序为它们注册顺序的逆序
- 对于基于函数的中间件,按照代码解释顺序的反序,即从下到上
- 对于基于类的中间件,那么 app 添加类中间的顺序的反序
- 如果同时存在基于函数的中间件和基于类的中间件,那么先执行基于函数的再执行基于类的中间件
单元测试
Python |
---|
| # app.py
# import level
import time
from fastapi.testclient import TestClient
# ......
# at the end of the file
client = TestClient(app)
def test_modify_request_response_middleware():
# Send a GET request to the hello endpoint
response = client.get("/api/info")
# Assert the response status code is 200
assert response.status_code == 200
# Assert the middleware has been applied
assert response.headers.get("X-Custom-Header") == "Modified"
# Assert the response content
assert response.json() == {"message": "Hello, World from V2"}
def test_rate_limiting_middleware():
time.sleep(1)
response = client.get("/api/info")
# Assert the response status code is 200
assert response.status_code == 200
time.sleep(1)
response = client.get("/api/info")
# Assert the response status code is 200
assert response.status_code == 200
time.sleep(1)
response = client.get("/api/info")
# Assert the response status code is 200
assert response.status_code == 429
|
最佳实践
- 轻量化中间件:由于中间件会在每次请求中执行,过多的计算将会导致服务器响应变慢,因此中间件必须保持轻量化
- 优化中间件顺序:减少重复操作并优化中间件执行逻辑,比如用于授权和验证的中间件放在最前面执行
- 利用缓存和记忆机制:减少重复计算和连接外部数据库等操作
- 基准测试:通过使用 cProfile 等工具,定量分析新引入的中间件对服务的影响大小
- 安全方面:利用 pydantic 进行输入数据校验;利用 JWT 等方法进行授权和验证;利用基于角色的权限控制(Role-Based Access Control, RBAC)限制敏感的 endpoint 访问
参考