Handwritten Digit Recognition with AI
By Francisco Sandi
Introduction
During my master’s program, I learned about various Deep Learning techniques, with convolutional neural networks (CNNs) standing out for their effectiveness in image recognition tasks. As a way to solidify my understanding, I decided to play around with the MNIST dataset, a classic in the AI community often dubbed the "Hello World" of deep learning.
This experiment opened the door to new challenges, from detecting and cropping characters to adapting the data format to match my training dataset. In this article, I’ll walk you through the process, the hurdles I faced, and the lessons I learned along the way.
ℹ️ Background
MNIST
The MNIST dataset is a large collection of handwritten digits, commonly used for training various image processing systems. It contains 60,000 training images and 10,000 testing images, each a 28x28 grayscale pixel representation of digits from 0 to 9. Due to its simplicity and well-labeled data, MNIST has become the go-to dataset for anyone beginning their journey in deep learning.
MNIST dataset
Neural Networks
Before diving into convolutional neural networks, it’s important to understand the foundation upon which they are built: neural networks (NN). At their core, neural networks are computational models inspired by the human brain's neural structure. They consist of layers of interconnected nodes, or "neurons," where each connection has an associated weight.
The network is trained by adjusting the weights of the connections based on the error in its predictions. This is done through a process called backpropagation, combined with an optimization algorithm like gradient descent. As the network iterates over the data, it gradually learns to make more accurate predictions by minimizing the error.
Simple Neural Network
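To make this concrete, here is a minimal, self-contained sketch (illustrative only, not part of the experiment below) of a single "neuron" learning y = 2x: PyTorch's autograd performs the backpropagation, and the parameters are updated with plain gradient descent.

import torch

# A single "neuron": one weight and one bias, learning y = 2x
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 4.0, 6.0])

lr = 0.1
for step in range(100):
    y_pred = w * x + b                 # forward pass
    loss = ((y_pred - y) ** 2).mean()  # mean squared error
    loss.backward()                    # backpropagation: compute gradients
    with torch.no_grad():
        w -= lr * w.grad               # gradient descent: update parameters
        b -= lr * b.grad
        w.grad.zero_()
        b.grad.zero_()

print(w.item(), b.item())  # w approaches 2, b approaches 0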
Convolutional Neural Networks (CNNs)
While traditional neural networks are powerful, they have limitations when it comes to handling image data, especially for tasks like digit recognition. CNNs are a specialized type of neural network designed to process data with a grid-like topology, such as images. The key difference lies in the use of convolutional layers, which slide small learnable filters across the image to detect local patterns, typically followed by pooling layers that downsample the resulting feature maps and a fully connected layer that performs the final classification.
The combination of these layers allows CNNs to automatically and adaptively learn spatial hierarchies of features, making them exceptionally well-suited for image recognition tasks. In the case of the MNIST dataset, the CNN can learn to recognize the shapes and structures of digits across different handwriting styles, leading to highly accurate digit classification.
Convolutional Neural Network
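To make the shape arithmetic concrete, here is a small standalone sketch (again illustrative, not the model used below) showing what one convolution-plus-pooling step does to a 28x28 grayscale image:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)  # a dummy batch of one 28x28 grayscale image

conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2)
pool = nn.MaxPool2d(kernel_size=2)

print(conv(x).shape)        # torch.Size([1, 16, 28, 28]) -- 16 learned feature maps
print(pool(conv(x)).shape)  # torch.Size([1, 16, 14, 14]) -- pooling halves each dimension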
🧪 The Experiment
I used PyTorch, a popular Deep Learning framework, to implement and train the Convolutional Neural Network (CNN) on the MNIST dataset. Here’s how I structured the experiment:
1. Loading the Dataset
I loaded the MNIST dataset and organized it into batches using PyTorch's DataLoader, which let me feed the data into the model efficiently during training and testing. The training dataset includes 60,000 images and the testing dataset 10,000. Each loader was set up to handle 100 images per batch, with shuffling enabled for the training data to ensure a diverse mix of samples in each batch.
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Training data
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=ToTensor(),
    download=True,
)

# Testing data
test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=ToTensor(),
    download=True,
)

loaders = {
    'train': DataLoader(train_data, batch_size=100, shuffle=True, num_workers=1),
    'test': DataLoader(test_data, batch_size=100, shuffle=True, num_workers=1),
}
Just to validate that the data was loaded correctly, I randomly previewed 10 elements from the training data with their corresponding labels:
import torch
import matplotlib.pyplot as plt

figure = plt.figure(figsize=(20, 8))
cols, rows = 10, 1
for i in range(1, cols * rows + 1):
    # Pick a random sample and show it with its label
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
2. Defining the CNN Model
The CNN architecture I used was relatively simple but effective for digit recognition. The model consisted of the following layers:
- The 1st convolutional layer had 16 filters, each with a 5x5 kernel, followed by a ReLU activation function and a 2x2 max-pooling layer to reduce the spatial dimensions.
- The 2nd convolutional layer had 32 filters, again with a 5x5 kernel, followed by ReLU and max-pooling.
- Finally, a fully connected layer took the output of the convolutions and reduced it to 10 output classes, corresponding to the digits 0-9.
This structure allowed the model to learn increasingly abstract features of the images as they passed through the layers, from simple edges to more complex shapes. Note that each 2x2 max-pooling halves the spatial dimensions (28 → 14 → 7), so the flattened input to the fully connected layer has 32 × 7 × 7 = 1,568 values.
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # First conv block: 1 input channel -> 16 feature maps, 28x28 -> 14x14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Second conv block: 16 -> 32 feature maps, 14x14 -> 7x7
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Fully connected layer: 32 * 7 * 7 features -> 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten for the linear layer
        output = self.out(x)
        return output, x  # also return the flattened features

cnn = CNN()
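As a quick sanity check of the 32 × 7 × 7 figure (this snippet is just illustrative), a dummy forward pass through the two convolutional blocks confirms the shapes:

import torch

dummy = torch.randn(1, 1, 28, 28)  # one fake 28x28 grayscale image
features = cnn.conv2(cnn.conv1(dummy))
print(features.shape)  # torch.Size([1, 32, 7, 7]) -> 32 * 7 * 7 = 1568 inputs to the linear layer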
3. Training the Model
With the model defined, I moved on to training it using the training dataset. I set up the training process as follows:
- Loss Function: I used the CrossEntropyLoss, a standard choice for classification tasks, which measures the difference between the predicted output and the actual labels.
- Optimizer: I chose the Adam optimizer, which is widely used due to its ability to adapt the learning rate for each parameter individually, leading to faster convergence.
- Training Loop: The training loop ran for 10 epochs. In each epoch, the model processed the entire training dataset in batches, updating the weights based on the loss calculated for each batch. After every 100 steps, the model printed the current loss, providing insights into the training progress.
import torch
import torch.nn as nn
from torch import optim

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=0.01)
num_epochs = 10

def train(num_epochs, cnn, loaders):
    cnn.train()
    total_step = len(loaders['train'])
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            # images and labels are already tensors, so they can be fed
            # to the model directly (no deprecated Variable wrapper needed)
            output = cnn(images)[0]
            loss = loss_func(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Print loss every 100 steps
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
    torch.save(cnn.state_dict(), 'cnn.th')

train(num_epochs, cnn, loaders)
4. Model Evaluation
After training, I evaluated the model's performance on the entire test dataset to determine its accuracy. The model achieved an accuracy close to 100%, which is consistent with expectations for a CNN trained on MNIST. This high accuracy confirmed that the model could effectively recognize handwritten digits within the controlled environment of the MNIST dataset.
def test():
    cnn.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in loaders['test']:
            test_output, last_layer = cnn(images)
            pred_y = torch.max(test_output, 1)[1]  # class with the highest score
            correct += (pred_y == labels).sum().item()
            total += labels.size(0)
    # Accumulate over all batches so the accuracy covers the full test set
    accuracy = correct / float(total)
    print('Test Accuracy of the model on the 10000 test images: %.2f' % accuracy)

test()
Test Accuracy of the model on the 10000 test images: 1.00
This experiment provided a strong foundation in applying CNNs for digit recognition and set the stage for tackling more complex real-world applications.
🖼️ Use with a real picture
After successfully training the CNN on the MNIST dataset, I wanted to see how well it could handle handwritten digits from a photo. This required several steps to preprocess the data to make it compatible with the model.
1. Loading Image
I began by loading my test image containing handwritten digits using OpenCV. This image was loaded in grayscale to simplify the preprocessing steps.
import cv2
import matplotlib.pyplot as plt
# Load the image
image_path = 'handwritten-numbers.jpg'
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
# Display the image
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.show()
Handwritten digits
2. Preprocessing Digits
The next step involved extracting individual digits from the image in a form resembling the MNIST dataset. This was quite challenging since I don't have much experience with image processing; I'm sure there are better ways to do this, but this is the approach I intuitively took:
- Binary Format: I converted the image to a binary format using thresholding. This step created a high-contrast binary image where the digits stood out clearly against the background.
- Finding Contours: I used contour detection to identify distinct regions in the binary image, each of which potentially contained a digit. Since my image contains multiple digits, the contours were then sorted from left to right to keep them in the correct order.
- Remove Noise: To ensure the extracted digits were of a reasonable size, I filtered out contours whose bounding boxes were small relative to the largest one, based on the diagonal length of the rectangles.
- Padding and Resizing: Each digit was padded so it sat centered within a square, then resized to 28x28 pixels, the same size as the MNIST images, to match the input format expected by the model.
- Transformations: Finally, I applied a series of transformations to prepare the digits for model input: converting the images to tensors, normalizing them, and maximizing contrast to enhance readability.
- Preview Digits: To verify the preprocessing, I previewed the extracted and processed digits. This step was crucial to ensure the digits were in good shape before feeding them into the model.
import torchvision.transforms as transforms
from PIL import Image
import cv2
import math
import numpy as np
import torch

# Threshold the image to create a binary image
_, binary_image = cv2.threshold(image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

# Find contours in the binary image
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

# Sort contours from left to right based on the x coordinate of the bounding box
contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[0])

# Process each contour to get rectangle information
rectangles = [cv2.boundingRect(contour) for contour in contours]

# Filter out rectangles that are too small
max_diagonal = max([math.sqrt(w**2 + h**2) for (x, y, w, h) in rectangles])
rectangles = [(x, y, w, h) for (x, y, w, h) in rectangles if math.sqrt(w**2 + h**2) > max_diagonal / 2]

# Define the transformation to preprocess the image
transform = transforms.Compose([
    transforms.Grayscale(),                # Convert image to grayscale
    transforms.Resize((28, 28)),           # Resize image to 28x28
    transforms.ToTensor(),                 # Convert image to tensor
    transforms.Normalize((0.5,), (0.5,)),  # Normalize the image
    transforms.Lambda(lambda x: torch.where(x > 0.5, torch.ones_like(x), torch.zeros_like(x)))  # Maximize contrast
])

digits = []
for (x, y, w, h) in rectangles:
    digit = binary_image[y:y+h, x:x+w]
    # Create a new black square image with padding
    max_dim = max(w, h) + 10
    padded_digit = np.zeros((max_dim, max_dim), dtype=np.uint8)
    # Compute the position to place the original digit in the center
    x_offset = (max_dim - w) // 2
    y_offset = (max_dim - h) // 2
    # Place the original digit image in the center of the new image
    padded_digit[y_offset:y_offset+h, x_offset:x_offset+w] = digit
    digit_resized = cv2.resize(padded_digit, (28, 28))
    # Filter out noise, and apply transformations
    min_white_pixels = 50
    if cv2.countNonZero(digit_resized) > min_white_pixels:
        digit_resized = Image.fromarray(digit_resized)
        digit_tensor = transform(digit_resized).unsqueeze(0)
        digits.append((digit_tensor, digit_resized))

# Preview digits obtained from the image
plt.figure(figsize=(10, 2))
for idx, (digit_tensor, digit_image) in enumerate(digits):
    plt.subplot(1, len(digits), idx + 1)
    plt.imshow(digit_image, cmap='gray')
    plt.axis('off')
plt.show()
Extracted digits
3. Making Predictions
With the digits preprocessed, I loaded the trained CNN model and used it to make predictions on the extracted digits. For each digit, the model produced a prediction, which was then displayed alongside the digit image. This provided a clear view of the model’s performance on real-world data.
# Load the model
model = CNN()
model.load_state_dict(torch.load('cnn.th'))
model.eval()

# Make predictions on the digits
predictions = []
plt.figure(figsize=(10, 2))
for idx, (digit_tensor, digit_image) in enumerate(digits):
    output, _ = model(digit_tensor)
    _, predicted = torch.max(output, 1)
    predictions.append(predicted.item())
    # Plot the digit with its predicted label
    plt.subplot(1, len(digits), idx + 1)
    plt.imshow(digit_image, cmap='gray')
    plt.title(f'Pred: {predicted.item()}')
    plt.axis('off')
plt.show()
Results of digit recognition
The predictions were visually confirmed by plotting the digits with their predicted labels. This let me assess the model’s accuracy, and I was pleasantly surprised by how well it generalized beyond the MNIST dataset.
Adapting the CNN to handle handwritten digits from a picture involved several preprocessing steps to convert the image into a format compatible with the trained model. This process demonstrated some of the practical challenges of working with real-world data and highlighted the importance of robust image preprocessing techniques.
However, when I tested the model with images containing less clear digits or those affected by shadows and varying lighting conditions, the results were less accurate. This indicates that further preprocessing and potentially more advanced techniques may be necessary to improve performance in real-world applications, where variability in digit quality and image conditions can significantly impact recognition accuracy.
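One direction I would try first (untested here, so treat it as a sketch) is OpenCV's adaptive thresholding, which computes a per-neighborhood threshold instead of the single global one used above and tends to cope better with shadows and uneven lighting:

import cv2

image = cv2.imread('handwritten-numbers.jpg', cv2.IMREAD_GRAYSCALE)

# Threshold each pixel against a Gaussian-weighted mean of its 31x31
# neighborhood instead of one global value; the block size (odd) and
# the constant subtracted from the mean would need tuning per image.
binary_image = cv2.adaptiveThreshold(
    image, 255,
    cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
    cv2.THRESH_BINARY_INV,
    31, 10,
)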
👀 Try It Yourself
To give you a hands-on experience with the trained model, I’ve created a web utility on this site where you can draw your own numbers and see how well the model recognizes them. This interactive tool lets you play around with the trained CNN and see its predictions in real time.
Feel free to experiment with different numbers and see how accurately the model can identify them. You can access the tool 👉 here.
AI Digit Recognition project
I hope you enjoy exploring the model's capabilities and find it helpful for understanding how convolutional neural networks can be used for image recognition!
📚 Resources
- But what is a neural network?: This video inspired me to do this experiment
- Deep Learning Basics: An introductory class on Deep Learning by Lex Fridman
- MNIST dataset: The MNIST database of handwritten digits in Kaggle
- Convolutional Neural Networks: Datacamp introduction to CNNs
If you’ve read this far, thanks for reading! I hope this was valuable for you in some way. Readers surely have very different levels of experience with deep learning, so please leave your thoughts in the comments so we can all share what we know and keep learning together.