Hi, I am new to programming with Python and deep learning, and I am a bit frustrated. I can't work out how to add a visualization step to this function. My goal is a line plot showing the accuracy achieved at each epoch.

import torch
import torch.nn.functional as F

# Develop a train and test mechanism
def train_and_test(model, optimizer, loss_fn, train_loader, test_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        valid_loss = 0.0
        for batch in train_loader:
            # Reset the gradient
            optimizer.zero_grad()
            # Unpack inputs and targets from the batch
            inputs, targets = batch
            # Move the tensors to the selected device
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            # Calculate loss
            loss = loss_fn(output, targets)
            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            # Accumulate the loss, weighted by batch size
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        # Evaluating the model on test
        num_correct = 0 
        num_examples = 0
        # Loop over the test batches
        for batch in test_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output, targets)
            valid_loss += loss.item() * inputs.size(0)
            # Evaluating the performance of the model
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(test_loader.dataset)
        # Accuracy
        accuracy = num_correct/num_examples
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, accuracy))

What I have tried:

I have tried this:
<pre lang="Python">
plt.plot(epochs, accuracy)
</pre>
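
As far as I can tell, the call above cannot draw a line because `epochs` is a single integer and `accuracy` only holds the value from the last epoch, so there is at most one point. Here is a minimal sketch of what the plot seems to need, assuming matplotlib is installed: one accuracy value stored per epoch, plotted against a list of epoch numbers. The names `epoch_accuracy`, `plot_accuracy`, `accuracy_history` and the sample counts are hypothetical, made up for illustration.

```python
def epoch_accuracy(num_correct, num_examples):
    """Accuracy for one epoch, computed the same way as in the loop above."""
    return num_correct / num_examples


def plot_accuracy(history):
    """Draw a line plot with one point per epoch (hypothetical helper)."""
    # Imported here so the bookkeeping part runs even without matplotlib.
    import matplotlib.pyplot as plt

    # x values must be a sequence with the same length as the history.
    epoch_numbers = list(range(1, len(history) + 1))
    plt.plot(epoch_numbers, history, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.show()


# Made-up (num_correct, num_examples) pairs standing in for three epochs.
sample_counts = [(50, 100), (65, 100), (80, 100)]

# Create the list once, before the epoch loop, and append once per epoch.
accuracy_history = []
for num_correct, num_examples in sample_counts:
    accuracy_history.append(epoch_accuracy(num_correct, num_examples))
```

Applied to the function in the question, this would mean creating the list before `for epoch in ...`, appending `accuracy` right after it is computed, returning the list at the end, and then calling `plot_accuracy(...)` on the returned value.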

This content, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)
