Click here to Skip to main content
15,887,214 members
Please Sign up or sign in to vote.
1.00/5 (1 vote)
See more:
I don't understand why my code produces a graph that looks complicated and difficult to understand. How do I adjust my code so that it plots something similar to the picture but instead of just '3's and '5's, I want to plot all the numbers in the confusion matrix.

My overall objective is classifying digits from 0 to 9 so I want to display all numbers 0-9 in a confusion matrix to show whether the decision tree algorithm was successful.

Picture of what I want to achieve:https://i.stack.imgur.com/hgY2R.png[^]

What I have tried:

# Define all possible classes (numbers 0-9)
classes = np.arange(10)

# Initialize a figure to plot the confusion matrix
size = 5
pad = 0.2
fig6,ax6 = plt.subplots(figsize=(size, size), layout='constrained')

# Loop over all pairs of true and predicted labels
for true_label in classes:
    for pred_label in classes:
        # Extract examples where true label is true_label and predicted label is pred_label
        examples = X_test[(y_test == str(true_label)) & (y_pred_dt == str(pred_label))]
        
        # Plot the examples
        for idx, image_data in enumerate(examples[:size*size]):
            x = idx % size + pred_label * (size + pad)
            y = idx // size + true_label * (size + pad)
            ax6.imshow(image_data.reshape(28, 28), cmap="binary",
                       extent=(x, x + 1, y, y + 1))

# Set ticks and labels
ax6.set_xticks([size / 2 + i * (size + pad) for i in range(len(classes))], labels=classes)
ax6.set_yticks([size / 2 + i * (size + pad) for i in range(len(classes))], labels=classes)

# Show grid lines
ax6.plot([size + pad / 2, size + pad / 2], [0, len(classes) * (size + pad)], "k:")
ax6.plot([0, len(classes) * (size + pad)], [size + pad / 2, size + pad / 2], "k:")

# Set axis limits
ax6.axis([0, len(classes) * (size + pad), 0, len(classes) * (size + pad)])

# Set axis labels
ax6.set_xlabel("Predicted label")
ax6.set_ylabel("True label")

# Show the plot
plt.show()


Picture of output: https://i.stack.imgur.com/dKzgJ.png[^]
Posted

This content, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)



CodeProject, 20 Bay Street, 11th Floor Toronto, Ontario, Canada M5J 2N8 +1 (416) 849-8900