Transfer Learning Heuristics
Transfer learning involves adapting a pre-trained neural network to a new, different dataset.
Udacity provides a set of basic heuristics to help determine the best way to approach transfer learning. The right approach depends on:
- The size of the new dataset and
- How similar the new dataset is to the original training data.
The size and similarity factors are somewhat subjective:
- A large dataset might contain 1,000,000 images, whereas a small one might have 2,000. Overfitting becomes a concern when using transfer learning with a small dataset.
- Images of dogs and wolves could be considered similar, whereas flower images would be very different from dog images.
The four cases are summarized below and discussed in detail in the following sections.

| | Similar data | Different data |
| --- | --- | --- |
| Small dataset | Freeze the pre-trained layers; train a new final layer | Keep only the early layers (frozen); train a new final layer |
| Large dataset | Replace the final layer; re-train the whole network | Re-train the whole network, starting from pre-trained or random weights |
The following generic pre-trained convolutional neural network will be used to explain how to adjust the network for each of the four cases.
This generic network contains three convolutional layers and three fully-connected layers. In this hypothetical example:
- the first convolutional layer detects edges,
- the second detects shapes, and
- the third detects higher-level features.
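For concreteness, here is a minimal PyTorch sketch of such a generic network; the input size (32x32 RGB), channel counts, and 10 output classes are hypothetical choices, not part of the original example.

```python
import torch.nn as nn

# A minimal sketch of the generic network described above, assuming
# hypothetical 32x32 RGB inputs, layer sizes, and 10 output classes.
generic_net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),   # conv layer 1: edges
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),  # conv layer 2: shapes
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),  # conv layer 3: higher-level features
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 128), nn.ReLU(),  # fully-connected layer 1
    nn.Linear(128, 64), nn.ReLU(),          # fully-connected layer 2
    nn.Linear(64, 10),                      # fully-connected layer 3 (output)
)
```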
Small Dataset, Similar Data
Since the new dataset is small, to avoid overfitting, the weights in the retained layers of the original network will be held constant rather than re-trained.
Since the datasets are similar, images from each dataset will share similar higher-level features. So, most or all of the pre-trained layers already contain relevant information about the new data and should be kept.
Approach:
- Slice off the end of the neural network
- Add a new fully-connected layer that matches the number of classes in the new dataset
- Randomize the weights of the new fully-connected layer
- Freeze all the weights from retained layers of the pre-trained network
- Train the network to update the weights of the new fully-connected layer
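A rough sketch of this case in PyTorch, assuming a VGG16 network pre-trained on ImageNet and a hypothetical 10-class dataset:

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 10  # hypothetical number of classes in the new dataset

# Load a network pre-trained on ImageNet (VGG16 here as an example).
model = models.vgg16(weights="IMAGENET1K_V1")

# Freeze every pre-trained weight so only the new layer is trained.
for param in model.parameters():
    param.requires_grad = False

# Slice off the final fully-connected layer and replace it with a new one
# sized to the new dataset; its weights are randomly initialized by default.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

# Only the unfrozen weights (the new layer's) are passed to the optimizer.
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
)
```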
Small Dataset, Different Data
Since the new dataset is small, to avoid overfitting, the weights in the retained layers of the original network will be held constant rather than re-trained.
Since the original training set and the new dataset are not similar, they do not share higher-level features. The new network should use only the layers of the pre-trained network that contain lower-level features.
- Slice off most of the pre-trained layers near the end of the network, keeping only the early layers that detect lower-level features
- Add to the remaining pre-trained layers a new fully-connected layer that matches the number of classes in the new dataset
- Randomize the weights of the new fully-connected layer
- Freeze all the weights from the retained layers of the pre-trained network
- Train the network to update the weights of the new fully-connected layer
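A sketch of this case, again assuming VGG16 and (hypothetically) treating its first convolutional block as the retained lower-level feature layers; the cutoff point and class count are illustrative assumptions:

```python
import torch.nn as nn
from torchvision import models

num_classes = 5  # hypothetical number of classes in the new dataset

pretrained = models.vgg16(weights="IMAGENET1K_V1")

# Keep only the early layers (VGG16's first conv block: two conv/ReLU
# pairs plus a max pool), which capture low-level features such as edges.
early_layers = pretrained.features[:5]

# Freeze the retained pre-trained weights.
for param in early_layers.parameters():
    param.requires_grad = False

# Attach a new, randomly initialized fully-connected layer sized to the
# new dataset; the first block of VGG16 outputs 64 channels.
model = nn.Sequential(
    early_layers,
    nn.AdaptiveAvgPool2d((7, 7)),
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, num_classes),
)
```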
Large Dataset, Similar Data
Since the new dataset is large, overfitting is less of a concern, so retraining all the weights is okay.
Since the original training set and the new dataset share higher level features, the entire neural network can be used.
- Remove the last fully-connected layer and add a new fully-connected layer that matches the number of classes in the new dataset
- Randomize the weights of the new fully-connected layer
- Initialize the rest of the weights using the pre-trained weights
- Re-train the entire neural network
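A sketch of this case, with VGG16 and a hypothetical 20-class dataset:

```python
import torch
import torch.nn as nn
from torchvision import models

num_classes = 20  # hypothetical number of classes in the new dataset

# Initialize every retained layer with its pre-trained weights.
model = models.vgg16(weights="IMAGENET1K_V1")

# Swap in a new, randomly initialized final fully-connected layer.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

# Nothing is frozen: the optimizer re-trains the entire network.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
```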
Large Dataset, Different Data
Since the new dataset is large, overfitting is less of a concern, so retraining all the weights should be okay.
Even though the new dataset is different from the original training data, initializing the weights from the pre-trained network might make training faster. If using the pre-trained network as a starting point for the weights does not produce a successful model, another option is to randomly initialize the convolutional neural network's weights and train the network from scratch.
- Remove the last fully-connected layer and add a new fully-connected layer that matches the number of classes in the new dataset
- Randomize the weights of the new fully-connected layer
- Retrain the network after either:
  - initializing the rest of the weights using the pre-trained weights, or
  - randomizing all of the weights in the network
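A sketch of the two starting points; the flag name and class count are hypothetical:

```python
import torch.nn as nn
from torchvision import models

num_classes = 100           # hypothetical number of classes in the new dataset
use_pretrained_init = True  # set False to randomize all weights instead

# Either start from the pre-trained weights or from random initialization.
model = models.vgg16(weights="IMAGENET1K_V1" if use_pretrained_init else None)

# In both options, the last fully-connected layer is replaced with a new,
# randomly initialized layer sized to the new dataset.
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

# Re-train the entire network on the new dataset from this starting point.
```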