Hi, I am new to programming with Python and deep learning, and I'm a bit frustrated. I can't add a visualization step to this function. My goal is to get a line plot showing the accuracy scored at each epoch.

```python
import torch

# Develop a train and test mechanism
def train_and_test(model, optimizer, loss_fn, train_loader, test_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs + 1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            # Reset the gradients
            optimizer.zero_grad()
            # Unpack the inputs and targets of this batch
            inputs, targets = batch
            # Move the data to the selected device (CPU or GPU)
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            # Calculate the loss
            loss = loss_fn(output, targets)
            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            # Accumulate the loss, weighted by the batch size
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        # Evaluate the model on the test set; no gradients are needed here
        model.eval()
        num_correct = 0
        num_examples = 0
        with torch.no_grad():
            # Loop over the test batches
            for batch in test_loader:
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                # Count correct predictions (softmax does not change the argmax, so use the logits directly)
                correct = torch.eq(output.argmax(dim=1), targets)
                num_correct += torch.sum(correct).item()
                num_examples += correct.shape[0]
        valid_loss /= len(test_loader.dataset)
        # Accuracy for this epoch
        accuracy = num_correct / num_examples
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(
            epoch, training_loss, valid_loss, accuracy))
```
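As written, the function prints each epoch's accuracy but keeps no record of it: `accuracy` is overwritten on every epoch and is local to the function. One way to get a per-epoch series out is to append each epoch's metrics to lists and return them. The sketch below is a compact variant of the function, assuming a classification model whose output has one logit per class; the `history` dictionary, its keys, and the synthetic dataset in the usage lines are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_and_test(model, optimizer, loss_fn, train_loader, test_loader, epochs=20, device="cpu"):
    # Hypothetical history dict: one list entry per epoch for each metric.
    history = {"train_loss": [], "valid_loss": [], "accuracy": []}
    for epoch in range(1, epochs + 1):
        # Training pass
        model.train()
        training_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        # Evaluation pass without gradient tracking
        model.eval()
        valid_loss, num_correct, num_examples = 0.0, 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                output = model(inputs)
                valid_loss += loss_fn(output, targets).item() * inputs.size(0)
                num_correct += (output.argmax(dim=1) == targets).sum().item()
                num_examples += targets.size(0)
        valid_loss /= len(test_loader.dataset)
        accuracy = num_correct / num_examples

        # Store this epoch's values instead of letting the next epoch overwrite them
        history["train_loss"].append(training_loss)
        history["valid_loss"].append(valid_loss)
        history["accuracy"].append(accuracy)
        print(f"Epoch: {epoch}, Training Loss: {training_loss:.2f}, "
              f"Validation Loss: {valid_loss:.2f}, accuracy = {accuracy:.2f}")
    return history


# Usage on a tiny synthetic dataset (placeholder data, not real results).
# The same loader is used for training and testing only to keep the example short.
torch.manual_seed(0)
X = torch.randn(64, 4)
y = (X[:, 0] > 0).long()  # label depends on the sign of the first feature
loader = DataLoader(TensorDataset(X, y), batch_size=16)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
history = train_and_test(model, optimizer, nn.CrossEntropyLoss(), loader, loader, epochs=3)
```

After the call, `history["accuracy"]` is a list with one value per epoch, so `plt.plot(range(1, len(history["accuracy"]) + 1), history["accuracy"])` has a full series to draw.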


What I have tried:

I have tried the following:

```python
plt.plot(epochs, accuracy)
plt.show()
```
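In the attempt, `epochs` is one integer (the epoch count) and `accuracy` is a single float that is overwritten every epoch and is local to `train_and_test`. So even when the call runs without a `NameError`, there is only one point to draw, and a single point with the default line style (no marker) shows up as an empty plot. A minimal sketch of the plotting side, assuming the per-epoch accuracies have already been collected in a Python list; `plot_accuracy` is a hypothetical helper name and the list in the usage line is placeholder data, not real results.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt


def plot_accuracy(accuracies):
    """Draw one point per epoch, joined by a line, and return the Figure and Axes."""
    epoch_numbers = range(1, len(accuracies) + 1)  # x axis: 1, 2, ..., N
    fig, ax = plt.subplots()
    ax.plot(epoch_numbers, accuracies, marker="o")  # marker keeps a 1-epoch run visible
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Validation accuracy")
    ax.set_title("Accuracy per epoch")
    return fig, ax


# Placeholder values for illustration only.
fig, ax = plot_accuracy([0.52, 0.61, 0.67, 0.70])
```

In a notebook or an interactive session, drop the `matplotlib.use("Agg")` line and call `plt.show()` after `plot_accuracy(...)`; otherwise `fig.savefig("accuracy.png")` writes the plot to a file.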

This content, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


