116 lines
3.5 KiB
Python
116 lines
3.5 KiB
Python
import inspect
|
|
import re
|
|
from datetime import timedelta
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.openapi.utils import get_openapi
|
|
from fastapi.routing import APIRoute
|
|
from fastapi_jwt_auth import AuthJWT
|
|
from fastapi_jwt_auth.exceptions import AuthJWTException
|
|
from pydantic import BaseModel
|
|
|
|
import src.settings as settings
|
|
from src.database.database import init_db
|
|
from src.exceptions.base_exception import BaseException
|
|
from src.exceptions.exception_handlers import (auth_exception_handler,
|
|
base_exception_handler)
|
|
|
|
# The FastAPI application object; middleware, routers and exception handlers
# are attached to it further down this module.
app = FastAPI(version="0.0.1", title="SignDataCollector")
|
|
|
|
@app.on_event("startup")
async def on_startup():
    """Await ``init_db()`` when the application fires its startup event."""
    await init_db()
|
|
|
|
|
|
class AuthSettings(BaseModel):
    """Authentication settings handed to fastapi-jwt-auth.

    An instance of this model is returned by the ``AuthJWT.load_config``
    callback in this module; the ``authjwt_``-prefixed fields configure the
    library.
    """

    authjwt_secret_key: str = settings.JWT_SECRET_KEY
    # Denylist checking is currently switched off; flip to True to enable it.
    authjwt_denylist_enabled: bool = False
    # Token types checked against the denylist (only used when it is enabled).
    # The value is a set literal, so it is annotated as ``set`` (not ``dict``).
    authjwt_denylist_token_checks: set = {"access", "refresh"}
    # NOTE(review): fastapi-jwt-auth reads token lifetimes from
    # ``authjwt_access_token_expires`` / ``authjwt_refresh_token_expires``.
    # These two fields lack that prefix, so they presumably serve other code
    # (e.g. as denylist TTLs) rather than setting token expiry — confirm.
    access_expires: timedelta = timedelta(seconds=settings.ACCESS_EXPIRES)
    refresh_expires: timedelta = timedelta(seconds=settings.REFRESH_EXPIRES)
    # NOTE(review): tokens are read from headers only (authjwt_token_location),
    # so the cookie CSRF / SameSite options below presumably have no effect.
    authjwt_cookie_csrf_protect: bool = True
    # A set of locations, so annotated as ``set`` (not ``dict``).
    authjwt_token_location: set = {"headers"}
    authjwt_cookie_samesite: str = "lax"
|
|
|
|
|
|
@AuthJWT.load_config
def auth_config():
    """Build and hand back the ``AuthSettings`` used to configure AuthJWT."""
    config = AuthSettings()
    return config
|
|
|
|
# CORS: accept requests from any origin, with any method and any header,
# and allow credentials.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any site make credentialed cross-origin requests; consider listing explicit
# origins — confirm this is intended.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
# Register the feature routers, in this order.
# NOTE(review): this import sits below the app setup rather than with the
# other imports — presumably to avoid a circular import; confirm.
from .routers import auth_router, signs_router, signvideo_router

for _router in (auth_router, signs_router, signvideo_router):
    app.include_router(_router)
# Keep the loop variable out of the module namespace.
del _router
|
|
|
|
# Map exception types to their handlers. ``BaseException`` here is the
# project class imported from src.exceptions.base_exception, which shadows
# the Python builtin of the same name.
for _exc_type, _handler in (
    (BaseException, base_exception_handler),
    (AuthJWTException, auth_exception_handler),
):
    app.add_exception_handler(_exc_type, _handler)
# Keep the loop variables out of the module namespace.
del _exc_type, _handler
|
|
|
|
# Substrings whose presence in an endpoint's source marks it as JWT-protected.
# "jwt_required" also matches "fresh_jwt_required", so that needs no entry.
_JWT_SOURCE_MARKERS = ("RoleChecker", "jwt_required", "jwt_optional")


def _endpoint_uses_jwt(endpoint):
    """Return True if the endpoint's source text mentions a JWT marker.

    This is a purely textual heuristic. A marker appearing only in a comment
    or docstring still counts, and auth applied without naming a marker in
    the endpoint's own source is not detected.

    :param endpoint: the route's endpoint callable
    :return: whether any of ``_JWT_SOURCE_MARKERS`` occurs in its source;
        False when the source cannot be retrieved
    :rtype: bool
    """
    try:
        source = inspect.getsource(endpoint)
    except (OSError, TypeError):
        # No retrievable source (builtin, dynamically created, etc.):
        # do not let schema generation crash because of it.
        return False
    return any(marker in source for marker in _JWT_SOURCE_MARKERS)


def custom_openapi():
    """custom_openapi generate the custom swagger api documentation

    Builds the schema once and caches it on ``app.openapi_schema``. The
    schema gets a "Bearer Auth" security scheme (Authorization header), and
    every operation whose endpoint source mentions a JWT marker is tagged
    with it, so Swagger UI shows the Authorize button for those routes.

    :return: custom openapi_schema
    :rtype: dict
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title="My Auth API",
        version="1.0",
        description="An API with an Authorize Button",
        routes=app.routes,
    )

    # get_openapi omits "components" when there are no schemas to list.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "Bearer Auth": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
            "description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
        }
    }

    paths = openapi_schema.get("paths", {})
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue
        # The schema is keyed by path_format (converters like ":int" removed);
        # routes excluded from the schema have no entry at all.
        path_item = paths.get(route.path_format)
        if not path_item or not _endpoint_uses_jwt(route.endpoint):
            continue
        for method in route.methods:
            operation = path_item.get(method.lower())
            if operation is not None:
                operation["security"] = [{"Bearer Auth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
|
|
|
|
|
# Make FastAPI serve the schema produced by custom_openapi instead of its
# default generator.
app.openapi = custom_openapi
|