Data-free knowledge distillation for target class classification

Xie Yitao1, Su Lumei1, Yang Fan2, Chen Yuhan1 (1. Xiamen University of Technology; 2. Xiamen University)

Abstract
Objective Data-free distillation has become the prevailing way to address the lack of training data. However, existing data-free distillation methods face difficulties with model convergence and insufficient compactness of the student model in practical applications. To meet the need for training models on a subset of classes and to select target-class knowledge from the teacher network flexibly, this paper proposes a new data-free knowledge distillation method: Masked Distillation for Target Classes (MDTC). Method Building on a generator that learns the batch-normalization statistics of the original data, MDTC uses a mask to block the backpropagation of gradients from non-target classes during the generator's updates, training a generator that synthesizes only target-class samples and thereby accurately extracting the targeted knowledge from the teacher model. In addition, MDTC introduces the teacher model into the feature learning of the generator's intermediate layers, optimizing the generator's initial parameter settings and parameter update strategy to accelerate model convergence. Result Thirteen sub-classification tasks of varying difficulty are constructed from four standard image classification datasets to evaluate MDTC. The experimental results show that MDTC extracts the targeted knowledge from the teacher model accurately and efficiently: its overall accuracy exceeds that of mainstream data-free distillation models while requiring less training time, and more than 40% of the student models even surpass the teacher model, with a maximum improvement of 3.6%. Conclusion MDTC outperforms existing data-free distillation models overall; it learns knowledge especially efficiently on easy classification tasks, and performs best when the extracted classes account for a small proportion of the teacher's classes.
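To make the masking operation concrete, the following minimal PyTorch sketch shows one way the student-side distillation loss could be restricted to the teacher's target-class outputs, as described above. It reflects our own assumptions: the names (masked_kd_loss, target_classes, temperature) and the T^2 scaling are illustrative and are not taken from the paper.

# Minimal sketch (assumed PyTorch setup): the mask keeps only the teacher's
# target-class logits, and the KL divergence is computed against the compact
# student, which only has output units for the target classes.
import torch.nn.functional as F

def masked_kd_loss(teacher_logits, student_logits, target_classes, temperature=4.0):
    # teacher_logits: (B, C) logits over all teacher classes.
    # student_logits: (B, K) logits of the compact student over the K target classes.
    # target_classes: length-K list of class indices to extract from the teacher.
    t_logits = teacher_logits[:, target_classes]               # mask: keep target classes only
    t_prob = F.softmax(t_logits / temperature, dim=1)          # softened teacher targets
    s_log_prob = F.log_softmax(student_logits / temperature, dim=1)
    # Scale by T^2, as is standard in knowledge distillation.
    return F.kl_div(s_log_prob, t_prob, reduction="batchmean") * temperature ** 2

# Hypothetical usage: a 100-class teacher distilled into a 10-class student.
# loss = masked_kd_loss(teacher(x), student(x), target_classes=list(range(10)))

In this reading, the student only carries output units for the K target classes, which is what allows it to stay compact while matching the teacher's behaviour on those classes.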
Keywords
Data-free knowledge distillation for target class classification

Xie Yitao1, Su Lumei1, Yang Fan2, Chen Yuhan1 (1. Xiamen University of Technology; 2. Xiamen University)

Abstract
Objective Knowledge distillation is a simple and effective method for compressing neural networks and has become a hot topic in model compression research. It uses a "teacher-student" architecture in which a large network guides the training of a smaller network to improve the smaller network's performance in application scenarios, thereby achieving network compression indirectly. In traditional methods, the student model's training relies on the teacher's training data, and the quality of the student model depends on the quality of that data; when data are scarce or unavailable, these methods fail to produce satisfactory results. Data-free knowledge distillation addresses the issue of limited training data by introducing synthetic data. Such methods mainly synthesize training data by distilling the teacher network's knowledge; for example, they use the intermediate representations of the teacher network for image-inversion synthesis, or employ the teacher network as a fixed discriminator that supervises a generator in synthesizing images for training the student network. Compared with traditional methods, data-free knowledge distillation does not rely on the teacher network's original training data, which greatly expands the application scope of knowledge distillation. However, because additional training data must be synthesized, the training process is somewhat less efficient than that of traditional methods. Furthermore, practical applications often concern only a few target classes, yet existing data-free knowledge distillation methods struggle to learn the knowledge of the target classes selectively; in particular, when the teacher model covers many classes, convergence becomes difficult and the student model cannot be made sufficiently compact. Therefore, this paper proposes a novel data-free knowledge distillation method: Masked Distillation for Target Classes (MDTC). It allows the student model to selectively learn the knowledge of target classes and maintains good performance even when the teacher network has many classes. Compared with existing methods, MDTC reduces the training difficulty and improves the training efficiency of data-free knowledge distillation. Method MDTC uses a generator to learn the batch-normalization statistics of the original data and, by constructing a mask that blocks the backpropagation of gradients from non-target classes during the generator's updates, trains a generator that synthesizes only target-class samples. In this way, the method extracts the target knowledge from the teacher model while generating synthetic data similar to the original data. In addition, MDTC introduces the teacher model into the feature learning of the generator's intermediate layers to supervise the generator's training, optimizing the generator's initial parameter settings and parameter update strategy and thereby accelerating model convergence. The MDTC algorithm has two stages. The first is the data-synthesis stage, in which the student network is fixed and only the generative network is updated. During this stage, MDTC extracts three synthetic samples from the shallow, middle, and deep layers of the generator, feeds them into the teacher network for prediction, and updates the generator's parameters according to the teacher network's feedback. When the shallow- or middle-layer parameters are updated, the parameters of the other layers are frozen and only that layer is updated; finally, when the output layer is updated, the parameters of the entire generative network are updated, gradually guiding the generator to learn to synthesize the target images. The second stage is the learning stage, in which the generative network is fixed and the synthetic samples are fed into the teacher network and the student network for prediction. The target knowledge in the teacher is extracted by the mask, and the KL divergence between the masked teacher output and the student network's prediction is used to update the student network. Result Four standard image classification datasets, MNIST, SVHN, CIFAR10, and CIFAR100, are divided into 13 sub-classification tasks by Pearson similarity, including 8 difficult tasks and 5 easy tasks. The performance of MDTC on sub-classification tasks of different difficulty is evaluated by classification accuracy. We also compare our method with five mainstream data-free knowledge distillation methods and the vanilla KD method. The experimental results show that our method outperforms the other mainstream data-free distillation models on 11 sub-tasks. Moreover, on MNIST1, MNIST2, SVHN1, SVHN3, CIFAR102, and CIFAR104 (6 of the 13 sub-classification tasks), our method even surpasses the teacher model trained on the original data, achieving accuracy rates of 99.61%, 99.46%, 95.85%, 95.80%, 94.57%, and 95.00%, with a remarkable 3.6% improvement over the teacher network's 91.40% accuracy on CIFAR104. Conclusion In this study, we propose a novel data-free knowledge distillation method, Masked Distillation for Target Classes (MDTC). The experimental results show that MDTC outperforms existing data-free distillation models overall; it learns knowledge especially efficiently on easy classification tasks and performs best when the extracted knowledge classes account for a small proportion of the teacher's classes.
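For readers who want a concrete picture of the data-synthesis stage, the sketch below shows one plausible PyTorch rendering of it under our own assumptions: non-target teacher logits are detached so they contribute no gradient to the generator, and a batch-normalization statistics term pulls the synthetic batch toward the statistics stored in the teacher's BN layers. The function names, the pseudo-label choice, and the weighting alpha are illustrative, and the staged shallow/middle/deep updates described above are omitted for brevity.

# Hedged sketch of one generator update in the data-synthesis stage. The teacher is
# assumed to be frozen and in eval() mode, so its BN buffers are read but never updated.
import torch
import torch.nn as nn
import torch.nn.functional as F

def bn_statistics_loss(teacher, images):
    # Match the batch statistics of the synthetic images to the teacher's stored BN statistics.
    losses = []

    def hook(module, inputs, output):
        x = inputs[0]
        mean = x.mean(dim=[0, 2, 3])
        var = x.var(dim=[0, 2, 3], unbiased=False)
        losses.append(F.mse_loss(mean, module.running_mean) +
                      F.mse_loss(var, module.running_var))

    handles = [m.register_forward_hook(hook)
               for m in teacher.modules() if isinstance(m, nn.BatchNorm2d)]
    teacher(images)                      # forward pass triggers the hooks
    for h in handles:
        h.remove()
    return sum(losses)

def generator_step(generator, teacher, target_classes, z, optimizer_g, alpha=1.0):
    # One generator update in which only target-class logits contribute gradients.
    images = generator(z)
    logits = teacher(images)                                   # (B, C)
    mask = torch.zeros_like(logits)
    mask[:, target_classes] = 1.0
    # Non-target logits are detached, so no gradient flows back through them.
    masked_logits = logits * mask + logits.detach() * (1.0 - mask)
    # Pseudo-labels: the teacher's most confident *target* class for each sample.
    target_idx = torch.as_tensor(target_classes, device=logits.device)
    pseudo_labels = target_idx[logits[:, target_classes].argmax(dim=1)]
    loss = (F.cross_entropy(masked_logits, pseudo_labels)
            + alpha * bn_statistics_loss(teacher, images))
    optimizer_g.zero_grad()
    loss.backward()
    optimizer_g.step()
    return loss.item()

In the subsequent learning stage, the generator would be frozen and the student updated with a masked KL-divergence loss such as the one sketched after the abstract above; this division of roles is the two-stage procedure described in the Method section.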
Keywords
