From 65d1a2a6e41a43146bf705b0e6cd7769d1bd729c Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 11:43:07 -0600 Subject: [PATCH] Sort of fix StrEnum issue --- src/crud/authentication.py | 4 ++-- src/crud/courseprogress.py | 20 ++++++++++++------ src/crud/highscores.py | 43 ++++++++++++++++---------------------- src/enums.py | 13 +++++++++++- src/main.py | 24 ++++++++++++--------- src/models.py | 11 +++------- src/schemas/highscores.py | 6 +----- src/schemas/users.py | 4 +--- 8 files changed, 65 insertions(+), 60 deletions(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 0f7a9cf..5986a57 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -27,12 +27,12 @@ def authenticate_user(db: Session, username: str, password: str): return db_user -def register(db: Session, username: str, password: str): +def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") - db_user = User(username=username, hashed_password=pwd_context.hash(password)) + db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index c96c7db..6a2f17e 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,12 +1,12 @@ from fastapi import HTTPException from sqlalchemy.orm import Session -from enums import CourseEnum +from enums import CourseEnum, course_enum_list +from models import User, CourseProgress -def get_course_progress(db: Session, username: str, course: CourseEnum): +def get_course_progress(db: Session, user: User, course: CourseEnum): """Get the progress a user has for a certain course""" - user = get_user_by_username(db, username) if course != CourseEnum.All: course_progress = ( db.query(CourseProgress) @@ -18,10 +18,18 @@ def get_course_progress(db: Session, username: str, course: CourseEnum): if course_progress: return [ CourseProgressBase( - progress_value=course_progress.progress_value, - course=course_progress.course, + progress_value = course_progress.progress_value, + course = course_progress.course, ) ] else: - return [CourseProgressBase(progress_value=0, course=course)] + return [CourseProgressBase(progress_value = 0, course=course)] return [] + + +def initialize_user(db: Session, user: User): + for course in course_enum_list: + db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id)) + db.commit() + + diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 0d21396..6471c34 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from enums import MinigameEnum from models import HighScore, User -from schemas.highscores import HighScoreCreate +from schemas.highscores import HighScoreBase from schemas.users import UserHighScore @@ -31,45 +31,38 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): for high_score in high_scores: owner = db.query(User).filter(User.user_id == high_score.owner_id).first() user_high_scores.append( - UserHighScore(username=owner.username, score_value=high_score.score_value) + UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar) ) return user_high_scores -def create_high_score(db: Session, high_score: HighScoreCreate): +def create_high_score(db: Session, user: User, high_score: HighScoreBase): """Create a new high score for a given minigame""" - owner = db.query(User).filter(User.user_id == high_score.owner_id).first() - if not owner: - raise HTTPException(status_code=400, detail="User does not exist") + def add_to_db(): + """Helper function that adds new score to database; prevents code duplication""" + db_high_score = HighScore( + score_value = high_score.score_value, + minigame = high_score.minigame, + owner_id = user.user_id, + ) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score + old_high_score = ( db.query(HighScore) .filter( - HighScore.owner_id == high_score.owner_id, + HighScore.owner_id == user.user_id, HighScore.minigame == high_score.minigame, ) .first() ) if old_high_score: if old_high_score.score_value < high_score.score_value: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) db.delete(old_high_score) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score + return add_to_db() else: return old_high_score else: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score + return add_to_db() diff --git a/src/enums.py b/src/enums.py index 3ec4085..088d795 100644 --- a/src/enums.py +++ b/src/enums.py @@ -28,9 +28,20 @@ class MinigameEnum(StrEnum): class CourseEnum(StrEnum): Fingerspelling = "Fingerspelling" - Basics = "Basics" + #Basics = "Basics" Hobbies = "Hobbies" Animals = "Animals" Colors = "Colors" FruitsVegetables = "FruitsVegetables" All = "All" + +# This is needed because for some reason iterating over an enum doesn't work properly... +course_enum_list = [ + CourseEnum.Fingerspelling, + #CourseEnum.Basics, + CourseEnum.Hobbies, + CourseEnum.Animals, + CourseEnum.Colors, + CourseEnum.FruitsVegetables, + CourseEnum.All +] diff --git a/src/main.py b/src/main.py index fb0855d..5a9f139 100644 --- a/src/main.py +++ b/src/main.py @@ -15,10 +15,7 @@ from models import Base from schemas import courseprogress, highscores, users app = FastAPI() - - Base.metadata.create_all(bind=engine) - bearer_scheme = HTTPBearer() @@ -61,14 +58,17 @@ async def read_users(db: Session = Depends(get_db)): async def patch_current_user( user: users.UserCreate, current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db), + db: Session = Depends(get_db) ): crud_users.patch_user(db, current_user_name, user) @app.post("/register") async def register(user: users.UserCreate, db: Session = Depends(get_db)): - return crud_authentication.register(db, user.username, user.password) + access_token = crud_authentication.register(db, user.username, user.password, user.avatar) + user = crud_users.get_user_by_username(db, user.username) + crud_courseprogress.initialize_user(db, user) + return access_token @app.post("/login") @@ -78,18 +78,21 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores", response_model=List[users.UserHighScore]) async def get_high_scores( - db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, + db: Session = Depends(get_db) ): return crud_highscores.get_high_scores(db, minigame, nr_highest) @app.post("/highscores", response_model=highscores.HighScore) async def create_high_score( - high_score: highscores.HighScoreCreate, db: Session = Depends(get_db) + high_score: highscores.HighScoreBase, + current_user_name = Depends(get_current_user_name), + db: Session = Depends(get_db) ): - return crud_highscores.create_high_score(db=db, high_score=high_score) + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_highscores.create_high_score(db, current_user, high_score) #### TESTING!! DELETE LATER @@ -104,6 +107,7 @@ async def protected_route(current_user_name=Depends(get_current_user_name)): async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db), + db: Session = Depends(get_db) ): - return crud_courseprogress.get_course_progress(db, current_user_name, course) + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_courseprogress.get_course_progress(db, current_user, course) diff --git a/src/models.py b/src/models.py index 6c7f3c6..948f6b8 100644 --- a/src/models.py +++ b/src/models.py @@ -12,6 +12,7 @@ class User(Base): user_id = Column(Integer, primary_key=True, index=True) username = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) + avatar = Column(String, nullable=False) high_scores = relationship( "HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic" @@ -20,19 +21,13 @@ class User(Base): "CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic" ) - # add a new column to store the high_score IDs - high_score_ids = Column(ARRAY(Integer), default=[]) - - # add a new column to store the course_progress IDs - course_progress_ids = Column(ARRAY(Integer), default=[]) - class HighScore(Base): __tablename__ = "high_scores" high_score_id = Column(Integer, primary_key=True, index=True) score_value = Column(Float, nullable=False) - minigame = Column(StrEnumType(MinigameEnum), nullable=False) + minigame = Column(String, nullable=False) owner_id = Column(Integer, ForeignKey("users.user_id")) owner = relationship("User", back_populates="high_scores") @@ -42,6 +37,6 @@ class CourseProgress(Base): course_progress_id = Column(Integer, primary_key=True, index=True) progress_value = Column(Float, nullable=False) - course = Column(StrEnumType(CourseEnum), nullable=False) + course = Column(String, nullable=False) owner_id = Column(Integer, ForeignKey("users.user_id")) owner = relationship("User", back_populates="course_progress") diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index afec43f..3f1879a 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -6,15 +6,11 @@ from enums import MinigameEnum class HighScoreBase(BaseModel): score_value: float minigame: MinigameEnum - owner_id: int - - -class HighScoreCreate(HighScoreBase): - pass class HighScore(HighScoreBase): high_score_id: int + owner_id: int class Config: orm_mode = True diff --git a/src/schemas/users.py b/src/schemas/users.py index eed0d7d..a14c4a3 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -5,15 +5,13 @@ from pydantic import BaseModel class UserBase(BaseModel): username: str + avatar: str class User(UserBase): user_id: int hashed_password: str - high_score_ids: List[int] = [] - course_progress_ids: List[int] = [] - class Config: orm_mode = True