Refactor crud module

This commit is contained in:
lvrossem
2023-03-31 07:13:13 -06:00
parent 49f8d7d713
commit 032a6ed543
6 changed files with 205 additions and 202 deletions

View File

@@ -1,176 +0,0 @@
from fastapi import HTTPException
from sqlalchemy import desc
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
import jwt
from enums import CourseEnum, MinigameEnum
from models import CourseProgress, HighScore, User
from schemas.courseprogress import CourseProgressBase
from schemas.highscores import HighScoreCreate
from schemas.users import UserCreate, UserHighScore
from passlib.context import CryptContext
DEFAULT_NR_HIGH_SCORES = 10
# JWT authentication setup
jwt_secret = "secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_id(db: Session, user_id: int):
"""Fetches a User from the database by their id"""
return db.query(User).filter(User.user_id == user_id).first()
def get_user_by_username(db: Session, username: str):
"""Fetches a User from the database by their username"""
return db.query(User).filter(User.username == username).first()
def get_users(db: Session):
"""Fetch a list of all users"""
return db.query(User).all()
def authenticate_user(db: Session, user: UserCreate):
db_user = get_user_by_username(db=db, username=user.username)
if not db_user:
return False
hashed_password = db_user.hashed_password
if not hashed_password or not pwd_context.verify(user.password, hashed_password):
return False
return db_user
def register(db: Session, username: str, password: str):
"""Register a new user"""
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username=username, hashed_password=pwd_context.hash(password))
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def login(db: Session, user: UserCreate):
user = authenticate_user(db, user)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_payload = {
"sub": user.username,
"exp": datetime.utcnow() + access_token_expires,
}
access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM)
return {"access_token": access_token}
def patch_user(db: Session, username: str, user: UserCreate):
db_user = get_user_by_username(db, username)
potential_duplicate = get_user_by_username(db, user.username)
if potential_duplicate:
if potential_duplicate.user_id != db_user.user_id:
raise HTTPException(status_code=400, detail="Username already registered")
db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password)
db.commit()
def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
"""Get the n highest scores of a given minigame"""
if nr_highest < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
user_high_scores = []
if not nr_highest:
nr_highest = DEFAULT_NR_HIGH_SCORES
if not minigame:
minigame = MinigameEnum.SpellingBee
high_scores = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.score_value))
.limit(nr_highest)
.all()
)
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(username=owner.username, score_value=high_score.score_value)
)
return user_high_scores
def create_high_score(db: Session, high_score: HighScoreCreate):
"""Create a new high score for a given minigame"""
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
if not owner:
raise HTTPException(status_code=400, detail="User does not exist")
old_high_score = (
db.query(HighScore)
.filter(
HighScore.owner_id == high_score.owner_id,
HighScore.minigame == high_score.minigame,
)
.first()
)
if old_high_score:
if old_high_score.score_value < high_score.score_value:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.delete(old_high_score)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
else:
return old_high_score
else:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
def get_course_progress(db: Session, username: str, course: CourseEnum):
"""Get the progress a user has for a certain course"""
user = get_user_by_username(db, username)
if course != CourseEnum.All:
course_progress = (
db.query(CourseProgress)
.filter(
CourseProgress.owner_id == user.user_id, CourseProgress.course == course
)
.first()
)
if course_progress:
return [
CourseProgressBase(
progress_value=course_progress.progress_value,
course=course_progress.course,
)
]
else:
return [CourseProgressBase(progress_value=0, course=course)]
return []

View File

@@ -0,0 +1,52 @@
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException
from sqlalchemy.orm import Session
from crud.users import get_user_by_username, pwd_context
from models import User
from schemas.users import UserCreate
DEFAULT_NR_HIGH_SCORES = 10
# JWT authentication setup
jwt_secret = "secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
def authenticate_user(db: Session, user: UserCreate):
"""Checks whether the provided credentials match with an existing User"""
db_user = get_user_by_username(db=db, username=user.username)
if not db_user:
return False
hashed_password = db_user.hashed_password
if not hashed_password or not pwd_context.verify(user.password, hashed_password):
return False
return db_user
def register(db: Session, username: str, password: str):
"""Register a new user"""
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username=username, hashed_password=pwd_context.hash(password))
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def login(db: Session, user: UserCreate):
user = authenticate_user(db, user)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_payload = {
"sub": user.username,
"exp": datetime.utcnow() + access_token_expires,
}
access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM)
return {"access_token": access_token}

View File

@@ -0,0 +1,27 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from enums import CourseEnum
def get_course_progress(db: Session, username: str, course: CourseEnum):
"""Get the progress a user has for a certain course"""
user = get_user_by_username(db, username)
if course != CourseEnum.All:
course_progress = (
db.query(CourseProgress)
.filter(
CourseProgress.owner_id == user.user_id, CourseProgress.course == course
)
.first()
)
if course_progress:
return [
CourseProgressBase(
progress_value=course_progress.progress_value,
course=course_progress.course,
)
]
else:
return [CourseProgressBase(progress_value=0, course=course)]
return []

75
src/crud/highscores.py Normal file
View File

@@ -0,0 +1,75 @@
from fastapi import HTTPException
from sqlalchemy import desc
from sqlalchemy.orm import Session
from enums import MinigameEnum
from models import HighScore, User
from schemas.highscores import HighScoreCreate
from schemas.users import UserHighScore
def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
"""Get the n highest scores of a given minigame"""
if nr_highest < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
user_high_scores = []
if not nr_highest:
nr_highest = DEFAULT_NR_HIGH_SCORES
if not minigame:
minigame = MinigameEnum.SpellingBee
high_scores = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.score_value))
.limit(nr_highest)
.all()
)
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(username=owner.username, score_value=high_score.score_value)
)
return user_high_scores
def create_high_score(db: Session, high_score: HighScoreCreate):
"""Create a new high score for a given minigame"""
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
if not owner:
raise HTTPException(status_code=400, detail="User does not exist")
old_high_score = (
db.query(HighScore)
.filter(
HighScore.owner_id == high_score.owner_id,
HighScore.minigame == high_score.minigame,
)
.first()
)
if old_high_score:
if old_high_score.score_value < high_score.score_value:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.delete(old_high_score)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
else:
return old_high_score
else:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score

31
src/crud/users.py Normal file
View File

@@ -0,0 +1,31 @@
from fastapi import HTTPException
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from models import User
from schemas.users import UserCreate
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def patch_user(db: Session, username: str, user: UserCreate):
"""Changes the username and/or the password of a User"""
db_user = get_user_by_username(db, username)
potential_duplicate = get_user_by_username(db, user.username)
if potential_duplicate:
if potential_duplicate.user_id != db_user.user_id:
raise HTTPException(status_code=400, detail="Username already registered")
db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password)
db.commit()
def get_user_by_username(db: Session, username: str):
"""Fetches a User from the database by their username"""
return db.query(User).filter(User.username == username).first()
def get_users(db: Session):
"""Fetch a list of all users"""
return db.query(User).all()

View File

@@ -1,13 +1,14 @@
from datetime import datetime, timedelta
from typing import List, Optional
import jwt
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
from sqlalchemy.orm import Session
import crud
from crud import authentication as crud_authentication
from crud import courseprogress as crud_courseprogress
from crud import highscores as crud_highscores
from crud import users as crud_users
from database import SessionLocal, engine
from enums import CourseEnum, MinigameEnum
from models import Base
@@ -28,11 +29,16 @@ def get_db():
finally:
db.close()
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(token.credentials, crud.jwt_secret, algorithms=[crud.ALGORITHM])
payload = jwt.decode(
token.credentials,
crud_authentication.jwt_secret,
algorithms=[crud_authentication.ALGORITHM],
)
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@@ -40,6 +46,7 @@ def get_current_user_name(
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@app.get("/")
async def root():
return {"message": "Hello world!"}
@@ -47,45 +54,42 @@ async def root():
@app.get("/users", response_model=List[users.User])
async def read_users(db: Session = Depends(get_db)):
users = crud.get_users(db)
return users
return crud_users.get_users(db)
@app.patch("/users")
async def patch_current_user(
user: users.UserCreate,
current_user_name = Depends(get_current_user_name),
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db),
):
crud.patch_user(db, current_user_name, user)
crud_users.patch_user(db, current_user_name, user)
@app.post("/register", response_model=users.User)
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
return crud.register(
db, user.username, user.password
)
return crud_authentication.register(db, user.username, user.password)
@app.post("/login")
async def login(user: users.UserCreate, db: Session = Depends(get_db)):
return crud.login(db, user)
return crud_authentication.login(db, user)
@app.get("/highscores", response_model=List[users.UserHighScore])
async def read_high_scores(
async def get_high_scores(
db: Session = Depends(get_db),
minigame: Optional[MinigameEnum] = None,
nr_highest: Optional[int] = None,
):
return crud.get_high_scores(db, minigame, nr_highest)
return crud_highscores.get_high_scores(db, minigame, nr_highest)
@app.post("/highscores", response_model=highscores.HighScore)
async def create_high_score(
high_score: highscores.HighScoreCreate, db: Session = Depends(get_db)
):
return crud.create_high_score(db=db, high_score=high_score)
return crud_highscores.create_high_score(db=db, high_score=high_score)
#### TESTING!! DELETE LATER
@@ -96,20 +100,10 @@ async def protected_route(current_user_name=Depends(get_current_user_name)):
return {"message": f"Hello, {current_user_name}!"}
def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)):
db_user = crud.get_user_by_username(db=db, username=user.username)
if not db_user:
return False
hashed_password = db_user.hashed_password
if not hashed_password or not pwd_context.verify(user.password, hashed_password):
return False
return db_user
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db),
):
return crud.get_course_progress(db, current_user_name, course)
return crud_courseprogress.get_course_progress(db, current_user_name, course)