Developing a Deep Learning Application
This project demonstrates development of a Python command-line application that uses a deep neural network to predict the contents of images. The application enables the user to train the network on a set of arbitrary, labeled images and to use that trained network to predict the labels of new images.
During development, the network achieved 78.4% accuracy on a dataset consisting of images of 102 different types of flowers.
This is page 2 of 3 for the project, in which I convert the network-training function into a command-line application.
Convert the Network Training Function to a Command-Line Application
import torch
from torch import nn, optim
from torchvision import transforms, datasets, models
import time
import argparse
import os
import pandas as pd
# ----------------------------- Argument parsing ------------------------------
# Defaults for the numeric options live in argparse itself so that an explicit
# user value of 0 is respected (the old `x if x else default` pattern silently
# replaced any falsy value, including a deliberate 0, with the default).
parser = argparse.ArgumentParser(description='Train a neural network.')
parser.add_argument('data_directory',  # Required: data_directory
                    help='the directory of images to train on.')
parser.add_argument('--save_directory',  # Optional: --save_directory
                    help='the directory where checkpoints will be saved.')
parser.add_argument('--arch',  # Optional: --arch
                    help='the base network architecture to train from.')
parser.add_argument('--learning_rate',  # Optional: --learning_rate
                    help='the learning rate to use while training.',
                    type=float,
                    default=0.003)
parser.add_argument('--hidden_units',  # Optional: --hidden_units
                    help='the number of hidden units while training.',
                    type=int,
                    default=512)
parser.add_argument('--epochs',  # Optional: --epochs
                    help='the number of epochs to train for.',
                    type=int,
                    default=20)
parser.add_argument('--gpu',  # Optional: --gpu
                    help='use the gpu instead of CPU for training.',
                    action="store_true")
args = parser.parse_args()

data_dir = args.data_directory
# Checkpoints default to ./checkpoints under the current working directory.
if args.save_directory:
    save_dir = args.save_directory
else:
    save_dir = os.path.join(os.getcwd(), 'checkpoints')
# Create the checkpoint directory on demand so the first save cannot fail.
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
# Unsupported (or omitted) architectures silently fall back to vgg11.
arch = args.arch if args.arch in ('vgg11', 'vgg13', 'vgg16', 'vgg19') else 'vgg11'
learning_rate = args.learning_rate
hidden_units = args.hidden_units
epochs = args.epochs
gpu = args.gpu

# Echo the effective configuration so each run is reproducible from its log.
print('\n data_dir: ' + data_dir)
print(' save_dir: ' + save_dir)
print(' arch: ' + arch)
print('learning_rate: ' + str(learning_rate))
print(' hidden_units: ' + str(hidden_units))
print(' epochs: ' + str(epochs))
print(' gpu: ' + str(gpu) + '\n')
######################### Helper Functions ####################################
def save_classifier_state(e):
    """Checkpoint the classifier head plus the metadata needed to rebuild it.

    Only ``model.classifier``'s weights are saved — the convolutional base is
    frozen and can be reconstructed from ``base_arch`` — together with the
    sizing and index-to-class mapping the prediction script needs.

    Parameters
    ----------
    e : int
        1-based epoch number, used to name the checkpoint file.

    Relies on the module-level ``save_dir``, ``arch``, ``hidden_units``,
    ``output_units``, ``model`` and ``idx_to_class_num``.
    """
    # os.path.join avoids the doubled slash the old string concatenation
    # produced with the default trailing-slash save_dir.
    filepath = os.path.join(save_dir,
                            'classifier_state_epoch_{:02d}.pth'.format(e))
    checkpoint = {'base_arch': arch,
                  'hidden_units': hidden_units,
                  'output_units': output_units,
                  'state_dict': model.classifier.state_dict(),
                  'idx_to_class_num': idx_to_class_num}
    torch.save(checkpoint, filepath)
def train_network(epochs, step_increment, print_table):
    """Train ``model``'s classifier head, validating once per epoch.

    Runs ``epochs`` passes over ``trainloader``. On the final batch of each
    epoch a full validation pass over ``validloader`` is performed, an
    optional progress row is printed, and the classifier is checkpointed via
    ``save_classifier_state``.

    Parameters
    ----------
    epochs : int
        Number of full passes over the training data.
    step_increment : int
        Currently unused; retained for interface compatibility. Validation
        is tied to the last batch of each epoch rather than a step interval.
    print_table : bool
        When True, print a tab-separated progress table.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``device``, ``trainloader`` and ``validloader``.
    """
    if print_table:
        print('Epoch\tStep\tTraining Loss\tValidation Loss\t\tAccuracy\tElapsed Time')
    step = 1
    # Wall-clock baseline for the Elapsed Time column.
    start_time = time.time()
    for e in range(epochs):
        train_loss, train_count, valid_loss, valid_count = 0, 0, 0, 0
        for inputs, labels in trainloader:
            model.train()  # training mode: dropout on
            # Move input and label tensors to the selected device.
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_count += len(inputs)
            if step == len(trainloader):
                # Last batch of the epoch: run the validation pass.
                # Gradients off for inference speed and memory.
                with torch.no_grad():
                    model.eval()  # inference mode: dropout off
                    equals_list = []
                    # Distinct names so the outer inputs/labels aren't shadowed.
                    for v_inputs, v_labels in validloader:
                        v_inputs = v_inputs.to(device)
                        v_labels = v_labels.to(device)
                        logps = model(v_inputs)
                        valid_loss += criterion(logps, v_labels).item()
                        valid_count += len(v_inputs)
                        # logps are log-probabilities; exp recovers probabilities.
                        ps = torch.exp(logps)
                        top_p, top_class = ps.topk(1, dim=1)
                        equals = top_class == v_labels.view(*top_class.shape)
                        equals_list.extend(equals.cpu().numpy().ravel())
                elapsed_time = time.time() - start_time
                if print_table:
                    print("{:2}/{}\t{:4}\t{:13.3f}\t{:15.3f}\t{:16.3f}\t{:12.1f}"
                          .format(e + 1, epochs,
                                  step,
                                  train_loss / train_count,
                                  valid_loss / valid_count,
                                  sum(equals_list) / len(equals_list),
                                  elapsed_time))
                train_loss, train_count, valid_loss, valid_count = 0, 0, 0, 0
            step += 1
        # Reset the step counter and checkpoint at the end of every epoch.
        step = 1
        save_classifier_state(e + 1)
##################### Script Execution Continues ##############################
# Training pipeline augments with random crops/flips/rotations; both pipelines
# normalize with the ImageNet statistics expected by pretrained VGG weights.
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

# ImageFolder expects data_dir/train and data_dir/valid, one subfolder per class.
train_data = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                  transform=train_transforms)
valid_data = datasets.ImageFolder(os.path.join(data_dir, 'valid'),
                                  transform=test_transforms)

# Invert class_to_idx so checkpoints can map model outputs back to class labels.
class_num_to_idx = train_data.class_to_idx
idx_to_class_num = {v: k for k, v in class_num_to_idx.items()}
output_units = len(class_num_to_idx)

batch_size = 48
trainloader = torch.utils.data.DataLoader(train_data,
                                          batch_size=batch_size,
                                          shuffle=True)
validloader = torch.utils.data.DataLoader(valid_data,
                                          batch_size=batch_size)

# Dispatch table replaces the if/elif chain; `arch` was validated above, so
# .get()'s vgg11 fallback mirrors the original default branch.
vgg_constructors = {'vgg13': models.vgg13,
                    'vgg16': models.vgg16,
                    'vgg19': models.vgg19}
model = vgg_constructors.get(arch, models.vgg11)(pretrained=True)

# Freeze the convolutional base; only the new classifier head is trained.
for param in model.parameters():
    param.requires_grad = False
# 25088 = VGG's flattened feature-map size feeding the classifier.
model.classifier = nn.Sequential(nn.Linear(25088, hidden_units),
                                 nn.ReLU(),
                                 nn.Dropout(0.2),
                                 nn.Linear(hidden_units, output_units),
                                 nn.LogSoftmax(dim=1))

# NLLLoss pairs with the LogSoftmax output; optimize classifier params only.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.classifier.parameters(),
                       lr=learning_rate)
device = 'cuda' if gpu else 'cpu'
model.to(device)

train_network(epochs, 20, True)