Minor stuff

This commit is contained in:
lvrossem 2023-04-01 10:03:18 -06:00
parent 65d1a2a6e4
commit d2933a95ba
9 changed files with 91 additions and 54 deletions

View File

@ -1,7 +1,8 @@
from datetime import datetime, timedelta
import jwt
from fastapi import HTTPException
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from crud.users import get_user_by_username, pwd_context
@ -15,6 +16,25 @@ jwt_secret = "secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
bearer_scheme = HTTPBearer()
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(
token.credentials,
jwt_secret,
algorithms=[ALGORITHM],
)
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
return username
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
def authenticate_user(db: Session, username: str, password: str):
"""Checks whether the provided credentials match with an existing User"""
@ -29,10 +49,14 @@ def authenticate_user(db: Session, username: str, password: str):
def register(db: Session, username: str, password: str, avatar: str):
"""Register a new user"""
if avatar == "":
raise HTTPException(status_code=400, detail="No avatar was provided")
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar)
db_user = User(
username=username, hashed_password=pwd_context.hash(password), avatar=avatar
)
db.add(db_user)
db.commit()
db.refresh(db_user)

View File

@ -2,7 +2,8 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session
from enums import CourseEnum, course_enum_list
from models import User, CourseProgress
from models import CourseProgress, User
from schemas.courseprogress import CourseProgressBase
def get_course_progress(db: Session, user: User, course: CourseEnum):
@ -15,6 +16,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
)
.first()
)
if course_progress:
return [
CourseProgressBase(
@ -23,13 +25,16 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
)
]
else:
return [CourseProgressBase(progress_value = 0, course=course)]
db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)
)
db.commit()
return [CourseProgressBase(progress_value=0.0, course=course)]
return []
def initialize_user(db: Session, user: User):
"""Create CourseProgress records with a value of 0 for a new user"""
for course in course_enum_list:
db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id))
db.commit()

View File

@ -31,13 +31,18 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar)
UserHighScore(
username=owner.username,
score_value=high_score.score_value,
avatar=owner.avatar,
)
)
return user_high_scores
def create_high_score(db: Session, user: User, high_score: HighScoreBase):
"""Create a new high score for a given minigame"""
def add_to_db():
"""Helper function that adds new score to database; prevents code duplication"""
db_high_score = HighScore(

View File

@ -18,6 +18,7 @@ def patch_user(db: Session, username: str, user: UserCreate):
db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password)
db_user.avatar = user.avatar
db.commit()

View File

@ -9,3 +9,11 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@ -28,13 +28,14 @@ class MinigameEnum(StrEnum):
class CourseEnum(StrEnum):
Fingerspelling = "Fingerspelling"
#Basics = "Basics"
Basics = "Basics"
Hobbies = "Hobbies"
Animals = "Animals"
Colors = "Colors"
FruitsVegetables = "FruitsVegetables"
All = "All"
# This is needed because for some reason iterating over an enum doesn't work properly...
course_enum_list = [
CourseEnum.Fingerspelling,
@ -43,5 +44,5 @@ course_enum_list = [
CourseEnum.Animals,
CourseEnum.Colors,
CourseEnum.FruitsVegetables,
CourseEnum.All
CourseEnum.All,
]

View File

@ -2,46 +2,19 @@ from typing import List, Optional
import jwt
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from crud import authentication as crud_authentication
from crud import courseprogress as crud_courseprogress
from crud import highscores as crud_highscores
from crud import users as crud_users
from database import SessionLocal, engine
from database import SessionLocal, engine, get_db
from enums import CourseEnum, MinigameEnum
from models import Base
from schemas import courseprogress, highscores, users
app = FastAPI()
Base.metadata.create_all(bind=engine)
bearer_scheme = HTTPBearer()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(
token.credentials,
crud_authentication.jwt_secret,
algorithms=[crud_authentication.ALGORITHM],
)
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
return username
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@app.get("/")
@ -57,15 +30,17 @@ async def read_users(db: Session = Depends(get_db)):
@app.patch("/users")
async def patch_current_user(
user: users.UserCreate,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db)
current_user_name=Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
crud_users.patch_user(db, current_user_name, user)
@app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
access_token = crud_authentication.register(db, user.username, user.password, user.avatar)
access_token = crud_authentication.register(
db, user.username, user.password, user.avatar
)
user = crud_users.get_user_by_username(db, user.username)
crud_courseprogress.initialize_user(db, user)
return access_token
@ -80,7 +55,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
async def get_high_scores(
minigame: Optional[MinigameEnum] = None,
nr_highest: Optional[int] = None,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
return crud_highscores.get_high_scores(db, minigame, nr_highest)
@ -88,8 +63,8 @@ async def get_high_scores(
@app.post("/highscores", response_model=highscores.HighScore)
async def create_high_score(
high_score: highscores.HighScoreBase,
current_user_name = Depends(get_current_user_name),
db: Session = Depends(get_db)
current_user_name=Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_highscores.create_high_score(db, current_user, high_score)
@ -99,15 +74,27 @@ async def create_high_score(
@app.get("/protected")
async def protected_route(current_user_name=Depends(get_current_user_name)):
async def protected_route(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
):
return {"message": f"Hello, {current_user_name}!"}
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db)
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_courseprogress.get_course_progress(db, current_user, course)
@app.patch("/courseprogress")
async def get_course_progress(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
course: Optional[CourseEnum] = CourseEnum.All,
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_courseprogress.get_course_progress(db, current_user, course)

View File

@ -7,6 +7,8 @@ from enums import CourseEnum, MinigameEnum, StrEnumType
class User(Base):
"""The database model for users"""
__tablename__ = "users"
user_id = Column(Integer, primary_key=True, index=True)
@ -23,6 +25,8 @@ class User(Base):
class HighScore(Base):
"""The database model for high scores"""
__tablename__ = "high_scores"
high_score_id = Column(Integer, primary_key=True, index=True)
@ -33,6 +37,8 @@ class HighScore(Base):
class CourseProgress(Base):
"""The database model for course progress"""
__tablename__ = "course_progress"
course_progress_id = Column(Integer, primary_key=True, index=True)

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel
class UserBase(BaseModel):
username: str
avatar: str
avatar: str = ""
class User(UserBase):