Added train, val, test selctor

This commit is contained in:
2023-03-22 21:21:44 +00:00
parent c15178da22
commit 51bd92d65a
8 changed files with 92 additions and 11 deletions

View File

@@ -0,0 +1,43 @@
"""add dataset column to signvideo
Revision ID: 3afa46612906
Revises: 2d2d4523082b
Create Date: 2023-03-22 20:54:32.035246
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = '3afa46612906'
down_revision = '2d2d4523082b'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('sign', schema=None) as batch_op:
batch_op.alter_column('category_id',
existing_type=sa.INTEGER(),
nullable=False)
with op.batch_alter_table('signvideo', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset', sqlmodel.sql.sqltypes.AutoString(), nullable=False, server_default='train'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('signvideo', schema=None) as batch_op:
batch_op.drop_column('dataset')
with op.batch_alter_table('sign', schema=None) as batch_op:
batch_op.alter_column('category_id',
existing_type=sa.INTEGER(),
nullable=True)
# ### end Alembic commands ###

View File

@@ -9,6 +9,7 @@ class SignVideo(SQLModelExtended, table=True):
id: int = Field(primary_key=True)
approved: bool = False
dataset: str = "train" # train, test, val
# foreign key to sign
sign_id: int = Field(default=None, foreign_key="sign.id")
@@ -22,4 +23,5 @@ class SignVideo(SQLModelExtended, table=True):
class SignVideoOut(BaseModel):
id: int
approved: bool
approved: bool
dataset: str

View File

@@ -157,6 +157,7 @@ async def sign_video(
class SignVideoUpdate(BaseModel):
approved: bool
dataset: str
@router.patch("/{video_id}/", status_code=status.HTTP_200_OK)
async def sign_video(
@@ -178,7 +179,12 @@ async def sign_video(
if not sign_video:
raise BaseException("Sign video not found")
# check if dataset is train, val or test
if update.dataset not in ["train", "val", "test"]:
raise BaseException("Dataset must be train, val or test")
sign_video.approved = update.approved
sign_video.dataset = update.dataset
await sign_video.save(session)