From 51c27f8bc62209712ad162efaa0f593afcf40794 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 15:38:13 +0200 Subject: [PATCH 01/37] Repo setup --- .gitignore | 3 +++ requirements.txt | 4 ++++ src/main.py | 8 ++++++++ 3 files changed, 15 insertions(+) create mode 100644 .gitignore create mode 100644 requirements.txt create mode 100644 src/main.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..804f5ce --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +venv/* +venv +src/__pycache__* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b4352ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +fastapi +pydantic +sqlalchemy +uvicorn[standard] \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..d786432 --- /dev/null +++ b/src/main.py @@ -0,0 +1,8 @@ +from fastapi import FastAPI + +app = FastAPI() + + +@app.get("/") +async def root(): + return {"message": "Hello World"} \ No newline at end of file From b16198a816d97cdf0f957988ad60b2018343d194 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 17:24:33 +0200 Subject: [PATCH 02/37] Start working on database models --- src/database.py | 12 ++++++++++++ src/enums.py | 11 +++++++++++ src/models.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 src/database.py create mode 100644 src/enums.py create mode 100644 src/models.py diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..16e67b5 --- /dev/null +++ b/src/database.py @@ -0,0 +1,12 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +# SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" + +engine = create_engine(SQLALCHEMY_DATABASE_URL) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() \ No newline at end of file diff --git a/src/enums.py b/src/enums.py new file mode 100644 index 0000000..036aabd --- /dev/null +++ b/src/enums.py @@ -0,0 +1,11 @@ +from enum import Enum + +class MinigameEnum(str, Enum): + SpellingBee = 'SpellingBee' + Hangman = 'Hangman' + JustSign = 'JustSign' + + +class CourseEnum(str, Enum): + Fingerspelling = 'Fingerspelling' + Animals = 'Animals' \ No newline at end of file diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000..7982cc6 --- /dev/null +++ b/src/models.py @@ -0,0 +1,34 @@ +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from .enums import MinigameEnum, CourseEnum +from .database import Base + + +class User(Base): + __tablename__ = "users" + + user_id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True, index=True, nullable=False) + hashed_password = Column(String, nullable=False) + + high_scores = relationship("HighScore", back_populates="owner") + course_progresses = relationship("CourseProgress", back_populates="owner") + + +class HighScore(Base): + __tablename__ = "high_scores" + + high_score_id = Column(Integer, primary_key=True, index=True) + score_value = Column(Integer, nullable=False) + minigame = Column(Enum(MinigameEnum), nullable=False) + owner = Column(Integer, ForeignKey("users.user_id")) + + +class CourseProgress(Base): + __tablename__ = "course_progress" + + course_progress_id = Column(Integer, primary_key=True, index=True) + course = Column(Enum(CourseEnum), nullable=False) + owner = Column(Integer, ForeignKey("users.user_id")) + \ No newline at end of file From e9b2ea4188eae16b20c26b9a01f27a7609437ceb Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 19:48:48 +0200 Subject: [PATCH 03/37] Add CRUD functions and Pydantic schemas --- src/crud.py | 25 +++++++++++++++++++++++++ src/database.py | 2 +- src/models.py | 5 +++-- src/schemas.py | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 src/crud.py create mode 100644 src/schemas.py diff --git a/src/crud.py b/src/crud.py new file mode 100644 index 0000000..77c25aa --- /dev/null +++ b/src/crud.py @@ -0,0 +1,25 @@ +from sqlalchemy.orm import Session + +from models import User, HighScore, CourseProgress +from . import models, schemas + + +def get_user(db: Session, user_id: int): + return db.query(models.User).filter(models.User.id == user_id).first() + + +def get_user_by_username(db: Session, username: str): + return db.query(User).filter(User.email == email).first() + + +def get_users(db: Session, skip: int = 0, limit: int = 100): + return db.query(User).all() + + +def create_user(db: Session, username: str, hashed_password: str): + db_user = models.User(username=username, hashed_password=hashed_password) + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user + diff --git a/src/database.py b/src/database.py index 16e67b5..e7d7dbe 100644 --- a/src/database.py +++ b/src/database.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker # SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" +SQLALCHEMY_DATABASE_URL = 'postgresql://user:password@postgresserver/wesign-dev' engine = create_engine(SQLALCHEMY_DATABASE_URL) diff --git a/src/models.py b/src/models.py index 7982cc6..4728139 100644 --- a/src/models.py +++ b/src/models.py @@ -1,4 +1,4 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float from sqlalchemy.orm import relationship from .enums import MinigameEnum, CourseEnum @@ -20,7 +20,7 @@ class HighScore(Base): __tablename__ = "high_scores" high_score_id = Column(Integer, primary_key=True, index=True) - score_value = Column(Integer, nullable=False) + score_value = Column(Float, nullable=False) minigame = Column(Enum(MinigameEnum), nullable=False) owner = Column(Integer, ForeignKey("users.user_id")) @@ -29,6 +29,7 @@ class CourseProgress(Base): __tablename__ = "course_progress" course_progress_id = Column(Integer, primary_key=True, index=True) + progress_value = Column(Float, nullable=False) course = Column(Enum(CourseEnum), nullable=False) owner = Column(Integer, ForeignKey("users.user_id")) \ No newline at end of file diff --git a/src/schemas.py b/src/schemas.py new file mode 100644 index 0000000..9472883 --- /dev/null +++ b/src/schemas.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel +from enums import MinigameEnum, CourseEnum + +class User(BaseModel): + user_id: int + username: str + hashed_password: str + + high_scores: list[HighScore] = [] + course_progresses: list[CourseProgress] = [] + + class Config: + orm_mode = True + + +class HighScore(BaseModel): + high_score_id: int + score_value: float + minigame: MinigameEnum + owner: User + + class Config: + orm_mode = True + + +class CourseProgress(BaseModel): + course_progress_id: int + progress_value: float + course: CourseEnum + owner: User + + class Config: + orm_mode = True \ No newline at end of file From 3de642cfdc01fc35e41f8d4155fc94cdfcf6f63c Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 22:16:23 +0200 Subject: [PATCH 04/37] Database issues, for now --- src/main.py | 24 +++++++++++++++++++++++- src/models.py | 4 ++-- src/schemas.py | 5 +++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/main.py b/src/main.py index d786432..813da20 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,29 @@ -from fastapi import FastAPI +from fastapi import Depends, FastAPI, HTTPException +from sqlalchemy.orm import Session +from models import Base +from database import SessionLocal, engine +from schemas import UserCreate +import crud app = FastAPI() +Base.metadata.create_all(bind=engine) + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +@app.post("/users/", response_model=schemas.User) +def create_user(user: UserCreate, db: Session = Depends(get_db)): + db_user = crud.get_user_by_email(db, email=user.email) + if db_user: + raise HTTPException(status_code=400, detail="Email already registered") + return crud.create_user(db=db, user=user) + @app.get("/") async def root(): diff --git a/src/models.py b/src/models.py index 4728139..d3aaf93 100644 --- a/src/models.py +++ b/src/models.py @@ -1,8 +1,8 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float from sqlalchemy.orm import relationship -from .enums import MinigameEnum, CourseEnum -from .database import Base +from enums import MinigameEnum, CourseEnum +from database import Base class User(Base): diff --git a/src/schemas.py b/src/schemas.py index 9472883..0f084f2 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -13,6 +13,11 @@ class User(BaseModel): orm_mode = True +class UserCreate(BaseModel): + username: str + password: str + + class HighScore(BaseModel): high_score_id: int score_value: float From 765f3e9befddaed3a1737161fc248f6bed554858 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 14:28:30 -0600 Subject: [PATCH 05/37] Dependency updates --- requirements.txt | 4 +++- src/enums.py | 8 ++++---- src/models.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index b4352ef..36b6aad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ fastapi pydantic sqlalchemy -uvicorn[standard] \ No newline at end of file +uvicorn[standard] +psycopg2-binary +fastapi_utils diff --git a/src/enums.py b/src/enums.py index 036aabd..b9e9f43 100644 --- a/src/enums.py +++ b/src/enums.py @@ -1,11 +1,11 @@ -from enum import Enum +from fastapi_utils.enums import StrEnum -class MinigameEnum(str, Enum): +class MinigameEnum(StrEnum): SpellingBee = 'SpellingBee' Hangman = 'Hangman' JustSign = 'JustSign' -class CourseEnum(str, Enum): +class CourseEnum(StrEnum): Fingerspelling = 'Fingerspelling' - Animals = 'Animals' \ No newline at end of file + Animals = 'Animals' diff --git a/src/models.py b/src/models.py index d3aaf93..522af50 100644 --- a/src/models.py +++ b/src/models.py @@ -1,5 +1,6 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float from sqlalchemy.orm import relationship +from enum import Enum from enums import MinigameEnum, CourseEnum from database import Base @@ -32,4 +33,4 @@ class CourseProgress(Base): progress_value = Column(Float, nullable=False) course = Column(Enum(CourseEnum), nullable=False) owner = Column(Integer, ForeignKey("users.user_id")) - \ No newline at end of file + From fc63176642e246d3dd3065f391dfa366d3e557fe Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 27 Mar 2023 23:19:59 +0200 Subject: [PATCH 06/37] Enough for today --- .gitignore | 2 +- src/crud.py | 15 +++++++------- src/enums.py | 20 ++++++++++++++++++ src/main.py | 4 ++-- src/models.py | 6 +++--- src/schemas.py | 38 ----------------------------------- src/schemas/courseprogress.py | 12 +++++++++++ src/schemas/highscores.py | 17 ++++++++++++++++ src/schemas/users.py | 18 +++++++++++++++++ 9 files changed, 80 insertions(+), 52 deletions(-) delete mode 100644 src/schemas.py create mode 100644 src/schemas/courseprogress.py create mode 100644 src/schemas/highscores.py create mode 100644 src/schemas/users.py diff --git a/.gitignore b/.gitignore index 804f5ce..30fbb48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ venv/* venv -src/__pycache__* \ No newline at end of file +*__pycache__* \ No newline at end of file diff --git a/src/crud.py b/src/crud.py index 77c25aa..ae879e3 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,23 +1,22 @@ from sqlalchemy.orm import Session from models import User, HighScore, CourseProgress -from . import models, schemas - +import schemas, models def get_user(db: Session, user_id: int): - return db.query(models.User).filter(models.User.id == user_id).first() + return db.query(models.User).filter(models.User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): - return db.query(User).filter(User.email == email).first() + return db.query(User).filter(models.User.username == username).first() -def get_users(db: Session, skip: int = 0, limit: int = 100): - return db.query(User).all() +def get_users(db: Session): + return db.query(models.User).all() -def create_user(db: Session, username: str, hashed_password: str): - db_user = models.User(username=username, hashed_password=hashed_password) +def create_user(db: Session, user: schemas.users.UserCreate): + db_user = models.User(username=user.username, hashed_password=user.hashed_password) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/enums.py b/src/enums.py index b9e9f43..01e323e 100644 --- a/src/enums.py +++ b/src/enums.py @@ -1,4 +1,24 @@ from fastapi_utils.enums import StrEnum +from sqlalchemy.types import TypeDecorator, Enum + + +class StrEnumType(TypeDecorator): + impl = Enum + + def __init__(self, enum_class, **kw): + self.enum_class = enum_class + super().__init__(enum_class, **kw) + + def process_bind_param(self, value, dialect): + if value is None: + return None + return value.value + + def process_result_value(self, value, dialect): + if value is None: + return None + return self.enum_class(value) + class MinigameEnum(StrEnum): SpellingBee = 'SpellingBee' diff --git a/src/main.py b/src/main.py index 813da20..0df2cce 100644 --- a/src/main.py +++ b/src/main.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session from models import Base from database import SessionLocal, engine -from schemas import UserCreate +from schemas.users import UserCreate, User import crud app = FastAPI() @@ -17,7 +17,7 @@ def get_db(): db.close() -@app.post("/users/", response_model=schemas.User) +@app.post("/users/", response_model=User) def create_user(user: UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_email(db, email=user.email) if db_user: diff --git a/src/models.py b/src/models.py index 522af50..ffc5f60 100644 --- a/src/models.py +++ b/src/models.py @@ -2,7 +2,7 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float from sqlalchemy.orm import relationship from enum import Enum -from enums import MinigameEnum, CourseEnum +from enums import MinigameEnum, CourseEnum, StrEnumType from database import Base @@ -22,7 +22,7 @@ class HighScore(Base): high_score_id = Column(Integer, primary_key=True, index=True) score_value = Column(Float, nullable=False) - minigame = Column(Enum(MinigameEnum), nullable=False) + minigame = Column(StrEnumType(MinigameEnum), nullable=False) owner = Column(Integer, ForeignKey("users.user_id")) @@ -31,6 +31,6 @@ class CourseProgress(Base): course_progress_id = Column(Integer, primary_key=True, index=True) progress_value = Column(Float, nullable=False) - course = Column(Enum(CourseEnum), nullable=False) + course = Column(StrEnumType(CourseEnum), nullable=False) owner = Column(Integer, ForeignKey("users.user_id")) diff --git a/src/schemas.py b/src/schemas.py deleted file mode 100644 index 0f084f2..0000000 --- a/src/schemas.py +++ /dev/null @@ -1,38 +0,0 @@ -from pydantic import BaseModel -from enums import MinigameEnum, CourseEnum - -class User(BaseModel): - user_id: int - username: str - hashed_password: str - - high_scores: list[HighScore] = [] - course_progresses: list[CourseProgress] = [] - - class Config: - orm_mode = True - - -class UserCreate(BaseModel): - username: str - password: str - - -class HighScore(BaseModel): - high_score_id: int - score_value: float - minigame: MinigameEnum - owner: User - - class Config: - orm_mode = True - - -class CourseProgress(BaseModel): - course_progress_id: int - progress_value: float - course: CourseEnum - owner: User - - class Config: - orm_mode = True \ No newline at end of file diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py new file mode 100644 index 0000000..8f42a4a --- /dev/null +++ b/src/schemas/courseprogress.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel +from enums import CourseEnum + + +class CourseProgress(BaseModel): + course_progress_id: int + progress_value: float + course: CourseEnum + owner: int + + class Config: + orm_mode = True diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py new file mode 100644 index 0000000..3cf35af --- /dev/null +++ b/src/schemas/highscores.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel +from src.enums import MinigameEnum + + +class HighScore(BaseModel): + high_score_id: int + score_value: float + minigame: MinigameEnum + owner_id: "User" + + class Config: + orm_mode = True + + +# It's ugly, but I have no choice +from users import User +HighScore.update_forward_refs() \ No newline at end of file diff --git a/src/schemas/users.py b/src/schemas/users.py new file mode 100644 index 0000000..7ae799b --- /dev/null +++ b/src/schemas/users.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + + +class User(BaseModel): + user_id: int + username: str + hashed_password: str + + high_scores: list[int] = [] + course_progresses: list[int] = [] + + class Config: + orm_mode = True + + +class UserCreate(BaseModel): + username: str + password: str From 3c49985a83bedd9a1231dfd5f80c7bffabe12523 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 13:25:02 +0200 Subject: [PATCH 07/37] Fix type error --- src/schemas/users.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/schemas/users.py b/src/schemas/users.py index 7ae799b..6a020d3 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -6,8 +6,8 @@ class User(BaseModel): username: str hashed_password: str - high_scores: list[int] = [] - course_progresses: list[int] = [] + high_scores: list[] = [] + course_progresses: list[] = [] class Config: orm_mode = True From 3a42c13026531a5c7f5836b29c450fa58f3dc67e Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 05:31:52 -0600 Subject: [PATCH 08/37] Use List from typing module instead of list --- src/schemas/users.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/schemas/users.py b/src/schemas/users.py index 6a020d3..3322c43 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from typing import List class User(BaseModel): @@ -6,8 +7,8 @@ class User(BaseModel): username: str hashed_password: str - high_scores: list[] = [] - course_progresses: list[] = [] + high_scores: List[int] = [] + course_progresses: List[int] = [] class Config: orm_mode = True From b57cbe52a2a5207567aa74c7a99d7d8a239782a7 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 17:26:50 +0200 Subject: [PATCH 09/37] Fix email issue --- src/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index 0df2cce..85468b6 100644 --- a/src/main.py +++ b/src/main.py @@ -19,12 +19,18 @@ def get_db(): @app.post("/users/", response_model=User) def create_user(user: UserCreate, db: Session = Depends(get_db)): - db_user = crud.get_user_by_email(db, email=user.email) + db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Email already registered") return crud.create_user(db=db, user=user) +@app.get("/users/", response_model=list[User]) +def read_users(db: Session = Depends(get_db)): + users = crud.get_users(db) + return users + + @app.get("/") async def root(): return {"message": "Hello World"} \ No newline at end of file From 252d844446c98a2623666fcc69ce715f8d4c8ed6 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 10:44:39 -0600 Subject: [PATCH 10/37] Fix request path issues --- src/crud.py | 6 +++--- src/database.py | 4 ++-- src/main.py | 21 +++++++++++++-------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/crud.py b/src/crud.py index ae879e3..8255e2d 100644 --- a/src/crud.py +++ b/src/crud.py @@ -4,15 +4,15 @@ from models import User, HighScore, CourseProgress import schemas, models def get_user(db: Session, user_id: int): - return db.query(models.User).filter(models.User.user_id == user_id).first() + return db.query(User).filter(User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): - return db.query(User).filter(models.User.username == username).first() + return db.query(User).filter(User.username == username).first() def get_users(db: Session): - return db.query(models.User).all() + return db.query(User).all() def create_user(db: Session, user: schemas.users.UserCreate): diff --git a/src/database.py b/src/database.py index e7d7dbe..9be44c1 100644 --- a/src/database.py +++ b/src/database.py @@ -3,10 +3,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker # SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -SQLALCHEMY_DATABASE_URL = 'postgresql://user:password@postgresserver/wesign-dev' +SQLALCHEMY_DATABASE_URL = 'postgresql://admin:WeSign123!@localhost/wesigndev' engine = create_engine(SQLALCHEMY_DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() \ No newline at end of file +Base = declarative_base() diff --git a/src/main.py b/src/main.py index 85468b6..2ebd2e3 100644 --- a/src/main.py +++ b/src/main.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import Session from models import Base from database import SessionLocal, engine from schemas.users import UserCreate, User +from typing import List import crud app = FastAPI() @@ -12,25 +13,29 @@ Base.metadata.create_all(bind=engine) def get_db(): db = SessionLocal() try: + print("Yield") yield db finally: + print("Close") db.close() -@app.post("/users/", response_model=User) -def create_user(user: UserCreate, db: Session = Depends(get_db)): +@app.get("/") +async def root(): + print("Hello world") + return {"message": "Hello world!"} + +@app.post("/users", response_model=User) +async def create_user(user: UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Email already registered") return crud.create_user(db=db, user=user) -@app.get("/users/", response_model=list[User]) -def read_users(db: Session = Depends(get_db)): +@app.get("/users")#, response_model=List[User]) +async def read_users(db: Session = Depends(get_db)): + print("here") users = crud.get_users(db) return users - -@app.get("/") -async def root(): - return {"message": "Hello World"} \ No newline at end of file From b7f41e85961916a4525a36de5d470069f474f97a Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 13:01:20 -0600 Subject: [PATCH 11/37] Working get request for users --- src/models.py | 20 +++++++++++++++----- src/schemas/highscores.py | 7 +------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/models.py b/src/models.py index ffc5f60..898c2cd 100644 --- a/src/models.py +++ b/src/models.py @@ -1,5 +1,6 @@ from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float from sqlalchemy.orm import relationship +from sqlalchemy.dialects.postgresql import ARRAY from enum import Enum from enums import MinigameEnum, CourseEnum, StrEnumType @@ -13,8 +14,14 @@ class User(Base): username = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) - high_scores = relationship("HighScore", back_populates="owner") - course_progresses = relationship("CourseProgress", back_populates="owner") + high_scores = relationship("HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic") + course_progress = relationship("CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic") + + # add a new column to store the high_score IDs + high_score_ids = Column(ARRAY(Integer), default=[]) + + # add a new column to store the course_progress IDs + course_progress_ids = Column(ARRAY(Integer), default=[]) class HighScore(Base): @@ -23,7 +30,9 @@ class HighScore(Base): high_score_id = Column(Integer, primary_key=True, index=True) score_value = Column(Float, nullable=False) minigame = Column(StrEnumType(MinigameEnum), nullable=False) - owner = Column(Integer, ForeignKey("users.user_id")) + owner_id = Column(Integer, ForeignKey("users.user_id")) + owner = relationship("User", back_populates="high_scores") + class CourseProgress(Base): @@ -32,5 +41,6 @@ class CourseProgress(Base): course_progress_id = Column(Integer, primary_key=True, index=True) progress_value = Column(Float, nullable=False) course = Column(StrEnumType(CourseEnum), nullable=False) - owner = Column(Integer, ForeignKey("users.user_id")) - + owner_id = Column(Integer, ForeignKey("users.user_id")) + owner = relationship("User", back_populates="course_progress") + diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 3cf35af..726e647 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -6,12 +6,7 @@ class HighScore(BaseModel): high_score_id: int score_value: float minigame: MinigameEnum - owner_id: "User" + owner_id: int class Config: orm_mode = True - - -# It's ugly, but I have no choice -from users import User -HighScore.update_forward_refs() \ No newline at end of file From 5eba17f793b2efbf536a0e7db71723a281ffc802 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 13:41:58 -0600 Subject: [PATCH 12/37] For real this time, it really works --- src/crud.py | 2 +- src/main.py | 9 +++++---- src/schemas/users.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/crud.py b/src/crud.py index 8255e2d..6be16ce 100644 --- a/src/crud.py +++ b/src/crud.py @@ -16,7 +16,7 @@ def get_users(db: Session): def create_user(db: Session, user: schemas.users.UserCreate): - db_user = models.User(username=user.username, hashed_password=user.hashed_password) + db_user = models.User(username=user.username, hashed_password=user.password) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/main.py b/src/main.py index 2ebd2e3..a5a10a0 100644 --- a/src/main.py +++ b/src/main.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session from models import Base from database import SessionLocal, engine -from schemas.users import UserCreate, User +from schemas import users from typing import List import crud @@ -25,17 +25,18 @@ async def root(): print("Hello world") return {"message": "Hello world!"} -@app.post("/users", response_model=User) -async def create_user(user: UserCreate, db: Session = Depends(get_db)): +@app.post("/users", response_model=users.User) +async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Email already registered") return crud.create_user(db=db, user=user) -@app.get("/users")#, response_model=List[User]) +@app.get("/users", response_model=List[users.User]) async def read_users(db: Session = Depends(get_db)): print("here") users = crud.get_users(db) + print(users[0].high_scores) return users diff --git a/src/schemas/users.py b/src/schemas/users.py index 3322c43..cbcaa6d 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -7,8 +7,8 @@ class User(BaseModel): username: str hashed_password: str - high_scores: List[int] = [] - course_progresses: List[int] = [] + high_score_ids: List[int] = [] + course_progress_ids: List[int] = [] class Config: orm_mode = True From 6e39852f497ba98ff8d82861ccafc6e1abd3658c Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 14:19:42 -0600 Subject: [PATCH 13/37] Add formatting tools --- format.sh | 3 +++ requirements.txt | 4 ++++ src/crud.py | 7 ++++--- src/database.py | 2 +- src/enums.py | 14 +++++++------- src/main.py | 15 ++++++++------- src/models.py | 21 ++++++++++++--------- src/schemas/courseprogress.py | 1 + src/schemas/highscores.py | 1 + src/schemas/users.py | 3 ++- 10 files changed, 43 insertions(+), 28 deletions(-) create mode 100644 format.sh diff --git a/format.sh b/format.sh new file mode 100644 index 0000000..3c6cc0e --- /dev/null +++ b/format.sh @@ -0,0 +1,3 @@ +flake8 src/* +black src/* +isort src/* diff --git a/requirements.txt b/requirements.txt index 36b6aad..78d704c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,7 @@ sqlalchemy uvicorn[standard] psycopg2-binary fastapi_utils +flake8 +black +isort +interrogate \ No newline at end of file diff --git a/src/crud.py b/src/crud.py index 6be16ce..c6263ca 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,7 +1,9 @@ from sqlalchemy.orm import Session -from models import User, HighScore, CourseProgress -import schemas, models +import models +import schemas +from models import CourseProgress, HighScore, User + def get_user(db: Session, user_id: int): return db.query(User).filter(User.user_id == user_id).first() @@ -21,4 +23,3 @@ def create_user(db: Session, user: schemas.users.UserCreate): db.commit() db.refresh(db_user) return db_user - diff --git a/src/database.py b/src/database.py index 9be44c1..6417662 100644 --- a/src/database.py +++ b/src/database.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker # SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -SQLALCHEMY_DATABASE_URL = 'postgresql://admin:WeSign123!@localhost/wesigndev' +SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigndev" engine = create_engine(SQLALCHEMY_DATABASE_URL) diff --git a/src/enums.py b/src/enums.py index 01e323e..c15e742 100644 --- a/src/enums.py +++ b/src/enums.py @@ -1,5 +1,5 @@ from fastapi_utils.enums import StrEnum -from sqlalchemy.types import TypeDecorator, Enum +from sqlalchemy.types import Enum, TypeDecorator class StrEnumType(TypeDecorator): @@ -18,14 +18,14 @@ class StrEnumType(TypeDecorator): if value is None: return None return self.enum_class(value) - + class MinigameEnum(StrEnum): - SpellingBee = 'SpellingBee' - Hangman = 'Hangman' - JustSign = 'JustSign' + SpellingBee = "SpellingBee" + Hangman = "Hangman" + JustSign = "JustSign" class CourseEnum(StrEnum): - Fingerspelling = 'Fingerspelling' - Animals = 'Animals' + Fingerspelling = "Fingerspelling" + Animals = "Animals" diff --git a/src/main.py b/src/main.py index a5a10a0..2856f02 100644 --- a/src/main.py +++ b/src/main.py @@ -1,15 +1,18 @@ +from typing import List + from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session -from models import Base -from database import SessionLocal, engine -from schemas import users -from typing import List + import crud +from database import SessionLocal, engine +from models import Base +from schemas import users app = FastAPI() Base.metadata.create_all(bind=engine) + def get_db(): db = SessionLocal() try: @@ -25,6 +28,7 @@ async def root(): print("Hello world") return {"message": "Hello world!"} + @app.post("/users", response_model=users.User) async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) @@ -35,8 +39,5 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/users", response_model=List[users.User]) async def read_users(db: Session = Depends(get_db)): - print("here") users = crud.get_users(db) - print(users[0].high_scores) return users - diff --git a/src/models.py b/src/models.py index 898c2cd..9781d5b 100644 --- a/src/models.py +++ b/src/models.py @@ -1,10 +1,11 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Float -from sqlalchemy.orm import relationship -from sqlalchemy.dialects.postgresql import ARRAY from enum import Enum -from enums import MinigameEnum, CourseEnum, StrEnumType +from sqlalchemy import Boolean, Column, Float, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.orm import relationship + from database import Base +from enums import CourseEnum, MinigameEnum, StrEnumType class User(Base): @@ -14,9 +15,13 @@ class User(Base): username = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) - high_scores = relationship("HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic") - course_progress = relationship("CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic") - + high_scores = relationship( + "HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic" + ) + course_progress = relationship( + "CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic" + ) + # add a new column to store the high_score IDs high_score_ids = Column(ARRAY(Integer), default=[]) @@ -34,7 +39,6 @@ class HighScore(Base): owner = relationship("User", back_populates="high_scores") - class CourseProgress(Base): __tablename__ = "course_progress" @@ -43,4 +47,3 @@ class CourseProgress(Base): course = Column(StrEnumType(CourseEnum), nullable=False) owner_id = Column(Integer, ForeignKey("users.user_id")) owner = relationship("User", back_populates="course_progress") - diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index 8f42a4a..994001f 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -1,4 +1,5 @@ from pydantic import BaseModel + from enums import CourseEnum diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 726e647..efe15de 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -1,4 +1,5 @@ from pydantic import BaseModel + from src.enums import MinigameEnum diff --git a/src/schemas/users.py b/src/schemas/users.py index cbcaa6d..f0cfda7 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel from typing import List +from pydantic import BaseModel + class User(BaseModel): user_id: int From d1c17389172813cf98ea85ef1a80c0e94ef363be Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 15:12:11 -0600 Subject: [PATCH 14/37] Some formatting and cleanup --- src/crud.py | 7 +++---- src/database.py | 1 - src/main.py | 5 +---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/crud.py b/src/crud.py index c6263ca..6cf1849 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,8 +1,7 @@ from sqlalchemy.orm import Session -import models -import schemas from models import CourseProgress, HighScore, User +from schemas.users import UserCreate def get_user(db: Session, user_id: int): @@ -17,8 +16,8 @@ def get_users(db: Session): return db.query(User).all() -def create_user(db: Session, user: schemas.users.UserCreate): - db_user = models.User(username=user.username, hashed_password=user.password) +def create_user(db: Session, user: UserCreate): + db_user = User(username=user.username, hashed_password=user.password) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/database.py b/src/database.py index 6417662..533adf5 100644 --- a/src/database.py +++ b/src/database.py @@ -2,7 +2,6 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -# SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigndev" engine = create_engine(SQLALCHEMY_DATABASE_URL) diff --git a/src/main.py b/src/main.py index 2856f02..83cc2f1 100644 --- a/src/main.py +++ b/src/main.py @@ -16,16 +16,13 @@ Base.metadata.create_all(bind=engine) def get_db(): db = SessionLocal() try: - print("Yield") yield db finally: - print("Close") db.close() @app.get("/") async def root(): - print("Hello world") return {"message": "Hello world!"} @@ -33,7 +30,7 @@ async def root(): async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: - raise HTTPException(status_code=400, detail="Email already registered") + raise HTTPException(status_code=400, detail="Username already registered") return crud.create_user(db=db, user=user) From fa543b19e7670f91cc2916663d89eaa8e08cbd24 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Tue, 28 Mar 2023 16:01:15 -0600 Subject: [PATCH 15/37] Get started on high score creation --- src/crud.py | 28 +++++++++++++++++++++++++++- src/main.py | 23 ++++++++++++++++++----- src/schemas/highscores.py | 13 ++++++++++--- src/schemas/users.py | 10 ++++++---- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/src/crud.py b/src/crud.py index 6cf1849..104409e 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,10 +1,11 @@ from sqlalchemy.orm import Session from models import CourseProgress, HighScore, User +from schemas.highscores import HighScoreCreate from schemas.users import UserCreate -def get_user(db: Session, user_id: int): +def get_user_by_id(db: Session, user_id: int): return db.query(User).filter(User.user_id == user_id).first() @@ -22,3 +23,28 @@ def create_user(db: Session, user: UserCreate): db.commit() db.refresh(db_user) return db_user + + +def get_high_scores(db: Session): + return db.query(HighScore).all() + + +def create_high_score(db: Session, high_score: HighScoreCreate): + db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, + ) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + owner = get_user_by_id(db, high_score.owner_id) + print("ID IS " + str(db_high_score.high_score_id) + " " + owner.username) + owner.high_score_ids.append(db_high_score.high_score_id) + owner.high_scores.append(db_high_score) + print("LIST OF IDS: " + str(owner.high_score_ids)) + db.commit() + db.refresh(owner) + owner2 = db.query(User).filter(User.user_id == high_score.owner_id).first() + print("LIST OF IDS: " + str(owner2.high_score_ids)) + return db_high_score diff --git a/src/main.py b/src/main.py index 83cc2f1..89a1243 100644 --- a/src/main.py +++ b/src/main.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session import crud from database import SessionLocal, engine from models import Base -from schemas import users +from schemas import highscores, users app = FastAPI() @@ -26,6 +26,12 @@ async def root(): return {"message": "Hello world!"} +@app.get("/users", response_model=List[users.User]) +async def read_users(db: Session = Depends(get_db)): + users = crud.get_users(db) + return users + + @app.post("/users", response_model=users.User) async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) @@ -34,7 +40,14 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): return crud.create_user(db=db, user=user) -@app.get("/users", response_model=List[users.User]) -async def read_users(db: Session = Depends(get_db)): - users = crud.get_users(db) - return users +@app.get("/highscores", response_model=List[highscores.HighScore]) +async def read_high_scores(db: Session = Depends(get_db)): + high_scores = crud.get_high_scores(db) + return high_scores + + +@app.post("/highscores", response_model=highscores.HighScore) +async def create_high_score( + high_score: highscores.HighScoreCreate, db: Session = Depends(get_db) +): + return crud.create_high_score(db=db, high_score=high_score) diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index efe15de..afec43f 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -1,13 +1,20 @@ from pydantic import BaseModel -from src.enums import MinigameEnum +from enums import MinigameEnum -class HighScore(BaseModel): - high_score_id: int +class HighScoreBase(BaseModel): score_value: float minigame: MinigameEnum owner_id: int + +class HighScoreCreate(HighScoreBase): + pass + + +class HighScore(HighScoreBase): + high_score_id: int + class Config: orm_mode = True diff --git a/src/schemas/users.py b/src/schemas/users.py index f0cfda7..7543434 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -3,9 +3,12 @@ from typing import List from pydantic import BaseModel -class User(BaseModel): - user_id: int +class UserBase(BaseModel): username: str + + +class User(UserBase): + user_id: int hashed_password: str high_score_ids: List[int] = [] @@ -15,6 +18,5 @@ class User(BaseModel): orm_mode = True -class UserCreate(BaseModel): - username: str +class UserCreate(UserBase): password: str From 41e6be454a95b0f8cb6a4a357bce8dbcff3f404f Mon Sep 17 00:00:00 2001 From: lvrossem Date: Wed, 29 Mar 2023 11:24:25 -0600 Subject: [PATCH 16/37] Work on high score endpoints --- src/crud.py | 93 ++++++++++++++++++++++++++++++--------- src/main.py | 9 ++-- src/schemas/highscores.py | 2 + src/schemas/users.py | 4 ++ 4 files changed, 83 insertions(+), 25 deletions(-) diff --git a/src/crud.py b/src/crud.py index 104409e..bfb61ca 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,9 +1,13 @@ from sqlalchemy.orm import Session +from sqlalchemy.dialects.postgresql import array from models import CourseProgress, HighScore, User +from enums import MinigameEnum from schemas.highscores import HighScoreCreate -from schemas.users import UserCreate +from schemas.users import UserCreate, UserHighScore +from sqlalchemy import desc +DEFAULT_NR_HIGH_SCORES = 10 def get_user_by_id(db: Session, user_id: int): return db.query(User).filter(User.user_id == user_id).first() @@ -25,26 +29,73 @@ def create_user(db: Session, user: UserCreate): return db_user -def get_high_scores(db: Session): - return db.query(HighScore).all() +def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): + + user_high_scores = [] + + if not n_highest: + n_highest = DEFAULT_NR_HIGH_SCORES + + if not minigame: + minigame = MinigameEnum.SpellingBee + + high_scores = db.query(HighScore).filter(HighScore.minigame == minigame).order_by(desc(HighScore.score_value)).limit(n_highest).all() + for high_score in high_scores: + owner = db.query(User).filter(User.user_id == high_score.owner_id).first() + user_high_scores.append(UserHighScore(username = owner.username, score_value = high_score.score_value)) + return user_high_scores def create_high_score(db: Session, high_score: HighScoreCreate): - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - owner = get_user_by_id(db, high_score.owner_id) - print("ID IS " + str(db_high_score.high_score_id) + " " + owner.username) - owner.high_score_ids.append(db_high_score.high_score_id) - owner.high_scores.append(db_high_score) - print("LIST OF IDS: " + str(owner.high_score_ids)) - db.commit() - db.refresh(owner) - owner2 = db.query(User).filter(User.user_id == high_score.owner_id).first() - print("LIST OF IDS: " + str(owner2.high_score_ids)) - return db_high_score + old_high_score = db.query(HighScore).filter(HighScore.owner_id == high_score.owner_id, HighScore.minigame == high_score.minigame).first() + if old_high_score: + print("Older high score found") + print(old_high_score.minigame) + if old_high_score.score_value < high_score.score_value: + print("Older score is lower") + db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, + ) + db.delete(old_high_score) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score + else: + print("Older score is higher") + return old_high_score + else: + db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, + ) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score + +""" +db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, +) +db.add(db_high_score) +db.commit() +db.refresh(db_high_score) +owner = db.query(User).filter(User.user_id == high_score.owner_id).first() +print("ID IS " + str(db_high_score.high_score_id) + " " + owner.username) +#owner.high_score_ids.append(db_high_score.high_score_id) +#db.add(owner) +#db.commit() +#owner.high_scores.append(db_high_score) +print("LIST OF IDS: " + str(owner.high_score_ids)) +#db.flush() +#db.refresh(owner) +#owner2 = db.query(User).filter(User.user_id == high_score.owner_id).first() +#print("LIST OF IDS: " + str(owner2.high_score_ids)) +return db_high_score +""" diff --git a/src/main.py b/src/main.py index 89a1243..2782dde 100644 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session @@ -7,6 +7,7 @@ import crud from database import SessionLocal, engine from models import Base from schemas import highscores, users +from enums import MinigameEnum app = FastAPI() @@ -40,9 +41,9 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): return crud.create_user(db=db, user=user) -@app.get("/highscores", response_model=List[highscores.HighScore]) -async def read_high_scores(db: Session = Depends(get_db)): - high_scores = crud.get_high_scores(db) +@app.get("/highscores", response_model=List[users.UserHighScore]) +async def read_high_scores(db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, n_highest: Optional[int] = None): + high_scores = crud.get_high_scores(db, minigame, n_highest) return high_scores diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index afec43f..45a8578 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -18,3 +18,5 @@ class HighScore(HighScoreBase): class Config: orm_mode = True + + diff --git a/src/schemas/users.py b/src/schemas/users.py index 7543434..eed0d7d 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -20,3 +20,7 @@ class User(UserBase): class UserCreate(UserBase): password: str + + +class UserHighScore(UserBase): + score_value: float From 46a3f5858df6853ea54b3bd3889aaa15b15e6707 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Wed, 29 Mar 2023 12:39:32 -0600 Subject: [PATCH 17/37] Throw error at POST /highscores if user does not exist --- src/crud.py | 55 ++++++++++++++++++--------------------- src/main.py | 8 ++++-- src/schemas/highscores.py | 2 -- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/src/crud.py b/src/crud.py index bfb61ca..5b0b664 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,14 +1,16 @@ -from sqlalchemy.orm import Session +from sqlalchemy import desc from sqlalchemy.dialects.postgresql import array +from sqlalchemy.orm import Session +from fastapi import HTTPException -from models import CourseProgress, HighScore, User from enums import MinigameEnum +from models import CourseProgress, HighScore, User from schemas.highscores import HighScoreCreate from schemas.users import UserCreate, UserHighScore -from sqlalchemy import desc DEFAULT_NR_HIGH_SCORES = 10 + def get_user_by_id(db: Session, user_id: int): return db.query(User).filter(User.user_id == user_id).first() @@ -30,7 +32,6 @@ def create_user(db: Session, user: UserCreate): def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): - user_high_scores = [] if not n_highest: @@ -39,15 +40,33 @@ def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): if not minigame: minigame = MinigameEnum.SpellingBee - high_scores = db.query(HighScore).filter(HighScore.minigame == minigame).order_by(desc(HighScore.score_value)).limit(n_highest).all() + high_scores = ( + db.query(HighScore) + .filter(HighScore.minigame == minigame) + .order_by(desc(HighScore.score_value)) + .limit(n_highest) + .all() + ) for high_score in high_scores: owner = db.query(User).filter(User.user_id == high_score.owner_id).first() - user_high_scores.append(UserHighScore(username = owner.username, score_value = high_score.score_value)) + user_high_scores.append( + UserHighScore(username=owner.username, score_value=high_score.score_value) + ) return user_high_scores def create_high_score(db: Session, high_score: HighScoreCreate): - old_high_score = db.query(HighScore).filter(HighScore.owner_id == high_score.owner_id, HighScore.minigame == high_score.minigame).first() + owner = db.query(User).filter(User.user_id == high_score.owner_id).first() + if not owner: + raise HTTPException(status_code=400, detail="User does not exist") + old_high_score = ( + db.query(HighScore) + .filter( + HighScore.owner_id == high_score.owner_id, + HighScore.minigame == high_score.minigame, + ) + .first() + ) if old_high_score: print("Older high score found") print(old_high_score.minigame) @@ -77,25 +96,3 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.refresh(db_high_score) return db_high_score -""" -db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, -) -db.add(db_high_score) -db.commit() -db.refresh(db_high_score) -owner = db.query(User).filter(User.user_id == high_score.owner_id).first() -print("ID IS " + str(db_high_score.high_score_id) + " " + owner.username) -#owner.high_score_ids.append(db_high_score.high_score_id) -#db.add(owner) -#db.commit() -#owner.high_scores.append(db_high_score) -print("LIST OF IDS: " + str(owner.high_score_ids)) -#db.flush() -#db.refresh(owner) -#owner2 = db.query(User).filter(User.user_id == high_score.owner_id).first() -#print("LIST OF IDS: " + str(owner2.high_score_ids)) -return db_high_score -""" diff --git a/src/main.py b/src/main.py index 2782dde..c537154 100644 --- a/src/main.py +++ b/src/main.py @@ -5,9 +5,9 @@ from sqlalchemy.orm import Session import crud from database import SessionLocal, engine +from enums import MinigameEnum from models import Base from schemas import highscores, users -from enums import MinigameEnum app = FastAPI() @@ -42,7 +42,11 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores", response_model=List[users.UserHighScore]) -async def read_high_scores(db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, n_highest: Optional[int] = None): +async def read_high_scores( + db: Session = Depends(get_db), + minigame: Optional[MinigameEnum] = None, + n_highest: Optional[int] = None, +): high_scores = crud.get_high_scores(db, minigame, n_highest) return high_scores diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 45a8578..afec43f 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -18,5 +18,3 @@ class HighScore(HighScoreBase): class Config: orm_mode = True - - From 8ca636c48efe049e3530aa4f80b811555845aa32 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Wed, 29 Mar 2023 14:14:27 -0600 Subject: [PATCH 18/37] Cleanup --- src/crud.py | 17 ++++++----------- src/enums.py | 4 ++++ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/crud.py b/src/crud.py index 5b0b664..fd88a52 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,7 +1,7 @@ +from fastapi import HTTPException from sqlalchemy import desc from sqlalchemy.dialects.postgresql import array from sqlalchemy.orm import Session -from fastapi import HTTPException from enums import MinigameEnum from models import CourseProgress, HighScore, User @@ -31,11 +31,11 @@ def create_user(db: Session, user: UserCreate): return db_user -def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): +def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): user_high_scores = [] - if not n_highest: - n_highest = DEFAULT_NR_HIGH_SCORES + if not nr_highest: + nr_highest = DEFAULT_NR_HIGH_SCORES if not minigame: minigame = MinigameEnum.SpellingBee @@ -44,7 +44,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): db.query(HighScore) .filter(HighScore.minigame == minigame) .order_by(desc(HighScore.score_value)) - .limit(n_highest) + .limit(nr_highest) .all() ) for high_score in high_scores: @@ -56,7 +56,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, n_highest: int): def create_high_score(db: Session, high_score: HighScoreCreate): - owner = db.query(User).filter(User.user_id == high_score.owner_id).first() + owner = db.query(User).filter(User.user_id == high_score.owner_id).first() if not owner: raise HTTPException(status_code=400, detail="User does not exist") old_high_score = ( @@ -68,10 +68,7 @@ def create_high_score(db: Session, high_score: HighScoreCreate): .first() ) if old_high_score: - print("Older high score found") - print(old_high_score.minigame) if old_high_score.score_value < high_score.score_value: - print("Older score is lower") db_high_score = HighScore( score_value=high_score.score_value, minigame=high_score.minigame, @@ -83,7 +80,6 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.refresh(db_high_score) return db_high_score else: - print("Older score is higher") return old_high_score else: db_high_score = HighScore( @@ -95,4 +91,3 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.commit() db.refresh(db_high_score) return db_high_score - diff --git a/src/enums.py b/src/enums.py index c15e742..11e4da0 100644 --- a/src/enums.py +++ b/src/enums.py @@ -28,4 +28,8 @@ class MinigameEnum(StrEnum): class CourseEnum(StrEnum): Fingerspelling = "Fingerspelling" + Basics = "Basics" + Hobbies = "Hobbies" Animals = "Animals" + Colors = "Colors" + FruitsVegetables = "FruitsVegetables" From dcff7e4f3577448da24a77e9d33bb4e8382a4adc Mon Sep 17 00:00:00 2001 From: lvrossem Date: Wed, 29 Mar 2023 15:03:03 -0600 Subject: [PATCH 19/37] Reformatting --- src/crud.py | 1 - src/models.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/crud.py b/src/crud.py index fd88a52..421c8a0 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,6 +1,5 @@ from fastapi import HTTPException from sqlalchemy import desc -from sqlalchemy.dialects.postgresql import array from sqlalchemy.orm import Session from enums import MinigameEnum diff --git a/src/models.py b/src/models.py index 9781d5b..6c7f3c6 100644 --- a/src/models.py +++ b/src/models.py @@ -1,6 +1,4 @@ -from enum import Enum - -from sqlalchemy import Boolean, Column, Float, ForeignKey, Integer, String +from sqlalchemy import Column, Float, ForeignKey, Integer, String from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import relationship From 3e12125c09c179e88a9378f4dbf175c0dc37c394 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 30 Mar 2023 07:22:30 -0600 Subject: [PATCH 20/37] First authentication prototype --- requirements.txt | 6 ++++- src/crud.py | 4 +-- src/main.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 78d704c..7dde8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,8 @@ fastapi_utils flake8 black isort -interrogate \ No newline at end of file +interrogate +python-jose[cryptography] +passlib +jwt +PyJWT \ No newline at end of file diff --git a/src/crud.py b/src/crud.py index 421c8a0..870fe74 100644 --- a/src/crud.py +++ b/src/crud.py @@ -22,8 +22,8 @@ def get_users(db: Session): return db.query(User).all() -def create_user(db: Session, user: UserCreate): - db_user = User(username=user.username, hashed_password=user.password) +def create_user(db: Session, username: str, hashed_password: str): + db_user = User(username=username, hashed_password=hashed_password) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/main.py b/src/main.py index c537154..bb907f1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,10 @@ +from datetime import datetime, timedelta from typing import List, Optional +import jwt from fastapi import Depends, FastAPI, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from passlib.context import CryptContext from sqlalchemy.orm import Session import crud @@ -14,6 +18,16 @@ app = FastAPI() Base.metadata.create_all(bind=engine) +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# JWT authentication setup +jwt_secret = "secret_key" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +bearer_scheme = HTTPBearer() + + def get_db(): db = SessionLocal() try: @@ -38,7 +52,9 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") - return crud.create_user(db=db, user=user) + return crud.create_user( + db=db, username=user.username, hashed_password=pwd_context.hash(user.password) + ) @app.get("/highscores", response_model=List[users.UserHighScore]) @@ -56,3 +72,48 @@ async def create_high_score( high_score: highscores.HighScoreCreate, db: Session = Depends(get_db) ): return crud.create_high_score(db=db, high_score=high_score) + + +#### TESTING!! DELETE LATER + + +async def get_current_user( + token: HTTPAuthorizationCredentials = Depends(bearer_scheme), +): + try: + payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) + username = payload.get("sub") + if username is None: + raise HTTPException(status_code=401, detail="Invalid JWT token") + return username + except jwt.exceptions.DecodeError: + raise HTTPException(status_code=401, detail="Invalid JWT token") + + +@app.get("/protected") +async def protected_route(current_user=Depends(get_current_user)): + return {"message": f"Hello, {current_user}!"} + + +def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): + db_user = crud.get_user_by_username(db=db, username=user.username) + if not db_user: + return False + hashed_password = db_user.hashed_password + if not hashed_password or not pwd_context.verify(user.password, hashed_password): + return False + return db_user + + +@app.post("/login") +async def login(user: users.UserCreate, db: Session = Depends(get_db)): + user = authenticate_user(user, db) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} From 849d2018f9c4b428098e9a2320cf86194932eb33 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 30 Mar 2023 08:02:46 -0600 Subject: [PATCH 21/37] Start working on course progress endpoints --- src/crud.py | 12 +++++++++++- src/enums.py | 1 + src/main.py | 12 +++++++++--- src/schemas/courseprogress.py | 9 +++++---- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/crud.py b/src/crud.py index 870fe74..3e2ff10 100644 --- a/src/crud.py +++ b/src/crud.py @@ -2,10 +2,11 @@ from fastapi import HTTPException from sqlalchemy import desc from sqlalchemy.orm import Session -from enums import MinigameEnum +from enums import MinigameEnum, CourseEnum from models import CourseProgress, HighScore, User from schemas.highscores import HighScoreCreate from schemas.users import UserCreate, UserHighScore +from schemas.courseprogress import CourseProgressBase DEFAULT_NR_HIGH_SCORES = 10 @@ -90,3 +91,12 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.commit() db.refresh(db_high_score) return db_high_score + +def get_course_progress(db: Session, user: User, course: CourseEnum): + if course != CourseEnum.All: + course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first() + if course_progress: + return [CourseProgressBase(progress_value = course_progress.progress_value, course = course_progress.course)] + else: + return [CourseProgressBase(progress_value = 0, course = course)] + return [] diff --git a/src/enums.py b/src/enums.py index 11e4da0..3ec4085 100644 --- a/src/enums.py +++ b/src/enums.py @@ -33,3 +33,4 @@ class CourseEnum(StrEnum): Animals = "Animals" Colors = "Colors" FruitsVegetables = "FruitsVegetables" + All = "All" diff --git a/src/main.py b/src/main.py index bb907f1..55558ed 100644 --- a/src/main.py +++ b/src/main.py @@ -9,9 +9,9 @@ from sqlalchemy.orm import Session import crud from database import SessionLocal, engine -from enums import MinigameEnum +from enums import MinigameEnum, CourseEnum from models import Base -from schemas import highscores, users +from schemas import highscores, users, courseprogress app = FastAPI() @@ -91,7 +91,7 @@ async def get_current_user( @app.get("/protected") -async def protected_route(current_user=Depends(get_current_user)): +async def protected_route(current_user = Depends(get_current_user)): return {"message": f"Hello, {current_user}!"} @@ -117,3 +117,9 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): } access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) return {"access_token": access_token} + + +@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) +async def get_course_progress(course: Optional[CourseEnum] = CourseEnum.All, current_user = Depends(get_current_user), db: Session = Depends(get_db)): + user = crud.get_user_by_username(db, current_user) + return crud.get_course_progress(db = db, user = user, course = course) diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index 994001f..d46486e 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -2,12 +2,13 @@ from pydantic import BaseModel from enums import CourseEnum - -class CourseProgress(BaseModel): - course_progress_id: int +class CourseProgressBase(BaseModel): progress_value: float course: CourseEnum - owner: int + +class CourseProgress(CourseProgressBase): + course_progress_id: int + owner_id: int class Config: orm_mode = True From 5fe168937f377a743797c923a05163d50989fa3f Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 05:07:07 -0600 Subject: [PATCH 22/37] Add lots of comments --- src/crud.py | 7 +++++++ src/main.py | 53 ++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/crud.py b/src/crud.py index 3e2ff10..6e1e59f 100644 --- a/src/crud.py +++ b/src/crud.py @@ -12,18 +12,22 @@ DEFAULT_NR_HIGH_SCORES = 10 def get_user_by_id(db: Session, user_id: int): + """ Fetches a User from the database by their id """ return db.query(User).filter(User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): + """ Fetches a User from the database by their username """ return db.query(User).filter(User.username == username).first() def get_users(db: Session): + """ Fetch a list of all users """ return db.query(User).all() def create_user(db: Session, username: str, hashed_password: str): + """ Create a new user """ db_user = User(username=username, hashed_password=hashed_password) db.add(db_user) db.commit() @@ -32,6 +36,7 @@ def create_user(db: Session, username: str, hashed_password: str): def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): + """ Get the n highest scores of a given minigame """ user_high_scores = [] if not nr_highest: @@ -56,6 +61,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): def create_high_score(db: Session, high_score: HighScoreCreate): + """ Create a new high score for a given minigame """ owner = db.query(User).filter(User.user_id == high_score.owner_id).first() if not owner: raise HTTPException(status_code=400, detail="User does not exist") @@ -93,6 +99,7 @@ def create_high_score(db: Session, high_score: HighScoreCreate): return db_high_score def get_course_progress(db: Session, user: User, course: CourseEnum): + """ Get the progress a user has for a certain course """ if course != CourseEnum.All: course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first() if course_progress: diff --git a/src/main.py b/src/main.py index 55558ed..bb88ecd 100644 --- a/src/main.py +++ b/src/main.py @@ -23,7 +23,7 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # JWT authentication setup jwt_secret = "secret_key" ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month bearer_scheme = HTTPBearer() @@ -36,6 +36,18 @@ def get_db(): db.close() +def get_current_user( + token: HTTPAuthorizationCredentials = Depends(bearer_scheme), +): + try: + payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) + username = payload.get("sub") + if username is None: + raise HTTPException(status_code=401, detail="Invalid JWT token") + return username + except jwt.exceptions.DecodeError: + raise HTTPException(status_code=401, detail="Invalid JWT token") + @app.get("/") async def root(): return {"message": "Hello world!"} @@ -47,8 +59,14 @@ async def read_users(db: Session = Depends(get_db)): return users -@app.post("/users", response_model=users.User) -async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): +@app.patch("/users") +async def patch_current_user(user: users.UserCreate, current_user = Depends(get_current_user), db: Session = Depends(get_db)): + db_user = crud.get_us + return users + + +@app.post("/register", response_model=users.User) +async def register(user: users.UserCreate, db: Session = Depends(get_db)): db_user = crud.get_user_by_username(db, username=user.username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") @@ -57,12 +75,28 @@ async def create_user(user: users.UserCreate, db: Session = Depends(get_db)): ) +@app.post("/login") +async def login(user: users.UserCreate, db: Session = Depends(get_db)): + user = authenticate_user(user, db) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} + + @app.get("/highscores", response_model=List[users.UserHighScore]) async def read_high_scores( db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, n_highest: Optional[int] = None, ): + if n_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") high_scores = crud.get_high_scores(db, minigame, n_highest) return high_scores @@ -105,18 +139,7 @@ def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): return db_user -@app.post("/login") -async def login(user: users.UserCreate, db: Session = Depends(get_db)): - user = authenticate_user(user, db) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token_payload = { - "sub": user.username, - "exp": datetime.utcnow() + access_token_expires, - } - access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) - return {"access_token": access_token} + @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) From 49f8d7d713a9e0aff0a96f0f315dc97edca8e02e Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 06:16:40 -0600 Subject: [PATCH 23/37] BIG refactors --- src/crud.py | 99 +++++++++++++++++++++++++++++------ src/main.py | 81 +++++++++------------------- src/schemas/courseprogress.py | 2 + 3 files changed, 109 insertions(+), 73 deletions(-) diff --git a/src/crud.py b/src/crud.py index 6e1e59f..1321cb2 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,42 +1,96 @@ from fastapi import HTTPException from sqlalchemy import desc -from sqlalchemy.orm import Session +from datetime import datetime, timedelta -from enums import MinigameEnum, CourseEnum +from sqlalchemy.orm import Session +import jwt + + +from enums import CourseEnum, MinigameEnum from models import CourseProgress, HighScore, User +from schemas.courseprogress import CourseProgressBase from schemas.highscores import HighScoreCreate from schemas.users import UserCreate, UserHighScore -from schemas.courseprogress import CourseProgressBase +from passlib.context import CryptContext + DEFAULT_NR_HIGH_SCORES = 10 +# JWT authentication setup +jwt_secret = "secret_key" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + def get_user_by_id(db: Session, user_id: int): - """ Fetches a User from the database by their id """ + """Fetches a User from the database by their id""" return db.query(User).filter(User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): - """ Fetches a User from the database by their username """ + """Fetches a User from the database by their username""" return db.query(User).filter(User.username == username).first() def get_users(db: Session): - """ Fetch a list of all users """ + """Fetch a list of all users""" return db.query(User).all() -def create_user(db: Session, username: str, hashed_password: str): - """ Create a new user """ - db_user = User(username=username, hashed_password=hashed_password) +def authenticate_user(db: Session, user: UserCreate): + db_user = get_user_by_username(db=db, username=user.username) + if not db_user: + return False + hashed_password = db_user.hashed_password + if not hashed_password or not pwd_context.verify(user.password, hashed_password): + return False + return db_user + + +def register(db: Session, username: str, password: str): + """Register a new user""" + db_user = get_user_by_username(db, username) + if db_user: + raise HTTPException(status_code=400, detail="Username already registered") + db_user = User(username=username, hashed_password=pwd_context.hash(password)) db.add(db_user) db.commit() db.refresh(db_user) return db_user +def login(db: Session, user: UserCreate): + user = authenticate_user(db, user) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} + + +def patch_user(db: Session, username: str, user: UserCreate): + db_user = get_user_by_username(db, username) + potential_duplicate = get_user_by_username(db, user.username) + if potential_duplicate: + if potential_duplicate.user_id != db_user.user_id: + raise HTTPException(status_code=400, detail="Username already registered") + + db_user.username = user.username + db_user.hashed_password = pwd_context.hash(user.password) + db.commit() + + def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): - """ Get the n highest scores of a given minigame """ + """Get the n highest scores of a given minigame""" + if nr_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") + user_high_scores = [] if not nr_highest: @@ -61,7 +115,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): def create_high_score(db: Session, high_score: HighScoreCreate): - """ Create a new high score for a given minigame """ + """Create a new high score for a given minigame""" owner = db.query(User).filter(User.user_id == high_score.owner_id).first() if not owner: raise HTTPException(status_code=400, detail="User does not exist") @@ -98,12 +152,25 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.refresh(db_high_score) return db_high_score -def get_course_progress(db: Session, user: User, course: CourseEnum): - """ Get the progress a user has for a certain course """ + +def get_course_progress(db: Session, username: str, course: CourseEnum): + """Get the progress a user has for a certain course""" + user = get_user_by_username(db, username) if course != CourseEnum.All: - course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first() + course_progress = ( + db.query(CourseProgress) + .filter( + CourseProgress.owner_id == user.user_id, CourseProgress.course == course + ) + .first() + ) if course_progress: - return [CourseProgressBase(progress_value = course_progress.progress_value, course = course_progress.course)] + return [ + CourseProgressBase( + progress_value=course_progress.progress_value, + course=course_progress.course, + ) + ] else: - return [CourseProgressBase(progress_value = 0, course = course)] + return [CourseProgressBase(progress_value=0, course=course)] return [] diff --git a/src/main.py b/src/main.py index bb88ecd..fb0d013 100644 --- a/src/main.py +++ b/src/main.py @@ -9,22 +9,15 @@ from sqlalchemy.orm import Session import crud from database import SessionLocal, engine -from enums import MinigameEnum, CourseEnum +from enums import CourseEnum, MinigameEnum from models import Base -from schemas import highscores, users, courseprogress +from schemas import courseprogress, highscores, users app = FastAPI() + Base.metadata.create_all(bind=engine) - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -# JWT authentication setup -jwt_secret = "secret_key" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month - bearer_scheme = HTTPBearer() @@ -35,12 +28,11 @@ def get_db(): finally: db.close() - -def get_current_user( +def get_current_user_name( token: HTTPAuthorizationCredentials = Depends(bearer_scheme), ): try: - payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) + payload = jwt.decode(token.credentials, crud.jwt_secret, algorithms=[crud.ALGORITHM]) username = payload.get("sub") if username is None: raise HTTPException(status_code=401, detail="Invalid JWT token") @@ -60,45 +52,33 @@ async def read_users(db: Session = Depends(get_db)): @app.patch("/users") -async def patch_current_user(user: users.UserCreate, current_user = Depends(get_current_user), db: Session = Depends(get_db)): - db_user = crud.get_us - return users +async def patch_current_user( + user: users.UserCreate, + current_user_name = Depends(get_current_user_name), + db: Session = Depends(get_db), +): + crud.patch_user(db, current_user_name, user) @app.post("/register", response_model=users.User) async def register(user: users.UserCreate, db: Session = Depends(get_db)): - db_user = crud.get_user_by_username(db, username=user.username) - if db_user: - raise HTTPException(status_code=400, detail="Username already registered") - return crud.create_user( - db=db, username=user.username, hashed_password=pwd_context.hash(user.password) + return crud.register( + db, user.username, user.password ) @app.post("/login") async def login(user: users.UserCreate, db: Session = Depends(get_db)): - user = authenticate_user(user, db) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token_payload = { - "sub": user.username, - "exp": datetime.utcnow() + access_token_expires, - } - access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) - return {"access_token": access_token} + return crud.login(db, user) @app.get("/highscores", response_model=List[users.UserHighScore]) async def read_high_scores( db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, - n_highest: Optional[int] = None, + nr_highest: Optional[int] = None, ): - if n_highest < 1: - raise HTTPException(status_code=400, detail="Invalid number of high scores") - high_scores = crud.get_high_scores(db, minigame, n_highest) - return high_scores + return crud.get_high_scores(db, minigame, nr_highest) @app.post("/highscores", response_model=highscores.HighScore) @@ -111,22 +91,9 @@ async def create_high_score( #### TESTING!! DELETE LATER -async def get_current_user( - token: HTTPAuthorizationCredentials = Depends(bearer_scheme), -): - try: - payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) - username = payload.get("sub") - if username is None: - raise HTTPException(status_code=401, detail="Invalid JWT token") - return username - except jwt.exceptions.DecodeError: - raise HTTPException(status_code=401, detail="Invalid JWT token") - - @app.get("/protected") -async def protected_route(current_user = Depends(get_current_user)): - return {"message": f"Hello, {current_user}!"} +async def protected_route(current_user_name=Depends(get_current_user_name)): + return {"message": f"Hello, {current_user_name}!"} def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): @@ -139,10 +106,10 @@ def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): return db_user - - - @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) -async def get_course_progress(course: Optional[CourseEnum] = CourseEnum.All, current_user = Depends(get_current_user), db: Session = Depends(get_db)): - user = crud.get_user_by_username(db, current_user) - return crud.get_course_progress(db = db, user = user, course = course) +async def get_course_progress( + course: Optional[CourseEnum] = CourseEnum.All, + current_user_name=Depends(get_current_user_name), + db: Session = Depends(get_db), +): + return crud.get_course_progress(db, current_user_name, course) diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index d46486e..1f071de 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -2,10 +2,12 @@ from pydantic import BaseModel from enums import CourseEnum + class CourseProgressBase(BaseModel): progress_value: float course: CourseEnum + class CourseProgress(CourseProgressBase): course_progress_id: int owner_id: int From 032a6ed5434b3ac1aa91dc4d160a7672e03fbd55 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 07:13:13 -0600 Subject: [PATCH 24/37] Refactor crud module --- src/crud.py | 176 ------------------------------------- src/crud/authentication.py | 52 +++++++++++ src/crud/courseprogress.py | 27 ++++++ src/crud/highscores.py | 75 ++++++++++++++++ src/crud/users.py | 31 +++++++ src/main.py | 46 +++++----- 6 files changed, 205 insertions(+), 202 deletions(-) delete mode 100644 src/crud.py create mode 100644 src/crud/authentication.py create mode 100644 src/crud/courseprogress.py create mode 100644 src/crud/highscores.py create mode 100644 src/crud/users.py diff --git a/src/crud.py b/src/crud.py deleted file mode 100644 index 1321cb2..0000000 --- a/src/crud.py +++ /dev/null @@ -1,176 +0,0 @@ -from fastapi import HTTPException -from sqlalchemy import desc -from datetime import datetime, timedelta - -from sqlalchemy.orm import Session -import jwt - - -from enums import CourseEnum, MinigameEnum -from models import CourseProgress, HighScore, User -from schemas.courseprogress import CourseProgressBase -from schemas.highscores import HighScoreCreate -from schemas.users import UserCreate, UserHighScore -from passlib.context import CryptContext - - -DEFAULT_NR_HIGH_SCORES = 10 - -# JWT authentication setup -jwt_secret = "secret_key" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - - -def get_user_by_id(db: Session, user_id: int): - """Fetches a User from the database by their id""" - return db.query(User).filter(User.user_id == user_id).first() - - -def get_user_by_username(db: Session, username: str): - """Fetches a User from the database by their username""" - return db.query(User).filter(User.username == username).first() - - -def get_users(db: Session): - """Fetch a list of all users""" - return db.query(User).all() - - -def authenticate_user(db: Session, user: UserCreate): - db_user = get_user_by_username(db=db, username=user.username) - if not db_user: - return False - hashed_password = db_user.hashed_password - if not hashed_password or not pwd_context.verify(user.password, hashed_password): - return False - return db_user - - -def register(db: Session, username: str, password: str): - """Register a new user""" - db_user = get_user_by_username(db, username) - if db_user: - raise HTTPException(status_code=400, detail="Username already registered") - db_user = User(username=username, hashed_password=pwd_context.hash(password)) - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user - - -def login(db: Session, user: UserCreate): - user = authenticate_user(db, user) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token_payload = { - "sub": user.username, - "exp": datetime.utcnow() + access_token_expires, - } - access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) - return {"access_token": access_token} - - -def patch_user(db: Session, username: str, user: UserCreate): - db_user = get_user_by_username(db, username) - potential_duplicate = get_user_by_username(db, user.username) - if potential_duplicate: - if potential_duplicate.user_id != db_user.user_id: - raise HTTPException(status_code=400, detail="Username already registered") - - db_user.username = user.username - db_user.hashed_password = pwd_context.hash(user.password) - db.commit() - - -def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): - """Get the n highest scores of a given minigame""" - if nr_highest < 1: - raise HTTPException(status_code=400, detail="Invalid number of high scores") - - user_high_scores = [] - - if not nr_highest: - nr_highest = DEFAULT_NR_HIGH_SCORES - - if not minigame: - minigame = MinigameEnum.SpellingBee - - high_scores = ( - db.query(HighScore) - .filter(HighScore.minigame == minigame) - .order_by(desc(HighScore.score_value)) - .limit(nr_highest) - .all() - ) - for high_score in high_scores: - owner = db.query(User).filter(User.user_id == high_score.owner_id).first() - user_high_scores.append( - UserHighScore(username=owner.username, score_value=high_score.score_value) - ) - return user_high_scores - - -def create_high_score(db: Session, high_score: HighScoreCreate): - """Create a new high score for a given minigame""" - owner = db.query(User).filter(User.user_id == high_score.owner_id).first() - if not owner: - raise HTTPException(status_code=400, detail="User does not exist") - old_high_score = ( - db.query(HighScore) - .filter( - HighScore.owner_id == high_score.owner_id, - HighScore.minigame == high_score.minigame, - ) - .first() - ) - if old_high_score: - if old_high_score.score_value < high_score.score_value: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) - db.delete(old_high_score) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score - else: - return old_high_score - else: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score - - -def get_course_progress(db: Session, username: str, course: CourseEnum): - """Get the progress a user has for a certain course""" - user = get_user_by_username(db, username) - if course != CourseEnum.All: - course_progress = ( - db.query(CourseProgress) - .filter( - CourseProgress.owner_id == user.user_id, CourseProgress.course == course - ) - .first() - ) - if course_progress: - return [ - CourseProgressBase( - progress_value=course_progress.progress_value, - course=course_progress.course, - ) - ] - else: - return [CourseProgressBase(progress_value=0, course=course)] - return [] diff --git a/src/crud/authentication.py b/src/crud/authentication.py new file mode 100644 index 0000000..def202a --- /dev/null +++ b/src/crud/authentication.py @@ -0,0 +1,52 @@ +from datetime import datetime, timedelta + +import jwt +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from crud.users import get_user_by_username, pwd_context +from models import User +from schemas.users import UserCreate + +DEFAULT_NR_HIGH_SCORES = 10 + +# JWT authentication setup +jwt_secret = "secret_key" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month + + +def authenticate_user(db: Session, user: UserCreate): + """Checks whether the provided credentials match with an existing User""" + db_user = get_user_by_username(db=db, username=user.username) + if not db_user: + return False + hashed_password = db_user.hashed_password + if not hashed_password or not pwd_context.verify(user.password, hashed_password): + return False + return db_user + + +def register(db: Session, username: str, password: str): + """Register a new user""" + db_user = get_user_by_username(db, username) + if db_user: + raise HTTPException(status_code=400, detail="Username already registered") + db_user = User(username=username, hashed_password=pwd_context.hash(password)) + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user + + +def login(db: Session, user: UserCreate): + user = authenticate_user(db, user) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py new file mode 100644 index 0000000..c96c7db --- /dev/null +++ b/src/crud/courseprogress.py @@ -0,0 +1,27 @@ +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from enums import CourseEnum + + +def get_course_progress(db: Session, username: str, course: CourseEnum): + """Get the progress a user has for a certain course""" + user = get_user_by_username(db, username) + if course != CourseEnum.All: + course_progress = ( + db.query(CourseProgress) + .filter( + CourseProgress.owner_id == user.user_id, CourseProgress.course == course + ) + .first() + ) + if course_progress: + return [ + CourseProgressBase( + progress_value=course_progress.progress_value, + course=course_progress.course, + ) + ] + else: + return [CourseProgressBase(progress_value=0, course=course)] + return [] diff --git a/src/crud/highscores.py b/src/crud/highscores.py new file mode 100644 index 0000000..0d21396 --- /dev/null +++ b/src/crud/highscores.py @@ -0,0 +1,75 @@ +from fastapi import HTTPException +from sqlalchemy import desc +from sqlalchemy.orm import Session + +from enums import MinigameEnum +from models import HighScore, User +from schemas.highscores import HighScoreCreate +from schemas.users import UserHighScore + + +def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): + """Get the n highest scores of a given minigame""" + if nr_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") + + user_high_scores = [] + + if not nr_highest: + nr_highest = DEFAULT_NR_HIGH_SCORES + + if not minigame: + minigame = MinigameEnum.SpellingBee + + high_scores = ( + db.query(HighScore) + .filter(HighScore.minigame == minigame) + .order_by(desc(HighScore.score_value)) + .limit(nr_highest) + .all() + ) + for high_score in high_scores: + owner = db.query(User).filter(User.user_id == high_score.owner_id).first() + user_high_scores.append( + UserHighScore(username=owner.username, score_value=high_score.score_value) + ) + return user_high_scores + + +def create_high_score(db: Session, high_score: HighScoreCreate): + """Create a new high score for a given minigame""" + owner = db.query(User).filter(User.user_id == high_score.owner_id).first() + if not owner: + raise HTTPException(status_code=400, detail="User does not exist") + old_high_score = ( + db.query(HighScore) + .filter( + HighScore.owner_id == high_score.owner_id, + HighScore.minigame == high_score.minigame, + ) + .first() + ) + if old_high_score: + if old_high_score.score_value < high_score.score_value: + db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, + ) + db.delete(old_high_score) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score + else: + return old_high_score + else: + db_high_score = HighScore( + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=high_score.owner_id, + ) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score diff --git a/src/crud/users.py b/src/crud/users.py new file mode 100644 index 0000000..4817d27 --- /dev/null +++ b/src/crud/users.py @@ -0,0 +1,31 @@ +from fastapi import HTTPException +from passlib.context import CryptContext +from sqlalchemy.orm import Session + +from models import User +from schemas.users import UserCreate + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +def patch_user(db: Session, username: str, user: UserCreate): + """Changes the username and/or the password of a User""" + db_user = get_user_by_username(db, username) + potential_duplicate = get_user_by_username(db, user.username) + if potential_duplicate: + if potential_duplicate.user_id != db_user.user_id: + raise HTTPException(status_code=400, detail="Username already registered") + + db_user.username = user.username + db_user.hashed_password = pwd_context.hash(user.password) + db.commit() + + +def get_user_by_username(db: Session, username: str): + """Fetches a User from the database by their username""" + return db.query(User).filter(User.username == username).first() + + +def get_users(db: Session): + """Fetch a list of all users""" + return db.query(User).all() diff --git a/src/main.py b/src/main.py index fb0d013..b11540c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,13 +1,14 @@ -from datetime import datetime, timedelta from typing import List, Optional import jwt from fastapi import Depends, FastAPI, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from passlib.context import CryptContext from sqlalchemy.orm import Session -import crud +from crud import authentication as crud_authentication +from crud import courseprogress as crud_courseprogress +from crud import highscores as crud_highscores +from crud import users as crud_users from database import SessionLocal, engine from enums import CourseEnum, MinigameEnum from models import Base @@ -28,11 +29,16 @@ def get_db(): finally: db.close() + def get_current_user_name( token: HTTPAuthorizationCredentials = Depends(bearer_scheme), ): try: - payload = jwt.decode(token.credentials, crud.jwt_secret, algorithms=[crud.ALGORITHM]) + payload = jwt.decode( + token.credentials, + crud_authentication.jwt_secret, + algorithms=[crud_authentication.ALGORITHM], + ) username = payload.get("sub") if username is None: raise HTTPException(status_code=401, detail="Invalid JWT token") @@ -40,6 +46,7 @@ def get_current_user_name( except jwt.exceptions.DecodeError: raise HTTPException(status_code=401, detail="Invalid JWT token") + @app.get("/") async def root(): return {"message": "Hello world!"} @@ -47,45 +54,42 @@ async def root(): @app.get("/users", response_model=List[users.User]) async def read_users(db: Session = Depends(get_db)): - users = crud.get_users(db) - return users + return crud_users.get_users(db) @app.patch("/users") async def patch_current_user( user: users.UserCreate, - current_user_name = Depends(get_current_user_name), + current_user_name=Depends(get_current_user_name), db: Session = Depends(get_db), ): - crud.patch_user(db, current_user_name, user) + crud_users.patch_user(db, current_user_name, user) @app.post("/register", response_model=users.User) async def register(user: users.UserCreate, db: Session = Depends(get_db)): - return crud.register( - db, user.username, user.password - ) + return crud_authentication.register(db, user.username, user.password) @app.post("/login") async def login(user: users.UserCreate, db: Session = Depends(get_db)): - return crud.login(db, user) + return crud_authentication.login(db, user) @app.get("/highscores", response_model=List[users.UserHighScore]) -async def read_high_scores( +async def get_high_scores( db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, ): - return crud.get_high_scores(db, minigame, nr_highest) + return crud_highscores.get_high_scores(db, minigame, nr_highest) @app.post("/highscores", response_model=highscores.HighScore) async def create_high_score( high_score: highscores.HighScoreCreate, db: Session = Depends(get_db) ): - return crud.create_high_score(db=db, high_score=high_score) + return crud_highscores.create_high_score(db=db, high_score=high_score) #### TESTING!! DELETE LATER @@ -96,20 +100,10 @@ async def protected_route(current_user_name=Depends(get_current_user_name)): return {"message": f"Hello, {current_user_name}!"} -def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): - db_user = crud.get_user_by_username(db=db, username=user.username) - if not db_user: - return False - hashed_password = db_user.hashed_password - if not hashed_password or not pwd_context.verify(user.password, hashed_password): - return False - return db_user - - @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, current_user_name=Depends(get_current_user_name), db: Session = Depends(get_db), ): - return crud.get_course_progress(db, current_user_name, course) + return crud_courseprogress.get_course_progress(db, current_user_name, course) From edd50b9ecb0a45519d7930c3cffe088ae826c592 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 07:43:10 -0600 Subject: [PATCH 25/37] Let register return access token without extra login --- src/crud/authentication.py | 12 ++++++------ src/main.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index def202a..0f7a9cf 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -16,13 +16,13 @@ ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month -def authenticate_user(db: Session, user: UserCreate): +def authenticate_user(db: Session, username: str, password: str): """Checks whether the provided credentials match with an existing User""" - db_user = get_user_by_username(db=db, username=user.username) + db_user = get_user_by_username(db, username) if not db_user: return False hashed_password = db_user.hashed_password - if not hashed_password or not pwd_context.verify(user.password, hashed_password): + if not hashed_password or not pwd_context.verify(password, hashed_password): return False return db_user @@ -36,11 +36,11 @@ def register(db: Session, username: str, password: str): db.add(db_user) db.commit() db.refresh(db_user) - return db_user + return login(db, username, password) -def login(db: Session, user: UserCreate): - user = authenticate_user(db, user) +def login(db: Session, username: str, password: str): + user = authenticate_user(db, username, password) if not user: raise HTTPException(status_code=401, detail="Invalid username or password") access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) diff --git a/src/main.py b/src/main.py index b11540c..fb0855d 100644 --- a/src/main.py +++ b/src/main.py @@ -66,14 +66,14 @@ async def patch_current_user( crud_users.patch_user(db, current_user_name, user) -@app.post("/register", response_model=users.User) +@app.post("/register") async def register(user: users.UserCreate, db: Session = Depends(get_db)): return crud_authentication.register(db, user.username, user.password) @app.post("/login") async def login(user: users.UserCreate, db: Session = Depends(get_db)): - return crud_authentication.login(db, user) + return crud_authentication.login(db, user.username, user.password) @app.get("/highscores", response_model=List[users.UserHighScore]) From 65d1a2a6e41a43146bf705b0e6cd7769d1bd729c Mon Sep 17 00:00:00 2001 From: lvrossem Date: Fri, 31 Mar 2023 11:43:07 -0600 Subject: [PATCH 26/37] Sort of fix StrEnum issue --- src/crud/authentication.py | 4 ++-- src/crud/courseprogress.py | 20 ++++++++++++------ src/crud/highscores.py | 43 ++++++++++++++++---------------------- src/enums.py | 13 +++++++++++- src/main.py | 24 ++++++++++++--------- src/models.py | 11 +++------- src/schemas/highscores.py | 6 +----- src/schemas/users.py | 4 +--- 8 files changed, 65 insertions(+), 60 deletions(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 0f7a9cf..5986a57 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -27,12 +27,12 @@ def authenticate_user(db: Session, username: str, password: str): return db_user -def register(db: Session, username: str, password: str): +def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") - db_user = User(username=username, hashed_password=pwd_context.hash(password)) + db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index c96c7db..6a2f17e 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,12 +1,12 @@ from fastapi import HTTPException from sqlalchemy.orm import Session -from enums import CourseEnum +from enums import CourseEnum, course_enum_list +from models import User, CourseProgress -def get_course_progress(db: Session, username: str, course: CourseEnum): +def get_course_progress(db: Session, user: User, course: CourseEnum): """Get the progress a user has for a certain course""" - user = get_user_by_username(db, username) if course != CourseEnum.All: course_progress = ( db.query(CourseProgress) @@ -18,10 +18,18 @@ def get_course_progress(db: Session, username: str, course: CourseEnum): if course_progress: return [ CourseProgressBase( - progress_value=course_progress.progress_value, - course=course_progress.course, + progress_value = course_progress.progress_value, + course = course_progress.course, ) ] else: - return [CourseProgressBase(progress_value=0, course=course)] + return [CourseProgressBase(progress_value = 0, course=course)] return [] + + +def initialize_user(db: Session, user: User): + for course in course_enum_list: + db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id)) + db.commit() + + diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 0d21396..6471c34 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from enums import MinigameEnum from models import HighScore, User -from schemas.highscores import HighScoreCreate +from schemas.highscores import HighScoreBase from schemas.users import UserHighScore @@ -31,45 +31,38 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): for high_score in high_scores: owner = db.query(User).filter(User.user_id == high_score.owner_id).first() user_high_scores.append( - UserHighScore(username=owner.username, score_value=high_score.score_value) + UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar) ) return user_high_scores -def create_high_score(db: Session, high_score: HighScoreCreate): +def create_high_score(db: Session, user: User, high_score: HighScoreBase): """Create a new high score for a given minigame""" - owner = db.query(User).filter(User.user_id == high_score.owner_id).first() - if not owner: - raise HTTPException(status_code=400, detail="User does not exist") + def add_to_db(): + """Helper function that adds new score to database; prevents code duplication""" + db_high_score = HighScore( + score_value = high_score.score_value, + minigame = high_score.minigame, + owner_id = user.user_id, + ) + db.add(db_high_score) + db.commit() + db.refresh(db_high_score) + return db_high_score + old_high_score = ( db.query(HighScore) .filter( - HighScore.owner_id == high_score.owner_id, + HighScore.owner_id == user.user_id, HighScore.minigame == high_score.minigame, ) .first() ) if old_high_score: if old_high_score.score_value < high_score.score_value: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) db.delete(old_high_score) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score + return add_to_db() else: return old_high_score else: - db_high_score = HighScore( - score_value=high_score.score_value, - minigame=high_score.minigame, - owner_id=high_score.owner_id, - ) - db.add(db_high_score) - db.commit() - db.refresh(db_high_score) - return db_high_score + return add_to_db() diff --git a/src/enums.py b/src/enums.py index 3ec4085..088d795 100644 --- a/src/enums.py +++ b/src/enums.py @@ -28,9 +28,20 @@ class MinigameEnum(StrEnum): class CourseEnum(StrEnum): Fingerspelling = "Fingerspelling" - Basics = "Basics" + #Basics = "Basics" Hobbies = "Hobbies" Animals = "Animals" Colors = "Colors" FruitsVegetables = "FruitsVegetables" All = "All" + +# This is needed because for some reason iterating over an enum doesn't work properly... +course_enum_list = [ + CourseEnum.Fingerspelling, + #CourseEnum.Basics, + CourseEnum.Hobbies, + CourseEnum.Animals, + CourseEnum.Colors, + CourseEnum.FruitsVegetables, + CourseEnum.All +] diff --git a/src/main.py b/src/main.py index fb0855d..5a9f139 100644 --- a/src/main.py +++ b/src/main.py @@ -15,10 +15,7 @@ from models import Base from schemas import courseprogress, highscores, users app = FastAPI() - - Base.metadata.create_all(bind=engine) - bearer_scheme = HTTPBearer() @@ -61,14 +58,17 @@ async def read_users(db: Session = Depends(get_db)): async def patch_current_user( user: users.UserCreate, current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db), + db: Session = Depends(get_db) ): crud_users.patch_user(db, current_user_name, user) @app.post("/register") async def register(user: users.UserCreate, db: Session = Depends(get_db)): - return crud_authentication.register(db, user.username, user.password) + access_token = crud_authentication.register(db, user.username, user.password, user.avatar) + user = crud_users.get_user_by_username(db, user.username) + crud_courseprogress.initialize_user(db, user) + return access_token @app.post("/login") @@ -78,18 +78,21 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores", response_model=List[users.UserHighScore]) async def get_high_scores( - db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, + db: Session = Depends(get_db) ): return crud_highscores.get_high_scores(db, minigame, nr_highest) @app.post("/highscores", response_model=highscores.HighScore) async def create_high_score( - high_score: highscores.HighScoreCreate, db: Session = Depends(get_db) + high_score: highscores.HighScoreBase, + current_user_name = Depends(get_current_user_name), + db: Session = Depends(get_db) ): - return crud_highscores.create_high_score(db=db, high_score=high_score) + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_highscores.create_high_score(db, current_user, high_score) #### TESTING!! DELETE LATER @@ -104,6 +107,7 @@ async def protected_route(current_user_name=Depends(get_current_user_name)): async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db), + db: Session = Depends(get_db) ): - return crud_courseprogress.get_course_progress(db, current_user_name, course) + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_courseprogress.get_course_progress(db, current_user, course) diff --git a/src/models.py b/src/models.py index 6c7f3c6..948f6b8 100644 --- a/src/models.py +++ b/src/models.py @@ -12,6 +12,7 @@ class User(Base): user_id = Column(Integer, primary_key=True, index=True) username = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) + avatar = Column(String, nullable=False) high_scores = relationship( "HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic" @@ -20,19 +21,13 @@ class User(Base): "CourseProgress", back_populates="owner", cascade="all, delete", lazy="dynamic" ) - # add a new column to store the high_score IDs - high_score_ids = Column(ARRAY(Integer), default=[]) - - # add a new column to store the course_progress IDs - course_progress_ids = Column(ARRAY(Integer), default=[]) - class HighScore(Base): __tablename__ = "high_scores" high_score_id = Column(Integer, primary_key=True, index=True) score_value = Column(Float, nullable=False) - minigame = Column(StrEnumType(MinigameEnum), nullable=False) + minigame = Column(String, nullable=False) owner_id = Column(Integer, ForeignKey("users.user_id")) owner = relationship("User", back_populates="high_scores") @@ -42,6 +37,6 @@ class CourseProgress(Base): course_progress_id = Column(Integer, primary_key=True, index=True) progress_value = Column(Float, nullable=False) - course = Column(StrEnumType(CourseEnum), nullable=False) + course = Column(String, nullable=False) owner_id = Column(Integer, ForeignKey("users.user_id")) owner = relationship("User", back_populates="course_progress") diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index afec43f..3f1879a 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -6,15 +6,11 @@ from enums import MinigameEnum class HighScoreBase(BaseModel): score_value: float minigame: MinigameEnum - owner_id: int - - -class HighScoreCreate(HighScoreBase): - pass class HighScore(HighScoreBase): high_score_id: int + owner_id: int class Config: orm_mode = True diff --git a/src/schemas/users.py b/src/schemas/users.py index eed0d7d..a14c4a3 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -5,15 +5,13 @@ from pydantic import BaseModel class UserBase(BaseModel): username: str + avatar: str class User(UserBase): user_id: int hashed_password: str - high_score_ids: List[int] = [] - course_progress_ids: List[int] = [] - class Config: orm_mode = True From d2933a95bab09d7468f8596a992830f891057317 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Sat, 1 Apr 2023 10:03:18 -0600 Subject: [PATCH 27/37] Minor stuff --- src/crud/authentication.py | 28 +++++++++++++++-- src/crud/courseprogress.py | 19 +++++++----- src/crud/highscores.py | 13 +++++--- src/crud/users.py | 1 + src/database.py | 8 +++++ src/enums.py | 7 +++-- src/main.py | 61 +++++++++++++++----------------------- src/models.py | 6 ++++ src/schemas/users.py | 2 +- 9 files changed, 91 insertions(+), 54 deletions(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 5986a57..e70fb85 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -1,7 +1,8 @@ from datetime import datetime, timedelta import jwt -from fastapi import HTTPException +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from crud.users import get_user_by_username, pwd_context @@ -15,6 +16,25 @@ jwt_secret = "secret_key" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month +bearer_scheme = HTTPBearer() + + +def get_current_user_name( + token: HTTPAuthorizationCredentials = Depends(bearer_scheme), +): + try: + payload = jwt.decode( + token.credentials, + jwt_secret, + algorithms=[ALGORITHM], + ) + username = payload.get("sub") + if username is None: + raise HTTPException(status_code=401, detail="Invalid JWT token") + return username + except jwt.exceptions.DecodeError: + raise HTTPException(status_code=401, detail="Invalid JWT token") + def authenticate_user(db: Session, username: str, password: str): """Checks whether the provided credentials match with an existing User""" @@ -29,10 +49,14 @@ def authenticate_user(db: Session, username: str, password: str): def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" + if avatar == "": + raise HTTPException(status_code=400, detail="No avatar was provided") db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") - db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar) + db_user = User( + username=username, hashed_password=pwd_context.hash(password), avatar=avatar + ) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 6a2f17e..c6d0a30 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -2,7 +2,8 @@ from fastapi import HTTPException from sqlalchemy.orm import Session from enums import CourseEnum, course_enum_list -from models import User, CourseProgress +from models import CourseProgress, User +from schemas.courseprogress import CourseProgressBase def get_course_progress(db: Session, user: User, course: CourseEnum): @@ -15,21 +16,25 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): ) .first() ) + if course_progress: return [ CourseProgressBase( - progress_value = course_progress.progress_value, - course = course_progress.course, + progress_value=course_progress.progress_value, + course=course_progress.course, ) ] else: - return [CourseProgressBase(progress_value = 0, course=course)] + db.add( + CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) + ) + db.commit() + return [CourseProgressBase(progress_value=0.0, course=course)] return [] def initialize_user(db: Session, user: User): + """Create CourseProgress records with a value of 0 for a new user""" for course in course_enum_list: - db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id)) + db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)) db.commit() - - diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 6471c34..17b2cce 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -31,19 +31,24 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): for high_score in high_scores: owner = db.query(User).filter(User.user_id == high_score.owner_id).first() user_high_scores.append( - UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar) + UserHighScore( + username=owner.username, + score_value=high_score.score_value, + avatar=owner.avatar, + ) ) return user_high_scores def create_high_score(db: Session, user: User, high_score: HighScoreBase): """Create a new high score for a given minigame""" + def add_to_db(): """Helper function that adds new score to database; prevents code duplication""" db_high_score = HighScore( - score_value = high_score.score_value, - minigame = high_score.minigame, - owner_id = user.user_id, + score_value=high_score.score_value, + minigame=high_score.minigame, + owner_id=user.user_id, ) db.add(db_high_score) db.commit() diff --git a/src/crud/users.py b/src/crud/users.py index 4817d27..50cadd7 100644 --- a/src/crud/users.py +++ b/src/crud/users.py @@ -18,6 +18,7 @@ def patch_user(db: Session, username: str, user: UserCreate): db_user.username = user.username db_user.hashed_password = pwd_context.hash(user.password) + db_user.avatar = user.avatar db.commit() diff --git a/src/database.py b/src/database.py index 533adf5..aca9992 100644 --- a/src/database.py +++ b/src/database.py @@ -9,3 +9,11 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/src/enums.py b/src/enums.py index 088d795..3b6ecd4 100644 --- a/src/enums.py +++ b/src/enums.py @@ -28,20 +28,21 @@ class MinigameEnum(StrEnum): class CourseEnum(StrEnum): Fingerspelling = "Fingerspelling" - #Basics = "Basics" + Basics = "Basics" Hobbies = "Hobbies" Animals = "Animals" Colors = "Colors" FruitsVegetables = "FruitsVegetables" All = "All" + # This is needed because for some reason iterating over an enum doesn't work properly... course_enum_list = [ CourseEnum.Fingerspelling, - #CourseEnum.Basics, + # CourseEnum.Basics, CourseEnum.Hobbies, CourseEnum.Animals, CourseEnum.Colors, CourseEnum.FruitsVegetables, - CourseEnum.All + CourseEnum.All, ] diff --git a/src/main.py b/src/main.py index 5a9f139..c7dea9c 100644 --- a/src/main.py +++ b/src/main.py @@ -2,46 +2,19 @@ from typing import List, Optional import jwt from fastapi import Depends, FastAPI, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from crud import authentication as crud_authentication from crud import courseprogress as crud_courseprogress from crud import highscores as crud_highscores from crud import users as crud_users -from database import SessionLocal, engine +from database import SessionLocal, engine, get_db from enums import CourseEnum, MinigameEnum from models import Base from schemas import courseprogress, highscores, users app = FastAPI() Base.metadata.create_all(bind=engine) -bearer_scheme = HTTPBearer() - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def get_current_user_name( - token: HTTPAuthorizationCredentials = Depends(bearer_scheme), -): - try: - payload = jwt.decode( - token.credentials, - crud_authentication.jwt_secret, - algorithms=[crud_authentication.ALGORITHM], - ) - username = payload.get("sub") - if username is None: - raise HTTPException(status_code=401, detail="Invalid JWT token") - return username - except jwt.exceptions.DecodeError: - raise HTTPException(status_code=401, detail="Invalid JWT token") @app.get("/") @@ -57,15 +30,17 @@ async def read_users(db: Session = Depends(get_db)): @app.patch("/users") async def patch_current_user( user: users.UserCreate, - current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name=Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), ): crud_users.patch_user(db, current_user_name, user) @app.post("/register") async def register(user: users.UserCreate, db: Session = Depends(get_db)): - access_token = crud_authentication.register(db, user.username, user.password, user.avatar) + access_token = crud_authentication.register( + db, user.username, user.password, user.avatar + ) user = crud_users.get_user_by_username(db, user.username) crud_courseprogress.initialize_user(db, user) return access_token @@ -80,7 +55,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): async def get_high_scores( minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return crud_highscores.get_high_scores(db, minigame, nr_highest) @@ -88,8 +63,8 @@ async def get_high_scores( @app.post("/highscores", response_model=highscores.HighScore) async def create_high_score( high_score: highscores.HighScoreBase, - current_user_name = Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name=Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) return crud_highscores.create_high_score(db, current_user, high_score) @@ -99,15 +74,27 @@ async def create_high_score( @app.get("/protected") -async def protected_route(current_user_name=Depends(get_current_user_name)): +async def protected_route( + current_user_name: str = Depends(crud_authentication.get_current_user_name), +): return {"message": f"Hello, {current_user_name}!"} @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, - current_user_name=Depends(get_current_user_name), - db: Session = Depends(get_db) + current_user_name: str = Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), +): + current_user = crud_users.get_user_by_username(db, current_user_name) + return crud_courseprogress.get_course_progress(db, current_user, course) + + +@app.patch("/courseprogress") +async def get_course_progress( + current_user_name: str = Depends(crud_authentication.get_current_user_name), + course: Optional[CourseEnum] = CourseEnum.All, + db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) return crud_courseprogress.get_course_progress(db, current_user, course) diff --git a/src/models.py b/src/models.py index 948f6b8..d2f1f8d 100644 --- a/src/models.py +++ b/src/models.py @@ -7,6 +7,8 @@ from enums import CourseEnum, MinigameEnum, StrEnumType class User(Base): + """The database model for users""" + __tablename__ = "users" user_id = Column(Integer, primary_key=True, index=True) @@ -23,6 +25,8 @@ class User(Base): class HighScore(Base): + """The database model for high scores""" + __tablename__ = "high_scores" high_score_id = Column(Integer, primary_key=True, index=True) @@ -33,6 +37,8 @@ class HighScore(Base): class CourseProgress(Base): + """The database model for course progress""" + __tablename__ = "course_progress" course_progress_id = Column(Integer, primary_key=True, index=True) diff --git a/src/schemas/users.py b/src/schemas/users.py index a14c4a3..3077af7 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class UserBase(BaseModel): username: str - avatar: str + avatar: str = "" class User(UserBase): From 5528ae8519e3d3ea2bf807b84339af7f58e56d72 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Sat, 1 Apr 2023 11:18:26 -0600 Subject: [PATCH 28/37] Finish most work on course progress endpoints --- src/crud/courseprogress.py | 45 ++++++++++++++++++++++++++++++++------ src/main.py | 25 ++++++++++----------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index c6d0a30..e3ba2b1 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -8,7 +8,14 @@ from schemas.courseprogress import CourseProgressBase def get_course_progress(db: Session, user: User, course: CourseEnum): """Get the progress a user has for a certain course""" - if course != CourseEnum.All: + result = [] + courses_to_fetch = [course] + if course == CourseEnum.All: + all_courses_list = [course for course in CourseEnum] + courses_to_fetch = filter( + lambda course: course != CourseEnum.All, all_courses_list + ) + for course in courses_to_fetch: course_progress = ( db.query(CourseProgress) .filter( @@ -18,19 +25,19 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): ) if course_progress: - return [ + result.append( CourseProgressBase( - progress_value=course_progress.progress_value, - course=course_progress.course, + progress_value=course_progress.progress_value, course=course ) - ] + ) else: db.add( CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) ) db.commit() - return [CourseProgressBase(progress_value=0.0, course=course)] - return [] + result.append(CourseProgressBase(progress_value=0.0, course=course)) + + return result def initialize_user(db: Session, user: User): @@ -38,3 +45,27 @@ def initialize_user(db: Session, user: User): for course in course_enum_list: db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)) db.commit() + + +def patch_course_progress(db: Session, user: User, course_progress: CourseProgressBase): + """Change the progress value for a given course""" + db_course_progress_list = [] + if course_progress.course != CourseEnum.All: + db_course_progress_list = ( + db.query(CourseProgress) + .filter( + CourseProgress.owner_id == user.user_id, + CourseProgress.course == course_progress.course, + ) + .all() + ) + else: + db_course_progress_list = ( + db.query(CourseProgress) + .filter(CourseProgress.owner_id == user.user_id) + .all() + ) + + for db_course_progress in db_course_progress_list: + db_course_progress.progress_value = course_progress.progress_value + db.commit() diff --git a/src/main.py b/src/main.py index c7dea9c..63eae43 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,5 @@ from typing import List, Optional -import jwt from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session @@ -70,16 +69,6 @@ async def create_high_score( return crud_highscores.create_high_score(db, current_user, high_score) -#### TESTING!! DELETE LATER - - -@app.get("/protected") -async def protected_route( - current_user_name: str = Depends(crud_authentication.get_current_user_name), -): - return {"message": f"Hello, {current_user_name}!"} - - @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, @@ -92,9 +81,19 @@ async def get_course_progress( @app.patch("/courseprogress") async def get_course_progress( + course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), - course: Optional[CourseEnum] = CourseEnum.All, db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_courseprogress.get_course_progress(db, current_user, course) + return crud_courseprogress.patch_course_progress(db, current_user, course_progress) + + +#### TESTING!! DELETE LATER + + +@app.get("/protected") +async def protected_route( + current_user_name: str = Depends(crud_authentication.get_current_user_name), +): + return {"message": f"Hello, {current_user_name}!"} From 78138e83c79ef59935fc160bf8475d521ad54a54 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 6 Apr 2023 14:10:31 -0600 Subject: [PATCH 29/37] Fix test setup for backend --- format.sh | 4 ++++ requirements.txt | 3 ++- src/__init__.py | 0 src/crud/authentication.py | 6 +++--- src/crud/courseprogress.py | 6 +++--- src/crud/highscores.py | 8 ++++---- src/crud/users.py | 4 ++-- src/main.py | 21 ++++++++++++--------- src/models.py | 4 ++-- src/schemas/courseprogress.py | 2 +- src/schemas/highscores.py | 2 +- tests/__init__.py | 0 tests/config/database.py | 20 ++++++++++++++++++++ tests/config/setup.py | 27 +++++++++++++++++++++++++++ tests/test_users.py | 25 +++++++++++++++++++++++++ tests/usertests.py | 15 +++++++++++++++ 16 files changed, 121 insertions(+), 26 deletions(-) create mode 100644 src/__init__.py create mode 100644 tests/__init__.py create mode 100644 tests/config/database.py create mode 100644 tests/config/setup.py create mode 100644 tests/test_users.py create mode 100644 tests/usertests.py diff --git a/format.sh b/format.sh index 3c6cc0e..dd67977 100644 --- a/format.sh +++ b/format.sh @@ -1,3 +1,7 @@ flake8 src/* black src/* isort src/* + +flake8 tests/* +black tests/* +isort tests/* diff --git a/requirements.txt b/requirements.txt index 7dde8fa..9da48b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ interrogate python-jose[cryptography] passlib jwt -PyJWT \ No newline at end of file +PyJWT +pytest-asyncio \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/crud/authentication.py b/src/crud/authentication.py index e70fb85..631be5a 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -5,9 +5,9 @@ from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session -from crud.users import get_user_by_username, pwd_context -from models import User -from schemas.users import UserCreate +from src.crud.users import get_user_by_username, pwd_context +from src.models import User +from src.schemas.users import UserCreate DEFAULT_NR_HIGH_SCORES = 10 diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index e3ba2b1..6fb4b38 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,9 +1,9 @@ from fastapi import HTTPException from sqlalchemy.orm import Session -from enums import CourseEnum, course_enum_list -from models import CourseProgress, User -from schemas.courseprogress import CourseProgressBase +from src.enums import CourseEnum, course_enum_list +from src.models import CourseProgress, User +from src.schemas.courseprogress import CourseProgressBase def get_course_progress(db: Session, user: User, course: CourseEnum): diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 17b2cce..45c397c 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -2,10 +2,10 @@ from fastapi import HTTPException from sqlalchemy import desc from sqlalchemy.orm import Session -from enums import MinigameEnum -from models import HighScore, User -from schemas.highscores import HighScoreBase -from schemas.users import UserHighScore +from src.enums import MinigameEnum +from src.models import HighScore, User +from src.schemas.highscores import HighScoreBase +from src.schemas.users import UserHighScore def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): diff --git a/src/crud/users.py b/src/crud/users.py index 50cadd7..1fc8daa 100644 --- a/src/crud/users.py +++ b/src/crud/users.py @@ -2,8 +2,8 @@ from fastapi import HTTPException from passlib.context import CryptContext from sqlalchemy.orm import Session -from models import User -from schemas.users import UserCreate +from src.models import User +from src.schemas.users import UserCreate pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") diff --git a/src/main.py b/src/main.py index 63eae43..75f3706 100644 --- a/src/main.py +++ b/src/main.py @@ -1,18 +1,21 @@ +import sys from typing import List, Optional from fastapi import Depends, FastAPI, HTTPException from sqlalchemy.orm import Session -from crud import authentication as crud_authentication -from crud import courseprogress as crud_courseprogress -from crud import highscores as crud_highscores -from crud import users as crud_users -from database import SessionLocal, engine, get_db -from enums import CourseEnum, MinigameEnum -from models import Base -from schemas import courseprogress, highscores, users +sys.path.append("..") + +from src.crud import authentication as crud_authentication +from src.crud import courseprogress as crud_courseprogress +from src.crud import highscores as crud_highscores +from src.crud import users as crud_users +from src.database import Base, SessionLocal, engine, get_db +from src.enums import CourseEnum, MinigameEnum +from src.schemas import courseprogress, highscores, users app = FastAPI() + Base.metadata.create_all(bind=engine) @@ -79,7 +82,7 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress") +@app.patch("/courseprogress/{course_name}") async def get_course_progress( course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), diff --git a/src/models.py b/src/models.py index d2f1f8d..516d9c6 100644 --- a/src/models.py +++ b/src/models.py @@ -2,8 +2,8 @@ from sqlalchemy import Column, Float, ForeignKey, Integer, String from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import relationship -from database import Base -from enums import CourseEnum, MinigameEnum, StrEnumType +from src.database import Base +from src.enums import CourseEnum, MinigameEnum, StrEnumType class User(Base): diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index 1f071de..768890d 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from enums import CourseEnum +from src.enums import CourseEnum class CourseProgressBase(BaseModel): diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 3f1879a..39632ba 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from enums import MinigameEnum +from src.enums import MinigameEnum class HighScoreBase(BaseModel): diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/config/database.py b/tests/config/database.py new file mode 100644 index 0000000..6760fea --- /dev/null +++ b/tests/config/database.py @@ -0,0 +1,20 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from src.database import Base + +SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigntest" + +engine = create_engine(SQLALCHEMY_DATABASE_URL) + +TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base.metadata.create_all(bind=engine) + + +def override_get_db(): + try: + db = TestSessionLocal() + yield db + finally: + db.close() diff --git a/tests/config/setup.py b/tests/config/setup.py new file mode 100644 index 0000000..fa49ebe --- /dev/null +++ b/tests/config/setup.py @@ -0,0 +1,27 @@ +import sys + +sys.path.append("..") + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from src.database import Base as ProductionBase +from tests.config.database import (SQLALCHEMY_DATABASE_URL, TestBase, + TestSessionLocal) + + +@pytest.fixture(scope="function") +def db_session(): + engine = create_engine(SQLALCHEMY_DATABASE_URL) + print(SQLALCHEMY_DATABASE_URL) + ProductionBase.metadata.create_all(bind=engine) + TestBase.metadata.create_all(bind=engine) + session = TestSessionLocal(bind=engine) + try: + yield session + finally: + session.rollback() + TestBase.metadata.drop_all(bind=engine) + ProductionBase.metadata.drop_all(bind=engine) + session.close() diff --git a/tests/test_users.py b/tests/test_users.py new file mode 100644 index 0000000..90f2fe2 --- /dev/null +++ b/tests/test_users.py @@ -0,0 +1,25 @@ +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.append("..") + +from src.main import app, get_db +from tests.config.database import override_get_db + +app.dependency_overrides[get_db] = override_get_db + +client = TestClient(app) + + +@pytest.mark.asyncio +async def test_add_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": "user27", "password": "mettn", "avatar": "lion"}, + ) + + print(response) + assert response.status_code == 200 diff --git a/tests/usertests.py b/tests/usertests.py new file mode 100644 index 0000000..1f79c86 --- /dev/null +++ b/tests/usertests.py @@ -0,0 +1,15 @@ +import pytest +from fastapi.testclient import TestClient + +from main import app + +client = TestClient(app) + + +def test_add_user(): + response = client.post( + "/users", + headers={"Content-Type": "application/json"}, + json={"username": "Lukas", "password": "mettn"}, + ) + assert response.status_code == 200 From 5e91bc7ef6438fa0521b3002498991ae52694bfa Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 6 Apr 2023 14:21:51 -0600 Subject: [PATCH 30/37] Minor changes to request paths --- src/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/main.py b/src/main.py index 75f3706..e3e0869 100644 --- a/src/main.py +++ b/src/main.py @@ -53,7 +53,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): return crud_authentication.login(db, user.username, user.password) -@app.get("/highscores", response_model=List[users.UserHighScore]) +@app.get("/highscores/{minigame}", response_model=List[users.UserHighScore]) async def get_high_scores( minigame: Optional[MinigameEnum] = None, nr_highest: Optional[int] = None, @@ -72,7 +72,9 @@ async def create_high_score( return crud_highscores.create_high_score(db, current_user, high_score) -@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) +@app.get( + "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressBase] +) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, current_user_name: str = Depends(crud_authentication.get_current_user_name), @@ -82,7 +84,7 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress/{course_name}") +@app.patch("/courseprogress") async def get_course_progress( course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), From d3a29a4b2988bb30b049e612f54cfb722052dc75 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 6 Apr 2023 14:26:07 -0600 Subject: [PATCH 31/37] Add input validation to register --- src/crud/authentication.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 631be5a..3ec81ec 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -49,8 +49,12 @@ def authenticate_user(db: Session, username: str, password: str): def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" - if avatar == "": + if len(avatar) == 0: raise HTTPException(status_code=400, detail="No avatar was provided") + if len(username) == 0: + raise HTTPException(status_code=400, detail="No username was provided") + if len(password) == 0: + raise HTTPException(status_code=400, detail="No password was provided") db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") @@ -64,6 +68,7 @@ def register(db: Session, username: str, password: str, avatar: str): def login(db: Session, username: str, password: str): + """Log in based on username and password; supply access token if succeeded""" user = authenticate_user(db, username, password) if not user: raise HTTPException(status_code=401, detail="Invalid username or password") From 101cd899c3ab74b7f721eb6b5f3ae312ce15b2b4 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Thu, 6 Apr 2023 16:08:27 -0600 Subject: [PATCH 32/37] Formatting --- src/crud/authentication.py | 1 - src/crud/courseprogress.py | 1 - src/crud/highscores.py | 2 ++ src/main.py | 6 +++--- src/models.py | 2 -- src/schemas/users.py | 2 -- tests/config/setup.py | 5 ++--- 7 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 3ec81ec..201e25b 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -7,7 +7,6 @@ from sqlalchemy.orm import Session from src.crud.users import get_user_by_username, pwd_context from src.models import User -from src.schemas.users import UserCreate DEFAULT_NR_HIGH_SCORES = 10 diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 6fb4b38..4818f1e 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,4 +1,3 @@ -from fastapi import HTTPException from sqlalchemy.orm import Session from src.enums import CourseEnum, course_enum_list diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 45c397c..5809027 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -7,6 +7,8 @@ from src.models import HighScore, User from src.schemas.highscores import HighScoreBase from src.schemas.users import UserHighScore +DEFAULT_NR_HIGH_SCORES = 10 + def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): """Get the n highest scores of a given minigame""" diff --git a/src/main.py b/src/main.py index e3e0869..29b16f4 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ import sys from typing import List, Optional -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI from sqlalchemy.orm import Session sys.path.append("..") @@ -10,7 +10,7 @@ from src.crud import authentication as crud_authentication from src.crud import courseprogress as crud_courseprogress from src.crud import highscores as crud_highscores from src.crud import users as crud_users -from src.database import Base, SessionLocal, engine, get_db +from src.database import Base, engine, get_db from src.enums import CourseEnum, MinigameEnum from src.schemas import courseprogress, highscores, users @@ -85,7 +85,7 @@ async def get_course_progress( @app.patch("/courseprogress") -async def get_course_progress( +async def patch_course_progress( course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), diff --git a/src/models.py b/src/models.py index 516d9c6..99a3965 100644 --- a/src/models.py +++ b/src/models.py @@ -1,9 +1,7 @@ from sqlalchemy import Column, Float, ForeignKey, Integer, String -from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.orm import relationship from src.database import Base -from src.enums import CourseEnum, MinigameEnum, StrEnumType class User(Base): diff --git a/src/schemas/users.py b/src/schemas/users.py index 3077af7..d98b5e6 100644 --- a/src/schemas/users.py +++ b/src/schemas/users.py @@ -1,5 +1,3 @@ -from typing import List - from pydantic import BaseModel diff --git a/tests/config/setup.py b/tests/config/setup.py index fa49ebe..71466e6 100644 --- a/tests/config/setup.py +++ b/tests/config/setup.py @@ -1,10 +1,9 @@ import sys -sys.path.append("..") - import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker + +sys.path.append("..") from src.database import Base as ProductionBase from tests.config.database import (SQLALCHEMY_DATABASE_URL, TestBase, From 8f3c303a2bb5aeea251b1d2cd5358f16a6d3ce88 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Sun, 9 Apr 2023 13:39:52 -0600 Subject: [PATCH 33/37] Get started with user tests --- src/crud/authentication.py | 11 ++-- src/crud/users.py | 12 +++- src/main.py | 10 ++- tests/config/database.py | 10 +++ tests/config/setup.py | 26 -------- tests/test_users.py | 132 +++++++++++++++++++++++++++++++++++-- 6 files changed, 162 insertions(+), 39 deletions(-) delete mode 100644 tests/config/setup.py diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 201e25b..c8f5654 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -5,7 +5,8 @@ from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session -from src.crud.users import get_user_by_username, pwd_context +from src.crud.users import (check_empty_fields, get_user_by_username, + pwd_context) from src.models import User DEFAULT_NR_HIGH_SCORES = 10 @@ -48,12 +49,8 @@ def authenticate_user(db: Session, username: str, password: str): def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" - if len(avatar) == 0: - raise HTTPException(status_code=400, detail="No avatar was provided") - if len(username) == 0: - raise HTTPException(status_code=400, detail="No username was provided") - if len(password) == 0: - raise HTTPException(status_code=400, detail="No password was provided") + check_empty_fields(username, password, avatar) + db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") diff --git a/src/crud/users.py b/src/crud/users.py index 1fc8daa..e284bd3 100644 --- a/src/crud/users.py +++ b/src/crud/users.py @@ -8,14 +8,24 @@ from src.schemas.users import UserCreate pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +def check_empty_fields(username: str, password: str, avatar: str): + "Checks if any user fields are empty" + if len(avatar) == 0: + raise HTTPException(status_code=400, detail="No avatar was provided") + if len(username) == 0: + raise HTTPException(status_code=400, detail="No username was provided") + if len(password) == 0: + raise HTTPException(status_code=400, detail="No password was provided") + + def patch_user(db: Session, username: str, user: UserCreate): """Changes the username and/or the password of a User""" + check_empty_fields(user.username, user.password, user.avatar) db_user = get_user_by_username(db, username) potential_duplicate = get_user_by_username(db, user.username) if potential_duplicate: if potential_duplicate.user_id != db_user.user_id: raise HTTPException(status_code=400, detail="Username already registered") - db_user.username = user.username db_user.hashed_password = pwd_context.hash(user.password) db_user.avatar = user.avatar diff --git a/src/main.py b/src/main.py index 29b16f4..d518c4e 100644 --- a/src/main.py +++ b/src/main.py @@ -24,11 +24,19 @@ async def root(): return {"message": "Hello world!"} -@app.get("/users", response_model=List[users.User]) +@app.get("/allusers", response_model=List[users.User]) async def read_users(db: Session = Depends(get_db)): return crud_users.get_users(db) +@app.get("/users", response_model=users.User) +async def read_user( + current_user_name: str = Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), +): + return crud_users.get_user_by_username(db, current_user_name) + + @app.patch("/users") async def patch_current_user( user: users.UserCreate, diff --git a/tests/config/database.py b/tests/config/database.py index 6760fea..ef82ea7 100644 --- a/tests/config/database.py +++ b/tests/config/database.py @@ -2,6 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from src.database import Base +from src.models import CourseProgress, HighScore, User SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigntest" @@ -12,6 +13,15 @@ TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base.metadata.create_all(bind=engine) +def clear_db(): + db = TestSessionLocal() + + db.query(HighScore).delete() + db.query(CourseProgress).delete() + db.query(User).delete() + db.commit() + + def override_get_db(): try: db = TestSessionLocal() diff --git a/tests/config/setup.py b/tests/config/setup.py deleted file mode 100644 index 71466e6..0000000 --- a/tests/config/setup.py +++ /dev/null @@ -1,26 +0,0 @@ -import sys - -import pytest -from sqlalchemy import create_engine - -sys.path.append("..") - -from src.database import Base as ProductionBase -from tests.config.database import (SQLALCHEMY_DATABASE_URL, TestBase, - TestSessionLocal) - - -@pytest.fixture(scope="function") -def db_session(): - engine = create_engine(SQLALCHEMY_DATABASE_URL) - print(SQLALCHEMY_DATABASE_URL) - ProductionBase.metadata.create_all(bind=engine) - TestBase.metadata.create_all(bind=engine) - session = TestSessionLocal(bind=engine) - try: - yield session - finally: - session.rollback() - TestBase.metadata.drop_all(bind=engine) - ProductionBase.metadata.drop_all(bind=engine) - session.close() diff --git a/tests/test_users.py b/tests/test_users.py index 90f2fe2..5a169c4 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -6,20 +6,144 @@ from fastapi.testclient import TestClient sys.path.append("..") from src.main import app, get_db -from tests.config.database import override_get_db +from tests.config.database import clear_db, override_get_db app.dependency_overrides[get_db] = override_get_db client = TestClient(app) +username = "user1" +password = "password" +avatar = "lion" + +patched_username = "New name" +patched_password = "New password" +patched_avatar = "New avatar" + @pytest.mark.asyncio -async def test_add_user(): +async def test_get_current_user(): + """Test the GET /users endpoint to get info about the current user""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": "user27", "password": "mettn", "avatar": "lion"}, + json={"username": username, "password": password, "avatar": avatar}, ) - print(response) assert response.status_code == 200 + + token = response.json()["access_token"] + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.get("/users", headers=headers) + assert response.status_code == 200 + response = response.json() + assert response["username"] == username + assert response["avatar"] == avatar + + +@pytest.mark.asyncio +async def test_get_current_user_without_auth(): + """Getting the current user without a token should fail""" + clear_db() + + response = client.get("/users", headers={"Content-Type": "application/json"}) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_patch_user(): + """Test the patching of a user's username, password and avatar""" + clear_db() + + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": patched_password, + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 200 + + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": patched_username, "password": patched_password}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.get("/users", headers=headers) + assert response.status_code == 200 + + # Correctness of password and username is already asserted by the login + assert response.json()["avatar"] == patched_avatar + + +@pytest.mark.asyncio +async def test_patch_user_with_empty_fields(): + """Patching a user with empty fields should fail""" + clear_db() + + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": patched_password, + "avatar": "", + }, + headers=headers, + ) + assert response.status_code == 400 + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": "", + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 400 + + response = client.patch( + "/users", + json={ + "username": "", + "password": patched_password, + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 400 From e7145369b57da478ad6c8c6ad79a643b65043f41 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 10 Apr 2023 14:07:30 -0600 Subject: [PATCH 34/37] Write most backend tests --- src/crud/courseprogress.py | 22 ++++- src/enums.py | 12 --- src/main.py | 2 +- tests/test_authentication.py | 139 ++++++++++++++++++++++++++++ tests/test_courseprogress.py | 174 +++++++++++++++++++++++++++++++++++ tests/test_highscores.py | 17 ++++ 6 files changed, 348 insertions(+), 18 deletions(-) create mode 100644 tests/test_authentication.py create mode 100644 tests/test_courseprogress.py create mode 100644 tests/test_highscores.py diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 4818f1e..66eda01 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Session +from fastapi import HTTPException -from src.enums import CourseEnum, course_enum_list +from src.enums import CourseEnum from src.models import CourseProgress, User from src.schemas.courseprogress import CourseProgressBase @@ -41,13 +42,19 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): def initialize_user(db: Session, user: User): """Create CourseProgress records with a value of 0 for a new user""" - for course in course_enum_list: - db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)) - db.commit() + for course in CourseEnum: + if course != CourseEnum.All: + db.add( + CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) + ) + db.commit() def patch_course_progress(db: Session, user: User, course_progress: CourseProgressBase): """Change the progress value for a given course""" + if course_progress.progress_value > 1 or course_progress.progress_value < 0: + raise HTTPException(status_code=400, detail="Invalid progress value") + db_course_progress_list = [] if course_progress.course != CourseEnum.All: db_course_progress_list = ( @@ -64,7 +71,12 @@ def patch_course_progress(db: Session, user: User, course_progress: CourseProgre .filter(CourseProgress.owner_id == user.user_id) .all() ) - + print(f"LENGTH OF LIST OF {course_progress.course}: {len(db_course_progress_list)}") for db_course_progress in db_course_progress_list: db_course_progress.progress_value = course_progress.progress_value db.commit() + + return [ + CourseProgressBase(course=db_cp.course, progress_value=db_cp.progress_value) + for db_cp in db_course_progress_list + ] diff --git a/src/enums.py b/src/enums.py index 3b6ecd4..3ec4085 100644 --- a/src/enums.py +++ b/src/enums.py @@ -34,15 +34,3 @@ class CourseEnum(StrEnum): Colors = "Colors" FruitsVegetables = "FruitsVegetables" All = "All" - - -# This is needed because for some reason iterating over an enum doesn't work properly... -course_enum_list = [ - CourseEnum.Fingerspelling, - # CourseEnum.Basics, - CourseEnum.Hobbies, - CourseEnum.Animals, - CourseEnum.Colors, - CourseEnum.FruitsVegetables, - CourseEnum.All, -] diff --git a/src/main.py b/src/main.py index d518c4e..71cf170 100644 --- a/src/main.py +++ b/src/main.py @@ -92,7 +92,7 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress") +@app.patch("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) async def patch_course_progress( course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 0000000..825d73e --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,139 @@ +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.append("..") + +from src.main import app, get_db +from tests.config.database import clear_db, override_get_db + +app.dependency_overrides[get_db] = override_get_db + +client = TestClient(app) + +username1 = "user1" +username2 = "user2" +password = "password" +avatar = "lion" + + +@pytest.mark.asyncio +async def test_register(): + """LEAVE THIS TEST AT THE TOP OF THE FILE!""" + """Test the register endpoint""" + clear_db() + + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username1, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + assert len(response.json()["access_token"]) > 0 + + +@pytest.mark.asyncio +async def test_register_duplicate_name_should_fail(): + """Test whether registering a user with an existing username fails""" + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username1, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 400 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_register_without_username_should_fail(): + """Test whether registering a user without passing a username fails""" + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"password": password, "avatar": avatar}, + ) + + assert response.status_code == 422 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_register_without_password_should_fail(): + """Test whether registering a user without passing a password fails""" + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username2, "avatar": avatar}, + ) + + assert response.status_code == 422 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_register_without_avatar_should_fail(): + """Test whether registering a user without passing an avatar fails""" + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username2, "password": password}, + ) + + # Not ideal that this is 400 instead of 422, but had no other choice than to give this field a default value + assert response.status_code == 400 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_login(): + """Test the login endpoint""" + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": username1, "password": password}, + ) + + assert response.status_code == 200 + assert len(response.json()["access_token"]) > 0 + + +@pytest.mark.asyncio +async def test_login_wrong_password_should_fail(): + wrong_password = password + "extra characters" + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": username1, "password": wrong_password}, + ) + + assert response.status_code == 401 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_login_without_username_should_fail(): + """Test whether logging in without passing a username fails""" + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": username1}, + ) + + assert response.status_code == 422 + assert "access_token" not in response.json() + + +@pytest.mark.asyncio +async def test_login_without_password_should_fail(): + """Test whether logging in without passing a password fails""" + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": username1}, + ) + + assert response.status_code == 422 + assert "access_token" not in response.json() diff --git a/tests/test_courseprogress.py b/tests/test_courseprogress.py new file mode 100644 index 0000000..ec41d56 --- /dev/null +++ b/tests/test_courseprogress.py @@ -0,0 +1,174 @@ +import random +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.append("..") + +from src.enums import CourseEnum +from src.main import app, get_db +from tests.config.database import clear_db, override_get_db + +app.dependency_overrides[get_db] = override_get_db + +client = TestClient(app) + +username = "user1" +password = "password" +avatar = "lion" + + +async def register_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + return response.json()["access_token"] + + +@pytest.mark.asyncio +async def test_register_creates_progress_of_zero(): + """Test whether registering a new user initializes all progress values to 0.0""" + clear_db() + + token = await register_user() + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for course in CourseEnum: + if course != CourseEnum.All: + response = client.get(f"/courseprogress/{course}", headers=headers) + assert response.status_code == 200 + + response = response.json()[0] + + assert response["progress_value"] == 0.0 + assert response["course"] == course + + +@pytest.mark.asyncio +async def test_get_all_returns_all(): + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + response = client.get("/courseprogress/All", headers=headers) + + assert response.status_code == 200 + response = response.json() + + for course in CourseEnum: + if course != CourseEnum.All: + assert {"progress_value": 0.0, "course": course} in response + + +@pytest.mark.asyncio +async def test_get_nonexisting_course_should_fail(): + clear_db() + token = await register_user() + + fake_course = "FakeCourse" + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + response = client.get(f"/courseprogress/{fake_course}", headers=headers) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_patch_course_progress(): + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for course in CourseEnum: + if course != CourseEnum.All: + progress_value = random.uniform(0, 1) + + response = client.patch( + f"/courseprogress", + headers=headers, + json={"progress_value": progress_value, "course": course}, + ) + + assert response.status_code == 200 + assert response.json()[0]["progress_value"] == progress_value + + +@pytest.mark.asyncio +async def test_patch_all_should_patch_all_courses(): + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + progress_value = random.uniform(0, 1) + + response = client.patch( + f"/courseprogress", + headers=headers, + json={"progress_value": progress_value, "course": "All"}, + ) + + assert response.status_code == 200 + + response = client.get("/courseprogress/All", headers=headers) + + assert response.status_code == 200 + response = response.json() + + for course in CourseEnum: + if course != CourseEnum.All: + assert {"progress_value": progress_value, "course": course} in response + + +@pytest.mark.asyncio +async def test_patch_nonexisting_course_should_fail(): + clear_db() + token = await register_user() + + fake_course = "FakeCourse" + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + progress_value = random.uniform(0, 1) + + response = client.patch( + f"/courseprogress", + headers=headers, + json={"progress_value": progress_value, "course": fake_course}, + ) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_patch_course_with_invalid_value_should_fail(): + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + too_high_progress_value = random.uniform(0, 1) + 2 + too_low_progress_value = random.uniform(0, 1) - 2 + + response = client.patch( + f"/courseprogress", + headers=headers, + json={"progress_value": too_high_progress_value, "course": "All"}, + ) + + assert response.status_code == 400 + + response = client.patch( + f"/courseprogress", + headers=headers, + json={"progress_value": too_low_progress_value, "course": "All"}, + ) + + assert response.status_code == 400 diff --git a/tests/test_highscores.py b/tests/test_highscores.py new file mode 100644 index 0000000..5e145a4 --- /dev/null +++ b/tests/test_highscores.py @@ -0,0 +1,17 @@ +import sys + +import pytest +from fastapi.testclient import TestClient + +sys.path.append("..") + +from src.main import app, get_db +from tests.config.database import clear_db, override_get_db + +app.dependency_overrides[get_db] = override_get_db + +client = TestClient(app) + +username = "user1" +password = "password" +avatar = "lion" From 8e128ca033f7a06d8eb153a299a286783f2e38a5 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 10 Apr 2023 14:44:21 -0600 Subject: [PATCH 35/37] The great endpoint refactor --- src/crud/courseprogress.py | 16 +++++------ src/crud/highscores.py | 6 ++-- src/main.py | 14 +++++---- src/schemas/courseprogress.py | 5 +++- src/schemas/highscores.py | 2 +- tests/test_authentication.py | 53 +++++++++++++++++++++++++++-------- tests/test_courseprogress.py | 26 ++++++++++------- tests/test_highscores.py | 18 ++++++++++++ 8 files changed, 100 insertions(+), 40 deletions(-) diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 66eda01..3000804 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -3,7 +3,7 @@ from fastapi import HTTPException from src.enums import CourseEnum from src.models import CourseProgress, User -from src.schemas.courseprogress import CourseProgressBase +from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent def get_course_progress(db: Session, user: User, course: CourseEnum): @@ -26,7 +26,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): if course_progress: result.append( - CourseProgressBase( + CourseProgressParent( progress_value=course_progress.progress_value, course=course ) ) @@ -35,7 +35,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) ) db.commit() - result.append(CourseProgressBase(progress_value=0.0, course=course)) + result.append(CourseProgressParent(progress_value=0.0, course=course)) return result @@ -50,18 +50,18 @@ def initialize_user(db: Session, user: User): db.commit() -def patch_course_progress(db: Session, user: User, course_progress: CourseProgressBase): +def patch_course_progress(db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase): """Change the progress value for a given course""" if course_progress.progress_value > 1 or course_progress.progress_value < 0: raise HTTPException(status_code=400, detail="Invalid progress value") db_course_progress_list = [] - if course_progress.course != CourseEnum.All: + if course != CourseEnum.All: db_course_progress_list = ( db.query(CourseProgress) .filter( CourseProgress.owner_id == user.user_id, - CourseProgress.course == course_progress.course, + CourseProgress.course == course, ) .all() ) @@ -71,12 +71,12 @@ def patch_course_progress(db: Session, user: User, course_progress: CourseProgre .filter(CourseProgress.owner_id == user.user_id) .all() ) - print(f"LENGTH OF LIST OF {course_progress.course}: {len(db_course_progress_list)}") + for db_course_progress in db_course_progress_list: db_course_progress.progress_value = course_progress.progress_value db.commit() return [ - CourseProgressBase(course=db_cp.course, progress_value=db_cp.progress_value) + CourseProgressParent(course=db_cp.course, progress_value=db_cp.progress_value) for db_cp in db_course_progress_list ] diff --git a/src/crud/highscores.py b/src/crud/highscores.py index 5809027..ffb2961 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -42,14 +42,14 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): return user_high_scores -def create_high_score(db: Session, user: User, high_score: HighScoreBase): +def create_high_score(db: Session, user: User, minigame: MinigameEnum, high_score: HighScoreBase): """Create a new high score for a given minigame""" def add_to_db(): """Helper function that adds new score to database; prevents code duplication""" db_high_score = HighScore( score_value=high_score.score_value, - minigame=high_score.minigame, + minigame=minigame, owner_id=user.user_id, ) db.add(db_high_score) @@ -61,7 +61,7 @@ def create_high_score(db: Session, user: User, high_score: HighScoreBase): db.query(HighScore) .filter( HighScore.owner_id == user.user_id, - HighScore.minigame == high_score.minigame, + HighScore.minigame == minigame, ) .first() ) diff --git a/src/main.py b/src/main.py index 71cf170..93cbd01 100644 --- a/src/main.py +++ b/src/main.py @@ -63,25 +63,26 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores/{minigame}", response_model=List[users.UserHighScore]) async def get_high_scores( - minigame: Optional[MinigameEnum] = None, + minigame: MinigameEnum, nr_highest: Optional[int] = None, db: Session = Depends(get_db), ): return crud_highscores.get_high_scores(db, minigame, nr_highest) -@app.post("/highscores", response_model=highscores.HighScore) +@app.post("/highscores/{minigame}", response_model=highscores.HighScore) async def create_high_score( + minigame: MinigameEnum, high_score: highscores.HighScoreBase, current_user_name=Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_highscores.create_high_score(db, current_user, high_score) + return crud_highscores.create_high_score(db, current_user, minigame, high_score) @app.get( - "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressBase] + "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent] ) async def get_course_progress( course: Optional[CourseEnum] = CourseEnum.All, @@ -92,14 +93,15 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) +@app.patch("/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent]) async def patch_course_progress( + course: CourseEnum, course_progress: courseprogress.CourseProgressBase, current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_courseprogress.patch_course_progress(db, current_user, course_progress) + return crud_courseprogress.patch_course_progress(db, current_user, course, course_progress) #### TESTING!! DELETE LATER diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index 768890d..2d8327c 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -5,10 +5,13 @@ from src.enums import CourseEnum class CourseProgressBase(BaseModel): progress_value: float + + +class CourseProgressParent(CourseProgressBase): course: CourseEnum -class CourseProgress(CourseProgressBase): +class CourseProgress(CourseProgressParent): course_progress_id: int owner_id: int diff --git a/src/schemas/highscores.py b/src/schemas/highscores.py index 39632ba..22849b1 100644 --- a/src/schemas/highscores.py +++ b/src/schemas/highscores.py @@ -5,12 +5,12 @@ from src.enums import MinigameEnum class HighScoreBase(BaseModel): score_value: float - minigame: MinigameEnum class HighScore(HighScoreBase): high_score_id: int owner_id: int + minigame: MinigameEnum class Config: orm_mode = True diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 825d73e..d13cf05 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -12,22 +12,32 @@ app.dependency_overrides[get_db] = override_get_db client = TestClient(app) -username1 = "user1" -username2 = "user2" +username = "user1" password = "password" avatar = "lion" +async def register_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + return response.json()["access_token"] + + @pytest.mark.asyncio async def test_register(): - """LEAVE THIS TEST AT THE TOP OF THE FILE!""" """Test the register endpoint""" clear_db() response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password, "avatar": avatar}, + json={"username": username, "password": password, "avatar": avatar}, ) assert response.status_code == 200 @@ -37,10 +47,13 @@ async def test_register(): @pytest.mark.asyncio async def test_register_duplicate_name_should_fail(): """Test whether registering a user with an existing username fails""" + clear_db() + await register_user() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password, "avatar": avatar}, + json={"username": username, "password": password, "avatar": avatar}, ) assert response.status_code == 400 @@ -50,6 +63,8 @@ async def test_register_duplicate_name_should_fail(): @pytest.mark.asyncio async def test_register_without_username_should_fail(): """Test whether registering a user without passing a username fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, @@ -63,10 +78,12 @@ async def test_register_without_username_should_fail(): @pytest.mark.asyncio async def test_register_without_password_should_fail(): """Test whether registering a user without passing a password fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username2, "avatar": avatar}, + json={"username": username, "avatar": avatar}, ) assert response.status_code == 422 @@ -76,10 +93,12 @@ async def test_register_without_password_should_fail(): @pytest.mark.asyncio async def test_register_without_avatar_should_fail(): """Test whether registering a user without passing an avatar fails""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": username2, "password": password}, + json={"username": username, "password": password}, ) # Not ideal that this is 400 instead of 422, but had no other choice than to give this field a default value @@ -90,10 +109,13 @@ async def test_register_without_avatar_should_fail(): @pytest.mark.asyncio async def test_login(): """Test the login endpoint""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": password}, + json={"username": username, "password": password}, ) assert response.status_code == 200 @@ -102,11 +124,14 @@ async def test_login(): @pytest.mark.asyncio async def test_login_wrong_password_should_fail(): + clear_db() + await register_user() + wrong_password = password + "extra characters" response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1, "password": wrong_password}, + json={"username": username, "password": wrong_password}, ) assert response.status_code == 401 @@ -116,10 +141,13 @@ async def test_login_wrong_password_should_fail(): @pytest.mark.asyncio async def test_login_without_username_should_fail(): """Test whether logging in without passing a username fails""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1}, + json={"password": password}, ) assert response.status_code == 422 @@ -129,10 +157,13 @@ async def test_login_without_username_should_fail(): @pytest.mark.asyncio async def test_login_without_password_should_fail(): """Test whether logging in without passing a password fails""" + clear_db() + await register_user() + response = client.post( "/login", headers={"Content-Type": "application/json"}, - json={"username": username1}, + json={"username": username}, ) assert response.status_code == 422 diff --git a/tests/test_courseprogress.py b/tests/test_courseprogress.py index ec41d56..33cf43f 100644 --- a/tests/test_courseprogress.py +++ b/tests/test_courseprogress.py @@ -52,6 +52,7 @@ async def test_register_creates_progress_of_zero(): @pytest.mark.asyncio async def test_get_all_returns_all(): + """Test whether the 'All'-course fetches all course progress values""" clear_db() token = await register_user() @@ -68,6 +69,7 @@ async def test_get_all_returns_all(): @pytest.mark.asyncio async def test_get_nonexisting_course_should_fail(): + """Test whether fetching the progress of a nonexisting course fails""" clear_db() token = await register_user() @@ -81,6 +83,7 @@ async def test_get_nonexisting_course_should_fail(): @pytest.mark.asyncio async def test_patch_course_progress(): + """Test whether patching the progress value of a course works properly""" clear_db() token = await register_user() @@ -91,9 +94,9 @@ async def test_patch_course_progress(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + f"/courseprogress/{course}", headers=headers, - json={"progress_value": progress_value, "course": course}, + json={"progress_value": progress_value}, ) assert response.status_code == 200 @@ -102,6 +105,7 @@ async def test_patch_course_progress(): @pytest.mark.asyncio async def test_patch_all_should_patch_all_courses(): + """Test whether patching the 'All'-course updates all progress values""" clear_db() token = await register_user() @@ -110,9 +114,9 @@ async def test_patch_all_should_patch_all_courses(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": progress_value, "course": "All"}, + json={"progress_value": progress_value}, ) assert response.status_code == 200 @@ -129,6 +133,7 @@ async def test_patch_all_should_patch_all_courses(): @pytest.mark.asyncio async def test_patch_nonexisting_course_should_fail(): + """Test whether patching a nonexisting course fails""" clear_db() token = await register_user() @@ -139,9 +144,9 @@ async def test_patch_nonexisting_course_should_fail(): progress_value = random.uniform(0, 1) response = client.patch( - f"/courseprogress", + f"/courseprogress/{fake_course}", headers=headers, - json={"progress_value": progress_value, "course": fake_course}, + json={"progress_value": progress_value}, ) assert response.status_code == 422 @@ -149,6 +154,7 @@ async def test_patch_nonexisting_course_should_fail(): @pytest.mark.asyncio async def test_patch_course_with_invalid_value_should_fail(): + """Test whether patching a course progress value with an invalid value fails""" clear_db() token = await register_user() @@ -158,17 +164,17 @@ async def test_patch_course_with_invalid_value_should_fail(): too_low_progress_value = random.uniform(0, 1) - 2 response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": too_high_progress_value, "course": "All"}, + json={"progress_value": too_high_progress_value}, ) assert response.status_code == 400 response = client.patch( - f"/courseprogress", + "/courseprogress/All", headers=headers, - json={"progress_value": too_low_progress_value, "course": "All"}, + json={"progress_value": too_low_progress_value}, ) assert response.status_code == 400 diff --git a/tests/test_highscores.py b/tests/test_highscores.py index 5e145a4..a5a9a48 100644 --- a/tests/test_highscores.py +++ b/tests/test_highscores.py @@ -15,3 +15,21 @@ client = TestClient(app) username = "user1" password = "password" avatar = "lion" + + +async def register_user(): + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + return response.json()["access_token"] + + +@pytest.mark.asyncio +async def test_post_highscore(): + """Test whether posting a new high score succeeds""" + clear_db() From 73ce1bf2e03cfb394fe99de95cb83b3354fb6fa1 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Mon, 10 Apr 2023 16:07:25 -0600 Subject: [PATCH 36/37] Finish off backend tests --- src/crud/courseprogress.py | 6 +- src/crud/highscores.py | 9 +- src/main.py | 13 +- tests/base.py | 16 +++ tests/test_authentication.py | 13 +- tests/test_courseprogress.py | 42 +++++-- tests/test_highscores.py | 230 +++++++++++++++++++++++++++++++++-- tests/test_users.py | 13 +- tests/usertests.py | 15 --- 9 files changed, 285 insertions(+), 72 deletions(-) create mode 100644 tests/base.py delete mode 100644 tests/usertests.py diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index 3000804..67a488c 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import Session from fastapi import HTTPException +from sqlalchemy.orm import Session from src.enums import CourseEnum from src.models import CourseProgress, User @@ -50,7 +50,9 @@ def initialize_user(db: Session, user: User): db.commit() -def patch_course_progress(db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase): +def patch_course_progress( + db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase +): """Change the progress value for a given course""" if course_progress.progress_value > 1 or course_progress.progress_value < 0: raise HTTPException(status_code=400, detail="Invalid progress value") diff --git a/src/crud/highscores.py b/src/crud/highscores.py index ffb2961..a7eee7b 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -12,8 +12,9 @@ DEFAULT_NR_HIGH_SCORES = 10 def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): """Get the n highest scores of a given minigame""" - if nr_highest < 1: - raise HTTPException(status_code=400, detail="Invalid number of high scores") + if nr_highest: + if nr_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") user_high_scores = [] @@ -42,7 +43,9 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): return user_high_scores -def create_high_score(db: Session, user: User, minigame: MinigameEnum, high_score: HighScoreBase): +def create_high_score( + db: Session, user: User, minigame: MinigameEnum, high_score: HighScoreBase +): """Create a new high score for a given minigame""" def add_to_db(): diff --git a/src/main.py b/src/main.py index 93cbd01..08f6e6f 100644 --- a/src/main.py +++ b/src/main.py @@ -65,16 +65,17 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): async def get_high_scores( minigame: MinigameEnum, nr_highest: Optional[int] = None, + current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): return crud_highscores.get_high_scores(db, minigame, nr_highest) -@app.post("/highscores/{minigame}", response_model=highscores.HighScore) +@app.put("/highscores/{minigame}", response_model=highscores.HighScore) async def create_high_score( minigame: MinigameEnum, high_score: highscores.HighScoreBase, - current_user_name=Depends(crud_authentication.get_current_user_name), + current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) @@ -93,7 +94,9 @@ async def get_course_progress( return crud_courseprogress.get_course_progress(db, current_user, course) -@app.patch("/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent]) +@app.patch( + "/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent] +) async def patch_course_progress( course: CourseEnum, course_progress: courseprogress.CourseProgressBase, @@ -101,7 +104,9 @@ async def patch_course_progress( db: Session = Depends(get_db), ): current_user = crud_users.get_user_by_username(db, current_user_name) - return crud_courseprogress.patch_course_progress(db, current_user, course, course_progress) + return crud_courseprogress.patch_course_progress( + db, current_user, course, course_progress + ) #### TESTING!! DELETE LATER diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 0000000..686bebb --- /dev/null +++ b/tests/base.py @@ -0,0 +1,16 @@ +import sys + +from fastapi.testclient import TestClient + +sys.path.append("..") + +from src.main import app, get_db +from tests.config.database import override_get_db + +app.dependency_overrides[get_db] = override_get_db + +client = TestClient(app) + +username = "user1" +password = "password" +avatar = "lion" diff --git a/tests/test_authentication.py b/tests/test_authentication.py index d13cf05..8be7cea 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,21 +1,10 @@ -import sys - import pytest from fastapi.testclient import TestClient -sys.path.append("..") - from src.main import app, get_db +from tests.base import avatar, client, password, username from tests.config.database import clear_db, override_get_db -app.dependency_overrides[get_db] = override_get_db - -client = TestClient(app) - -username = "user1" -password = "password" -avatar = "lion" - async def register_user(): response = client.post( diff --git a/tests/test_courseprogress.py b/tests/test_courseprogress.py index 33cf43f..d6ee649 100644 --- a/tests/test_courseprogress.py +++ b/tests/test_courseprogress.py @@ -1,23 +1,13 @@ import random -import sys import pytest from fastapi.testclient import TestClient -sys.path.append("..") - from src.enums import CourseEnum from src.main import app, get_db +from tests.base import avatar, client, password, username from tests.config.database import clear_db, override_get_db -app.dependency_overrides[get_db] = override_get_db - -client = TestClient(app) - -username = "user1" -password = "password" -avatar = "lion" - async def register_user(): response = client.post( @@ -67,6 +57,19 @@ async def test_get_all_returns_all(): assert {"progress_value": 0.0, "course": course} in response +@pytest.mark.asyncio +async def test_get_course_progress_value_without_auth_should_fail(): + """Test whether fetching a course progress value without authentication fails""" + clear_db() + + headers = {"Content-Type": "application/json"} + + for course in CourseEnum: + response = client.get(f"/courseprogress/{course}", headers=headers) + + assert response.status_code == 403 + + @pytest.mark.asyncio async def test_get_nonexisting_course_should_fail(): """Test whether fetching the progress of a nonexisting course fails""" @@ -178,3 +181,20 @@ async def test_patch_course_with_invalid_value_should_fail(): ) assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_patch_course_progress_value_without_auth_should_fail(): + """Test whether updating a course progress value without authentication fails""" + clear_db() + + headers = {"Content-Type": "application/json"} + + for course in CourseEnum: + response = client.patch( + f"/courseprogress/{course}", + headers=headers, + json={"progress_value": random.uniform(0, 1)}, + ) + + assert response.status_code == 403 diff --git a/tests/test_highscores.py b/tests/test_highscores.py index a5a9a48..5b56291 100644 --- a/tests/test_highscores.py +++ b/tests/test_highscores.py @@ -1,21 +1,13 @@ -import sys +import random import pytest from fastapi.testclient import TestClient -sys.path.append("..") - +from src.enums import MinigameEnum from src.main import app, get_db +from tests.base import avatar, client, password, username from tests.config.database import clear_db, override_get_db -app.dependency_overrides[get_db] = override_get_db - -client = TestClient(app) - -username = "user1" -password = "password" -avatar = "lion" - async def register_user(): response = client.post( @@ -30,6 +22,218 @@ async def register_user(): @pytest.mark.asyncio -async def test_post_highscore(): - """Test whether posting a new high score succeeds""" +async def test_put_highscore(): + """Test whether putting a new high score succeeds""" clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + score_value = random.random() + response = client.put( + f"/highscores/{minigame}", + headers=headers, + json={"score_value": score_value}, + ) + + assert response.status_code == 200 + + response = response.json() + + assert response["minigame"] == minigame + assert response["score_value"] == score_value + + +@pytest.mark.asyncio +async def test_put_lower_highscore_does_not_change_old_value(): + """Test whether putting a new high score lower than the current one doesn't change the old one""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + score_value = random.random() + response = client.put( + f"/highscores/{minigame}", + headers=headers, + json={"score_value": score_value}, + ) + + assert response.status_code == 200 + + response = response.json() + + assert response["minigame"] == minigame + assert response["score_value"] == score_value + + lower_score_value = score_value - 100 + response = client.put( + f"/highscores/{minigame}", + headers=headers, + json={"score_value": lower_score_value}, + ) + + assert response.status_code == 200 + + response = response.json() + + assert response["minigame"] == minigame + assert response["score_value"] == score_value + + +@pytest.mark.asyncio +async def test_put_highscore_for_nonexisting_minigame_should_fail(): + """Test whether putting a new high score for a nonexisting minigame fails""" + clear_db() + token = await register_user() + + fake_minigame = "FakeGame" + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.put( + f"/highscores/{fake_minigame}", + headers=headers, + json={"score_value": random.random()}, + ) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_put_highscores_without_auth_should_fail(): + """Test whether putting high scores without authentication fails""" + clear_db() + + headers = {"Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.put( + f"/highscores/{minigame}", + headers=headers, + json={"score_value": random.random()}, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_get_highscores_without_auth_should_fail(): + """Test whether fetching high scores without authentication fails""" + clear_db() + + headers = {"Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.get( + f"/highscores/{minigame}", + headers=headers, + ) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_get_highscore_for_nonexisting_minigame_should_fail(): + """Test whether fetching a new high score for a nonexisting minigame fails""" + clear_db() + token = await register_user() + + fake_minigame = "FakeGame" + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.get( + f"/highscores/{fake_minigame}", + headers=headers, + ) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_get_invalid_number_of_highscores_should_fail(): + """Test whether getting a numbe rof high scores lower than 1 fails""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.get( + f"/highscores/{minigame}?nr_highest={random.randint(-100, 0)}", + headers=headers, + ) + + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_get_highscores_should_work_with_default_value(): + """Test whether fetching high scores without passing an explicit amount still succeeds""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.get( + f"/highscores/{minigame}", + headers=headers, + ) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_get_highscores_returns_sorted_list_with_correct_length(): + clear_db() + token = await register_user() + + headers = {"Content-Type": "application/json"} + + for minigame in MinigameEnum: + clear_db() + nr_entries = random.randint(5, 50) + + users_score_tuples = [ + (f"user{i + 1}", random.random()) for i in range(nr_entries) + ] + + for user, score in users_score_tuples: + response = client.post( + "/register", + headers=headers, + json={"username": user, "password": password, "avatar": avatar}, + ) + + assert response.status_code == 200 + + token = response.json()["access_token"] + + response = client.put( + f"/highscores/{minigame}?nr_highest={random.randint(1, nr_entries)}", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={"score_value": score}, + ) + + assert response.status_code == 200 + + response = client.get( + f"/highscores/{minigame}", + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + ) + + assert response.status_code == 200 + response = response.json() + + for i in range(1, len(response)): + assert response[i]["score_value"] <= response[i - 1]["score_value"] diff --git a/tests/test_users.py b/tests/test_users.py index 5a169c4..f8ee485 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,21 +1,10 @@ -import sys - import pytest from fastapi.testclient import TestClient -sys.path.append("..") - from src.main import app, get_db +from tests.base import avatar, client, password, username from tests.config.database import clear_db, override_get_db -app.dependency_overrides[get_db] = override_get_db - -client = TestClient(app) - -username = "user1" -password = "password" -avatar = "lion" - patched_username = "New name" patched_password = "New password" patched_avatar = "New avatar" diff --git a/tests/usertests.py b/tests/usertests.py deleted file mode 100644 index 1f79c86..0000000 --- a/tests/usertests.py +++ /dev/null @@ -1,15 +0,0 @@ -import pytest -from fastapi.testclient import TestClient - -from main import app - -client = TestClient(app) - - -def test_add_user(): - response = client.post( - "/users", - headers={"Content-Type": "application/json"}, - json={"username": "Lukas", "password": "mettn"}, - ) - assert response.status_code == 200 From f9aad400e0cc93d98e4610f75e1b0554e1d1daa2 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Wed, 12 Apr 2023 12:02:25 -0600 Subject: [PATCH 37/37] More highscore endpoint functionality & tests --- src/crud/highscores.py | 39 ++++++++++++++----- src/main.py | 18 +++------ tests/test_highscores.py | 83 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 23 deletions(-) diff --git a/src/crud/highscores.py b/src/crud/highscores.py index a7eee7b..7a8d1a2 100644 --- a/src/crud/highscores.py +++ b/src/crud/highscores.py @@ -7,20 +7,41 @@ from src.models import HighScore, User from src.schemas.highscores import HighScoreBase from src.schemas.users import UserHighScore -DEFAULT_NR_HIGH_SCORES = 10 - -def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): +def get_high_scores( + db: Session, minigame: MinigameEnum, user: User, nr_highest: int, mine_only: bool +): """Get the n highest scores of a given minigame""" - if nr_highest: - if nr_highest < 1: - raise HTTPException(status_code=400, detail="Invalid number of high scores") + if nr_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") + + if mine_only: + if nr_highest > 1: + raise HTTPException( + status_code=400, + detail="nr_highest should be 1 when requesting high score of current user only", + ) + else: + high_score = ( + db.query(HighScore) + .filter( + HighScore.minigame == minigame, HighScore.owner_id == user.user_id + ) + .first() + ) + if high_score: + return [ + UserHighScore( + username=user.username, + score_value=high_score.score_value, + avatar=user.avatar, + ) + ] + else: + return [] user_high_scores = [] - if not nr_highest: - nr_highest = DEFAULT_NR_HIGH_SCORES - if not minigame: minigame = MinigameEnum.SpellingBee diff --git a/src/main.py b/src/main.py index 08f6e6f..320d627 100644 --- a/src/main.py +++ b/src/main.py @@ -64,11 +64,15 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)): @app.get("/highscores/{minigame}", response_model=List[users.UserHighScore]) async def get_high_scores( minigame: MinigameEnum, - nr_highest: Optional[int] = None, + nr_highest: Optional[int] = 1, + mine_only: Optional[bool] = True, current_user_name: str = Depends(crud_authentication.get_current_user_name), db: Session = Depends(get_db), ): - return crud_highscores.get_high_scores(db, minigame, nr_highest) + print(str(nr_highest)) + print(str(mine_only)) + user = crud_users.get_user_by_username(db, current_user_name) + return crud_highscores.get_high_scores(db, minigame, user, nr_highest, mine_only) @app.put("/highscores/{minigame}", response_model=highscores.HighScore) @@ -107,13 +111,3 @@ async def patch_course_progress( return crud_courseprogress.patch_course_progress( db, current_user, course, course_progress ) - - -#### TESTING!! DELETE LATER - - -@app.get("/protected") -async def protected_route( - current_user_name: str = Depends(crud_authentication.get_current_user_name), -): - return {"message": f"Hello, {current_user_name}!"} diff --git a/tests/test_highscores.py b/tests/test_highscores.py index 5b56291..1eadcd1 100644 --- a/tests/test_highscores.py +++ b/tests/test_highscores.py @@ -134,6 +134,13 @@ async def test_get_highscores_without_auth_should_fail(): assert response.status_code == 403 + response = client.get( + f"/highscores/{minigame}?mine_only=false&nr_highest={random.randint(1, 50)}", + headers=headers, + ) + + assert response.status_code == 403 + @pytest.mark.asyncio async def test_get_highscore_for_nonexisting_minigame_should_fail(): @@ -152,6 +159,13 @@ async def test_get_highscore_for_nonexisting_minigame_should_fail(): assert response.status_code == 422 + response = client.get( + f"/highscores/{fake_minigame}?mine_only=false&nr_highest={random.randint(1, 50)}", + headers=headers, + ) + + assert response.status_code == 422 + @pytest.mark.asyncio async def test_get_invalid_number_of_highscores_should_fail(): @@ -189,6 +203,7 @@ async def test_get_highscores_should_work_with_default_value(): @pytest.mark.asyncio async def test_get_highscores_returns_sorted_list_with_correct_length(): + """Test whether getting a list of high scores gets a list in descending order and of the correct length""" clear_db() token = await register_user() @@ -197,6 +212,7 @@ async def test_get_highscores_returns_sorted_list_with_correct_length(): for minigame in MinigameEnum: clear_db() nr_entries = random.randint(5, 50) + token = "" users_score_tuples = [ (f"user{i + 1}", random.random()) for i in range(nr_entries) @@ -214,7 +230,7 @@ async def test_get_highscores_returns_sorted_list_with_correct_length(): token = response.json()["access_token"] response = client.put( - f"/highscores/{minigame}?nr_highest={random.randint(1, nr_entries)}", + f"/highscores/{minigame}", headers={ "Authorization": f"Bearer {token}", "Content-Type": "application/json", @@ -225,7 +241,7 @@ async def test_get_highscores_returns_sorted_list_with_correct_length(): assert response.status_code == 200 response = client.get( - f"/highscores/{minigame}", + f"/highscores/{minigame}?mine_only=false&nr_highest={int(nr_entries)}", headers={ "Authorization": f"Bearer {token}", "Content-Type": "application/json", @@ -235,5 +251,68 @@ async def test_get_highscores_returns_sorted_list_with_correct_length(): assert response.status_code == 200 response = response.json() + assert len(response) == nr_entries + for i in range(1, len(response)): assert response[i]["score_value"] <= response[i - 1]["score_value"] + + +@pytest.mark.asyncio +async def test_get_own_existing_high_score_should_return_high_score(): + """Test whether fetching your own high score of a game succeeds""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.put( + f"/highscores/{minigame}", + headers=headers, + json={"score_value": random.random()}, + ) + + assert response.status_code == 200 + + response = client.get( + f"/highscores/{minigame}", + headers=headers, + ) + + assert response.status_code == 200 + assert len(response.json()) == 1 + + +@pytest.mark.asyncio +async def test_get_own_nonexisting_high_score_should_return_empty_list(): + """Test whether fetching the high score of a game you haven't played returns an empty list""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.get( + f"/highscores/{minigame}", + headers=headers, + ) + + assert response.status_code == 200 + assert len(response.json()) == 0 + + +@pytest.mark.asyncio +async def test_get_multiple_own_high_scores_of_same_game_should_fail(): + """Test whether asking more than one of your high scores on a single game fails""" + clear_db() + token = await register_user() + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + for minigame in MinigameEnum: + response = client.get( + f"/highscores/{minigame}?nr_highest={random.randint(2, 20)}", + headers=headers, + ) + + assert response.status_code == 400