Transfer learning
In general, transfer learning refers to using knowledge gained in one domain to solve a related problem in another domain. In deep learning, however, it specifically refers to reusing a neural network trained for one task as the starting point for a similar task in a different domain. The new task uses the feature detectors learned from the previous task, so we do not have to train the model to learn them from scratch.
Deep-learning models tend to have a huge number of parameters because of the connectivity patterns among units in different layers. Training such a large model requires a considerable amount of data; otherwise, the model may suffer from overfitting. For many problems that call for a deep-learning solution, a large amount of data will not be available. For instance, deep-learning models provide state-of-the-art solutions for object recognition in images, but a new recognition problem may come with only a small labeled dataset. In such cases, transfer learning can be used to create features based on the feature detectors learned by an existing trained deep-learning model. Those features can then be used to build a simple model with the available data to solve the new problem at hand. The only parameters that the new model needs to learn are those of the simple model, which reduces the chances of overfitting. The pretrained models are generally trained on a huge corpus of data, so their parameters serve as reliable feature detectors.
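To put rough numbers on the reduction in trainable parameters, the following sketch counts the parameters of VGG16 from its published layer shapes and compares them with a simple classifier trained on extracted features. The two-class task and the choice of a single softmax layer on a 4096-dimensional feature vector are assumptions made for illustration only.

```python
# Layer shapes follow the published VGG16 architecture (224x224 RGB input).
# Each tuple is (input channels, output channels) of a 3x3 convolution.
VGG16_CONVS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]
# Each tuple is (input units, output units) of a fully connected layer.
VGG16_DENSE = [(7 * 7 * 512, 4096), (4096, 4096), (4096, 1000)]


def conv_params(c_in, c_out, k=3):
    """Weights of a k x k convolution plus one bias per output channel."""
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weights of a fully connected layer plus one bias per output unit."""
    return n_in * n_out + n_out


vgg16_total = (sum(conv_params(i, o) for i, o in VGG16_CONVS)
               + sum(dense_params(i, o) for i, o in VGG16_DENSE))

# Hypothetical new task: 2 classes, one softmax layer on 4096-d features.
num_classes = 2
new_classifier = dense_params(4096, num_classes)

print(f"VGG16 parameters:          {vgg16_total:,}")
print(f"New classifier parameters: {new_classifier:,}")
```

Under these assumptions, the full network has 138,357,544 parameters, while the new classifier has only 8,194 to learn from the available data.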
When we process images in CNNs, the initial layers learn to detect very generic features, such as curves, edges, color composition, and so on. Deeper in the network, the convolutional layers learn to detect more complex features that are relevant to the specific kind of dataset. We can use a pretrained network and choose not to train the first few layers, as they learn very generic features. Instead, we can concentrate on training only the parameters of the last few layers, since these learn complex features that are specific to the problem at hand. This ensures that we have fewer parameters to train and that we use the data judiciously, spending it on the complex, task-specific features rather than on the generic ones.
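The mechanics of freezing early layers can be sketched with a toy two-layer network in NumPy. This is a minimal illustration, not a real CNN: the weights W1 stand in for pretrained early-layer parameters (here they are simply random, which is an assumption), and the data, layer sizes, and learning rate are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 200 samples, 8 input features, binary labels.
X = rng.normal(size=(200, 8))
y = (X[:, 0] + X[:, 1] > 0).astype(float).reshape(-1, 1)

# W1 stands in for pretrained early-layer weights (random here: an assumption).
params = {
    "W1": rng.normal(scale=0.5, size=(8, 16)),
    "W2": rng.normal(scale=0.1, size=(16, 1)),
}
trainable = {"W1": False, "W2": True}  # freeze the early layer
W1_before = params["W1"].copy()


def forward(X, p):
    h = np.maximum(0.0, X @ p["W1"])             # generic features (ReLU)
    prob = 1.0 / (1.0 + np.exp(-(h @ p["W2"])))  # task-specific output
    return h, prob


def loss_fn(prob, y):
    eps = 1e-12  # avoids log(0)
    return float(-np.mean(y * np.log(prob + eps)
                          + (1 - y) * np.log(1 - prob + eps)))


_, prob = forward(X, params)
loss_start = loss_fn(prob, y)

lr = 0.1
for _ in range(300):
    h, prob = forward(X, params)
    d_logits = (prob - y) / len(X)  # gradient of the loss w.r.t. the logits
    grads = {
        "W2": h.T @ d_logits,
        "W1": X.T @ ((d_logits @ params["W2"].T) * (h > 0)),
    }
    for name in params:
        if trainable[name]:         # frozen layers receive no update
            params[name] -= lr * grads[name]

_, prob = forward(X, params)
loss_end = loss_fn(prob, y)
print(f"loss before: {loss_start:.3f}, loss after: {loss_end:.3f}")
print("W1 unchanged:", np.array_equal(params["W1"], W1_before))
```

Only W2 changes during training, so the number of parameters fitted to the new data drops from 144 to 16 in this toy setting. Deep-learning frameworks expose the same idea through a per-layer or per-parameter switch rather than a hand-written dictionary.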
Transfer learning is widely used in image processing through CNNs, where the filters act as feature detectors. The most common pretrained CNNs that are used for transfer learning are AlexNet, VGG16, VGG19, Inception V3, and ResNet, among others. The following diagram illustrates a pretrained VGG16 network that is used for transfer learning:
The input images, represented by x, are fed to the pretrained VGG16 network, and the 4096-dimensional output feature vector, x', is extracted from the last fully connected hidden layer, that is, the layer just before the 1,000-class output layer. The extracted features, x', along with the corresponding class labels, y, are used to train a simple classification network, reducing the amount of data required to solve the problem.
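The pipeline described above (x to x', then a simple classifier trained on x' and y) can be sketched in NumPy as follows. The function extract_features is a hypothetical stand-in for the pretrained VGG16 network: it applies a fixed random projection, uses 64 dimensions instead of 4096, and works on synthetic data, so the sketch shows only the flow of the method, not real image features.

```python
import numpy as np

rng = np.random.default_rng(42)

INPUT_DIM = 20
FEATURE_DIM = 64   # stands in for the 4096-d VGG16 feature vector
NUM_CLASSES = 3

# Fixed weights standing in for the pretrained network (hypothetical).
W_frozen = rng.normal(size=(INPUT_DIM, FEATURE_DIM)) / np.sqrt(INPUT_DIM)


def extract_features(x):
    """Map inputs x to feature vectors x' using weights that are never updated."""
    return np.maximum(0.0, x @ W_frozen)


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Synthetic dataset: three Gaussian blobs in INPUT_DIM dimensions.
centers = rng.normal(size=(NUM_CLASSES, INPUT_DIM))
y = rng.integers(0, NUM_CLASSES, size=300)
x = centers[y] + rng.normal(size=(300, INPUT_DIM))

x_prime = extract_features(x)             # features are computed once
train, test = slice(0, 240), slice(240, 300)

# Simple softmax classifier trained on (x', y) with gradient descent.
W = np.zeros((FEATURE_DIM, NUM_CLASSES))
b = np.zeros(NUM_CLASSES)
onehot = np.eye(NUM_CLASSES)[y]
lr = 0.02

for _ in range(500):
    p = softmax(x_prime[train] @ W + b)
    err = (p - onehot[train]) / 240       # gradient w.r.t. the logits
    W -= lr * (x_prime[train].T @ err)
    b -= lr * err.sum(axis=0)

pred = np.argmax(x_prime[test] @ W + b, axis=1)
accuracy = float(np.mean(pred == y[test]))
print(f"held-out accuracy: {accuracy:.2f}")
```

Because the features are computed once and the extractor is never updated, training touches only the classifier's weights W and b, which is what keeps the data requirement low.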
We will solve an image classification problem in the healthcare domain by using transfer learning in Chapter 2, Transfer Learning.