Abstract
A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. It is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel.
Key takeaways
- Knowledge from a cumbersome model, such as an ensemble of models, can be transferred to a smaller model suitable for deployment using a process called “distillation”.
- Distillation involves using the class probabilities generated by the cumbersome model as soft targets for training the smaller model.
- Soft targets provide more information per training case than hard targets, and also reduce the variance in the gradient between training cases, which allows the smaller model to be trained on less data and at a higher learning rate.
- The temperature parameter in the softmax function can be adjusted to control the softness of the target probabilities, and matching the logits of the cumbersome model is a special case of distillation.
- Specialist models, each trained on a subset of confusable classes, can improve performance when combined with a generalist model.
- Training specialist models can be parallelized, which can significantly speed up training time.
- Soft targets can act as a regularizer, allowing models to generalize well even with a small amount of training data.
Experiment
Preliminary concepts
- When the soft targets have high entropy, they provide much more information per training case than hard targets and much less variance in the gradient between training cases.
- The softmax function converts the logits $z_i$ into probabilities $q_i$:
$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
where $T$ is a temperature that is normally set to 1. A higher value of $T$ produces a softer probability distribution over classes: the probabilities are spread more evenly across all classes.
- The magnitude of the gradients produced by the soft targets scales as $1/T^2$, so the soft-target term is multiplied by $T^2$ when it is combined with the hard-target term.
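As a concrete illustration, here is a minimal numpy sketch of the temperature softmax; the function and variable names are illustrative, not from the paper:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Convert logits z_i into probabilities q_i = exp(z_i/T) / sum_j exp(z_j/T)."""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()                      # subtract the max for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

logits = np.array([5.0, 2.0, -1.0])
print(softmax_with_temperature(logits, T=1.0))   # peaked: most of the mass on the first class
print(softmax_with_temperature(logits, T=10.0))  # softer: mass spread more evenly over classes
```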
Objectives
- Demonstrate that knowledge can be effectively transferred from a large, cumbersome model to a smaller, more efficient model using distillation, and to show that this distilled model can achieve comparable performance to the larger model.
- Show that specialist models trained on confusable subsets of classes can improve performance compared to a single generalist model, and to validate that these specialist models can be trained quickly and independently.
- Demonstrate that soft targets act as regularizers to avoid overfitting and improve generalization, even when training data is limited.
Setup
- Models:
- Cumbersome Model:
- In the MNIST experiments, a large network with two hidden layers of 1200 rectified linear units is used.
- In the speech recognition experiments, an architecture with 8 hidden layers each containing 2560 rectified linear units and a final softmax layer is used.
- In the JFT experiments, a deep convolutional neural network is used.
- Distilled Model:
- In the MNIST experiments, a smaller net with two hidden layers of 800 (or, in one experiment, 30) rectified linear units is used.
- The speech recognition experiments use a single model of the same size as the individual models in the ensemble.
- Specialist Model: Models trained on a confusable subset of classes. They are initialized with the weights of a generalist model. The specialists have a smaller softmax layer that combines all non-specialist classes into a single “dustbin” class.
- Datasets:
- MNIST: A dataset of handwritten digits.
- Speech Recognition: A dataset of about 2000 hours of spoken English data yielding about 700M training examples.
- JFT: An internal Google dataset with 100 million labeled images and 15,000 labels.
- Hyperparameters:
- The temperature parameter, T, in the softmax is varied during distillation.
- A weighted average of two objective functions (cross-entropy with soft targets and cross-entropy with hard targets) is used when the correct labels are known.
- Metrics:
- MNIST: Test errors.
- Speech Recognition: Frame classification accuracy and Word Error Rate (WER).
- JFT: Classification accuracy (top 1).
- Distillation method:
- Train on a transfer set and use a soft target distribution for each case in the transfer set that is produced by using the cumbersome model with a high temperature in its softmax.
- Use the same high temperature when training the distilled model, but use a temperature of 1 after it has been trained.
- When the correct labels are known for all or some of the transfer set, use a weighted average of two objectives: cross-entropy with the soft targets (distillation, computed at the high temperature) and cross-entropy with the correct labels (vanilla training, T=1); a sketch of this combined loss follows this list.
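Below is a minimal PyTorch sketch of that combined objective, assuming `student_logits` and `teacher_logits` are per-class logits for a batch; `T` and `alpha` are illustrative hyperparameter names, and the `T**2` factor compensates for the $1/T^2$ gradient scaling noted above:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Weighted average of soft-target and hard-target cross-entropy."""
    teacher_logits = teacher_logits.detach()                       # no gradient through the cumbersome model
    soft_targets = F.softmax(teacher_logits / T, dim=-1)           # teacher probabilities at high temperature
    log_probs = F.log_softmax(student_logits / T, dim=-1)          # student log-probabilities at the same temperature
    soft_loss = -(soft_targets * log_probs).sum(dim=-1).mean()     # cross-entropy with the soft targets
    hard_loss = F.cross_entropy(student_logits, labels)            # ordinary cross-entropy at T=1
    # The soft term is scaled by T**2 so its gradient magnitude stays
    # comparable to the hard term when the temperature is changed.
    return alpha * (T ** 2) * soft_loss + (1.0 - alpha) * hard_loss
```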
- Specialization method:
- Start from the trained baseline full network: each specialist model is initialized with the weights of the generalist model.
- Cluster the columns of the covariance matrix of the predictions of the generalist model, so that a set of classes that are often predicted together is used as the targets for one of the specialist models.
- Train the specialist with half its examples coming from its special subset and half sampled at random from the remainder of the training set.
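A sketch of how the specialist subsets and dustbin labels might be derived; the paper uses an online version of K-means on the columns of the covariance matrix, whereas this illustration uses ordinary scikit-learn KMeans and hypothetical helper names:

```python
import numpy as np
from sklearn.cluster import KMeans

def specialist_subsets(generalist_probs, n_specialists=10):
    """Group classes that the generalist often confuses.

    generalist_probs: (n_examples, n_classes) softmax outputs of the generalist.
    Columns of the prediction covariance matrix are clustered so that classes
    that tend to be predicted together land in the same specialist subset.
    """
    cov = np.cov(generalist_probs, rowvar=False)                     # (n_classes, n_classes)
    cluster_ids = KMeans(n_clusters=n_specialists, n_init=10).fit_predict(cov)
    return [np.flatnonzero(cluster_ids == k) for k in range(n_specialists)]

def to_specialist_labels(y, subset):
    """Remap full-class labels for one specialist: classes outside its subset
    are merged into a single 'dustbin' class (index len(subset))."""
    class_to_local = {c: i for i, c in enumerate(subset)}
    dustbin = len(subset)
    return np.array([class_to_local.get(c, dustbin) for c in y])
```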
Results
Ablation study
While distilling on MNIST, omit all examples of the digit 3 from the transfer set:
→ from the perspective of the distilled model, 3 is a mythical digit that it has never seen.
→ the distilled model only makes 206 test errors of which 133 are on the 1010 threes in the test set.
Most of the errors are caused by the fact that the learned bias for the 3 class is much too low.
Increase the bias for the 3 class by 3.5:
→ the distilled model makes 109 errors of which 14 are on 3s.
So with the right bias, the distilled model gets 98.6% of the test 3s correct despite never having seen a 3 during training.
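A minimal sketch of that post-hoc bias adjustment, assuming access to the per-class bias vector of the distilled model's output layer (names are illustrative):

```python
import numpy as np

def adjust_class_bias(output_bias, class_index, delta):
    """Raise the logit bias of one class after training, e.g. an unseen digit.

    output_bias: 1-D array of per-class biases in the final (pre-softmax) layer.
    """
    adjusted = output_bias.copy()
    adjusted[class_index] += delta        # e.g. delta = 3.5 for class 3 in the MNIST ablation
    return adjusted
```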