Currently, especially in NLP, very large-scale models are being trained. A large portion of them can't even fit on an average person's hardware. We could instead train a small network that runs within the limited computational resources of a mobile device, but small models can't extract many of the complex features that are handy for generating predictions unless we devise some elegant algorithm to do so. On top of that, because of the law of diminishing returns, a large increase in model size yields only a small increase in accuracy.
There are currently two ways to solve this problem:
- Knowledge Distillation.
- Model Compression.
In this blog, I focus on knowledge distillation. Using distillation, one could reduce the size of a model like BERT by 87% and still retain 96% of its performance.
Shortcomings of normal neural networks
Take the MNIST dataset as an example and pick a sample picture of the number 3.
In the training data, the label 3 translates to the corresponding one-hot vector:
[0 0 0 1 0 0 0 0 0 0]
This vector simply tells us that the number in the image is 3, but it says nothing about the shape of the number 3, for example that the shape of 3 is similar to that of 8. Hence, the neural network is never explicitly asked to learn a generalized understanding of the training data.
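As a concrete illustration, here is a minimal NumPy sketch of the hard target described above (the class count and label are just the MNIST values from this example):

```python
import numpy as np

NUM_CLASSES = 10  # MNIST digits 0-9

def one_hot(label: int, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Build the 'hard target' used during normal training."""
    target = np.zeros(num_classes)
    target[label] = 1.0
    return target

print(one_hot(3))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] -- says nothing about 3 resembling 8
```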
Generalization of Information
The goal of a neural network is to predict the output for samples it has never seen during training by generalizing the knowledge within the training data. Take the example of a discriminative neural network whose objective is to identify the number in a picture. The network returns a distribution of probabilities across all classes 0, 1, 2, ..., 9, and this distribution tells us a lot about the network's ability to generalize over the concepts within the training data.
For a decently trained neural network on MNIST, given an image of the number 3:
- the probability for the number 3 is significantly greater than the probabilities for the numbers 8 and 0;
- the probabilities of 8 and 0 are comparable to each other;
- still, the probabilities of 8 and 0 are noticeably higher than those of the other numbers.
So the neural network is able to identify that the shape of the number in the image is 3, but it also suggests that the shape of 3 is quite similar to the shapes of the numbers 8 and 0.
For a task like the above, we usually train a large and complex network or an ensemble model that can extract important features from the image data and can therefore produce better predictions.
However, such a model is usually very cumbersome (hence the term cumbersome model/network, meaning deep and complex). Its depth gives it the ability to extract complex features, and its complexity gives it the power to stay accurate. But the model is heavy enough that one needs a large amount of memory and a powerful GPU to perform its large and complex calculations. That is why we want to transfer the knowledge learned by this model to a much smaller model that can easily be used on a mobile device.
Knowledge Distillation
A few Definitions
- soft targets: the network's probability distribution across all classes.
- hard targets: the one-hot vector representation within the original training data.
- transfer set: the data passed through the cumbersome model, with the model's output (probability distribution) used as the respective truth values. It can consist of the dataset used to train the original model, a new dataset, or both (see the sketch below).
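As a rough sketch of how a transfer set could be built, the snippet below passes data through a teacher model and stores its output distributions as the truth values; `teacher`, `data_loader`, and the temperature value are hypothetical placeholders rather than names from any particular library:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def build_transfer_set(teacher, data_loader, temperature=4.0):
    """Run the cumbersome (teacher) model over the data and keep its softened
    output distributions as the targets of the transfer set."""
    teacher.eval()
    inputs, soft_targets = [], []
    for x, _ in data_loader:                 # the original hard labels can optionally be kept as well
        logits = teacher(x)                  # teacher's raw class scores
        probs = F.softmax(logits / temperature, dim=-1)  # softened probability distribution
        inputs.append(x)
        soft_targets.append(probs)
    return torch.cat(inputs), torch.cat(soft_targets)
```

The transfer set can then consist of the original training data, new data, or both, exactly as in the definition above.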
General idea of knowledge distillation
Knowledge distillation is a simple way to improve the performance of deep learning models on mobile devices. In this process, we train a large and complex network or an ensemble model which can extract important features from the given data and can, therefore, produce better predictions. Then we train a small network with the help of the cumbersome model. This small network will be able to produce comparable results, and in some cases, it can even be made capable of replicating the results of the cumbersome network.
You can distill the large and complex network into another, much smaller network, and the smaller network will do a reasonable job of approximating the original function learned by the deep network.
Teacher and Student
The distilled model (student) is trained to mimic the output of the larger network (teacher) instead of being trained on the raw data directly.
The point is that the teacher outputs class probabilities: soft labels rather than hard labels. A number classifier (classifying 0, 3, 8) might say "0: 0.1, 3: 0.75, 8: 0.15" instead of "0: 0, 3: 1, 8: 0". Why bother? Because these soft labels are more informative than the original ones: they tell the student that a 3 does very slightly resemble a 0 or an 8.
Student models can often come very close to teacher-level performance. Recent work even suggests that students can actually exceed teacher performance.
Temperature & Entropy
Temperature and entropy are concepts borrowed from physics, where we know that entropy increases with temperature. The same intuition carries over here: a higher temperature yields a higher-entropy, more evenly spread probability distribution.
When soft targets have high entropy, they provide much more information per training sample than hard targets. For example, the soft targets "0: 0.1, 3: 0.75, 8: 0.15" contain information such as the fact that 0 and 8 are somehow similar (both receive non-negligible probability). However, the hard targets "0: 0, 3: 1, 8: 0" contain no such relation between 0 and 8.
However, soft targets are less useful if the output probability distribution has low entropy (e.g. "0: 0.01, 3: 0.98, 8: 0.01"). In this case, we need to raise the entropy to make the targets more informative.
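To put numbers on this, here is a small NumPy sketch (using the illustrative probabilities from above) that compares the Shannon entropy of a high-entropy and a low-entropy soft target:

```python
import numpy as np

def entropy(p) -> float:
    """Shannon entropy in nats; higher means the probability mass is more spread out."""
    p = np.asarray(p, dtype=float)
    return float(-np.sum(p * np.log(p + 1e-12)))

high_entropy_soft = [0.10, 0.75, 0.15]  # informative: 0 and 8 receive noticeable mass
low_entropy_soft  = [0.01, 0.98, 0.01]  # nearly one-hot: little extra information

print(entropy(high_entropy_soft))  # ~0.73 nats
print(entropy(low_entropy_soft))   # ~0.11 nats
```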
Specifically, we use a parameter called temperature ($T$) to adjust the level of entropy; the formula is
$$q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)} \tag 1$$
Note:
- $z_i$ is the logit of class $i$,
- $z_j$ are the logits of all classes,
- $T$ is the temperature,
- $q_i$ is the resulting probability.

Some properties of the temperature:
- For high temperatures $(T \to \infty)$, all classes get nearly the same probability.
- For $T = 1$, we recover the ordinary softmax probabilities.
- For low temperatures $(T \to 0)$, the probability of the class with the highest logit tends to $1$.

In distillation, we raise the temperature of the final softmax until the cumbersome model produces a suitably soft set of targets. We then use the same high temperature when training the small model to match these soft targets.
Here is an example of adjusting Temperature $T$:
| | 0 | 3 | 8 |
|---|---|---|---|
| truth | 0 | 1 | 0 |
| logit | 0.1 | 0.7 | 0.2 |
| Temp = 0.5 | 0.183 | 0.594 | 0.223 |
| Temp = 1.0 | 0.254 | 0.464 | 0.282 |
| Temp = 2.0 | 0.294 | 0.397 | 0.309 |
| Temp = 5.0 | 0.318 | 0.358 | 0.324 |
Higher temperature results in a softer probability distribution over classes.
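The table can be reproduced with a direct implementation of Eq. (1); the following NumPy sketch applies it to the logits listed above:

```python
import numpy as np

def softmax_with_temperature(logits, T):
    """Eq. (1): q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max()        # subtract the max for numerical stability (does not change the result)
    e = np.exp(z)
    return e / e.sum()

logits = [0.1, 0.7, 0.2]   # raw scores for the classes 0, 3, 8
for T in (0.5, 1.0, 2.0, 5.0):
    print(f"T={T}:", np.round(softmax_with_temperature(logits, T), 3))
```

Running this reproduces the rows of the table above (up to small rounding differences in the last decimal place) and shows the distribution flattening as $T$ grows.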
Suitable soft targets lead to:
- a smaller loss, and hence smaller correction gradients during backpropagation;
- less variation between the gradients of different training examples.

As a result:
- a larger learning rate can be used to train the model;
- a smaller dataset can be used to train the model.
Training the Distilled Model
The simplest form of distillation trains the small model on the soft targets obtained from the cumbersome model at a high temperature. However, it works even better to fit both the hard targets and the soft targets from the cumbersome model.
One of the most efficient ways of doing this is to use a weighted combination of two objective functions (as shown in the figure above):
- the cross-entropy with the soft targets, where the distilled model's softmax is run at the same high temperature that was used to generate the soft targets from the cumbersome model;
- the cross-entropy with the hard targets, computed with exactly the same logits of the distilled model but with the temperature set to 1.
The following is a more detailed figure for training the distilled model:
We calculate the total loss by
$$L=\lambda L^{soft}+(1-\lambda) L^{hard} \tag 2$$
where the weight of the first (soft) term is usually the larger one, and the soft-target cross-entropy loss is
$$L^{soft}=-\sum_{c=1}^{C} y_{c}^{soft} \log \frac{e^{z_{c}/T}}{\sum_{j=1}^{C} e^{z_{j}/T}} \tag 3$$
where $C$ is the number of classes, $z_c$ are the distilled model's logits, and $y_{c}^{soft}$ is the soft target for class $c$ produced by the cumbersome model with the high-temperature setting.
During inference, the temperature of the distilled model is set to 1 to do prediction normally.
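To make Eq. (2) and Eq. (3) concrete, here is a minimal PyTorch-style sketch of the combined training loss; the temperature, the weight `lam`, and the batch shapes are illustrative assumptions, not values prescribed by the original work:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_labels, T=4.0, lam=0.9):
    """Total loss L = lam * L_soft + (1 - lam) * L_hard   (Eq. 2)."""
    # Soft-target cross-entropy (Eq. 3): both models use the same high temperature T.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)          # y^soft from the cumbersome model
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    l_soft = -(soft_targets * student_log_probs).sum(dim=-1).mean()

    # Hard-target cross-entropy: the same student logits, but with temperature 1.
    l_hard = F.cross_entropy(student_logits, hard_labels)

    return lam * l_soft + (1.0 - lam) * l_hard

# Hypothetical usage with random tensors standing in for a batch of 8 MNIST images:
student_logits = torch.randn(8, 10, requires_grad=True)  # distilled (student) model outputs
teacher_logits = torch.randn(8, 10)                      # cumbersome (teacher) model outputs
hard_labels = torch.randint(0, 10, (8,))                 # original labels as class indices
loss = distillation_loss(student_logits, teacher_logits, hard_labels)
loss.backward()
```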
References:
- Knowledge distillation 1: https://medium.com/neuralmachine/knowledge-distillation-dc241d7c2322
- Knowledge distillation 2: https://towardsdatascience.com/distillation-of-knowledge-in-neural-networks-cc02f79698b6
- Knowledge distillation 3: https://blog.csdn.net/xbinworld/article/details/83063726?biz_id=102&utm_term=%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F&utm_medium=distribute.pc_search_result.none-task-blog-2~blog~sobaiduweb~default-0-83063726&spm=1018.2118.3001.4187
- Knowledge distillation 4: https://www.zhihu.com/question/50519680
- Hinton, Dark knowledge: https://www.ttic.edu/dl/dark14.pdf
- Teacher and Student: https://www.quora.com/What-is-a-teacher-student-model-in-a-Convolutional-neural-network