Sort of fix StrEnum issue

This commit is contained in:
lvrossem
2023-03-31 11:43:07 -06:00
parent edd50b9ecb
commit 65d1a2a6e4
8 changed files with 65 additions and 60 deletions

View File

@@ -27,12 +27,12 @@ def authenticate_user(db: Session, username: str, password: str):
return db_user
def register(db: Session, username: str, password: str):
def register(db: Session, username: str, password: str, avatar: str):
"""Register a new user"""
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username=username, hashed_password=pwd_context.hash(password))
db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar)
db.add(db_user)
db.commit()
db.refresh(db_user)

View File

@@ -1,12 +1,12 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from enums import CourseEnum
from enums import CourseEnum, course_enum_list
from models import User, CourseProgress
def get_course_progress(db: Session, username: str, course: CourseEnum):
def get_course_progress(db: Session, user: User, course: CourseEnum):
"""Get the progress a user has for a certain course"""
user = get_user_by_username(db, username)
if course != CourseEnum.All:
course_progress = (
db.query(CourseProgress)
@@ -18,10 +18,18 @@ def get_course_progress(db: Session, username: str, course: CourseEnum):
if course_progress:
return [
CourseProgressBase(
progress_value=course_progress.progress_value,
course=course_progress.course,
progress_value = course_progress.progress_value,
course = course_progress.course,
)
]
else:
return [CourseProgressBase(progress_value=0, course=course)]
return [CourseProgressBase(progress_value = 0, course=course)]
return []
def initialize_user(db: Session, user: User):
for course in course_enum_list:
db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id))
db.commit()

View File

@@ -4,7 +4,7 @@ from sqlalchemy.orm import Session
from enums import MinigameEnum
from models import HighScore, User
from schemas.highscores import HighScoreCreate
from schemas.highscores import HighScoreBase
from schemas.users import UserHighScore
@@ -31,45 +31,38 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(username=owner.username, score_value=high_score.score_value)
UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar)
)
return user_high_scores
def create_high_score(db: Session, high_score: HighScoreCreate):
def create_high_score(db: Session, user: User, high_score: HighScoreBase):
"""Create a new high score for a given minigame"""
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
if not owner:
raise HTTPException(status_code=400, detail="User does not exist")
def add_to_db():
"""Helper function that adds new score to database; prevents code duplication"""
db_high_score = HighScore(
score_value = high_score.score_value,
minigame = high_score.minigame,
owner_id = user.user_id,
)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
old_high_score = (
db.query(HighScore)
.filter(
HighScore.owner_id == high_score.owner_id,
HighScore.owner_id == user.user_id,
HighScore.minigame == high_score.minigame,
)
.first()
)
if old_high_score:
if old_high_score.score_value < high_score.score_value:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.delete(old_high_score)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
return add_to_db()
else:
return old_high_score
else:
db_high_score = HighScore(
score_value=high_score.score_value,
minigame=high_score.minigame,
owner_id=high_score.owner_id,
)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
return add_to_db()

View File

@@ -28,9 +28,20 @@ class MinigameEnum(StrEnum):
class CourseEnum(StrEnum):
Fingerspelling = "Fingerspelling"
Basics = "Basics"
#Basics = "Basics"
Hobbies = "Hobbies"
Animals = "Animals"
Colors = "Colors"
FruitsVegetables = "FruitsVegetables"
All = "All"
# This is needed because for some reason iterating over an enum doesn't work properly...
course_enum_list = [
CourseEnum.Fingerspelling,
#CourseEnum.Basics,
CourseEnum.Hobbies,
CourseEnum.Animals,
CourseEnum.Colors,
CourseEnum.FruitsVegetables,
CourseEnum.All
]

View File

@@ -15,10 +15,7 @@ from models import Base
from schemas import courseprogress, highscores, users
app = FastAPI()
Base.metadata.create_all(bind=engine)
bearer_scheme = HTTPBearer()
@@ -61,14 +58,17 @@ async def read_users(db: Session = Depends(get_db)):
async def patch_current_user(
user: users.UserCreate,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db),
db: Session = Depends(get_db)
):
crud_users.patch_user(db, current_user_name, user)
@app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
return crud_authentication.register(db, user.username, user.password)
access_token = crud_authentication.register(db, user.username, user.password, user.avatar)
user = crud_users.get_user_by_username(db, user.username)
crud_courseprogress.initialize_user(db, user)
return access_token
@app.post("/login")
@@ -78,18 +78,21 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
@app.get("/highscores", response_model=List[users.UserHighScore])
async def get_high_scores(
db: Session = Depends(get_db),
minigame: Optional[MinigameEnum] = None,
nr_highest: Optional[int] = None,
db: Session = Depends(get_db)
):
return crud_highscores.get_high_scores(db, minigame, nr_highest)
@app.post("/highscores", response_model=highscores.HighScore)
async def create_high_score(
high_score: highscores.HighScoreCreate, db: Session = Depends(get_db)
high_score: highscores.HighScoreBase,
current_user_name = Depends(get_current_user_name),
db: Session = Depends(get_db)
):
return crud_highscores.create_high_score(db=db, high_score=high_score)
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_highscores.create_high_score(db, current_user, high_score)
#### TESTING!! DELETE LATER
@@ -104,6 +107,7 @@ async def protected_route(current_user_name=Depends(get_current_user_name)):
async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db),
db: Session = Depends(get_db)
):
return crud_courseprogress.get_course_progress(db, current_user_name, course)
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_courseprogress.get_course_progress(db, current_user, course)

View File

@@ -12,6 +12,7 @@ class User(Base):
user_id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
avatar = Column(String, nullable=False)
high_scores = relationship(
"HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic"
@@ -20,19 +21,13 @@ class User(Base):
"CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic"
)
# add a new column to store the high_score IDs
high_score_ids = Column(ARRAY(Integer), default=[])
# add a new column to store the course_progress IDs
course_progress_ids = Column(ARRAY(Integer), default=[])
class HighScore(Base):
__tablename__ = "high_scores"
high_score_id = Column(Integer, primary_key=True, index=True)
score_value = Column(Float, nullable=False)
minigame = Column(StrEnumType(MinigameEnum), nullable=False)
minigame = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="high_scores")
@@ -42,6 +37,6 @@ class CourseProgress(Base):
course_progress_id = Column(Integer, primary_key=True, index=True)
progress_value = Column(Float, nullable=False)
course = Column(StrEnumType(CourseEnum), nullable=False)
course = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="course_progress")

View File

@@ -6,15 +6,11 @@ from enums import MinigameEnum
class HighScoreBase(BaseModel):
score_value: float
minigame: MinigameEnum
owner_id: int
class HighScoreCreate(HighScoreBase):
pass
class HighScore(HighScoreBase):
high_score_id: int
owner_id: int
class Config:
orm_mode = True

View File

@@ -5,15 +5,13 @@ from pydantic import BaseModel
class UserBase(BaseModel):
username: str
avatar: str
class User(UserBase):
user_id: int
hashed_password: str
high_score_ids: List[int] = []
course_progress_ids: List[int] = []
class Config:
orm_mode = True