From 8f3c303a2bb5aeea251b1d2cd5358f16a6d3ce88 Mon Sep 17 00:00:00 2001 From: lvrossem Date: Sun, 9 Apr 2023 13:39:52 -0600 Subject: [PATCH] Get started with user tests --- src/crud/authentication.py | 11 ++-- src/crud/users.py | 12 +++- src/main.py | 10 ++- tests/config/database.py | 10 +++ tests/config/setup.py | 26 -------- tests/test_users.py | 132 +++++++++++++++++++++++++++++++++++-- 6 files changed, 162 insertions(+), 39 deletions(-) delete mode 100644 tests/config/setup.py diff --git a/src/crud/authentication.py b/src/crud/authentication.py index 201e25b..c8f5654 100644 --- a/src/crud/authentication.py +++ b/src/crud/authentication.py @@ -5,7 +5,8 @@ from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session -from src.crud.users import get_user_by_username, pwd_context +from src.crud.users import (check_empty_fields, get_user_by_username, + pwd_context) from src.models import User DEFAULT_NR_HIGH_SCORES = 10 @@ -48,12 +49,8 @@ def authenticate_user(db: Session, username: str, password: str): def register(db: Session, username: str, password: str, avatar: str): """Register a new user""" - if len(avatar) == 0: - raise HTTPException(status_code=400, detail="No avatar was provided") - if len(username) == 0: - raise HTTPException(status_code=400, detail="No username was provided") - if len(password) == 0: - raise HTTPException(status_code=400, detail="No password was provided") + check_empty_fields(username, password, avatar) + db_user = get_user_by_username(db, username) if db_user: raise HTTPException(status_code=400, detail="Username already registered") diff --git a/src/crud/users.py b/src/crud/users.py index 1fc8daa..e284bd3 100644 --- a/src/crud/users.py +++ b/src/crud/users.py @@ -8,14 +8,24 @@ from src.schemas.users import UserCreate pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +def check_empty_fields(username: str, password: str, avatar: str): + "Checks if any user fields are empty" + if len(avatar) == 0: + raise HTTPException(status_code=400, detail="No avatar was provided") + if len(username) == 0: + raise HTTPException(status_code=400, detail="No username was provided") + if len(password) == 0: + raise HTTPException(status_code=400, detail="No password was provided") + + def patch_user(db: Session, username: str, user: UserCreate): """Changes the username and/or the password of a User""" + check_empty_fields(user.username, user.password, user.avatar) db_user = get_user_by_username(db, username) potential_duplicate = get_user_by_username(db, user.username) if potential_duplicate: if potential_duplicate.user_id != db_user.user_id: raise HTTPException(status_code=400, detail="Username already registered") - db_user.username = user.username db_user.hashed_password = pwd_context.hash(user.password) db_user.avatar = user.avatar diff --git a/src/main.py b/src/main.py index 29b16f4..d518c4e 100644 --- a/src/main.py +++ b/src/main.py @@ -24,11 +24,19 @@ async def root(): return {"message": "Hello world!"} -@app.get("/users", response_model=List[users.User]) +@app.get("/allusers", response_model=List[users.User]) async def read_users(db: Session = Depends(get_db)): return crud_users.get_users(db) +@app.get("/users", response_model=users.User) +async def read_user( + current_user_name: str = Depends(crud_authentication.get_current_user_name), + db: Session = Depends(get_db), +): + return crud_users.get_user_by_username(db, current_user_name) + + @app.patch("/users") async def patch_current_user( user: users.UserCreate, diff --git a/tests/config/database.py b/tests/config/database.py index 6760fea..ef82ea7 100644 --- a/tests/config/database.py +++ b/tests/config/database.py @@ -2,6 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from src.database import Base +from src.models import CourseProgress, HighScore, User SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigntest" @@ -12,6 +13,15 @@ TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base.metadata.create_all(bind=engine) +def clear_db(): + db = TestSessionLocal() + + db.query(HighScore).delete() + db.query(CourseProgress).delete() + db.query(User).delete() + db.commit() + + def override_get_db(): try: db = TestSessionLocal() diff --git a/tests/config/setup.py b/tests/config/setup.py deleted file mode 100644 index 71466e6..0000000 --- a/tests/config/setup.py +++ /dev/null @@ -1,26 +0,0 @@ -import sys - -import pytest -from sqlalchemy import create_engine - -sys.path.append("..") - -from src.database import Base as ProductionBase -from tests.config.database import (SQLALCHEMY_DATABASE_URL, TestBase, - TestSessionLocal) - - -@pytest.fixture(scope="function") -def db_session(): - engine = create_engine(SQLALCHEMY_DATABASE_URL) - print(SQLALCHEMY_DATABASE_URL) - ProductionBase.metadata.create_all(bind=engine) - TestBase.metadata.create_all(bind=engine) - session = TestSessionLocal(bind=engine) - try: - yield session - finally: - session.rollback() - TestBase.metadata.drop_all(bind=engine) - ProductionBase.metadata.drop_all(bind=engine) - session.close() diff --git a/tests/test_users.py b/tests/test_users.py index 90f2fe2..5a169c4 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -6,20 +6,144 @@ from fastapi.testclient import TestClient sys.path.append("..") from src.main import app, get_db -from tests.config.database import override_get_db +from tests.config.database import clear_db, override_get_db app.dependency_overrides[get_db] = override_get_db client = TestClient(app) +username = "user1" +password = "password" +avatar = "lion" + +patched_username = "New name" +patched_password = "New password" +patched_avatar = "New avatar" + @pytest.mark.asyncio -async def test_add_user(): +async def test_get_current_user(): + """Test the GET /users endpoint to get info about the current user""" + clear_db() + response = client.post( "/register", headers={"Content-Type": "application/json"}, - json={"username": "user27", "password": "mettn", "avatar": "lion"}, + json={"username": username, "password": password, "avatar": avatar}, ) - print(response) assert response.status_code == 200 + + token = response.json()["access_token"] + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.get("/users", headers=headers) + assert response.status_code == 200 + response = response.json() + assert response["username"] == username + assert response["avatar"] == avatar + + +@pytest.mark.asyncio +async def test_get_current_user_without_auth(): + """Getting the current user without a token should fail""" + clear_db() + + response = client.get("/users", headers={"Content-Type": "application/json"}) + + assert response.status_code == 403 + + +@pytest.mark.asyncio +async def test_patch_user(): + """Test the patching of a user's username, password and avatar""" + clear_db() + + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": patched_password, + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 200 + + response = client.post( + "/login", + headers={"Content-Type": "application/json"}, + json={"username": patched_username, "password": patched_password}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.get("/users", headers=headers) + assert response.status_code == 200 + + # Correctness of password and username is already asserted by the login + assert response.json()["avatar"] == patched_avatar + + +@pytest.mark.asyncio +async def test_patch_user_with_empty_fields(): + """Patching a user with empty fields should fail""" + clear_db() + + response = client.post( + "/register", + headers={"Content-Type": "application/json"}, + json={"username": username, "password": password, "avatar": avatar}, + ) + assert response.status_code == 200 + + token = response.json()["access_token"] + + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": patched_password, + "avatar": "", + }, + headers=headers, + ) + assert response.status_code == 400 + + response = client.patch( + "/users", + json={ + "username": patched_username, + "password": "", + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 400 + + response = client.patch( + "/users", + json={ + "username": "", + "password": patched_password, + "avatar": patched_avatar, + }, + headers=headers, + ) + assert response.status_code == 400