Files
Signlanguage_Datacollector/backend/src/app.py

116 lines
3.5 KiB
Python

import inspect
import re
from datetime import timedelta
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.routing import APIRoute
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel
import src.settings as settings
from src.database.database import init_db
from src.exceptions.base_exception import BaseException
from src.exceptions.exception_handlers import (auth_exception_handler,
base_exception_handler)
# The single FastAPI application object; routers, middleware, exception
# handlers and the OpenAPI override below are all attached to it.
app = FastAPI(version="0.0.1", title="SignDataCollector")
async def on_startup():
    """Run ``init_db()`` when the application receives its startup event."""
    await init_db()


# Registered explicitly instead of via the ``@app.on_event("startup")``
# decorator; both forward to the router's event-handler registry.
# NOTE(review): newer FastAPI versions prefer a ``lifespan`` handler — consider
# migrating once the pinned FastAPI version supports it.
app.add_event_handler("startup", on_startup)
class AuthSettings(BaseModel):
    """JWT settings handed to fastapi_jwt_auth through ``AuthJWT.load_config``.

    fastapi_jwt_auth reads options named with the ``authjwt_`` prefix
    (e.g. ``authjwt_access_token_expires``); un-prefixed fields such as
    ``access_expires`` are not consumed by the library.
    """

    # Signing secret, taken from the project settings module.
    authjwt_secret_key: str = settings.JWT_SECRET_KEY
    # Token denylisting (revocation) is currently switched off.
    authjwt_denylist_enabled: bool = False
    # Token types that would be checked against the denylist if it were
    # enabled. This is a set, so it is annotated as one (previously ``dict``).
    authjwt_denylist_token_checks: set = {"access", "refresh"}
    # Token lifetimes under the option names fastapi_jwt_auth actually reads,
    # so settings.ACCESS_EXPIRES / settings.REFRESH_EXPIRES take effect.
    authjwt_access_token_expires: timedelta = timedelta(seconds=settings.ACCESS_EXPIRES)
    authjwt_refresh_token_expires: timedelta = timedelta(seconds=settings.REFRESH_EXPIRES)
    # Legacy field names kept for any code that reads them directly.
    access_expires: timedelta = timedelta(seconds=settings.ACCESS_EXPIRES)
    refresh_expires: timedelta = timedelta(seconds=settings.REFRESH_EXPIRES)
    # Cookie options; presumably inert while tokens are read from headers only.
    authjwt_cookie_csrf_protect: bool = True
    authjwt_cookie_samesite: str = "lax"
    # Tokens are looked up in request headers (a set, previously annotated ``dict``).
    authjwt_token_location: set = {"headers"}
@AuthJWT.load_config
def auth_config():
    """Provide the AuthSettings instance used to configure fastapi_jwt_auth."""
    jwt_settings = AuthSettings()
    return jwt_settings
# CORS policy: accept every origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; since JWTs travel in headers here, consider
# an explicit origin list or allow_credentials=False — confirm with frontend.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Mount the feature routers. The import is kept at this point in the module
# rather than moved to the top (NOTE(review): possibly to avoid an import
# cycle — confirm before relocating).
from .routers import auth_router, signs_router, signvideo_router

for _router in (auth_router, signs_router, signvideo_router):
    app.include_router(_router)
# Map exception types to their handlers. Note: ``BaseException`` here is the
# project's class from src.exceptions.base_exception, which shadows the
# Python builtin of the same name within this module.
for _exc_type, _exc_handler in (
    (BaseException, base_exception_handler),
    (AuthJWTException, auth_exception_handler),
):
    app.add_exception_handler(_exc_type, _exc_handler)
# Markers that flag an endpoint as needing a JWT when they appear in its
# source. "jwt_required" also matches "fresh_jwt_required".
_AUTH_MARKER_PATTERN = re.compile(
    r"RoleChecker|jwt_required|jwt_optional|jwt_refresh_token_required"
)


def _endpoint_uses_jwt(endpoint) -> bool:
    """Return True if the source of *endpoint* mentions a JWT/role marker.

    Endpoints whose source cannot be retrieved (builtins, dynamically created
    callables) are treated as unprotected instead of breaking schema
    generation.
    """
    try:
        source = inspect.getsource(endpoint)
    except (OSError, TypeError):
        return False
    return _AUTH_MARKER_PATTERN.search(source) is not None


def custom_openapi():
    """Generate the OpenAPI schema with a "Bearer Auth" security scheme.

    The schema is built once and cached on ``app.openapi_schema``. Every
    operation whose endpoint source mentions one of the JWT markers gets the
    "Bearer Auth" requirement, so Swagger UI's Authorize button sends the
    Authorization header for it.

    :return: the (cached) OpenAPI schema
    :rtype: dict
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description="An API with an Authorize Button",
        routes=app.routes,
    )
    # "components" can be absent when no models are declared, so create it.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "Bearer Auth": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
            "description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
        }
    }

    paths = openapi_schema.get("paths", {})
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue
        # The schema is keyed by path_format (converters such as {id:int}
        # stripped); routes excluded from the schema have no entry at all.
        path_item = paths.get(route.path_format)
        if path_item is None or not _endpoint_uses_jwt(route.endpoint):
            continue
        for method in route.methods:
            operation = path_item.get(method.lower())
            if operation is not None:
                operation["security"] = [{"Bearer Auth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
# Replace FastAPI's default schema generator with custom_openapi, which adds
# the Bearer security scheme used by Swagger UI's Authorize button.
setattr(app, "openapi", custom_openapi)