Compare commits

...

10 Commits

Author SHA1 Message Date
lvrossem
55a5f59524 All tests pass 2023-04-18 07:26:07 -06:00
lvrossem
d376e39773 Almost there 2023-04-18 07:01:12 -06:00
lvrossem
4235395579 Finish up learnable tests 2023-04-18 05:24:17 -06:00
lvrossem
3968dfd4eb More refactoring 2023-04-18 02:52:26 -06:00
lvrossem
d074074b03 More and more refactors 2023-04-17 15:38:25 -06:00
lvrossem
81e9eb154b Fix tests for users and highscores 2023-04-17 14:52:36 -06:00
lvrossem
3596394f3f Fix auth tests ... again 2023-04-17 13:56:58 -06:00
lvrossem
6a8cb2c3bd Fix infinite sql query bug 2023-04-17 13:21:46 -06:00
lvrossem
38eb9027d6 More refactors 2023-04-17 07:51:53 -06:00
lvrossem
0bf764a0f4 Refactoring: auth tests pass 2023-04-16 07:15:03 -06:00
19 changed files with 903 additions and 337 deletions

View File

@ -47,15 +47,18 @@ def authenticate_user(db: Session, username: str, password: str):
return db_user
def register(db: Session, username: str, password: str, avatar: str):
def register(db: Session, username: str, password: str, avatar_index: int):
"""Register a new user"""
check_empty_fields(username, password, avatar)
check_empty_fields(username, password, avatar_index)
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(
username=username, hashed_password=pwd_context.hash(password), avatar=avatar
username=username,
hashed_password=pwd_context.hash(password),
avatar_index=avatar_index,
playtime=0.0,
)
db.add(db_user)
db.commit()

View File

@ -1,10 +1,25 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from typing import List
from src.enums import CourseEnum
from src.models import CourseProgress, User
from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent
from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent, SavedCourseProgress
from src.schemas.learnableprogress import SavedLearnableProgress
from src.crud.learnableprogress import get_learnables
def get_learnable_values(learnables: List[SavedLearnableProgress]):
completed_learnables = sum(
[1 if learnable.progress == 5.0 else 0 for learnable in learnables]
)
in_use_learnables = sum(
[1 if learnable.in_use else 0 for learnable in learnables]
)
total_learnables = len(learnables)
return completed_learnables, in_use_learnables, total_learnables
def get_course_progress(db: Session, user: User, course: CourseEnum):
"""Get the progress a user has for a certain course"""
@ -12,9 +27,11 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
courses_to_fetch = [course]
if course == CourseEnum.All:
all_courses_list = [course for course in CourseEnum]
courses_to_fetch = filter(
courses_to_fetch = [course for course in filter(
lambda course: course != CourseEnum.All, all_courses_list
)
)]
print([course for course in courses_to_fetch])
for course in courses_to_fetch:
course_progress = (
db.query(CourseProgress)
@ -25,17 +42,34 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
)
if course_progress:
result.append(
CourseProgressParent(
progress_value=course_progress.progress_value, course=course
)
)
print("CURRENT COURSE: " + course_progress.course)
learnables = get_learnables(db, user, course)
completed_learnables, in_use_learnables, total_learnables = get_learnable_values(learnables)
result.append(SavedCourseProgress(
course_index=course_progress.course,
progress=course_progress.progress,
completed_learnables=completed_learnables,
in_use_learnables=in_use_learnables,
total_learnables=total_learnables,
learnables=learnables,
))
else:
db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)
)
db.add(CourseProgress(progress=0.0, course=course, owner_id=user.user_id))
db.commit()
result.append(CourseProgressParent(progress_value=0.0, course=course))
result.append(SavedCourseProgress(
course_index=course,
progress=0.0,
completed_learnables=0,
in_use_learnables=0,
total_learnables=0,
learnables=[],
))
print(f"RESULT: {result}")
return result
@ -44,9 +78,7 @@ def initialize_user(db: Session, user: User):
"""Create CourseProgress records with a value of 0 for a new user"""
for course in CourseEnum:
if course != CourseEnum.All:
db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)
)
db.add(CourseProgress(progress=0.0, course=course, owner_id=user.user_id))
db.commit()
@ -54,7 +86,7 @@ def patch_course_progress(
db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase
):
"""Change the progress value for a given course"""
if course_progress.progress_value > 1 or course_progress.progress_value < 0:
if course_progress.progress > 1 or course_progress.progress < 0:
raise HTTPException(status_code=400, detail="Invalid progress value")
db_course_progress_list = []
@ -75,10 +107,23 @@ def patch_course_progress(
)
for db_course_progress in db_course_progress_list:
db_course_progress.progress_value = course_progress.progress_value
db_course_progress.progress = course_progress.progress
db.commit()
return [
CourseProgressParent(course=db_cp.course, progress_value=db_cp.progress_value)
for db_cp in db_course_progress_list
]
result = []
for db_cp in db_course_progress_list:
learnables = get_learnables(db, user, db_cp.course)
completed_learnables, in_use_learnables, total_learnables = get_learnable_values(learnables)
result.append(SavedCourseProgress(
course_index=db_cp.course,
progress=db_cp.progress,
completed_learnables=completed_learnables,
in_use_learnables=in_use_learnables,
total_learnables=total_learnables,
learnables=learnables,
))
return result

View File

@ -1,25 +1,48 @@
import datetime
from fastapi import HTTPException
from sqlalchemy import desc
from sqlalchemy import asc, desc, func
from sqlalchemy.orm import Session
from src.enums import MinigameEnum
from src.models import HighScore, User
from src.schemas.highscores import HighScoreBase
from src.schemas.users import UserHighScore
from src.schemas.highscores import HighScoreBase, Score
def get_high_scores(
db: Session, minigame: MinigameEnum, user: User, nr_highest: int, mine_only: bool
def get_most_recent_high_scores(db: Session, minigame: MinigameEnum, amount: int):
"""Get the n most recent high scores of a given minigame"""
if amount < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
high_scores = []
if not minigame:
minigame = MinigameEnum.SpellingBee
high_scores_query = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.time))
.limit(amount)
)
for high_score in high_scores_query:
high_scores.append(
Score(score_value=high_score.score_value, time=str(high_score.time))
)
return high_scores
def get_highest_high_scores(
db: Session, minigame: MinigameEnum, user: User, amount: int, mine_only: bool
):
"""Get the n highest scores of a given minigame"""
if nr_highest < 1:
if amount < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
if mine_only:
if nr_highest > 1:
if amount > 1:
raise HTTPException(
status_code=400,
detail="nr_highest should be 1 when requesting high score of current user only",
detail="amount should be 1 when requesting high score of current user only",
)
else:
high_score = (
@ -31,37 +54,30 @@ def get_high_scores(
)
if high_score:
return [
UserHighScore(
username=user.username,
Score(
score_value=high_score.score_value,
avatar=user.avatar,
time=str(high_score.time),
)
]
else:
return []
user_high_scores = []
high_scores = []
if not minigame:
minigame = MinigameEnum.SpellingBee
high_scores = (
high_scores_query = (
db.query(HighScore)
.filter(HighScore.minigame == minigame)
.order_by(desc(HighScore.score_value))
.limit(nr_highest)
.all()
.limit(amount)
)
for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append(
UserHighScore(
username=owner.username,
score_value=high_score.score_value,
avatar=owner.avatar,
)
for high_score in high_scores_query:
high_scores.append(
Score(score_value=high_score.score_value, time=str(high_score.time))
)
return user_high_scores
return high_scores
def create_high_score(
@ -75,11 +91,14 @@ def create_high_score(
score_value=high_score.score_value,
minigame=minigame,
owner_id=user.user_id,
time=str(datetime.datetime.now()),
)
db.add(db_high_score)
db.commit()
db.refresh(db_high_score)
return db_high_score
return Score(
score_value=db_high_score.score_value, time=str(db_high_score.time)
)
old_high_score = (
db.query(HighScore)
@ -94,6 +113,8 @@ def create_high_score(
db.delete(old_high_score)
return add_to_db()
else:
return old_high_score
return Score(
score_value=old_high_score.score_value, time=str(old_high_score.time)
)
else:
return add_to_db()

View File

@ -0,0 +1,121 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import asc
from src.enums import CourseEnum
from src.models import CourseProgress, LearnableProgress, User
from src.schemas.learnableprogress import SavedLearnableProgress
def get_learnables(db: Session, user: User, course: CourseEnum):
"""Get all learnables of a certain course"""
db_course = (
db.query(CourseProgress)
.filter(
CourseProgress.owner_id == user.user_id, CourseProgress.course == course
)
.first()
)
db_learnable_query = (
db.query(LearnableProgress)
.filter(LearnableProgress.course_progress_id == db_course.course_progress_id)
.order_by(asc(LearnableProgress.index))
.all()
)
return [
SavedLearnableProgress(
index=dbl.index, in_use=dbl.in_use, name=dbl.name, progress=dbl.progress
)
for dbl in db_learnable_query
]
def create_learnable(
db: Session, user: User, course: CourseEnum, learnable: SavedLearnableProgress
):
"""Create a new learnable for a given course"""
if learnable.index < 0:
raise HTTPException(status_code=400, detail="Negative index not allowed")
if learnable.in_use is None:
raise HTTPException(
status_code=400, detail="Please indicate whether the learnable is in use"
)
if len(learnable.name) < 1:
raise HTTPException(
status_code=400, detail="No name was provided for the learnable"
)
potential_duplicate = (
db.query(LearnableProgress)
.filter(LearnableProgress.name == learnable.name)
.first()
)
if potential_duplicate:
raise HTTPException(
status_code=400, detail="No duplicate learnable names allowed"
)
db_course = (
db.query(CourseProgress)
.filter(
CourseProgress.owner_id == user.user_id, CourseProgress.course == course
)
.first()
)
db_learnable = LearnableProgress(
index=learnable.index,
in_use=learnable.in_use,
name=learnable.name,
progress=0.0,
course_progress_id=db_course.course_progress_id,
)
db.add(db_learnable)
db.commit()
def patch_learnable(db: Session, user: User, learnable_name: str, learnable: SavedLearnableProgress):
"""Patch an existing learnable"""
db_learnable = (
db.query(LearnableProgress)
.filter(LearnableProgress.name == learnable_name)
.first()
)
if not db_learnable:
raise HTTPException(
status_code=400, detail="Learnable with provided name not found"
)
potential_duplicate = (
db.query(LearnableProgress)
.filter(LearnableProgress.name == learnable.name, LearnableProgress.learnable_progress_id != db_learnable.learnable_progress_id)
.first()
)
if potential_duplicate:
raise HTTPException(
status_code=400, detail="No duplicate learnable names allowed"
)
if learnable.index < -1:
raise HTTPException(status_code=400, detail="Invalid learnable index")
elif learnable.index > -1:
db_learnable.index = learnable.index
if learnable.in_use is not None:
db_learnable.in_use = learnable.in_use
if len(learnable.name) > 0:
db_learnable.name = learnable.name
# TODO: chek progress semantics
db.add(db_learnable)
db.commit()

View File

@ -2,15 +2,23 @@ from fastapi import HTTPException
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from src.models import User
from src.schemas.users import UserCreate
from src.crud.highscores import (get_highest_high_scores,
get_most_recent_high_scores)
from src.crud.courseprogress import get_course_progress
from src.enums import CourseEnum, MinigameEnum
from src.models import CourseProgress, LearnableProgress, User
from src.schemas.courseprogress import SavedCourseProgress
from src.schemas.highscores import SavedMinigameProgress
from src.schemas.learnableprogress import SavedLearnableProgress
from src.schemas.users import SavedUser, UserCreate
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def check_empty_fields(username: str, password: str, avatar: str):
def check_empty_fields(username: str, password: str, avatar_index: int):
"Checks if any user fields are empty"
if len(avatar) == 0:
if avatar_index < 0:
raise HTTPException(status_code=400, detail="No avatar was provided")
if len(username) == 0:
raise HTTPException(status_code=400, detail="No username was provided")
@ -20,15 +28,28 @@ def check_empty_fields(username: str, password: str, avatar: str):
def patch_user(db: Session, username: str, user: UserCreate):
"""Changes the username and/or the password of a User"""
check_empty_fields(user.username, user.password, user.avatar)
# check_empty_fields(user.username, user.password, user.avatar_index)
db_user = get_user_by_username(db, username)
potential_duplicate = get_user_by_username(db, user.username)
if potential_duplicate:
if potential_duplicate.user_id != db_user.user_id:
raise HTTPException(status_code=400, detail="Username already registered")
db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password)
db_user.avatar = user.avatar
if user.playtime < 0:
raise HTTPException(status_code=400, detail="Negative playtime is invalid")
if len(user.username) > 0:
db_user.username = user.username
if len(user.password) > 0:
db_user.hashed_password = pwd_context.hash(user.password)
if user.avatar_index > -1:
db_user.avatar_index = user.avatar_index
elif user.avatar_index < -1:
raise HTTPException(status_code=400, detail="Invalid avatar index")
db_user.playtime += user.playtime
db.commit()
@ -40,3 +61,29 @@ def get_user_by_username(db: Session, username: str):
def get_users(db: Session):
"""Fetch a list of all users"""
return db.query(User).all()
def get_saved_data(db: Session, username: str):
"""Fetches all saved progress for the current user from the database"""
user = get_user_by_username(db, username)
minigames = []
courses = get_course_progress(db, user, CourseEnum.All)
for minigame in MinigameEnum:
minigames.append(
SavedMinigameProgress(
minigame_index=minigame,
latest_scores=get_most_recent_high_scores(db, minigame, 10),
highest_scores=get_highest_high_scores(db, minigame, user, 10, False),
)
)
user_progress = SavedUser(
username=user.username,
avatar_index=user.avatar_index,
playtime=user.playtime,
minigames=minigames,
courses=courses,
)
return user_progress

View File

@ -2,24 +2,6 @@ from fastapi_utils.enums import StrEnum
from sqlalchemy.types import Enum, TypeDecorator
class StrEnumType(TypeDecorator):
impl = Enum
def __init__(self, enum_class, **kw):
self.enum_class = enum_class
super().__init__(enum_class, **kw)
def process_bind_param(self, value, dialect):
if value is None:
return None
return value.value
def process_result_value(self, value, dialect):
if value is None:
return None
return self.enum_class(value)
class MinigameEnum(StrEnum):
SpellingBee = "SpellingBee"
Hangman = "Hangman"

View File

@ -9,10 +9,11 @@ sys.path.append("..")
from src.crud import authentication as crud_authentication
from src.crud import courseprogress as crud_courseprogress
from src.crud import highscores as crud_highscores
from src.crud import learnableprogress as crud_learnables
from src.crud import users as crud_users
from src.database import Base, engine, get_db
from src.enums import CourseEnum, MinigameEnum
from src.schemas import courseprogress, highscores, users
from src.schemas import courseprogress, highscores, learnableprogress, users
app = FastAPI()
@ -24,17 +25,20 @@ async def root():
return {"message": "Hello world!"}
@app.get("/allusers", response_model=List[users.User])
"""
@app.get("/allusers", response_model=List[users.SavedUser])
async def read_users(db: Session = Depends(get_db)):
return crud_users.get_users(db)
"""
@app.get("/users", response_model=users.User)
"""
@app.get("/users", response_model=users.SavedUser)
async def read_user(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
return crud_users.get_user_by_username(db, current_user_name)
"""
@app.patch("/users")
@ -46,10 +50,18 @@ async def patch_current_user(
crud_users.patch_user(db, current_user_name, user)
@app.get("/saveddata", response_model=users.SavedUser)
async def read_saved_data(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
return crud_users.get_saved_data(db, current_user_name)
@app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
access_token = crud_authentication.register(
db, user.username, user.password, user.avatar
db, user.username, user.password, user.avatar_index
)
user = crud_users.get_user_by_username(db, user.username)
crud_courseprogress.initialize_user(db, user)
@ -61,21 +73,24 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
return crud_authentication.login(db, user.username, user.password)
@app.get("/highscores/{minigame}", response_model=List[users.UserHighScore])
@app.get("/highscores/{minigame}", response_model=List[highscores.Score])
async def get_high_scores(
minigame: MinigameEnum,
nr_highest: Optional[int] = 1,
amount: Optional[int] = 1,
mine_only: Optional[bool] = True,
most_recent: Optional[bool] = False,
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
print(str(nr_highest))
print(str(mine_only))
if most_recent:
return crud_highscores.get_most_recent_high_scores(db, minigame, amount)
user = crud_users.get_user_by_username(db, current_user_name)
return crud_highscores.get_high_scores(db, minigame, user, nr_highest, mine_only)
return crud_highscores.get_highest_high_scores(
db, minigame, user, amount, mine_only
)
@app.put("/highscores/{minigame}", response_model=highscores.HighScore)
@app.put("/highscores/{minigame}", response_model=highscores.Score)
async def create_high_score(
minigame: MinigameEnum,
high_score: highscores.HighScoreBase,
@ -87,7 +102,7 @@ async def create_high_score(
@app.get(
"/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent]
"/courseprogress/{course}", response_model=List[courseprogress.SavedCourseProgress]
)
async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All,
@ -99,7 +114,7 @@ async def get_course_progress(
@app.patch(
"/courseprogress/{course}", response_model=List[courseprogress.CourseProgressParent]
"/courseprogress/{course}", response_model=List[courseprogress.CourseProgressBase]
)
async def patch_course_progress(
course: CourseEnum,
@ -111,3 +126,38 @@ async def patch_course_progress(
return crud_courseprogress.patch_course_progress(
db, current_user, course, course_progress
)
@app.get(
"/learnables/{course}",
response_model=List[learnableprogress.SavedLearnableProgress],
)
async def create_learnable(
course: CourseEnum,
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_learnables.get_learnables(db, current_user, course)
@app.post("/learnables/{course}")
async def create_learnable(
course: CourseEnum,
learnable: learnableprogress.SavedLearnableProgress,
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
crud_learnables.create_learnable(db, current_user, course, learnable)
@app.patch("/learnables/{name}")
async def create_learnable(
name: str,
learnable: learnableprogress.SavedLearnableProgress,
current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
crud_learnables.patch_learnable(db, current_user, name, learnable)

View File

@ -1,4 +1,5 @@
from sqlalchemy import Column, Float, ForeignKey, Integer, String
from sqlalchemy import (Boolean, Column, DateTime, Float, ForeignKey, Integer,
String)
from sqlalchemy.orm import relationship
from src.database import Base
@ -12,7 +13,8 @@ class User(Base):
user_id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
avatar = Column(String, nullable=False)
avatar_index = Column(Integer, nullable=False)
playtime = Column(Float, nullable=False)
high_scores = relationship(
"HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic"
@ -29,6 +31,7 @@ class HighScore(Base):
high_score_id = Column(Integer, primary_key=True, index=True)
score_value = Column(Float, nullable=False)
time = Column(DateTime, nullable=False)
minigame = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="high_scores")
@ -40,7 +43,24 @@ class CourseProgress(Base):
__tablename__ = "course_progress"
course_progress_id = Column(Integer, primary_key=True, index=True)
progress_value = Column(Float, nullable=False)
progress = Column(Float, nullable=False)
course = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="course_progress")
learnables = relationship("LearnableProgress", back_populates="course")
class LearnableProgress(Base):
"""The database model for learnable progress"""
__tablename__ = "learnable_progress"
learnable_progress_id = Column(Integer, primary_key=True, index=True)
index = Column(Integer, nullable=False)
in_use = Column(Boolean, nullable=False)
name = Column(String, unique=True, nullable=False)
progress = Column(Float, nullable=False)
course_progress_id = Column(
Integer, ForeignKey("course_progress.course_progress_id")
)
course = relationship("CourseProgress", back_populates="learnables")

View File

@ -1,10 +1,13 @@
from typing import List
from pydantic import BaseModel
from src.enums import CourseEnum
from src.schemas.learnableprogress import SavedLearnableProgress
class CourseProgressBase(BaseModel):
progress_value: float
progress: float
class CourseProgressParent(CourseProgressBase):
@ -17,3 +20,15 @@ class CourseProgress(CourseProgressParent):
class Config:
orm_mode = True
class SavedCourseProgress(BaseModel):
course_index: CourseEnum
progress: float
completed_learnables: int
in_use_learnables: int
total_learnables: int
learnables: List[SavedLearnableProgress]
class Config:
orm_mode = True

View File

@ -1,3 +1,5 @@
from typing import List
from pydantic import BaseModel
from src.enums import MinigameEnum
@ -7,10 +9,17 @@ class HighScoreBase(BaseModel):
score_value: float
class HighScore(HighScoreBase):
high_score_id: int
owner_id: int
minigame: MinigameEnum
class Score(HighScoreBase):
time: str
class Config:
orm_mode = True
class SavedMinigameProgress(BaseModel):
minigame_index: MinigameEnum
latest_scores: List[Score]
highest_scores: List[Score]
class Config:
orm_mode = True

View File

@ -0,0 +1,8 @@
from pydantic import BaseModel
class SavedLearnableProgress(BaseModel):
index: int = -1
in_use: bool = None
name: str = ""
progress: float = -1.0

View File

@ -1,22 +1,25 @@
from typing import List
from pydantic import BaseModel
from src.schemas.courseprogress import SavedCourseProgress
from src.schemas.highscores import SavedMinigameProgress
class UserBase(BaseModel):
username: str
avatar: str = ""
class User(UserBase):
user_id: int
hashed_password: str
class Config:
orm_mode = True
username: str = ""
avatar_index: int = -1
class UserCreate(UserBase):
password: str
password: str = ""
playtime: float = 0.0
class UserHighScore(UserBase):
score_value: float
class SavedUser(UserBase):
playtime: float
minigames: List[SavedMinigameProgress]
courses: List[SavedCourseProgress]
class Config:
orm_mode = True

View File

@ -13,4 +13,23 @@ client = TestClient(app)
username = "user1"
password = "password"
avatar = "lion"
avatar_index = 1
def get_headers(token=None):
if token:
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
else:
return {"Content-Type": "application/json"}
async def register_user():
response = client.post(
"/register",
headers=get_headers(),
json={"username": username, "password": password, "avatar_index": avatar_index},
)
assert response.status_code == 200
return response.json()["access_token"]

View File

@ -2,7 +2,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from src.database import Base
from src.models import CourseProgress, HighScore, User
from src.models import CourseProgress, HighScore, LearnableProgress, User
SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigntest"
@ -17,6 +17,7 @@ def clear_db():
db = TestSessionLocal()
db.query(HighScore).delete()
db.query(LearnableProgress).delete()
db.query(CourseProgress).delete()
db.query(User).delete()
db.commit()

View File

@ -1,32 +1,19 @@
import pytest
from fastapi.testclient import TestClient
from src.main import app, get_db
from tests.base import avatar, client, password, username
from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
from tests.base import (avatar_index, client, get_headers, password,
register_user, username)
from tests.config.database import clear_db
@pytest.mark.asyncio
async def test_register():
async def test_register_should_succeed():
"""Test the register endpoint"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
headers=get_headers(),
json={"username": username, "password": password, "avatar_index": avatar_index},
)
assert response.status_code == 200
@ -41,8 +28,8 @@ async def test_register_duplicate_name_should_fail():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
headers=get_headers(),
json={"username": username, "password": password, "avatar_index": avatar_index},
)
assert response.status_code == 400
@ -56,11 +43,11 @@ async def test_register_without_username_should_fail():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"password": password, "avatar": avatar},
headers=get_headers(),
json={"password": password, "avatar_index": avatar_index},
)
assert response.status_code == 422
assert response.status_code == 400
assert "access_token" not in response.json()
@ -71,11 +58,11 @@ async def test_register_without_password_should_fail():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "avatar": avatar},
headers=get_headers(),
json={"username": username, "avatar_index": avatar_index},
)
assert response.status_code == 422
assert response.status_code == 400
assert "access_token" not in response.json()
@ -86,24 +73,23 @@ async def test_register_without_avatar_should_fail():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"username": username, "password": password},
)
# Not ideal that this is 400 instead of 422, but had no other choice than to give this field a default value
assert response.status_code == 400
assert "access_token" not in response.json()
@pytest.mark.asyncio
async def test_login():
async def test_login_should_succeed():
"""Test the login endpoint"""
clear_db()
await register_user()
response = client.post(
"/login",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"username": username, "password": password},
)
@ -119,7 +105,7 @@ async def test_login_wrong_password_should_fail():
wrong_password = password + "extra characters"
response = client.post(
"/login",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"username": username, "password": wrong_password},
)
@ -129,31 +115,31 @@ async def test_login_wrong_password_should_fail():
@pytest.mark.asyncio
async def test_login_without_username_should_fail():
"""Test whether logging in without passing a username fails"""
"""Test whether logging in without passing a username fails, since the default is an empty string"""
clear_db()
await register_user()
response = client.post(
"/login",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"password": password},
)
assert response.status_code == 422
assert response.status_code == 401
assert "access_token" not in response.json()
@pytest.mark.asyncio
async def test_login_without_password_should_fail():
"""Test whether logging in without passing a password fails"""
"""Test whether logging in without passing a password fails, since the default is an empty string"""
clear_db()
await register_user()
response = client.post(
"/login",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"username": username},
)
assert response.status_code == 422
assert response.status_code == 401
assert "access_token" not in response.json()

View File

@ -1,33 +1,19 @@
import random
import pytest
from fastapi.testclient import TestClient
from src.enums import CourseEnum
from src.main import app, get_db
from tests.base import avatar, client, password, username
from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
from tests.base import client, get_headers, register_user
from tests.config.database import clear_db
@pytest.mark.asyncio
async def test_register_creates_progress_of_zero():
async def test_register_should_create_progress_of_zero():
"""Test whether registering a new user initializes all progress values to 0.0"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
@ -36,17 +22,17 @@ async def test_register_creates_progress_of_zero():
response = response.json()[0]
assert response["progress_value"] == 0.0
assert response["course"] == course
assert response["progress"] == 0.0
assert response["course_index"] == course
@pytest.mark.asyncio
async def test_get_all_returns_all():
async def test_get_all_sould_return_all():
"""Test whether the 'All'-course fetches all course progress values"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.get("/courseprogress/All", headers=headers)
assert response.status_code == 200
@ -54,15 +40,15 @@ async def test_get_all_returns_all():
for course in CourseEnum:
if course != CourseEnum.All:
assert {"progress_value": 0.0, "course": course} in response
assert {"progress": 0.0, "course_index": course, "completed_learnables": 0, "in_use_learnables": 0, "total_learnables": 0, "learnables": []} in response
@pytest.mark.asyncio
async def test_get_course_progress_value_without_auth_should_fail():
async def test_get_course_progress_without_auth_should_fail():
"""Test whether fetching a course progress value without authentication fails"""
clear_db()
headers = {"Content-Type": "application/json"}
headers = get_headers()
for course in CourseEnum:
response = client.get(f"/courseprogress/{course}", headers=headers)
@ -78,32 +64,32 @@ async def test_get_nonexisting_course_should_fail():
fake_course = "FakeCourse"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.get(f"/courseprogress/{fake_course}", headers=headers)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_patch_course_progress():
async def test_patch_course_progress_should_succeed():
"""Test whether patching the progress value of a course works properly"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
progress_value = random.uniform(0, 1)
progress = random.uniform(0, 1)
response = client.patch(
f"/courseprogress/{course}",
headers=headers,
json={"progress_value": progress_value},
json={"progress": progress},
)
assert response.status_code == 200
assert response.json()[0]["progress_value"] == progress_value
assert response.json()[0]["progress"] == progress
@pytest.mark.asyncio
@ -112,14 +98,14 @@ async def test_patch_all_should_patch_all_courses():
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
progress_value = random.uniform(0, 1)
progress = random.uniform(0, 1)
response = client.patch(
"/courseprogress/All",
headers=headers,
json={"progress_value": progress_value},
json={"progress": progress},
)
assert response.status_code == 200
@ -131,7 +117,7 @@ async def test_patch_all_should_patch_all_courses():
for course in CourseEnum:
if course != CourseEnum.All:
assert {"progress_value": progress_value, "course": course} in response
assert {"progress": progress, "course_index": course, "completed_learnables": 0, "in_use_learnables": 0, "total_learnables": 0, "learnables": []} in response
@pytest.mark.asyncio
@ -142,14 +128,14 @@ async def test_patch_nonexisting_course_should_fail():
fake_course = "FakeCourse"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
progress_value = random.uniform(0, 1)
progress = random.uniform(0, 1)
response = client.patch(
f"/courseprogress/{fake_course}",
headers=headers,
json={"progress_value": progress_value},
json={"progress": progress},
)
assert response.status_code == 422
@ -161,15 +147,15 @@ async def test_patch_course_with_invalid_value_should_fail():
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
too_high_progress_value = random.uniform(0, 1) + 2
too_low_progress_value = random.uniform(0, 1) - 2
too_high_progress = random.uniform(0, 1) + 2
too_low_progress = random.uniform(0, 1) - 2
response = client.patch(
"/courseprogress/All",
headers=headers,
json={"progress_value": too_high_progress_value},
json={"progress": too_high_progress},
)
assert response.status_code == 400
@ -177,24 +163,24 @@ async def test_patch_course_with_invalid_value_should_fail():
response = client.patch(
"/courseprogress/All",
headers=headers,
json={"progress_value": too_low_progress_value},
json={"progress": too_low_progress},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_patch_course_progress_value_without_auth_should_fail():
async def test_patch_course_progress_without_auth_should_fail():
"""Test whether updating a course progress value without authentication fails"""
clear_db()
headers = {"Content-Type": "application/json"}
headers = get_headers()
for course in CourseEnum:
response = client.patch(
f"/courseprogress/{course}",
headers=headers,
json={"progress_value": random.uniform(0, 1)},
json={"progress": random.uniform(0, 1)},
)
assert response.status_code == 403

View File

@ -1,33 +1,20 @@
import random
import pytest
from fastapi.testclient import TestClient
from src.enums import MinigameEnum
from src.main import app, get_db
from tests.base import avatar, client, password, username
from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
from tests.base import (avatar_index, client, get_headers, password,
register_user)
from tests.config.database import clear_db
@pytest.mark.asyncio
async def test_put_highscore():
async def test_put_highscore_should_succeed():
"""Test whether putting a new high score succeeds"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
score_value = random.random()
@ -41,17 +28,16 @@ async def test_put_highscore():
response = response.json()
assert response["minigame"] == minigame
assert response["score_value"] == score_value
@pytest.mark.asyncio
async def test_put_lower_highscore_does_not_change_old_value():
async def test_put_lower_highscore_should_not_change_old_value():
"""Test whether putting a new high score lower than the current one doesn't change the old one"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
score_value = random.random()
@ -65,7 +51,6 @@ async def test_put_lower_highscore_does_not_change_old_value():
response = response.json()
assert response["minigame"] == minigame
assert response["score_value"] == score_value
lower_score_value = score_value - 100
@ -79,7 +64,6 @@ async def test_put_lower_highscore_does_not_change_old_value():
response = response.json()
assert response["minigame"] == minigame
assert response["score_value"] == score_value
@ -91,7 +75,7 @@ async def test_put_highscore_for_nonexisting_minigame_should_fail():
fake_minigame = "FakeGame"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.put(
f"/highscores/{fake_minigame}",
@ -107,7 +91,7 @@ async def test_put_highscores_without_auth_should_fail():
"""Test whether putting high scores without authentication fails"""
clear_db()
headers = {"Content-Type": "application/json"}
headers = get_headers()
for minigame in MinigameEnum:
response = client.put(
@ -124,7 +108,7 @@ async def test_get_highscores_without_auth_should_fail():
"""Test whether fetching high scores without authentication fails"""
clear_db()
headers = {"Content-Type": "application/json"}
headers = get_headers()
for minigame in MinigameEnum:
response = client.get(
@ -135,7 +119,7 @@ async def test_get_highscores_without_auth_should_fail():
assert response.status_code == 403
response = client.get(
f"/highscores/{minigame}?mine_only=false&nr_highest={random.randint(1, 50)}",
f"/highscores/{minigame}?mine_only=false&amount={random.randint(1, 50)}",
headers=headers,
)
@ -150,7 +134,7 @@ async def test_get_highscore_for_nonexisting_minigame_should_fail():
fake_minigame = "FakeGame"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.get(
f"/highscores/{fake_minigame}",
@ -160,7 +144,7 @@ async def test_get_highscore_for_nonexisting_minigame_should_fail():
assert response.status_code == 422
response = client.get(
f"/highscores/{fake_minigame}?mine_only=false&nr_highest={random.randint(1, 50)}",
f"/highscores/{fake_minigame}?mine_only=false&amount={random.randint(1, 50)}",
headers=headers,
)
@ -173,11 +157,11 @@ async def test_get_invalid_number_of_highscores_should_fail():
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
response = client.get(
f"/highscores/{minigame}?nr_highest={random.randint(-100, 0)}",
f"/highscores/{minigame}?amount={random.randint(-100, 0)}",
headers=headers,
)
@ -185,12 +169,12 @@ async def test_get_invalid_number_of_highscores_should_fail():
@pytest.mark.asyncio
async def test_get_highscores_should_work_with_default_value():
async def test_get_highscores_should_succeed_with_default_value():
"""Test whether fetching high scores without passing an explicit amount still succeeds"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
response = client.get(
@ -202,16 +186,16 @@ async def test_get_highscores_should_work_with_default_value():
@pytest.mark.asyncio
async def test_get_highscores_returns_sorted_list_with_correct_length():
async def test_get_highscores_should_return_sorted_list_with_correct_length():
"""Test whether getting a list of high scores gets a list in descending order and of the correct length"""
clear_db()
token = await register_user()
headers = {"Content-Type": "application/json"}
headers = get_headers()
for minigame in MinigameEnum:
clear_db()
nr_entries = random.randint(5, 50)
nr_entries = random.randint(5, 10)
token = ""
users_score_tuples = [
@ -222,7 +206,11 @@ async def test_get_highscores_returns_sorted_list_with_correct_length():
response = client.post(
"/register",
headers=headers,
json={"username": user, "password": password, "avatar": avatar},
json={
"username": user,
"password": password,
"avatar_index": avatar_index,
},
)
assert response.status_code == 200
@ -231,21 +219,15 @@ async def test_get_highscores_returns_sorted_list_with_correct_length():
response = client.put(
f"/highscores/{minigame}",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
headers=get_headers(token),
json={"score_value": score},
)
assert response.status_code == 200
response = client.get(
f"/highscores/{minigame}?mine_only=false&nr_highest={int(nr_entries)}",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
f"/highscores/{minigame}?mine_only=false&amount={int(nr_entries)}",
headers=get_headers(token),
)
assert response.status_code == 200
@ -256,14 +238,27 @@ async def test_get_highscores_returns_sorted_list_with_correct_length():
for i in range(1, len(response)):
assert response[i]["score_value"] <= response[i - 1]["score_value"]
response = client.get(
f"/highscores/{minigame}?most_recent=true&mine_only=false&amount={int(nr_entries)}",
headers=get_headers(token),
)
assert response.status_code == 200
response = response.json()
assert len(response) == nr_entries
for i in range(1, len(response)):
assert response[i]["time"] <= response[i - 1]["time"]
@pytest.mark.asyncio
async def test_get_own_existing_high_score_should_return_high_score():
async def test_get_own_existing_high_score_should_succeed():
"""Test whether fetching your own high score of a game succeeds"""
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
response = client.put(
@ -289,7 +284,7 @@ async def test_get_own_nonexisting_high_score_should_return_empty_list():
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
response = client.get(
@ -307,11 +302,11 @@ async def test_get_multiple_own_high_scores_of_same_game_should_fail():
clear_db()
token = await register_user()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
for minigame in MinigameEnum:
response = client.get(
f"/highscores/{minigame}?nr_highest={random.randint(2, 20)}",
f"/highscores/{minigame}?amount={random.randint(2, 20)}",
headers=headers,
)

300
tests/test_learnables.py Normal file
View File

@ -0,0 +1,300 @@
import random
import pytest
from src.enums import CourseEnum
from tests.base import client, get_headers, register_user
from tests.config.database import clear_db
@pytest.mark.asyncio
async def test_create_learnables_should_succeed():
"""Test whether creating a new learnable succeeds"""
clear_db()
token = await register_user()
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
nr_learnables = random.randint(1, 5)
for i in range(nr_learnables):
response = client.post(
f"/learnables/{course}",
json={
"index": i,
"in_use": bool(random.randint(0, 1)),
"name": f"{course} {i}",
},
headers=headers,
)
assert response.status_code == 200
response = client.get(f"/learnables/{course}", headers=headers)
assert response.status_code == 200
response = response.json()
assert len(response) == nr_learnables
@pytest.mark.asyncio
async def test_patch_learnables_should_succeed():
"""Test whether patching learnables succeeds"""
clear_db()
token = await register_user()
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": random.randint(0, 100),
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=headers,
)
assert response.status_code == 200
new_index = random.randint(0, 100)
new_in_use = bool(random.randint(0, 1))
new_name = "New" + course
response = client.patch(f"/learnables/{course}", json={"index": new_index, "in_use": new_in_use, "name": new_name}, headers=headers)
assert response.status_code == 200
response = client.get(f"/learnables/{course}", headers=headers)
assert response.status_code == 200
response = response.json()[0]
assert response["index"] == new_index
assert response["in_use"] == new_in_use
assert response["name"] == new_name
@pytest.mark.asyncio
async def test_create_learnables_without_name_should_fail():
"""Test whether creating a new learnable without name fails"""
clear_db()
token = await register_user()
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": random.randint(0, 100),
"in_use": bool(random.randint(0, 1)),
},
headers=headers,
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_create_learnables_without_index_should_fail():
"""Test whether creating a new learnable without index fails"""
clear_db()
token = await register_user()
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"name": course,
"in_use": bool(random.randint(0, 1)),
},
headers=headers,
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_create_learnables_without_in_use_should_fail():
"""Test whether creating a new learnable without in_use fails"""
clear_db()
token = await register_user()
headers = get_headers(token)
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": random.randint(0, 100),
"name": course,
},
headers=headers,
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_create_learnable_without_auth_should_fail():
"""Test whether creating learnables without authentication fails"""
clear_db()
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(),
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_get_learnables_of_nonexisting_course_should_fail():
"""Test whether learnables of a nonexisting course fails"""
clear_db()
token = await register_user()
course = "FakeCourse"
response = client.get(f"/learnables/{course}", headers=get_headers(token))
assert response.status_code == 422
@pytest.mark.asyncio
async def test_post_learnable_to_nonexisting_course_should_fail():
"""Test whether creating a learnable for a nonexisting course fails fails"""
clear_db()
token = await register_user()
course = "FakeCourse"
response = client.post(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(token),
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_get_learnables_without_auth_should_fail():
"""Test whether fetching learnables without authentication fails"""
clear_db()
for course in CourseEnum:
if course != CourseEnum.All:
response = client.get(f"/learnables/{course}", headers=get_headers())
assert response.status_code == 403
@pytest.mark.asyncio
async def test_patch_learnable_without_auth_should_fail():
"""Test whether patching learnables without authentication fails"""
clear_db()
token = await register_user()
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(token),
)
assert response.status_code == 200
response = client.patch(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(),
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_create_learnable_with_existing_name_should_fail():
"""Test whether putting high scores without authentication fails"""
clear_db()
token = await register_user()
for course in CourseEnum:
if course != CourseEnum.All:
response = client.post(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(token),
)
assert response.status_code == 200
response = client.post(
f"/learnables/{course}",
json={
"index": 1,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(token),
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_patch_nonexisting_learnable_should_fail():
"""Test whether patching nonexisting learnables fails"""
clear_db()
token = await register_user()
for course in CourseEnum:
if course != CourseEnum.All:
response = client.patch(
f"/learnables/{course}",
json={
"index": 0,
"in_use": bool(random.randint(0, 1)),
"name": f"{course}",
},
headers=get_headers(token),
)
assert response.status_code == 400

View File

@ -1,70 +1,29 @@
import pytest
from fastapi.testclient import TestClient
from src.main import app, get_db
from tests.base import avatar, client, password, username
from tests.config.database import clear_db, override_get_db
from tests.base import (client, get_headers, register_user,
username)
from tests.config.database import clear_db
patched_username = "New name"
patched_password = "New password"
patched_avatar = "New avatar"
patched_avatar_index = 2
@pytest.mark.asyncio
async def test_get_current_user():
"""Test the GET /users endpoint to get info about the current user"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = client.get("/users", headers=headers)
assert response.status_code == 200
response = response.json()
assert response["username"] == username
assert response["avatar"] == avatar
@pytest.mark.asyncio
async def test_get_current_user_without_auth():
"""Getting the current user without a token should fail"""
clear_db()
response = client.get("/users", headers={"Content-Type": "application/json"})
assert response.status_code == 403
@pytest.mark.asyncio
async def test_patch_user():
async def test_patch_user_should_succeed():
"""Test the patching of a user's username, password and avatar"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = await register_user()
token = response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.patch(
"/users",
json={
"username": patched_username,
"password": patched_password,
"avatar": patched_avatar,
"avatar_index": patched_avatar_index,
},
headers=headers,
)
@ -72,67 +31,63 @@ async def test_patch_user():
response = client.post(
"/login",
headers={"Content-Type": "application/json"},
headers=get_headers(),
json={"username": patched_username, "password": patched_password},
)
assert response.status_code == 200
token = response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
headers = get_headers(token)
response = client.get("/users", headers=headers)
response = client.get("/saveddata", headers=headers)
assert response.status_code == 200
# Correctness of password and username is already asserted by the login
assert response.json()["avatar"] == patched_avatar
assert response.json()["avatar_index"] == patched_avatar_index
@pytest.mark.asyncio
async def test_patch_user_with_empty_fields():
"""Patching a user with empty fields should fail"""
async def test_patch_user_with_empty_fields_should_succeed():
"""Patching a user with empty fields should still succeed"""
clear_db()
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = await register_user()
token = response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = client.patch(
"/users",
json={
"username": patched_username,
"password": patched_password,
"avatar": "",
},
headers=headers,
)
assert response.status_code == 400
response = client.patch(
"/users",
json={
"username": patched_username,
"password": "",
"avatar": patched_avatar,
},
headers=headers,
)
assert response.status_code == 400
headers = get_headers(token)
response = client.patch(
"/users",
json={
"username": "",
"password": patched_password,
"avatar": patched_avatar,
"avatar_index": patched_avatar_index,
"playtime": 0.0,
},
headers=headers,
)
assert response.status_code == 400
assert response.status_code == 200
response = client.patch(
"/users",
json={
"username": username,
"password": patched_password,
"avatar_index": -1,
"playtime": 0.0,
},
headers=headers,
)
assert response.status_code == 200
response = client.patch(
"/users",
json={
"username": username,
"password": "",
"avatar_index": patched_avatar_index,
"playtime": 0.0,
},
headers=headers,
)
assert response.status_code == 200