Get started with user tests
This commit is contained in:
@@ -5,7 +5,8 @@ from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.crud.users import get_user_by_username, pwd_context
|
||||
from src.crud.users import (check_empty_fields, get_user_by_username,
|
||||
pwd_context)
|
||||
from src.models import User
|
||||
|
||||
DEFAULT_NR_HIGH_SCORES = 10
|
||||
@@ -48,12 +49,8 @@ def authenticate_user(db: Session, username: str, password: str):
|
||||
|
||||
def register(db: Session, username: str, password: str, avatar: str):
|
||||
"""Register a new user"""
|
||||
if len(avatar) == 0:
|
||||
raise HTTPException(status_code=400, detail="No avatar was provided")
|
||||
if len(username) == 0:
|
||||
raise HTTPException(status_code=400, detail="No username was provided")
|
||||
if len(password) == 0:
|
||||
raise HTTPException(status_code=400, detail="No password was provided")
|
||||
check_empty_fields(username, password, avatar)
|
||||
|
||||
db_user = get_user_by_username(db, username)
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
@@ -8,14 +8,24 @@ from src.schemas.users import UserCreate
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def check_empty_fields(username: str, password: str, avatar: str):
|
||||
"Checks if any user fields are empty"
|
||||
if len(avatar) == 0:
|
||||
raise HTTPException(status_code=400, detail="No avatar was provided")
|
||||
if len(username) == 0:
|
||||
raise HTTPException(status_code=400, detail="No username was provided")
|
||||
if len(password) == 0:
|
||||
raise HTTPException(status_code=400, detail="No password was provided")
|
||||
|
||||
|
||||
def patch_user(db: Session, username: str, user: UserCreate):
|
||||
"""Changes the username and/or the password of a User"""
|
||||
check_empty_fields(user.username, user.password, user.avatar)
|
||||
db_user = get_user_by_username(db, username)
|
||||
potential_duplicate = get_user_by_username(db, user.username)
|
||||
if potential_duplicate:
|
||||
if potential_duplicate.user_id != db_user.user_id:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
db_user.username = user.username
|
||||
db_user.hashed_password = pwd_context.hash(user.password)
|
||||
db_user.avatar = user.avatar
|
||||
|
||||
10
src/main.py
10
src/main.py
@@ -24,11 +24,19 @@ async def root():
|
||||
return {"message": "Hello world!"}
|
||||
|
||||
|
||||
@app.get("/users", response_model=List[users.User])
|
||||
@app.get("/allusers", response_model=List[users.User])
|
||||
async def read_users(db: Session = Depends(get_db)):
|
||||
return crud_users.get_users(db)
|
||||
|
||||
|
||||
@app.get("/users", response_model=users.User)
|
||||
async def read_user(
|
||||
current_user_name: str = Depends(crud_authentication.get_current_user_name),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return crud_users.get_user_by_username(db, current_user_name)
|
||||
|
||||
|
||||
@app.patch("/users")
|
||||
async def patch_current_user(
|
||||
user: users.UserCreate,
|
||||
|
||||
@@ -2,6 +2,7 @@ from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from src.database import Base
|
||||
from src.models import CourseProgress, HighScore, User
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "postgresql://admin:WeSign123!@localhost/wesigntest"
|
||||
|
||||
@@ -12,6 +13,15 @@ TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def clear_db():
|
||||
db = TestSessionLocal()
|
||||
|
||||
db.query(HighScore).delete()
|
||||
db.query(CourseProgress).delete()
|
||||
db.query(User).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
def override_get_db():
|
||||
try:
|
||||
db = TestSessionLocal()
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
from src.database import Base as ProductionBase
|
||||
from tests.config.database import (SQLALCHEMY_DATABASE_URL, TestBase,
|
||||
TestSessionLocal)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session():
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
||||
print(SQLALCHEMY_DATABASE_URL)
|
||||
ProductionBase.metadata.create_all(bind=engine)
|
||||
TestBase.metadata.create_all(bind=engine)
|
||||
session = TestSessionLocal(bind=engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.rollback()
|
||||
TestBase.metadata.drop_all(bind=engine)
|
||||
ProductionBase.metadata.drop_all(bind=engine)
|
||||
session.close()
|
||||
@@ -6,20 +6,144 @@ from fastapi.testclient import TestClient
|
||||
sys.path.append("..")
|
||||
|
||||
from src.main import app, get_db
|
||||
from tests.config.database import override_get_db
|
||||
from tests.config.database import clear_db, override_get_db
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
username = "user1"
|
||||
password = "password"
|
||||
avatar = "lion"
|
||||
|
||||
patched_username = "New name"
|
||||
patched_password = "New password"
|
||||
patched_avatar = "New avatar"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user():
|
||||
async def test_get_current_user():
|
||||
"""Test the GET /users endpoint to get info about the current user"""
|
||||
clear_db()
|
||||
|
||||
response = client.post(
|
||||
"/register",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"username": "user27", "password": "mettn", "avatar": "lion"},
|
||||
json={"username": username, "password": password, "avatar": avatar},
|
||||
)
|
||||
|
||||
print(response)
|
||||
assert response.status_code == 200
|
||||
|
||||
token = response.json()["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
response = client.get("/users", headers=headers)
|
||||
assert response.status_code == 200
|
||||
response = response.json()
|
||||
assert response["username"] == username
|
||||
assert response["avatar"] == avatar
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_without_auth():
|
||||
"""Getting the current user without a token should fail"""
|
||||
clear_db()
|
||||
|
||||
response = client.get("/users", headers={"Content-Type": "application/json"})
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_user():
|
||||
"""Test the patching of a user's username, password and avatar"""
|
||||
clear_db()
|
||||
|
||||
response = client.post(
|
||||
"/register",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"username": username, "password": password, "avatar": avatar},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
token = response.json()["access_token"]
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
response = client.patch(
|
||||
"/users",
|
||||
json={
|
||||
"username": patched_username,
|
||||
"password": patched_password,
|
||||
"avatar": patched_avatar,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.post(
|
||||
"/login",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"username": patched_username, "password": patched_password},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
token = response.json()["access_token"]
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
response = client.get("/users", headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Correctness of password and username is already asserted by the login
|
||||
assert response.json()["avatar"] == patched_avatar
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_user_with_empty_fields():
|
||||
"""Patching a user with empty fields should fail"""
|
||||
clear_db()
|
||||
|
||||
response = client.post(
|
||||
"/register",
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"username": username, "password": password, "avatar": avatar},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
token = response.json()["access_token"]
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
|
||||
response = client.patch(
|
||||
"/users",
|
||||
json={
|
||||
"username": patched_username,
|
||||
"password": patched_password,
|
||||
"avatar": "",
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = client.patch(
|
||||
"/users",
|
||||
json={
|
||||
"username": patched_username,
|
||||
"password": "",
|
||||
"avatar": patched_avatar,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
response = client.patch(
|
||||
"/users",
|
||||
json={
|
||||
"username": "",
|
||||
"password": patched_password,
|
||||
"avatar": patched_avatar,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
Reference in New Issue
Block a user