Loading Image Data

Other examples have used fairly artificial datasets that would not be used in real-world image classification. In practice, you'll likely be dealing with full-sized images like the ones you'd get from smartphone cameras. In this notebook, we'll look at how to load images and use them to train neural networks.

We'll be using a dataset of cat and dog photos available from Kaggle. Here are a couple of example images:

This example uses the dataset to train a neural network that can differentiate between cats and dogs. These days that doesn't seem like a big accomplishment, but five years ago it was a serious challenge for computer vision systems.

import torch
from torchvision import datasets, transforms

The easiest way to load image data is with datasets.ImageFolder from torchvision (documentation). In general you’ll use ImageFolder like so:

dataset = datasets.ImageFolder('path/to/data', transform=transform)

where 'path/to/data' is the file path to the data directory and transform is a sequence of processing steps built with the transforms module from torchvision, typically combined with transforms.Compose. ImageFolder expects the files and directories to be organized like this:

root/dog/xxx.png
root/dog/xxy.png
root/cat/123.png
root/cat/nsdf3.png

where each class has its own directory (cat and dog) for the images. Each image is labeled with the class taken from its directory name, so here the image 123.png would be loaded with the class label cat. The dataset can be downloaded from here. It has already been split into a training set and a test set.
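To make the labeling rule concrete, here is a minimal pure-Python sketch of ImageFolder-style labeling. It assumes that class indices are assigned by sorting the directory names, which is what torchvision does, so cat gets 0 and dog gets 1. The helper list_samples and its short extension list are hypothetical. This is an illustration, not torchvision's implementation.

```python
import os
import tempfile

# Extensions this sketch treats as images (an assumption; torchvision
# accepts a longer list of image formats)
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")

def list_samples(root):
    """Hypothetical helper mimicking ImageFolder's labeling scheme:
    each subdirectory of root is a class, classes are sorted by name,
    and each image gets the index of the class directory it lives in."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        # Walk recursively, so nested folders inherit the class label
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples

# Build a tiny throwaway tree: root/cat/123.png, root/dog/xxx.png, root/dog/xxy.png
with tempfile.TemporaryDirectory() as root:
    for rel in ["cat/123.png", "dog/xxx.png", "dog/xxy.png"]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()  # empty placeholder file
    classes, class_to_idx, samples = list_samples(root)
    print(classes)       # ['cat', 'dog']
    print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

Sorting the directory names means the label indices do not depend on the order in which the file system happens to list the folders.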


When you load in the data with ImageFolder, you'll need to define some transforms. For example, the images come in different sizes, but they all need to be the same size for training. They can be resized with transforms.Resize() or cropped with transforms.CenterCrop(), transforms.RandomResizedCrop(), etc. We'll also need to convert the images to PyTorch tensors with transforms.ToTensor(). Typically, these transforms are combined into a pipeline with transforms.Compose(), which accepts a list of transforms and runs them in sequence. To scale, then crop, then convert to a tensor, it looks something like this:

transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

There are plenty of transforms available. I'll cover more in a bit, and you can read about the rest in the documentation.
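Compose just runs its steps in order, so the idea fits in a few lines of plain Python. The helpers below (compose, resize_shorter_side, center_crop_box) are hypothetical. They track only an image's (width, height), not its pixels, and assume the following behavior:

- Resize with a single int scales the shorter side to that value and keeps the aspect ratio.
- CenterCrop cuts out a centered square.

The rounding details are modeled on torchvision's source, but this is a sketch, not the library code.

```python
def compose(*steps):
    """Hypothetical stand-in for transforms.Compose: returns a callable
    that feeds its input through each step in order."""
    def pipeline(x):
        for step in steps:
            x = step(x)
        return x
    return pipeline

def resize_shorter_side(size_wh, size):
    """(width, height) after Resize(size) with an int size: the shorter side
    becomes `size`; the longer side is scaled and truncated to an int."""
    w, h = size_wh
    if w <= h:
        return size, int(size * h / w)
    return int(size * w / h), size

def center_crop_box(size_wh, crop):
    """(left, top, right, bottom) of the square that CenterCrop(crop) keeps."""
    w, h = size_wh
    left = int(round((w - crop) / 2.0))
    top = int(round((h - crop) / 2.0))
    return left, top, left + crop, top + crop

# "Scale, then crop" applied to the size of a 500x375 landscape photo
pipeline = compose(lambda s: resize_shorter_side(s, 255),
                   lambda s: center_crop_box(s, 224))
print(resize_shorter_side((500, 375), 255))  # (340, 255)
print(pipeline((500, 375)))                  # (58, 16, 282, 240)
```

Resizing the shorter side to 255 before cropping to 224 guarantees that both sides are large enough for the crop, whatever the aspect ratio. The crop then trims a little from the edges.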

Data Loaders

With the ImageFolder loaded, you have to pass it to a DataLoader. The DataLoader takes a dataset (such as you would get from ImageFolder) and returns batches of images and the corresponding labels. You can set various parameters, like the batch size and whether the data is reshuffled at every epoch.

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

Here dataloader is an iterable. To get data out of it, you need to loop through it, or convert it to an iterator with iter() and call next() on it.

# Looping through it, get a batch on each loop 
for images, labels in dataloader:
    pass  # each iteration yields one batch of images and labels

# Get one batch
images, labels = next(iter(dataloader))
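To see what "batches" and "reshuffled at every epoch" mean mechanically, here is a much-simplified, hypothetical stand-in for DataLoader written in plain Python. SimpleLoader and its seed parameter are illustrative only. The real DataLoader also handles worker processes, samplers, and collating tensors into a single batch tensor.

```python
import random

class SimpleLoader:
    """Hypothetical, simplified DataLoader: each call to iter() starts a new
    epoch, draws a fresh ordering if shuffle=True, and yields (xs, ys) batches."""

    def __init__(self, dataset, batch_size=32, shuffle=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self.rng.shuffle(indices)  # new order for every epoch
        for start in range(0, len(indices), self.batch_size):
            batch = [self.dataset[i] for i in indices[start:start + self.batch_size]]
            # "Collate": turn a list of (x, y) pairs into (xs, ys)
            xs, ys = zip(*batch)
            yield list(xs), list(ys)

    def __len__(self):
        # Batches per epoch, keeping a final partial batch
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

data = [(f"img{i}", i % 2) for i in range(10)]
loader = SimpleLoader(data, batch_size=4, shuffle=True, seed=0)

print(len(loader))  # 3
for images, labels in loader:  # loop through one epoch
    print(len(images), labels)

images, labels = next(iter(loader))  # grab a single batch
```

Because the loader is an iterable rather than an iterator, every for loop or iter() call starts a fresh pass over the data.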
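
# Build a dataset and loader from the cat and dog training images
data_dir = 'loading-image-data-into-pytorch/Cat_Dog_data'

transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
# Point ImageFolder at train/ so the labels come from the cat and dog folders
dataset = datasets.ImageFolder(data_dir + '/train', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)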

The following is a helper function that displays a tensor image with matplotlib.

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

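def imshow(image, ax=None, title=None, normalize=True):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    # PyTorch images are (channels, height, width); matplotlib wants (height, width, channels)
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        # Undo Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) so values are back in [0, 1]
        mean = np.array([0.5, 0.5, 0.5])
        std = np.array([0.5, 0.5, 0.5])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)
    if title is not None:
        ax.set_title(title)
    # Hide tick marks and tick labels
    ax.tick_params(axis='both', length=0)
    ax.set_xticks([])
    ax.set_yticks([])

    return ax

images, labels = next(iter(dataloader))
imshow(images[0], normalize=False);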
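The transpose((1, 2, 0)) call in imshow reorders the axes of the image array. Here is a dependency-free sketch of the same reordering on nested Python lists. The chw_to_hwc helper is hypothetical and stands in for what NumPy's transpose does here.

```python
def chw_to_hwc(image):
    """Reorder a nested-list image from (channels, height, width), the layout
    ToTensor produces, to (height, width, channels), the layout matplotlib's
    imshow expects. Same effect as array.transpose((1, 2, 0))."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    return [[[image[c][y][x] for c in range(channels)]
             for x in range(width)]
            for y in range(height)]

# A 3-channel, 2x2 "image": channel c holds c*10 + (row*2 + col)
chw = [[[c * 10 + r * 2 + col for col in range(2)] for r in range(2)] for c in range(3)]
hwc = chw_to_hwc(chw)
print(hwc[0][1])  # [1, 11, 21], the three channel values of the pixel at row 0, col 1
```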


Data Augmentation

A common strategy for training neural networks is to introduce randomness in the input data itself. For example, you can randomly rotate, mirror, scale, and/or crop your images during training. This will help your network generalize as it’s seeing the same images but in different locations, with different sizes, in different orientations, etc.

To randomly rotate, scale and crop, and then flip your images, you would define your transforms like this:

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5], 
                                                            [0.5, 0.5, 0.5])])
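Each random transform draws fresh parameters every time an image is loaded. The following plain-Python sketch models that sampling under these assumptions:

- RandomRotation(30) picks an angle uniformly from [-30, 30].
- RandomHorizontalFlip mirrors the image with probability 0.5.
- RandomResizedCrop picks a region covering 8% to 100% of the area, with an aspect ratio between 3/4 and 4/3. That region is then resized to 224x224, a step not shown here.

The crop logic is modeled on torchvision's parameter sampling, and all function names are hypothetical.

```python
import math
import random

def rotation_angle(degrees, rng):
    """Angle for RandomRotation(degrees) with a single number: uniform in [-degrees, degrees]."""
    return rng.uniform(-degrees, degrees)

def random_resized_crop_box(width, height, rng, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)):
    """(left, top, w, h) of the region a RandomResizedCrop-style transform cuts out
    before resizing it: random area fraction and aspect ratio, up to 10 tries,
    then a center-crop fallback."""
    area = width * height
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):
        target_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(*log_ratio))  # sample the ratio on a log scale
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            left = rng.randint(0, width - w)  # inclusive bounds
            top = rng.randint(0, height - h)
            return left, top, w, h
    # Fallback: a central crop clamped to the allowed aspect ratios
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        w, h = int(round(height * max(ratio))), height
    else:
        w, h = width, height
    return (width - w) // 2, (height - h) // 2, w, h

def random_horizontal_flip(image, rng, p=0.5):
    """With probability p, mirror each row of a (height x width) nested list."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image

rng = random.Random(0)
print(rotation_angle(30, rng))
print(random_resized_crop_box(500, 375, rng))
print(random_horizontal_flip([[1, 2, 3], [4, 5, 6]], rng))
```

Running the last three lines repeatedly gives different results, which is the point: the network rarely sees exactly the same pixels twice.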

You'll also typically want to normalize images with transforms.Normalize. Normalize operates on tensors, so it comes after ToTensor in the pipeline. You pass in a list of means and a list of standard deviations, one per color channel, and each channel is normalized like so:

input[channel] = (input[channel] - mean[channel]) / std[channel]

Subtracting the mean centers the data around zero, and dividing by std scales it. With a mean and std of 0.5 for every channel, the pixel values in [0, 1] produced by ToTensor end up between -1 and 1. Normalizing helps keep the network weights near zero, which in turn makes backpropagation more stable. Without normalization, networks often fail to learn.
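As a quick check of the formula, here is a plain-Python sketch applied to a single RGB pixel, using the 0.5 means and standard deviations from above. normalize and denormalize are hypothetical helpers, not torchvision functions. denormalize is the inverse that imshow applies before display.

```python
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

def normalize(pixel, mean=MEAN, std=STD):
    """input[c] = (input[c] - mean[c]) / std[c], applied to one RGB pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]

def denormalize(pixel, mean=MEAN, std=STD):
    """Inverse of normalize (std[c] * x + mean[c]), clipped to [0, 1] for display."""
    return [min(max(s * x + m, 0.0), 1.0) for x, m, s in zip(pixel, mean, std)]

print(normalize([0.0, 0.5, 1.0]))               # [-1.0, 0.0, 1.0]
print(denormalize(normalize([0.0, 0.5, 1.0])))  # [0.0, 0.5, 1.0]
```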

You can find a list of all the available transforms in the torchvision documentation. When you're testing, however, you'll want to use images that aren't altered (except that you'll need to normalize them the same way). So, for validation/test images, you'll typically just resize and crop, then convert to a tensor and normalize.

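data_dir = 'loading-image-data-into-pytorch/Cat_Dog_data'

# Random augmentation for training, then conversion and normalization
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5], 
                                                            [0.5, 0.5, 0.5])])
# Deterministic resize and crop for testing, normalized the same way
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5, 0.5, 0.5], 
                                                           [0.5, 0.5, 0.5])])

train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)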
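# Show the first five images from one batch of each loader
for loader_str, loader in [('Train Image', trainloader), 
                           ('Test Image', testloader)]:
    data_iter = iter(loader)

    images, labels = next(data_iter)
    ncol = 5
    fig, axes = plt.subplots(figsize=(12.5,4), ncols=ncol)
    for ii in range(ncol):
        ax = axes[ii]
        ax.set_title('{} {}'.format(loader_str, ii+1))
        # Both pipelines normalize with 0.5/0.5, so let imshow undo it
        imshow(images[ii], ax=ax, normalize=True)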



Classifying these images likely won't work well with a fully-connected network, no matter how deep. These images have three color channels and a much higher resolution (previous examples used 28x28 images, which are tiny).
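One way to see the scale problem is to count the parameters in just the first fully-connected layer. This sketch makes two assumptions. First, the earlier 28x28 images had a single (grayscale) channel. Second, there are 256 hidden units, an arbitrary choice for illustration. The images here are 224x224 with three channels after cropping. first_layer_params is a hypothetical helper.

```python
def first_layer_params(height, width, channels, hidden_units):
    """Weights plus biases of a fully-connected layer on a flattened image."""
    inputs = height * width * channels
    return inputs * hidden_units + hidden_units

# 256 hidden units is an arbitrary, illustrative choice
print(first_layer_params(28, 28, 1, 256))    # 200960
print(first_layer_params(224, 224, 3, 256))  # 38535424
```

That is roughly 190 times as many weights in one layer. A fully-connected layer also treats each pixel as an unrelated input, ignoring the spatial structure of the image.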

Pre-trained networks can be used to build a model that can actually solve this problem.