diff --git a/src/crud.py b/src/crud.py index 3e2ff10..6e1e59f 100644 --- a/src/crud.py +++ b/src/crud.py @@ -12,18 +12,22 @@ DEFAULT_NR_HIGH_SCORES = 10 def get_user_by_id(db: Session, user_id: int): + """ Fetches a User from the database by their id """ return db.query(User).filter(User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): + """ Fetches a User from the database by their username """ return db.query(User).filter(User.username == username).first() def get_users(db: Session): + """ Fetch a list of all users """ return db.query(User).all() def create_user(db: Session, username: str, hashed_password: str): + """ Create a new user """ db_user = User(username=username, hashed_password=hashed_password) db.add(db_user) db.commit() @@ -32,6 +36,7 @@ def create_user(db: Session, username: str, hashed_password: str): def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): + """ Get the n highest scores of a given minigame """ user_high_scores = [] if not nr_highest: @@ -56,6 +61,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): def create_high_score(db: Session, high_score: HighScoreCreate): + """ Create a new high score for a given minigame """ owner = db.query(User).filter(User.user_id == high_score.owner_id).first() if not owner: raise HTTPException(status_code=400, detail="User does not exist") @@ -93,6 +99,7 @@ def create_high_score(db: Session, high_score: HighScoreCreate): return db_high_score def get_course_progress(db: Session, user: User, course: CourseEnum): + """ Get the progress a user has for a certain course """ if course != CourseEnum.All: course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first() if course_progress: diff --git a/src/main.py b/src/main.py index 55558ed..bb88ecd 100644 --- a/src/main.py +++ b/src/main.py @@ -23,7 +23,7 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # JWT authentication setup jwt_secret = "secret_key" ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month bearer_scheme = HTTPBearer() @@ -36,6 +36,18 @@ def get_db(): db.close() +def get_current_user( + token: HTTPAuthorizationCredentials = Depends(bearer_scheme), +): + try: + payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) + username = payload.get("sub") + if username is None: + raise HTTPException(status_code=401, detail="Invalid JWT token") + return username + except jwt.exceptions.DecodeError: + raise HTTPException(status_code=401, detail="Invalid JWT token") + @app.get("/") async def root(): return {"message": "Hello world!"} @@ -47,8 +59,14 @@ async def read_users(db: Session = Depends(get_db)): return users -@app.post("/users", response_model=users.User) -async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): +@app.patch("/users") +async def patch_current_user(user: users.UserCreate, current_user = Depends(get_current_user), db: Session = Depends(get_db)): + db_user = crud.get_us + return users + + +@app.post("/register", response_model=users.User) +async def register(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") @@ -57,12 +75,28 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): ) +@app.post("/login") +async def login(user: users.UserCreate, db: Session = Depends(get_db)): + user = authenticate_user(user, db) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} + + @app.get("/highscores", response_model=List[users.UserHighScore]) async def read_high_scores( db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, n_highest: Optional[int] = None, ): + if n_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") high_scores = crud.get_high_scores(db, minigame, n_highest) return high_scores @@ -105,18 +139,7 @@ def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): return db_user -@app.post("/login") -async def login(user: users.UserCreate, db: Session = Depends(get_db)): - user = authenticate_user(user, db) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token_payload = { - "sub": user.username, - "exp": datetime.utcnow() + access_token_expires, - } - access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) - return {"access_token": access_token} + @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])