Initial Commit
This commit is contained in:
@@ -12,12 +12,14 @@ def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None)
|
||||
running_loss = 0.0
|
||||
model.train(True)
|
||||
for i, data in enumerate(dataloader):
|
||||
|
||||
inputs, labels = data
|
||||
inputs = inputs.squeeze(0).to(device)
|
||||
labels = labels.to(device, dtype=torch.long)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs).expand(1, -1, -1)
|
||||
|
||||
loss = criterion(outputs[0], labels[0])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -159,7 +161,7 @@ def evaluate(model, dataloader, device, print_stats=False):
|
||||
logger = get_logger(__name__)
|
||||
|
||||
pred_correct, pred_all = 0, 0
|
||||
stats = {i: [0, 0] for i in range(101)}
|
||||
stats = {i: [0, 0] for i in range(251)}
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
inputs, labels = data
|
||||
|
||||
Reference in New Issue
Block a user