Added train, val, test selctor
This commit is contained in:
43
backend/alembic/versions/3afa46612906_add_dataset_column_to_signvideo.py
Executable file
43
backend/alembic/versions/3afa46612906_add_dataset_column_to_signvideo.py
Executable file
@@ -0,0 +1,43 @@
|
||||
"""add dataset column to signvideo
|
||||
|
||||
Revision ID: 3afa46612906
|
||||
Revises: 2d2d4523082b
|
||||
Create Date: 2023-03-22 20:54:32.035246
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '3afa46612906'
|
||||
down_revision = '2d2d4523082b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('sign', schema=None) as batch_op:
|
||||
batch_op.alter_column('category_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('signvideo', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dataset', sqlmodel.sql.sqltypes.AutoString(), nullable=False, server_default='train'))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('signvideo', schema=None) as batch_op:
|
||||
batch_op.drop_column('dataset')
|
||||
|
||||
with op.batch_alter_table('sign', schema=None) as batch_op:
|
||||
batch_op.alter_column('category_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -9,6 +9,7 @@ class SignVideo(SQLModelExtended, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
|
||||
approved: bool = False
|
||||
dataset: str = "train" # train, test, val
|
||||
|
||||
# foreign key to sign
|
||||
sign_id: int = Field(default=None, foreign_key="sign.id")
|
||||
@@ -22,4 +23,5 @@ class SignVideo(SQLModelExtended, table=True):
|
||||
|
||||
class SignVideoOut(BaseModel):
|
||||
id: int
|
||||
approved: bool
|
||||
approved: bool
|
||||
dataset: str
|
||||
@@ -157,6 +157,7 @@ async def sign_video(
|
||||
|
||||
class SignVideoUpdate(BaseModel):
|
||||
approved: bool
|
||||
dataset: str
|
||||
|
||||
@router.patch("/{video_id}/", status_code=status.HTTP_200_OK)
|
||||
async def sign_video(
|
||||
@@ -178,7 +179,12 @@ async def sign_video(
|
||||
if not sign_video:
|
||||
raise BaseException("Sign video not found")
|
||||
|
||||
# check if dataset is train, val or test
|
||||
if update.dataset not in ["train", "val", "test"]:
|
||||
raise BaseException("Dataset must be train, val or test")
|
||||
|
||||
sign_video.approved = update.approved
|
||||
sign_video.dataset = update.dataset
|
||||
await sign_video.save(session)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user