import inspect
import re
from datetime import timedelta

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.routing import APIRoute
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel

import src.settings as settings
from src.database.database import init_db
from src.exceptions.base_exception import BaseException
from src.exceptions.exception_handlers import (
    auth_exception_handler,
    base_exception_handler,
)

app = FastAPI(title="SignDataCollector", version="0.0.1")

# Names whose presence in an endpoint's source text marks the endpoint as
# JWT-protected for documentation purposes. Compiled once so the schema
# builder does not re-scan the same source with several separate searches.
_AUTH_MARKER_RE = re.compile(
    r"RoleChecker|jwt_required|fresh_jwt_required|jwt_optional"
)


@app.on_event("startup")
async def on_startup():
    """Run ``init_db`` once when the application starts."""
    await init_db()


class AuthSettings(BaseModel):
    """AuthSettings model

    Configuration object handed to ``AuthJWT.load_config``.
    """

    authjwt_secret_key: str = settings.JWT_SECRET_KEY
    # Denylist checking is currently switched off; set to True to enable it.
    authjwt_denylist_enabled: bool = False
    # The default is a set literal, so the annotation is ``set`` (was ``dict``).
    authjwt_denylist_token_checks: set = {"access", "refresh"}
    # NOTE(review): these two names do not carry the ``authjwt_`` prefix, so
    # they are presumably consumed by application code rather than by the
    # library's own expiry settings -- confirm against the callers.
    access_expires: timedelta = timedelta(seconds=settings.ACCESS_EXPIRES)
    refresh_expires: timedelta = timedelta(seconds=settings.REFRESH_EXPIRES)
    authjwt_cookie_csrf_protect: bool = True
    # The default is a set literal, so the annotation is ``set`` (was ``dict``).
    authjwt_token_location: set = {"headers"}
    authjwt_cookie_samesite: str = "lax"


@AuthJWT.load_config
def auth_config():
    """Return the settings object used to configure ``AuthJWT``."""
    return AuthSettings()


# Add middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# include the routers
from .routers import auth_router, signs_router, signvideo_router

app.include_router(auth_router)
app.include_router(signs_router)
app.include_router(signvideo_router)

# Add the exception handlers
app.add_exception_handler(BaseException, base_exception_handler)
app.add_exception_handler(AuthJWTException, auth_exception_handler)


def _endpoint_requires_auth(endpoint):
    """Return True if the endpoint's source mentions an auth marker.

    :param endpoint: the callable registered as a route handler
    :return: True when the source text matches ``_AUTH_MARKER_RE``; False
        when it does not or when the source cannot be retrieved
    :rtype: bool
    """
    try:
        source = inspect.getsource(endpoint)
    except (OSError, TypeError):
        # Source is unavailable (e.g. no source file, or not a plain
        # function); treat the endpoint as unprotected instead of failing
        # the whole schema build.
        return False
    return _AUTH_MARKER_RE.search(source) is not None


def custom_openapi():
    """custom_openapi generate the custom swagger api documentation

    Builds the OpenAPI schema once, registers a "Bearer Auth" security scheme
    and attaches it to every operation whose endpoint source mentions one of
    the auth markers. The result is cached on ``app.openapi_schema``.

    :return: custom openapi_schema
    :rtype: openapi_schema
    """
    if app.openapi_schema:
        return app.openapi_schema

    # NOTE(review): title/version here differ from the ones passed to
    # FastAPI(...) above -- confirm which pair is intended.
    openapi_schema = get_openapi(
        title="My Auth API",
        version="1.0",
        description="An API with an Authorize Button",
        routes=app.routes,
        # servers=[{"url": config.api_path}],
    )

    # "components" is absent when no route contributes a schema component,
    # so create it on demand instead of indexing into it directly.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "Bearer Auth": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
            "description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
        }
    }

    # Get all routes where jwt_optional() or jwt_required
    schema_paths = openapi_schema.get("paths", {})
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue
        # Inspect the source once per route rather than once per HTTP method.
        if not _endpoint_requires_auth(route.endpoint):
            continue

        # The schema is keyed by the converter-free path format
        # ("/items/{id}" rather than "/items/{id:int}").
        schema_path = getattr(route, "path_format", route.path)
        operations = schema_paths.get(schema_path)
        if operations is None:
            # Route is not part of the schema (e.g. include_in_schema=False).
            continue

        for method in route.methods:
            operation = operations.get(method.lower())
            if operation is not None:
                operation["security"] = [{"Bearer Auth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi