Hi, I am new to programming with Python and deep learning, and a bit frustrated. I can't add a visualization part to this function. My goal is to get a line plot showing the accuracy scored at each epoch.
def train_and_test(model, optimizer, loss_fn, train_loader, test_loader, epochs=20, device="cpu"):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Per-epoch metrics are printed and also recorded and returned, so they can
    be plotted afterwards. For example, a line plot of accuracy per epoch::

        history = train_and_test(model, optimizer, loss_fn, train_loader, test_loader)
        plt.plot(history["epoch"], history["accuracy"])
        plt.xlabel("epoch")
        plt.ylabel("accuracy")
        plt.show()

    Args:
        model: module to train; switched to ``train()`` / ``eval()`` mode as needed.
        optimizer: optimizer that updates ``model``'s parameters.
        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar tensor.
        train_loader: iterable of ``(inputs, targets)`` batches used for training.
        test_loader: iterable of ``(inputs, targets)`` batches used for validation
            loss and accuracy. Targets are compared against the argmax of the
            model output along dim 1 (assumes class-index targets).
        epochs: number of epochs to run.
        device: device each batch is moved to.

    Returns:
        dict of equal-length lists, one entry per epoch: ``"epoch"``,
        ``"training_loss"``, ``"valid_loss"`` and ``"accuracy"``.
    """
    history = {"epoch": [], "training_loss": [], "valid_loss": [], "accuracy": []}

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        training_loss = 0.0
        num_train = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            training_loss += loss.item() * inputs.size(0)
            num_train += inputs.size(0)
        # Divide by samples actually seen (correct even with drop_last / samplers).
        training_loss = training_loss / num_train if num_train else 0.0

        # ---- evaluation pass ----
        model.eval()
        valid_loss = 0.0
        num_correct = 0
        num_examples = 0
        # Gradients are not needed for evaluation; skip building the autograd graph.
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                # argmax of raw logits equals argmax of their softmax, so no softmax needed.
                predictions = output.argmax(dim=1)
                num_correct += (predictions == targets).sum().item()
                num_examples += targets.size(0)
        # Guard against an empty evaluation set instead of raising ZeroDivisionError.
        valid_loss = valid_loss / num_examples if num_examples else 0.0
        accuracy = num_correct / num_examples if num_examples else 0.0

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(
            epoch, training_loss, valid_loss, accuracy))

        # Record this epoch's metrics so the caller can plot them.
        history["epoch"].append(epoch)
        history["training_loss"].append(training_loss)
        history["valid_loss"].append(valid_loss)
        history["accuracy"].append(accuracy)

    return history
What I have tried:
I tried the following:
<pre lang="Python">
plt.plot(epochs, accuracy)
plt.show()
</pre>