fix rate limit being global instead of per-IP
This commit is contained in:
@@ -4,7 +4,7 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pyrate_limiter import Duration, Limiter, Rate
|
||||
from pyrate_limiter import AbstractBucket, BucketFactory, Duration, InMemoryBucket, Limiter, MonotonicClock, Rate, RateItem
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from fastapi import Depends
|
||||
import asyncmy
|
||||
@@ -60,8 +60,27 @@ async def connect_db(app: FastAPI):
|
||||
app.state.pool.close()
|
||||
await app.state.pool.wait_closed()
|
||||
|
||||
class MultiBucketFactory(BucketFactory):
|
||||
def __init__(self, rates, clock):
|
||||
self.clock = clock
|
||||
self.rates = rates
|
||||
self.buckets = {}
|
||||
|
||||
def wrap_item(self, name: str, weight: int = 1) -> RateItem:
|
||||
"""Time-stamping item, return a RateItem"""
|
||||
now = self.clock.now()
|
||||
return RateItem(name, now, weight=weight)
|
||||
|
||||
def get(self, item: RateItem) -> AbstractBucket:
|
||||
if item.name not in self.buckets:
|
||||
new_bucket = self.create(InMemoryBucket, self.rates)
|
||||
self.buckets.update({item.name: new_bucket})
|
||||
|
||||
return self.buckets[item.name]
|
||||
|
||||
app = FastAPI(lifespan=connect_db)
|
||||
limiter = Limiter(Rate(50, Duration.MINUTE))
|
||||
rates = [Rate(50, Duration.MINUTE)]
|
||||
limiter = Limiter(MultiBucketFactory(rates,MonotonicClock()))
|
||||
|
||||
async def create_tables(pool):
|
||||
async with pool.acquire() as conn:
|
||||
|
||||
Reference in New Issue
Block a user