Initial Commit

This commit is contained in:
2023-04-07 09:44:12 +00:00
parent 42d655a451
commit c49645d7bc
13 changed files with 423 additions and 128 deletions

View File

@@ -12,12 +12,14 @@ def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None)
running_loss = 0.0
model.train(True)
for i, data in enumerate(dataloader):
inputs, labels = data
inputs = inputs.squeeze(0).to(device)
labels = labels.to(device, dtype=torch.long)
optimizer.zero_grad()
outputs = model(inputs).expand(1, -1, -1)
loss = criterion(outputs[0], labels[0])
loss.backward()
optimizer.step()
@@ -159,7 +161,7 @@ def evaluate(model, dataloader, device, print_stats=False):
logger = get_logger(__name__)
pred_correct, pred_all = 0, 0
stats = {i: [0, 0] for i in range(101)}
stats = {i: [0, 0] for i in range(251)}
for i, data in enumerate(dataloader):
inputs, labels = data