Knowledge Distillation

Student network is trained to mimic the output of teacher network. Source: Upadhyay 2018.

Deep learning is being used in a plethora of applications, ranging from computer vision and digital assistants to healthcare and finance. The popularity of machine learning and deep learning owes much to the high accuracy of the results, which is often obtained by averaging the predictions of a large ensemble of models. However, such computationally intensive models cannot be deployed on mobile devices or FPGAs for instant use. These devices have constraints such as limited memory and input/output ports.

One way of mitigating this problem is to use Knowledge Distillation. We train an ensemble of models or a complex model ('teacher') on the data. We then train a lighter model ('student') with the help of the complex model. The less-intensive student model can then be deployed on FPGAs.

Discussion

  • What is a teacher-student network?

    Some of the most accurate machine learning models average the predictions of a large ensemble of models. When deploying on hardware devices like FPGAs, however, problems ensue. FPGAs have a limited number of I/O ports, which forces developers to drastically reduce the number of inputs and outputs at each layer of their network.

    To alleviate this problem, we use two networks - a teacher and a student. Essentially, we train a bulky ensemble of models (teacher) and use a smaller, lighter model (student) for testing, prediction and deployment. The student is trained to mimic the prediction capabilities of the teacher. How we go about doing this constitutes the crux of Knowledge Distillation.

    In other words, the ensemble is simply a function that maps inputs to outputs. Transferring the knowledge in this function to the student network is knowledge distillation.

  • What is dark knowledge and softmax temperature?

    In classification problems, neural networks compute a logit \(z_i\) for each class. A softmax layer "normalizes" these logits into probabilities \(q_i\). For a softer distribution, the logits are divided by a constant, called the temperature \(T\), before applying the softmax:

    $$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$$

    When the temperature is 1, the probabilities obtained are said to be unsoftened. Hinton et al. observed that, in general, the best temperature depends on the number of units in the hidden layers of the network. For example, when the number of units in the hidden layer was 300, temperatures above 8 worked well, whereas when the number of units was 30, temperatures in the range of 2.5-4 worked best. The higher the temperature, the softer the probabilities.

    Consider a classification problem with four classes, [cow, dog, cat, car]. For an image of a dog, the hard targets would be [0, 1, 0, 0], which tells us little about what the ensemble has learned. With softened outputs, we might get [0.05, 0.3, 0.2, 0.005]: the predicted probability of cow is 10 times that of car. It's this 'dark' knowledge that needs to be distilled from the teacher network to the student.
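
    As a minimal, hedged sketch of temperature scaling (the class order and logit values below are made up for illustration):

    import numpy as np
     
    def softmax(z, T=1.0):
       # Softmax with temperature T; higher T gives a softer distribution
       z = np.asarray(z, dtype=float) / T
       e = np.exp(z - z.max())   # subtract the max for numerical stability
       return e / e.sum()
     
    # Hypothetical teacher logits for the classes [cow, dog, cat, car]
    logits = [1.0, 6.0, 4.0, -2.0]
    print(softmax(logits, T=1))   # hard-ish: dog dominates, car is almost zero
    print(softmax(logits, T=4))   # softer: dog's similarity to cat and cow becomes visible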

  • How could I implement knowledge distillation?
    Distilling the knowledge from a teacher to a student. Source: Neural Network Distiller 2019.

    Buciluǎ et al. designed the first methods of model compression. Later, Hinton et al. showed the means of distilling the knowledge from an ensemble of models into a single, lighter model.

    For example, in image classification, the student would be trained on the class probabilities (or the logits) output by the teacher. The logits encode a similarity structure over the classes and help in training good classifiers. Extracting this form of 'dark knowledge' from the teacher network and passing it on to the student is called distillation.

    Kariya's Medium article provides a simple implementation of Hinton's paper. He touches upon dark knowledge and proceeds to build a simple CNN-based network on the MNIST dataset, showing how the teacher-trained student performed better than a standalone student.

    Implementing knowledge distillation can be a resource-intensive task. It requires the training of the student model on the teacher's logits, in addition to training the teacher model.

    While training the student, care should be taken to avoid training instabilities such as vanishing gradients, which can occur if the student's learning rate is poorly chosen.
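
    Concretely, a hedged sketch of the usual training objective, following Hinton et al.'s weighted combination of soft and hard targets (the function name and the weights below are illustrative assumptions, not a fixed recipe):

    import tensorflow as tf
     
    def distillation_loss(y_true, teacher_logits, student_logits, T=4.0, alpha=0.1):
       # Soft targets: the teacher's softened probabilities at temperature T
       soft_targets = tf.nn.softmax(teacher_logits / T)
       # Cross-entropy of the student's softened predictions against the soft targets
       soft_loss = tf.keras.losses.categorical_crossentropy(soft_targets, student_logits / T, from_logits=True)
       # Standard cross-entropy against the ground-truth labels at T = 1
       hard_loss = tf.keras.losses.categorical_crossentropy(y_true, student_logits, from_logits=True)
       # T^2 rescaling keeps the soft-target gradients at a comparable magnitude (Hinton et al.)
       return alpha * hard_loss + (1.0 - alpha) * (T ** 2) * soft_loss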

  • How about performance?

    The objective of distilling the knowledge from an ensemble of models into a single, lightweight model is to ease the processes of deployment and testing. It is of paramount importance that accuracy not be compromised in trying to achieve this objective.

    In the original paper by Hinton et al., the performance of the student network after knowledge distillation improved compared with a standalone student network. Both networks were trained on the MNIST dataset of images. The accuracies of the various models have been tabulated.

    As the table shows, the best results are still obtained from the bulky ensemble of models; the student alternatives should be used only when resources are constrained.

  • What are the challenges with Knowledge Distillation?
    A framework for visual question answering. Source: Mun et al. 2018, fig. 2.

    KD is limited to classification tasks that use a softmax layer. Sometimes the assumptions are too strict, as in FitNets, where the resulting student models may not suit constrained deployment environments. Other approaches to model compression may therefore be preferred over KD.

    However, KD continues to be a promising area of research. In 2017, it was adapted for multiclass object detection. In 2018, KD was applied to construct specialized student models for visual question answering. Also in 2018, Guo et al. improved the robustness of the student network so that it resists perturbations.

    In some domains, such as healthcare, DNNs are not preferred; decision trees are favoured because their predictions can be more easily interpreted. KD has been used to distil a DNN into a decision tree, thereby providing good performance and interpretability.

Milestones

1989

Hanson and Pratt propose network pruning using biased weight decay. They call their pruned networks minimal networks. In the early 1990s, other pruning methods such as optimal brain damage and optimal brain surgeon are proposed. These are early approaches to compress a neural network model. Knowledge distillation as an alternative is invented about two decades later.

2006

Buciluǎ et al. publish a paper titled Model Compression. They present a method for “compressing” large, complex ensembles into smaller, faster models, usually without significant loss in performance. They use the ensemble to label large unlabelled datasets. They then use this labelled data to train a single model that performs as well as the ensemble. Because it's not easy to obtain large sets of unlabelled data, they develop an algorithm called MUNGE to generate pseudo data. This work is limited to shallow networks.

2014
Mimic SNN: a shallow neural network mimicking the teacher network performs as well as deep CNN. Source: Ba and Caruana 2014, fig. 1.

Ba and Caruana propose the teacher-student learning method. They show that shallow models can perform as well as deep models. A complex teacher network (either a deep network or an ensemble) is trained. Instead of using the softmax output, the logits are used to train the shallow student network. Thus, the student network benefits from what the teacher network has learned without losing information via the softmax layer. Some call this softened softmax.
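
A minimal sketch of this logit-matching step, assuming the teacher's logits for a batch are already available (names are placeholders):

    import tensorflow as tf
     
    def logit_matching_loss(teacher_logits, student_logits):
       # Ba and Caruana train the student to regress the teacher's logits (L2 loss),
       # bypassing the softmax so no information is squashed away
       return tf.reduce_mean(tf.square(teacher_logits - student_logits))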

2014

Hinton et al. introduce the idea of passing on the 'dark' knowledge from an ensemble of models into a lighter, deployable model. In a paper published in March 2015, they explain that they're "distilling knowledge" from the complex model. The core idea is that models should generalize well to new data rather than optimize on training data. Instead of using logits, they use distillation, in which the softmax is used with a higher temperature, also called "soft targets". They note that using logits is a special case of distillation.

2015
FitNet uses intermediate-level hints from teacher network. Source: Romero et al. 2015, fig. 1.

FitNet aims to produce a student network that's thinner than the teacher network while being of similar depth. In addition to the teacher's distilled knowledge from the final softmax layer, FitNets also make use of intermediate-level hints from the hidden layers. Yim et al. propose a variation of this in 2017 by distilling knowledge from the inner product of features of two layers.
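
A hedged sketch of FitNets' hint-based stage: an L2 loss between the teacher's hint layer and a small regressor applied to the student's guided layer (the 1x1-convolution regressor and its filter count below are illustrative assumptions):

    import tensorflow as tf
     
    # Regressor maps the student's thinner feature maps to the teacher's hint dimensions
    regressor = tf.keras.layers.Conv2D(filters=64, kernel_size=1, padding='same')
     
    def hint_loss(teacher_hint, student_guided):
       # Minimised in a pre-training stage, before the usual distillation loss is applied
       return tf.reduce_mean(tf.square(teacher_hint - regressor(student_guided)))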

2018
In BANs, each generation of student trains the next generation. Source: Furlanello et al. 2018, fig. 1.

Furlanello et al. show that student models parameterized similarly to their teacher models can outperform them. They call these Born-Again Networks (BANs), where model compression is not the goal. Students are trained to predict the correct labels and to match the teacher's output distribution (knowledge distillation).

2019

Researchers at the Indian Institute of Science, Bangalore, propose Zero-Shot Knowledge Distillation (ZSKD), in which they don't use the teacher's training dataset or a transfer dataset for distillation. Instead, they synthesize pseudo data from the teacher's model parameters, which they call Data Impressions (DI). This is then used as a transfer dataset to perform distillation. Another research group, aiming to reduce training data requirements, shows that just 1% of the training data can be adequate.

2019
Relational Knowledge Distillation. Source: Park et al. 2019, fig. 1.

Park et al. look at the mutual relationships among data samples and transfer this knowledge to the student network. Called Relational Knowledge Distillation (RKD), this departs from the conventional approach of looking at individual samples. Liu et al. propose something similar, calling it Instance Relationship Graph (IRG). Attention networks are another approach to distilling relationships.

Sep 2019

Yuan et al. note that conventional KD can be reversed; that is, the teacher can also learn from the student. Another observation is that a poorly-trained teacher can still improve the student. They see KD not just as similarity across categories but also as a regularization of soft targets. With this understanding, they propose Teacher-free Knowledge Distillation (Tf-KD), in which a student model learns from itself.

Sample Code

  • import os
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # select the first GPU; set to '' to force CPU
     
    import tensorflow as tf
    import numpy as np
    from tensorflow import keras
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Conv2D,MaxPooling2D,Dropout,Flatten,Dense,Activation,Lambda
    from tensorflow.keras.optimizers import SGD
     
    (train_data,train_labels),(test_data,test_labels) = keras.datasets.cifar10.load_data()
    train_data = train_data.astype('float32')
    test_data = test_data.astype('float32')
    train_data = train_data/255
    test_data = test_data/255
    train_labels = keras.utils.to_categorical(train_labels.astype('float32'))
    test_labels = keras.utils.to_categorical(test_labels.astype('float32'))
     
    def swish(x):
       beta = 1.5
       # Swish activation: x * sigmoid(beta * x)
       return x * keras.backend.sigmoid(beta * x)
     
    def new_softmax(logits, temperature=1):
       # Softened softmax, applied row-wise so each sample's probabilities sum to 1
       logits = logits/temperature
       exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))   # subtract the max for numerical stability
       return exp/np.sum(exp, axis=-1, keepdims=True)
     
    print(train_data.shape)
     
    #teacher
    model = Sequential()
    model.add(Conv2D(32,(3,3),activation=swish, kernel_initializer='he_uniform', padding='same', input_shape=(32,32,3)))
    model.add(Conv2D(32, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.2))
    model.add(Conv2D(64, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model.add(Conv2D(64, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model.add(Conv2D(128, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation=swish, kernel_initializer='he_uniform'))
    model.add(Dropout(0.2))
    model.add(Dense(10, name='logits'))
    model.add(Activation('softmax'))
    model.summary()
     
    opt = SGD(lr=0.001, momentum=0.9)
     
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_data,train_labels,epochs=50)
    (loss,accuracy) = model.evaluate(test_data,test_labels)
    print(loss, accuracy)
     
    model_sans_softmax = keras.models.Model(inputs=model.input, outputs = model.get_layer('logits').output)
    new_logits = model_sans_softmax.predict(train_data)
    unsoftened_prob = new_softmax(new_logits, 1)
    print("Unsoftened probabilities " + str(unsoftened_prob[0]))
    temperature = 4
    softened_prob = new_softmax(new_logits, temperature)
    print("Softened probabilities " + str(softened_prob[0]))
     
    #student
    model1 = Sequential()
    model1.add(Conv2D(32, (3,3), activation=swish, kernel_initializer='he_uniform', padding='same', input_shape=(32,32,3)))
    model1.add(Conv2D(32, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model1.add(MaxPooling2D((2, 2)))
    model1.add(Dropout(0.2))
    model1.add(Conv2D(8,(3,3),activation=swish,kernel_initializer='he_uniform', padding='same', input_shape=(16,16,32)))
    model1.add(MaxPooling2D((4,4)))
    model1.add(Conv2D(4,(3,3),activation=swish,kernel_initializer='he_uniform', padding='same'))
    model1.add(Conv2D(8,(3,3),activation=swish,kernel_initializer='he_uniform', padding='same'))
    model1.add(Conv2D(64, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model1.add(MaxPooling2D((2, 2)))
    model1.add(Dropout(0.2))
    model1.add(Conv2D(128, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model1.add(Conv2D(128, (3, 3), activation=swish, kernel_initializer='he_uniform', padding='same'))
    model1.add(MaxPooling2D((2, 2)))
    model1.add(Dropout(0.2))
    model1.add(Flatten())
    model1.add(Dense(128, activation=swish, kernel_initializer='he_uniform'))
    model1.add(Dropout(0.2))
    model1.add(Dense(10, name='logits'))
    model1.add(Activation('softmax'))
    model1.summary()
    logits = model1.get_layer('logits').output
    logits = Lambda(lambda x:x/temperature)(logits)
    out = Activation('softmax',name='soft')(logits)
     
    new_student = keras.models.Model(inputs=model1.input,outputs=out)
    new_student.summary()
     
    # Use a fresh optimizer instance for the student model
    new_student.compile(optimizer=SGD(lr=0.001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
     
    new_student.fit(train_data,softened_prob,epochs=100)
    (loss,accuracy) = new_student.evaluate(test_data,test_labels)
    print(loss, accuracy)
     

References

  1. Ba, Lei Jimmy and Rich Caruana. 2014. "Do Deep Nets Really Need to be Deep?" Advances in Neural Information Processing Systems 27 (NIPS 2014), pp. 2654-2662. Accessed 2019-10-28.
  2. Buciluǎ, C., R. Caruana, and A. Niculescu-Mizil. 2006. "Model compression." In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 535-541, August. Accessed 2019-10-28.
  3. Chen, Guobin, Wongun Choi, Xiang Yu, Tony Han, and Manmohan Chandraker. 2017. "Learning Efficient Object Detection Models with Knowledge Distillation." Advances in Neural Information Processing Systems 30 (NIPS 2017), pp. 742-751. Accessed 2019-10-28.
  4. Cheng, Yu, Duo Wang, Pan Zhou, and Tao Zhan. 2018. "Model Compression and Acceleration for Deep Neural Networks." IEEE Signal Processing Magazine, pp. 126-136, January. Accessed 2019-10-28.
  5. Furlanello, Tommaso, Zachary C. Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar. 2018. "Born Again Neural Networks." arXiv, v2, June 29. Accessed 2019-11-02.
  6. Guo, Tianyu, Chang Xu, Shiyi He, Boxin Shi, Chao Xu, and Dacheng Tao. 2018. "Robust Student Network Learning." arXiv, v2, July 31. Accessed 2019-11-02.
  7. Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. 2014. "Dark knowledge." Accessed 2019-09-18.
  8. Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. 2015. "Distilling the Knowledge in a Neural Network." arXiv, v1, March 09. Accessed 2019-09-18.
  9. Kariya, Mahendra. 2018a. "Dark Knowledge in Neural Networks." Medium, December 31. Accessed 2019-09-18.
  10. Kariya, Mahendra. 2018b. "Dark Knowledge in Neural Networks." A Colab notebook, via Google CoLab. Accessed 2019-09-18.
  11. Lee, Seunghyun, and Byung Cheol Song. 2019. "Graph-based Knowledge Distillation by Multi-head Attention Network." arXiv, v2, July 09. Accessed 2019-10-28.
  12. Li, Tianhong, Jianguo Li, Zhuang Liu, and Changshui Zhang. 2019. "Few Sample Knowledge Distillation for Efficient Network Compression." arXiv, v2, April 04. Accessed 2019-10-28.
  13. Liu, Xuan, Xiaoguang Wang, and Stan Matwin. 2018. "Improving the Interpretability of Deep Neural Networks with Knowledge Distillation." arXiv, v1, December 28. Accessed 2019-10-28.
  14. Liu, Yufan, Jiajiong Cao, Bing Li, Chunfeng Yuan, Weiming Hu, Yangxi Li, and Yunqiang Duan. 2019. "Knowledge Distillation via Instance Relationship Graph." CVPR2019, June 16-20, pp. 7096-7104. Accessed 2019-10-28.
  15. Mun, Jonghwan, Kimin Lee, Jinwoo Shin, and Bohyung Han. 2018. "Learning to Specialize with Knowledge Distillation for Visual Question Answering." Advances in Neural Information Processing Systems 31 (NIPS 2018), pp. 8081-8091. Accessed 2019-10-28.
  16. Nayak, Gaurav Kumar, Konda Reddy Mopuri, Vaisakh Shaj, R Venkatesh Babu, and Anirban Chakraborty. 2019. "Zero-Shot Knowledge Distillation in Deep Networks." arXiv, v1, May 20. Accessed 2019-10-26.
  17. Neural Network Distiller. 2019. "Knowledge Distillation." Documentation, Neural Network Distiller. Accessed 2019-10-28.
  18. Park, Wonpyo, Dongju Kim, Yan Lu, and Minsu Cho. 2019. "Relational Knowledge Distillation." CVPR2019, June 16-20, pp. 3967-3976. Accessed 2019-10-28.
  19. Romero, Adriana, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta, and Yoshua Bengio. 2015. "FitNets: Hints for Thin Deep Nets." arXiv, v4, March 27. Accessed 2019-10-28.
  20. Upadhyay, Ujjwal. 2018. "Knowledge Distillation." Medium, April 05. Accessed 2019-10-28.
  21. Yim, Junho, Donggyu Joo, Jihoon Bae, and Junmo Kim. 2017. "A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning." CVPR2017, July 21-26, pp. 4133-4141. Accessed 2019-10-28.
  22. Yuan, Li, Francis E.H.Tay, Guilin Li, Tao Wang, and Jiashi Feng. 2019. "Revisit Knowledge Distillation: a Teacher-free Framework." arXiv, v1, September 25. Accessed 2019-10-28.

Further Reading

  1. Malli, Refik Can. 2019. "SqueezeNet implementation with Keras Framework." rcmalli/keras-squeezenet, on GitHub, February 19. Accessed 2019-09-18.
  2. Iandola, Forrest N., Song Han, Matthew W. Moskewicz, Khalid Ashraf, William J. Dally, and Kurt Keutzer. 2016. "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size." v4, arXiv, November 04. Accessed 2019-09-18.
  3. Brownlee, Jason. 2019. "How to Develop a CNN From Scratch for CIFAR-10 Photo Classification." Machine Learning Mastery, May 13. Updated 2019-07-05. Accessed 2019-09-18.
  4. Kompella, Ravindra. 2018. "Tap into the dark knowledge using neural nets — Knowledge distillation." Towards Data Science, July 30. Accessed 2019-09-18.
  5. Preda, Gabriel. 2018. "CNN with Tensorflow|Keras for Fashion MNIST." v19, Kaggle, November 14. Accessed 2019-09-18.
  6. Gluon. 2019. "Prepare the ImageNet dataset." Gluon. Accessed 2019-09-18.

Cite As

Devopedia. 2020. "Knowledge Distillation." Version 11, July 24. Accessed 2024-06-25. https://devopedia.org/knowledge-distillation