TextBrewer: NLP中的知识蒸馏工具包

释放双眼,带上耳机,听听看~!
了解TextBrewer,一个融合了多种知识蒸馏技术的NLP工具包,提供便捷快速的知识蒸馏框架,用于压缩神经网络模型、提升推理速度和减少内存占用。

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

1.简介

TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。

主要特点:

  • 模型无关:适用于多种模型结构(主要面向Transfomer结构)
  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
  • 非侵入式:无需对教师与学生模型本身结构进行修改
  • 支持典型的NLP任务:文本分类、阅读理解、序列标注等

TextBrewer目前支持的知识蒸馏技术有:

  • 软标签与硬标签混合训练
  • 动态损失权重调整与蒸馏温度调整
  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
  • 任意构建中间层特征匹配方案
  • 多教师知识蒸馏

TextBrewer的主要功能与模块分为3块:

  1. Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
  3. Utilities:模型参数分析显示等辅助工具

用户需要准备:

  1. 已训练好的教师模型, 待蒸馏的学生模型
  2. 训练数据与必要的实验配置, 即可开始蒸馏

在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见https://juejin.cn/post/examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。

  • https://juejin.cn/post/examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
  • https://juejin.cn/post/examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
  • https://juejin.cn/post/examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。

  • https://juejin.cn/post/examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。

  • https://juejin.cn/post/examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。

  • https://juejin.cn/post/examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。

  • https://juejin.cn/post/examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。

  • 2.4.1蒸馏效果

    我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。

    我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。

    • 英文模型
    Model #Layers Hidden size Feed-forward size #Params Relative size
    BERT-base-cased (教师) 12 768 3072 108M 100%
    T6 (学生) 6 768 3072 65M 60%
    T3 (学生) 3 768 3072 44M 41%
    T3-small (学生) 3 384 1536 17M 16%
    T4-Tiny (学生) 4 312 1200 14M 13%
    T12-nano (学生) 12 256 1024 17M 16%
    BiGRU (学生) 768 31M 29%
    • 中文模型
    Model #Layers Hidden size Feed-forward size #Params Relative size
    RoBERTa-wwm-ext (教师) 12 768 3072 102M 100%
    Electra-base (教师) 12 768 3072 102M 100%
    T3 (学生) 3 768 3072 38M 37%
    T3-small (学生) 3 384 1536 14M 14%
    T4-Tiny (学生) 4 312 1200 11M 11%
    Electra-small (学生) 12 256 1024 12M 12%

    2.4.2 蒸馏配置

    distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
    #其他参数为默认值
    

    不同的模型用的matches我们采用了以下配置:

    Model matches
    BiGRU None
    T6 L6_hidden_mse + L6_hidden_smmd
    T3 L3_hidden_mse + L3_hidden_smmd
    T3-small L3n_hidden_mse + L3_hidden_smmd
    T4-Tiny L4t_hidden_mse + L4_hidden_smmd
    T12-nano small_hidden_mse + small_hidden_smmd
    Electra-small small_hidden_mse + small_hidden_smmd

    各种matches的定义在https://juejin.cn/post/examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。

    2.4.3训练配置

    蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。

    2.4.4英文实验结果

    在英文实验中,我们使用了如下三个典型数据集。

    Dataset Task type Metrics #Train #Dev Note
    MNLI 文本分类 m/mm Acc 393K 20K 句对三分类任务
    SQuAD 1.1 阅读理解 EM/F1 88K 11K 篇章片段抽取型阅读理解
    CoNLL-2003 序列标注 F1 23K 6K 命名实体识别任务

    我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。

    Public results:

    Model (public) MNLI SQuAD CoNLL-2003
    DistilBERT (T6) 81.6 / 81.1 78.1 / 86.2
    BERT6-PKD (T6) 81.5 / 81.0 77.1 / 85.3
    BERT-of-Theseus (T6) 82.4/ 82.1
    BERT3-PKD (T3) 76.7 / 76.3
    TinyBERT (T4-tiny) 82.8 / 82.9 72.7 / 82.1

    Our results:

    Model (ours) MNLI SQuAD CoNLL-2003
    BERT-base-cased (教师) 83.7 / 84.0 81.5 / 88.6 91.1
    BiGRU 85.3
    T6 83.5 / 84.0 80.8 / 88.1 90.7
    T3 81.8 / 82.7 76.4 / 84.9 87.5
    T3-small 81.3 / 81.7 72.3 / 81.4 78.6
    T4-tiny 82.0 / 82.6 75.2 / 84.0 89.1
    T12-nano 83.2 / 83.9 79.0 / 86.6 89.6

    说明:

    1. 公开模型的名称后括号内是其等价的模型结构
    2. 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
    3. 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据

    2.4.5中文实验结果

    在中文实验中,我们使用了如下典型数据集。

    Dataset Task type Metrics #Train #Dev Note
    XNLI 文本分类 Acc 393K 2.5K MNLI的中文翻译版本,3分类任务
    LCQMC 文本分类 Acc 239K 8.8K 句对二分类任务,判断两个句子的语义是否相同
    CMRC 2018 阅读理解 EM/F1 10K 3.4K 篇章片段抽取型阅读理解
    DRCD 阅读理解 EM/F1 27K 3.5K 繁体中文篇章片段抽取型阅读理解
    MSRA NER 序列标注 F1 45K 3.4K (测试集) 中文命名实体识别

    实验结果如下表所示。

    Model XNLI LCQMC CMRC 2018 DRCD
    RoBERTa-wwm-ext (教师) 79.9 89.4 68.8 / 86.4 86.5 / 92.5
    T3 78.4 89.0 66.4 / 84.2 78.2 / 86.4
    T3-small 76.0 88.1 58.0 / 79.3 75.8 / 84.8
    T4-tiny 76.2 88.4 61.8 / 81.8 77.3 / 86.1
    Model XNLI LCQMC CMRC 2018 DRCD MSRA NER
    Electra-base (教师) 77.8 89.8 65.6 / 84.7 86.9 / 92.3 95.14
    Electra-small 77.7 89.3 66.5 / 84.9 85.5 / 91.3 93.48

    说明:

    1. 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
    2. CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
    3. Electra-base的教师模型训练设置参考自Chinese-ELECTRA
    4. Electra-small学生模型采用预训练权重初始化

    3.核心概念

    3.1Configurations

    • TrainingConfigDistillationConfig:训练和蒸馏相关的配置。

    3.2Distillers

    Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:

    • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。
    • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用
    • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配
    • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。
    • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

    3.3用户定义函数

    蒸馏实验中,有两个组件需要由用户提供,分别是callbackadaptor :

    3.3.1Callback

    回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。

    3.3.2Adaptor

    将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。

    更多细节可参见完整文档中的说明。

    4.FAQ

    Q: 学生模型该如何初始化?

    A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入Feed Different batches to Student and Teacher, Feed Cached Values 章节。

    Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?

    A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

    更多优质内容分享请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

    本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
    AI教程

    Hello world for Manim: 数学可视化和动画的Python库

    2023-11-26 19:04:14

    AI教程

    Python从零到壹 图像锐化Sobel、Laplacian算子实现边缘检测

    2023-11-26 19:18:14

    个人中心
    购物车
    优惠劵
    今日签到
    有新私信 私信列表
    搜索