
Background

The evolution of AI application scenarios
As 5G mobile networks roll out with strong national support, more mobile applications and smart terminals will rapidly become part of consumers' lives, and artificial intelligence is indispensable for improving both user experience and commercial services. In recent years the battleground of the internet industry has gradually shifted from the mobile ecosystem to AI, and "consolidate the mobile foundation, win the decisive AI moment" has become the strategic direction of many companies. Feed streams, short videos, and live streaming keep consumers' entertainment time locked to smartphones, while the rise of smart speakers, wearables, in-vehicle smart devices, and smart homes shows that the IoT era has arrived. From the PC internet to the mobile internet and on to the intelligent IoT, Baidu has kept using AI to make search simpler: from the PC search box to the Baidu App feed to smart speakers, the way people obtain information has become simpler and smarter. Behind this, computation, perception, and decision making have gradually moved from data centers to edge devices.

Consumer demand
Efficient deployment of deep learning models on all kinds of consumer smart terminals has become an important requirement. The special deployment environment of edge devices poses new challenges for AI models. Constrained by power consumption and device size, smart terminals have relatively weak compute and storage, and the main demands fall into three areas:

  • Speed: face-recognition gates, face unlock on phones, and similar scenarios are latency sensitive and need real-time responses.
  • Storage: in power-grid environment monitoring, for example, an image object-detection model is deployed on the monitoring device with only 200 MB of available memory.
  • Power: for on-device AI such as offline translation, the model's power consumption directly determines the device's battery life.

All of these require shrinking existing models to fit the target environment: without losing accuracy, make the model smaller, faster, and less power hungry.

Academic demand
Current deep learning research is largely based on open standard datasets and tasks, such as ImageNet for classification, COCO for detection, and Cityscapes for segmentation, and the state-of-the-art models released by researchers and institutions target these open tasks. The data and task complexity of real production environments, however, differ considerably from these open tasks. Directly applying a model designed and optimized for an open task to a real need inevitably carries irrelevant, weakly relevant, or redundant information, so existing open-source models need to be pruned, compressed, and optimized for the specific task and scenario.

Why not design a small model directly?
The common ways to produce a small model are to design a more efficient network structure, reduce the number of parameters, and reduce the amount of computation, while maintaining or improving accuracy. One may ask: why not just design a small CNN from scratch? In practice there are many business verticals with very different task complexity; hand-designing a small model for each is very hard and requires strong domain knowledge. Model compression, by contrast, starts from a classic small model and, with only light processing, quickly improves its performance on every axis, achieving "more, faster, better, cheaper".

The figure above plots classification models with distillation and quantization applied: the horizontal axis is inference latency and the vertical axis is model accuracy. On top of a hand-designed classic small model, distillation and quantization further improve both accuracy and inference speed.

PaddleSlim overview

PaddleSlim is the first fully open-source deep learning model compression toolkit in China with completely independent intellectual property and the most complete feature set. It integrates the common model compression methods: quantization, pruning, distillation, network architecture search, and hardware-aware architecture search. Its features include a flexible and efficient development workflow, industrial-grade compressed model output, massively parallel architecture search, compressed models that plug directly into the TensorRT and Paddle Lite inference engines, compression support for the full PaddlePaddle model zoo, complete Chinese and English documentation, and a growing set of compression results and usage examples. PaddleSlim aims to make models produced with deep learning easier to deploy in industry.

  • For business users, PaddleSlim provides complete model compression solutions for vision scenarios such as image classification, detection, and segmentation, and is also exploring compression for NLP models. It further provides, and keeps extending, benchmarks of the compression strategies on classic open tasks for reference.
  • For researchers and developers of compression algorithms, PaddleSlim provides low-level helper interfaces for the various compression strategies, making it easy to reproduce, investigate, and apply the latest methods from papers. PaddleSlim supports developers' innovation on compression strategies through low-level capabilities, technical consulting, and business scenarios.

Features of PaddleSlim

  • Simple interface: configurable parameters are managed centrally in a config file, which makes experiments easy to manage, and adding very little code to an ordinary training script is enough to compress a model.
  • Strong results: even for models with little redundancy such as MobileNetV1 and MobileNetV2, the filter pruning tool and the automatic architecture search tool can still shrink model size with minimal accuracy loss. The distillation strategy clearly improves the original model's accuracy. Combining quantization-aware training with distillation shrinks the model while improving accuracy. The architecture search tool is tens of times faster than traditional RL methods.
  • More powerful and flexible: the pruning process is automated, pruning supports more network structures, and distillation supports multiple modes with user-defined loss combinations.

Core techniques

PaddleSlim supports all the mainstream model compression methods, including pruning, quantization, distillation, and NAS; it supports combining multiple strategies; and configuration is simple.
In short: pruning removes convolution parameters to reduce a large model's parameter count. Quantization converts a model from float32 (4 bytes per value) to a low-bit integer format such as int8 (1 byte per value), reducing computation and model size. Distillation transfers a large model's knowledge into a small model to raise the small model's accuracy. NAS searches for network structures under constraints on model size and inference speed, automatically designing more efficient networks. The figure below is an overview of PaddleSlim's features.

Pruning
Deep learning models contain many redundant convolution parameters; removing them greatly reduces the parameter count and speeds up inference. The figure illustrates what pruning is. A is a feature map, and W1 on the right holds the parameters of a convolution layer: it has many filters, each producing one output channel. Pruning removes some unimportant filters, so the output B of the convolution has fewer channels. In the convolution after B, each filter of W loses the corresponding input channels while the number of filters stays the same, so the final output C has the same shape as without pruning.
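The shape bookkeeping described above can be sketched with plain NumPy (a toy illustration, not PaddleSlim's implementation; the filter indices removed below are chosen arbitrarily):

```python
import numpy as np

# Toy shapes, layout (out_channels, in_channels, kH, kW):
W1 = np.random.rand(16, 4, 3, 3)   # conv producing B: 16 filters -> B has 16 channels
W2 = np.random.rand(32, 16, 3, 3)  # next conv consuming B's 16 channels

pruned = {3, 7, 11, 15}            # hypothetical "unimportant" filters of W1
keep = [i for i in range(16) if i not in pruned]

W1_pruned = W1[keep]               # remove whole filters -> B now has 12 channels
W2_pruned = W2[:, keep]            # drop the matching input channels of W2

print(W1_pruned.shape)  # (12, 4, 3, 3)
print(W2_pruned.shape)  # (32, 12, 3, 3): filter count unchanged, so C keeps 32 channels
```

Note that the second conv keeps all 32 of its filters, which is why the final output C is unchanged in shape.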

Which filters to prune can be determined in the following ways:

  • For a single convolution layer, prune based on L0_norm, L1_norm, or FPGM scores.
  • Across multiple convolution layers, prune based on sensitivity analysis.

Within one convolution layer, the filters can be ranked by importance with the rules below, and the unimportant ones pruned. Suppose, as in the figure, the input has a single channel and there are three filters.

  • L0_norm score: count the non-zero values; for example, the green filter has 2 non-zeros. The lower the score, the less important the filter, meaning it can be pruned.
  • L1_norm score: sum the absolute values of the filter's weights (the Manhattan norm); for example, the green filter scores 0 + 0.5 + 0.5 + 0 = 1. The lower the score, the less important the filter.
  • FPGM: flatten each filter into a vector, treat the vectors as points in a space, and compute pairwise Euclidean distances between filters. A filter's score is the sum of its distances to all other filters; for example, the orange filter scores 0.2 + 0.5831 = 0.7831. The lower the score, the less important the filter.
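The three scoring rules can be reproduced with a minimal NumPy sketch. The first kernel matches the green kernel's values from the text; the other two are made up for illustration:

```python
import numpy as np

# Three 2x2 single-channel kernels; kernels[0] is the "green" kernel from the text.
kernels = np.array([
    [[0.0, 0.5], [0.5, 0.0]],
    [[0.1, 0.2], [0.3, 0.4]],
    [[1.0, 0.0], [0.0, 1.0]],
])
flat = kernels.reshape(len(kernels), -1)

l0 = (flat != 0).sum(axis=1)   # L0_norm: count of non-zero weights
l1 = np.abs(flat).sum(axis=1)  # L1_norm: sum of absolute weights

# FPGM: a filter's score is the sum of Euclidean distances to all other filters
fpgm = np.array([
    sum(np.linalg.norm(flat[i] - flat[j]) for j in range(len(flat)) if j != i)
    for i in range(len(flat))
])

print(l0)    # [2 4 2] -> the green kernel has 2 non-zeros, as in the text
print(l1)    # green kernel scores 0 + 0.5 + 0.5 + 0 = 1
print(fpgm)  # the lowest FPGM score marks the most redundant (replaceable) kernel
```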

Once the filters in each layer are ranked, the next step is deciding how large a fraction of each layer to prune, which is where sensitivity analysis comes in. In the figure, the horizontal axis is the pruning ratio of one convolution layer, and the vertical axis is the model's accuracy on the test set after pruning at that ratio. Suppose the unpruned model's accuracy is 90. For convolution layer 1, prune 25%, 50%, and 75% in turn, measure the test accuracy at each ratio, mark each result as a red point, and connect the points into a sensitivity curve. Then restore layer 1 and prune layer K at the same ratios, marking its accuracies as green points. If, as in the upper curve for layer K, accuracy drops little after pruning 25%, 50%, and 75%, the model is insensitive to pruning layer K's filters, so layer K can be pruned more aggressively. After sensitivity analysis determines each layer's importance, L0_norm, L1_norm, or FPGM can be used to score the filters and carry out the pruning.
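How sensitivity curves translate into per-layer pruning ratios can be sketched as follows. The numbers are hypothetical; in the hands-on section PaddleSlim's get_ratios_by_loss does this from real measurements:

```python
# Hypothetical sensitivities: {layer: {pruned_ratio: accuracy loss vs. baseline}}
sens = {
    "conv_1": {0.25: 0.001, 0.50: 0.004, 0.75: 0.030},
    "conv_k": {0.25: 0.002, 0.50: 0.012, 0.75: 0.080},
}

def ratios_under_loss(sens, max_loss):
    """For each layer, pick the largest pruning ratio whose loss stays within max_loss."""
    out = {}
    for layer, curve in sens.items():
        ok = [r for r, loss in curve.items() if loss <= max_loss]
        if ok:
            out[layer] = max(ok)
    return out

# conv_1's curve is flatter (less sensitive), so it tolerates a larger ratio.
print(ratios_under_loss(sens, 0.01))  # {'conv_1': 0.5, 'conv_k': 0.25}
```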

Quantization-aware training
Quantization maps floating-point values to low-bit integers, reducing model size and computation and speeding up inference. Quantization comes in several flavors: quantization-aware (online) training, post-training (offline) quantization, Embedding quantization, and weight-only quantization.
Without training data, choose Embedding quantization or weight-only quantization; Embedding quantization targets NLP tasks only. Weight-only quantization quantizes the convolution or fully connected parameters; it shrinks the model and increases effective read bandwidth, giving a clear speedup for IO-bound computation.
With training data, use post-training quantization or quantization-aware training; the difference is whether re-training is required. These approaches can improve model accuracy while shrinking model size. Post-training quantization avoids the cost of re-training, and in most cases models reach high accuracy even without it.
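The core float32-to-int8 mapping can be illustrated with a minimal symmetric linear quantization sketch (an illustration of the idea, not PaddleSlim's exact scheme):

```python
import numpy as np

def quantize_int8(x):
    """Symmetric linear quantization: map the largest magnitude to 127."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the int8 codes."""
    return q.astype(np.float32) * scale

w = np.array([0.5, -1.0, 0.0, 1.27], dtype=np.float32)
q, s = quantize_int8(w)
print(q)                 # [  50 -100    0  127]
print(dequantize(q, s))  # close to w: each value now costs 1 byte instead of 4
```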

Distillation
Distillation transfers the knowledge of a complex network (the teacher model) into a small network (the student model), with the goal of improving the student's accuracy. The core question is how to transfer the teacher's knowledge to the student; the usual approach is to supervise the student's training with the teacher's outputs. The key design decision is a suitable distillation loss: a loss defined between teacher and student that supervises the student's learning.
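One common concrete choice of distillation loss is soft-label cross-entropy between temperature-softened teacher and student outputs; a minimal NumPy sketch with hypothetical logits for a single sample (later in this tutorial the loss is instead an L2 distance between feature maps):

```python
import numpy as np

def softmax(z, T=1.0):
    """Softmax with temperature T; larger T gives softer distributions."""
    z = z / T
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

teacher_logits = np.array([[4.0, 1.0, 0.5]])  # hypothetical teacher output
student_logits = np.array([[2.0, 1.5, 0.5]])  # hypothetical student output
T = 4.0

p_t = softmax(teacher_logits, T)  # soft targets from the teacher
p_s = softmax(student_logits, T)

# Distillation loss: cross-entropy of the student against the teacher's soft targets.
distill_loss = float(-(p_t * np.log(p_s)).sum(axis=-1).mean())
print(distill_loss)

# In training this is combined with the ordinary hard-label loss,
# e.g. total = distill_loss + ce_loss.
```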

NAS
NAS (neural architecture search) automatically designs neural networks: given a search space, a search algorithm automatically finds high-performing network structures. The search loop works as shown in the figure below: the search strategy samples a model from the search space, the model's performance is estimated, and the estimate is fed back to the search strategy.
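The loop described above (sample from the search space, estimate performance, feed the estimate back) can be sketched with the simplest possible strategy, random search; the search space and reward below are made up for illustration:

```python
import random

random.seed(0)
search_space = {"depth": [2, 3, 4], "width": [16, 32, 64]}  # hypothetical space

def estimate_performance(arch):
    """Stand-in evaluator: an accuracy proxy minus a latency penalty,
    mimicking a search constrained by model size / inference speed."""
    acc_proxy = 0.1 * arch["depth"] + 0.005 * arch["width"]
    latency_penalty = 0.001 * arch["depth"] * arch["width"]
    return acc_proxy - latency_penalty

best, best_score = None, float("-inf")
for _ in range(20):
    arch = {k: random.choice(v) for k, v in search_space.items()}  # strategy samples a model
    score = estimate_performance(arch)                             # performance estimation
    if score > best_score:                                         # feedback to the strategy
        best, best_score = arch, score
print(best, best_score)
```

Real NAS replaces random sampling with smarter strategies (RL, evolutionary, or gradient-based), but the loop structure is the same.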

Compression results

Classification models
Data: ImageNet2012; model: MobileNetV1

Compression strategy | Accuracy gain (baseline: 70.91%) | Model size (baseline: 17.0M)
Knowledge distillation (ResNet50) | +1.06% | -
Knowledge distillation (ResNet50) + int8 quant-aware training | +1.10% | -71.76%
Pruning (FLOPs -50%) + int8 quant-aware training | -1.71% | -86.47%

Object detection models

Data: Pascal VOC; model: MobileNet-V1-YOLOv3

Compression method | mAP (baseline: 76.2%) | Model size (baseline: 94MB)
Knowledge distillation (ResNet34-YOLOv3) | +2.8% | -
Pruning (FLOPs -52.88%) | +1.4% | -67.76%
Knowledge distillation (ResNet34-YOLOv3) + pruning (FLOPs -69.57%) | +2.6% | -67.00%

Data: COCO; model: MobileNet-V1-YOLOv3

Compression method | mAP (baseline: 29.3%) | Model size
Knowledge distillation (ResNet34-YOLOv3) | +2.1% | -
Knowledge distillation (ResNet34-YOLOv3) + pruning (FLOPs -67.56%) | -0.3% | -66.90%

Searched models
Data: ImageNet2012; model: MobileNetV2

Hardware | Inference latency | Top-1 accuracy (baseline: 71.90%)
RK3288 | -23% | +0.07%
Android cellphone | -20% | +0.16%
iPhone 6s | -17% | +0.32%

Environment setup

  1. Install PaddlePaddle; the latest version is recommended. See the PaddlePaddle installation guide.
  2. Install PaddleSlim.

    PaddleSlim has the same hardware and software requirements as PaddlePaddle.

# Install PaddleSlim
!pip install --upgrade paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple  
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already up-to-date: paddleslim in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.0.1)
Requirement already satisfied, skipping upgrade: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleslim) (4.36.1)

The simple examples below show how to use PaddleSlim's interfaces to prune and distill a model.
For more model compression tutorials, see: https://paddlepaddle.github.io/PaddleSlim/index.html

Hands-on with PaddleSlim: pruning

This tutorial uses the image classification model MobileNetV1 to show how to quickly run sensitivity analysis on convolution layer channels with PaddleSlim's sensitivity interface, then prune according to the results. The same steps apply to any other model that contains convolution layers.

Parts of this example follow the paper: Pruning Filters for Efficient ConvNets

The example consists of the following steps:

  1. Import dependencies: PaddleSlim depends on PaddlePaddle; install PaddlePaddle correctly, then import PaddlePaddle and PaddleSlim.
  2. Build the model: define the network structure.
  3. Define the input data: define a data reader for the MNIST dataset.
  4. Define a model evaluation method: predefine a method that measures accuracy on the test data, for reuse in later sections.
  5. Train the model: sensitivity analysis must start from a trained model.
  6. Get the names of the convolution parameters to analyze: the sensitivity interface takes parameter names; this step shows how to obtain them.
  7. Analyze sensitivity: run the sensitivity analysis, including how to parallelize it for speed.
  8. Prune the model: prune appropriately according to the sensitivity information.

Each step depends on the previous one.

The following sections walk through each step.

1. Import dependencies

PaddleSlim depends on PaddlePaddle. Make sure PaddlePaddle is installed correctly, then import PaddlePaddle and PaddleSlim as follows.

import paddle
import paddle.fluid as fluid
import paddleslim as slim

2. Build the network

This section builds a classification model for MNIST: MobileNetV1 with input shape [1, 28, 28] and 10 output classes. For convenience, paddleslim.models predefines a helper for building classification models; run the code below to build the model.

use_gpu = False
place = fluid.CPUPlace() # device on which to run training
if paddle.fluid.is_compiled_with_cuda():
    use_gpu = True
    place = fluid.CUDAPlace(0) # use the first GPU device

# Build the training and test Programs, and return the network's input and output variables
exe, train_program, val_program, inputs, outputs = slim.models.image_classification("MobileNet", [1, 28, 28], 10, use_gpu=use_gpu)

3. Define the input data

To keep the example fast we use the simple MNIST dataset; Paddle's paddle.dataset.mnist package handles downloading and reading MNIST. The code is as follows:

import paddle.dataset.mnist as reader
train_reader = paddle.batch(
        reader.train(), batch_size=128, drop_last=True) # generator over the training data
test_reader = paddle.batch(
        reader.test(), batch_size=128, drop_last=True) # generator over the test data
data_feeder = fluid.DataFeeder(inputs, place) # maps the generator's output to the network inputs
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz 
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz 
Begin to download
........
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz 
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz 
Begin to download
..
Download finished

4. Define a model evaluation method

Computing sensitivity requires evaluating, on the test data, the model obtained after pruning a single convolution layer. We define the following method for that:

import numpy as np
def test(program): 
    acc_top1_ns = []
    acc_top5_ns = []
    for data in test_reader():
        acc_top1_n, acc_top5_n, _ = exe.run( # run the test network on one batch
            program,
            feed=data_feeder.feed(data),
            fetch_list=outputs)
        acc_top1_ns.append(np.mean(acc_top1_n)) # collect this batch's accuracy
        acc_top5_ns.append(np.mean(acc_top5_n))
    print("Final eva - acc_top1: {}; acc_top5: {}".format(
        np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns)))) # average the per-batch test accuracies and print them
    return np.mean(np.array(acc_top1_ns)) # return top-1 accuracy over the whole test set

5. Train the model

Sensitivity analysis only makes sense on a trained model. Since this task is simple, we train for one epoch and analyze the resulting model. For models that take longer to train, you can load pretrained weights instead.

The training code:

for data in train_reader():
    acc1, acc5, loss = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs) # train on one batch
print(np.mean(acc1), np.mean(acc5), np.mean(loss))
1.0 1.0 0.014517799

Evaluate the current model on the test set with the evaluation method defined above:

test(val_program)
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
0.9665465

6. Get the convolution parameters to analyze

params = []
for param in train_program.global_block().all_parameters(): # iterate over all parameters
    if "_sep_weights" in param.name: # select the parameters to analyze by name
        params.append(param.name)
print(params)
params = params[:5] # for brevity, analyze only the first 5 parameters
['conv2_1_sep_weights', 'conv2_2_sep_weights', 'conv3_1_sep_weights', 'conv3_2_sep_weights', 'conv4_1_sep_weights', 'conv4_2_sep_weights', 'conv5_1_sep_weights', 'conv5_2_sep_weights', 'conv5_3_sep_weights', 'conv5_4_sep_weights', 'conv5_5_sep_weights', 'conv5_6_sep_weights', 'conv6_sep_weights']

7. Analyze sensitivity

7-1 Basic sensitivity computation

Call the sensitivity interface to analyze the trained model.

During the computation, sensitivity information is appended to the file named by the sensitivities_file option; entries already present in that file are not recomputed.

First delete any existing sensitivities_0.data in the current directory:

!rm -rf sensitivities_0.data

Besides the convolution parameters to analyze, we can also specify the granularity and range of the sensitivity analysis, i.e. the ratios by which each convolution layer is pruned in turn.

If the model is very sensitive, e.g. pruning 40% of a single convolution layer's channels already costs 90% of the test accuracy, then capping pruned_ratios at 0.4 is enough, e.g. [0.1, 0.2, 0.3, 0.4].

For more precise sensitivity information, we can use a finer granularity for pruned_ratios, e.g. [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4].

The finer the granularity of pruned_ratios, the slower the sensitivity computation.

sens_0 = slim.prune.sensitivity( # call the sensitivity interface
        val_program, # evaluate sensitivity on the test network
        place, # the device defined earlier
        params, # parameters to analyze
        test, # the model evaluation method
        sensitivities_file="sensitivities_0.data", # file storing the sensitivity information
        pruned_ratios=[0.1, 0.2])  # ratios to prune in turn
print(sens_0)
2020-05-11 11:22:23,898-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
2020-05-11 11:23:27,210-INFO: pruned param: conv2_1_sep_weights; 0.1; loss=0.003316054353490472
2020-05-11 11:23:27,211-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9633413553237915; acc_top5: 0.9986979365348816
2020-05-11 11:24:30,642-INFO: pruned param: conv2_1_sep_weights; 0.2; loss=0.003419717540964484
2020-05-11 11:24:30,644-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9632411599159241; acc_top5: 0.9987980723381042
2020-05-11 11:25:34,129-INFO: pruned param: conv2_2_sep_weights; 0.1; loss=0.0004145299026276916
2020-05-11 11:25:34,131-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9661458134651184; acc_top5: 0.9987980723381042
2020-05-11 11:26:37,465-INFO: pruned param: conv2_2_sep_weights; 0.2; loss=0.006321242079138756
2020-05-11 11:26:37,467-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9604367017745972; acc_top5: 0.9983974099159241
2020-05-11 11:27:41,218-INFO: pruned param: conv3_1_sep_weights; 0.1; loss=-0.000725396501366049
2020-05-11 11:27:41,221-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9672476053237915; acc_top5: 0.9989984035491943
2020-05-11 11:28:45,097-INFO: pruned param: conv3_1_sep_weights; 0.2; loss=0.007253903429955244
2020-05-11 11:28:45,099-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9595352411270142; acc_top5: 0.9985977411270142
2020-05-11 11:29:48,997-INFO: pruned param: conv3_2_sep_weights; 0.1; loss=0.00259065767750144
2020-05-11 11:29:48,999-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9640424847602844; acc_top5: 0.9988982081413269
2020-05-11 11:30:52,993-INFO: pruned param: conv3_2_sep_weights; 0.2; loss=0.016580332070589066
2020-05-11 11:30:52,994-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9505208134651184; acc_top5: 0.9989984035491943
2020-05-11 11:31:56,862-INFO: pruned param: conv4_1_sep_weights; 0.1; loss=0.0009326614672318101
2020-05-11 11:31:56,863-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9656450152397156; acc_top5: 0.9991987347602844
2020-05-11 11:33:00,752-INFO: pruned param: conv4_1_sep_weights; 0.2; loss=0.011813485994935036
Final eva - acc_top1: 0.9551281929016113; acc_top5: 0.9984976053237915
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486}}

7-2 Extending the sensitivity information

The sensitivity above was computed with pruned_ratios=[0.1, 0.2]; we can extend it to [0.1, 0.2, 0.3]:

sens_0 = slim.prune.sensitivity( # call the sensitivity interface
        val_program, # evaluate sensitivity on the test network
        place, # the device defined earlier
        params, # parameters to analyze
        test, # the model evaluation method
        sensitivities_file="sensitivities_0.data", # file storing the sensitivity information
        pruned_ratios=[0.3]) # additional ratio to prune
print(sens_0)
2020-05-11 12:04:00,400-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
2020-05-11 12:05:03,899-INFO: pruned param: conv2_1_sep_weights; 0.3; loss=0.007046638522297144
2020-05-11 12:05:03,901-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9597355723381042; acc_top5: 0.9984976053237915
2020-05-11 12:06:07,507-INFO: pruned param: conv2_2_sep_weights; 0.3; loss=0.009015562944114208
2020-05-11 12:06:07,508-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9578325152397156; acc_top5: 0.9983974099159241
2020-05-11 12:07:10,976-INFO: pruned param: conv3_1_sep_weights; 0.3; loss=0.03471500054001808
2020-05-11 12:07:10,978-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9329928159713745; acc_top5: 0.9976963400840759
2020-05-11 12:08:14,431-INFO: pruned param: conv3_2_sep_weights; 0.3; loss=0.05564764887094498
2020-05-11 12:08:14,432-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9127604365348816; acc_top5: 0.9971955418586731
2020-05-11 12:09:17,918-INFO: pruned param: conv4_1_sep_weights; 0.3; loss=0.03233160451054573
Final eva - acc_top1: 0.9352964758872986; acc_top5: 0.9974960088729858
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605}}

7-3 Speeding up sensitivity computation with multiple processes
The time sensitivity analysis takes depends on the number of convolution layers analyzed and the speed of model evaluation. Multiple processes can speed it up: compute different pruned_ratios in different processes, then merge the results.

7-3-1 Computing sensitivity in another process
In the sections above we computed sensitivity for pruned_ratios=[0.1, 0.2, 0.3] and saved it to sensitivities_0.data. In another process we can set pruned_ratios=[0.4] and save the result to sensitivities_1.data. The code:

sens_1 = slim.prune.sensitivity( # call the sensitivity interface
        val_program, # evaluate sensitivity on the test network
        place, # the device defined earlier
        params, # parameters to analyze
        test, # the model evaluation method
        sensitivities_file="sensitivities_1.data", # separate file for this process's results
        pruned_ratios=[0.4]) # ratio to prune
print(sens_1)
{'conv2_1_sep_weights': {0.4: 0.015814325}, 'conv3_1_sep_weights': {0.4: 0.06366808}, 'conv4_1_sep_weights': {0.4: 0.056377064}, 'conv3_2_sep_weights': {0.4: 0.027110271}, 'conv2_2_sep_weights': {0.4: 0.03727665}}

7-3-2 Loading the sensitivity files produced by multiple processes

s_0 = slim.prune.load_sensitivities("sensitivities_0.data") # load sensitivity information from file
s_1 = slim.prune.load_sensitivities("sensitivities_1.data")
print(s_0)
print(s_1)
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605}}
{'conv2_1_sep_weights': {0.4: 0.015814325}, 'conv3_1_sep_weights': {0.4: 0.06366808}, 'conv4_1_sep_weights': {0.4: 0.056377064}, 'conv3_2_sep_weights': {0.4: 0.027110271}, 'conv2_2_sep_weights': {0.4: 0.03727665}}

7-3-3 Merging the sensitivity information

s = slim.prune.merge_sensitive([s_0, s_1]) # merge the sensitivity information
print(s)
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385, 0.4: 0.015814325}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563, 0.4: 0.03727665}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715, 0.4: 0.06366808}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765, 0.4: 0.027110271}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605, 0.4: 0.056377064}}

8. Prune the model

Prune the model according to the sensitivity information produced above.

8-1 Computing pruning ratios
First, call PaddleSlim's get_ratios_by_loss method to compute pruning ratios from the sensitivities; adjusting the loss argument yields a suitable set of ratios:

loss = 0.01
ratios = slim.prune.get_ratios_by_loss(s_0, loss) # compute a set of pruning ratios from the sensitivities
print(ratios)
{'conv2_1_sep_weights': 0.3, 'conv2_2_sep_weights': 0.3, 'conv3_1_sep_weights': 0.20999995231656052, 'conv3_2_sep_weights': 0.15296293795649352, 'conv4_1_sep_weights': 0.1833331959422747}

8-2 Pruning the training network

pruner = slim.prune.Pruner() # construct a Pruner instance
print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program))) # FLOPs before pruning
pruned_program, _, _ = pruner.prune(
        train_program, # training network to prune
        fluid.global_scope(), # scope holding the parameter values
        params=ratios.keys(), # names of the parameters to prune
        ratios=ratios.values(), # pruning ratios
        place=place) # device where the parameters live
print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program))) # FLOPs after pruning
FLOPs before pruning: 10896832.0
FLOPs after pruning: 9588167.0

8-3 Pruning the test network

Note: when pruning the test network, only_graph must be set to True; see the Pruner API documentation for the reason.

pruner = slim.prune.Pruner()
print("FLOPs before pruning: {}".format(slim.analysis.flops(val_program)))
pruned_val_program, _, _ = pruner.prune(
        val_program, # prune the test network
        fluid.global_scope(),
        params=ratios.keys(),
        ratios=ratios.values(),
        place=place,
        only_graph=True) # prune only the graph structure; leave the parameter values in the scope untouched
print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_val_program)))
FLOPs before pruning: 10896832.0
FLOPs after pruning: 9588167.0

Test the pruned model's accuracy on the test set:

test(pruned_val_program)
Final eva - acc_top1: 0.7620192170143127; acc_top5: 0.9674479365348816
0.7620192

8-4 Training the pruned model

Train the pruned model on the training set for one epoch:

for data in train_reader():
    acc1, acc5, loss = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)
print(np.mean(acc1), np.mean(acc5), np.mean(loss))
1.0 1.0 0.004613471

Test the model's accuracy after training:

test(pruned_val_program)
Final eva - acc_top1: 0.9762620329856873; acc_top5: 0.9992988705635071
0.97626203

Hands-on with PaddleSlim: distillation

Note: restart the environment before running the code in this section.

This tutorial uses the image classification model MobileNetV1 to show how to quickly use PaddleSlim's knowledge distillation interface. The example consists of the following steps:

  1. Import dependencies: PaddleSlim depends on PaddlePaddle; install PaddlePaddle correctly, then import PaddlePaddle and PaddleSlim.
  2. Define the student_program and teacher_program.
  3. Select feature maps.
  4. Merge the programs and add the distillation loss.
  5. Train the model.

The following sections walk through each step.

1. Import dependencies

PaddleSlim depends on Paddle 1.7. Make sure Paddle is installed correctly, then import Paddle and PaddleSlim:

import paddle
import paddle.fluid as fluid
import paddleslim as slim

2. Define the student_program and teacher_program

This tutorial trains and validates knowledge distillation on MNIST, with input shape [1, 28, 28] and 10 output classes. ResNet50 serves as the teacher to distill a MobileNet student.

from paddleslim.models.mobilenet import MobileNet
model = MobileNet()
student_program = fluid.Program() # program for training the student network
student_startup = fluid.Program() # program for initializing the student network
with fluid.program_guard(student_program, student_startup):
    image = fluid.data(
        name='image', shape=[None] + [1, 28, 28], dtype='float32') # declare the image input
    label = fluid.data(name='label', shape=[None, 1], dtype='int64') # declare the label input
    out = model.net(input=image, class_dim=10) # define the network structure
    cost = fluid.layers.cross_entropy(input=out, label=label) # add the classification cost
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) # compute accuracy
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
from paddleslim.models.resnet import ResNet50
teacher_model = ResNet50()
teacher_program = fluid.Program()
teacher_startup = fluid.Program()
with fluid.program_guard(teacher_program, teacher_startup):
    with fluid.unique_name.guard():
        image = fluid.data(
            name='image', shape=[None] + [1, 28, 28], dtype='float32')
        predict = teacher_model.net(image, class_dim=10)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(teacher_startup)
[]

3. Select feature maps

We can inspect all the Variables in student_program with its list_vars method and choose one or more of them to fit the teacher's corresponding variables.

# get all student variables
student_vars = []
for v in student_program.list_vars():
    student_vars.append((v.name, v.shape))
# print the student's variables to pick candidates for distillation
print("="*50+"student_model_vars"+"="*50)
print(student_vars)

# get all teacher variables
teacher_vars = []
for v in teacher_program.list_vars():
    teacher_vars.append((v.name, v.shape))
#uncomment the following lines to observe teacher's variables for distillation
#print("="*50+"teacher_model_vars"+"="*50)
#print(teacher_vars)
==================================================student_model_vars==================================================
[('image', (-1, 1, 28, 28)), ('label', (-1, 1)), ('conv1_weights', (32, 1, 3, 3)), ('conv2d_0.tmp_0', (-1, 32, 14, 14)), ('conv1_bn_scale', (32,)), ('conv1_bn_offset', (32,)), ('conv1_bn_mean', (32,)), ('conv1_bn_variance', (32,)), ('batch_norm_0.tmp_0', (32,)), ('batch_norm_0.tmp_1', (32,)), ('batch_norm_0.tmp_2', (-1, 32, 14, 14)), ('batch_norm_0.tmp_3', (-1, 32, 14, 14)), ('conv2_1_dw_weights', (32, 1, 3, 3)), ('depthwise_conv2d_0.tmp_0', (-1, 32, 14, 14)), ('conv2_1_dw_bn_scale', (32,)), ('conv2_1_dw_bn_offset', (32,)), ('conv2_1_dw_bn_mean', (32,)), ('conv2_1_dw_bn_variance', (32,)), ('batch_norm_1.tmp_0', (32,)), ('batch_norm_1.tmp_1', (32,)), ('batch_norm_1.tmp_2', (-1, 32, 14, 14)), ('batch_norm_1.tmp_3', (-1, 32, 14, 14)), ('conv2_1_sep_weights', (64, 32, 1, 1)), ('conv2d_1.tmp_0', (-1, 64, 14, 14)), ('conv2_1_sep_bn_scale', (64,)), ('conv2_1_sep_bn_offset', (64,)), ('conv2_1_sep_bn_mean', (64,)), ('conv2_1_sep_bn_variance', (64,)), ('batch_norm_2.tmp_0', (64,)), ('batch_norm_2.tmp_1', (64,)), ('batch_norm_2.tmp_2', (-1, 64, 14, 14)), ('batch_norm_2.tmp_3', (-1, 64, 14, 14)), ('conv2_2_dw_weights', (64, 1, 3, 3)), ('depthwise_conv2d_1.tmp_0', (-1, 64, 7, 7)), ('conv2_2_dw_bn_scale', (64,)), ('conv2_2_dw_bn_offset', (64,)), ('conv2_2_dw_bn_mean', (64,)), ('conv2_2_dw_bn_variance', (64,)), ('batch_norm_3.tmp_0', (64,)), ('batch_norm_3.tmp_1', (64,)), ('batch_norm_3.tmp_2', (-1, 64, 7, 7)), ('batch_norm_3.tmp_3', (-1, 64, 7, 7)), ('conv2_2_sep_weights', (128, 64, 1, 1)), ('conv2d_2.tmp_0', (-1, 128, 7, 7)), ('conv2_2_sep_bn_scale', (128,)), ('conv2_2_sep_bn_offset', (128,)), ('conv2_2_sep_bn_mean', (128,)), ('conv2_2_sep_bn_variance', (128,)), ('batch_norm_4.tmp_0', (128,)), ('batch_norm_4.tmp_1', (128,)), ('batch_norm_4.tmp_2', (-1, 128, 7, 7)), ('batch_norm_4.tmp_3', (-1, 128, 7, 7)), ('conv3_1_dw_weights', (128, 1, 3, 3)), ('depthwise_conv2d_2.tmp_0', (-1, 128, 7, 7)), ('conv3_1_dw_bn_scale', (128,)), ('conv3_1_dw_bn_offset', (128,)), 
('conv3_1_dw_bn_mean', (128,)), ('conv3_1_dw_bn_variance', (128,)), ('batch_norm_5.tmp_0', (128,)), ('batch_norm_5.tmp_1', (128,)), ('batch_norm_5.tmp_2', (-1, 128, 7, 7)), ('batch_norm_5.tmp_3', (-1, 128, 7, 7)), ('conv3_1_sep_weights', (128, 128, 1, 1)), ('conv2d_3.tmp_0', (-1, 128, 7, 7)), ('conv3_1_sep_bn_scale', (128,)), ('conv3_1_sep_bn_offset', (128,)), ('conv3_1_sep_bn_mean', (128,)), ('conv3_1_sep_bn_variance', (128,)), ('batch_norm_6.tmp_0', (128,)), ('batch_norm_6.tmp_1', (128,)), ('batch_norm_6.tmp_2', (-1, 128, 7, 7)), ('batch_norm_6.tmp_3', (-1, 128, 7, 7)), ('conv3_2_dw_weights', (128, 1, 3, 3)), ('depthwise_conv2d_3.tmp_0', (-1, 128, 4, 4)), ('conv3_2_dw_bn_scale', (128,)), ('conv3_2_dw_bn_offset', (128,)), ('conv3_2_dw_bn_mean', (128,)), ('conv3_2_dw_bn_variance', (128,)), ('batch_norm_7.tmp_0', (128,)), ('batch_norm_7.tmp_1', (128,)), ('batch_norm_7.tmp_2', (-1, 128, 4, 4)), ('batch_norm_7.tmp_3', (-1, 128, 4, 4)), ('conv3_2_sep_weights', (256, 128, 1, 1)), ('conv2d_4.tmp_0', (-1, 256, 4, 4)), ('conv3_2_sep_bn_scale', (256,)), ('conv3_2_sep_bn_offset', (256,)), ('conv3_2_sep_bn_mean', (256,)), ('conv3_2_sep_bn_variance', (256,)), ('batch_norm_8.tmp_0', (256,)), ('batch_norm_8.tmp_1', (256,)), ('batch_norm_8.tmp_2', (-1, 256, 4, 4)), ('batch_norm_8.tmp_3', (-1, 256, 4, 4)), ('conv4_1_dw_weights', (256, 1, 3, 3)), ('depthwise_conv2d_4.tmp_0', (-1, 256, 4, 4)), ('conv4_1_dw_bn_scale', (256,)), ('conv4_1_dw_bn_offset', (256,)), ('conv4_1_dw_bn_mean', (256,)), ('conv4_1_dw_bn_variance', (256,)), ('batch_norm_9.tmp_0', (256,)), ('batch_norm_9.tmp_1', (256,)), ('batch_norm_9.tmp_2', (-1, 256, 4, 4)), ('batch_norm_9.tmp_3', (-1, 256, 4, 4)), ('conv4_1_sep_weights', (256, 256, 1, 1)), ('conv2d_5.tmp_0', (-1, 256, 4, 4)), ('conv4_1_sep_bn_scale', (256,)), ('conv4_1_sep_bn_offset', (256,)), ('conv4_1_sep_bn_mean', (256,)), ('conv4_1_sep_bn_variance', (256,)), ('batch_norm_10.tmp_0', (256,)), ('batch_norm_10.tmp_1', (256,)), ('batch_norm_10.tmp_2', (-1, 256, 
4, 4)), ('batch_norm_10.tmp_3', (-1, 256, 4, 4)), ('conv4_2_dw_weights', (256, 1, 3, 3)), ('depthwise_conv2d_5.tmp_0', (-1, 256, 2, 2)), ('conv4_2_dw_bn_scale', (256,)), ('conv4_2_dw_bn_offset', (256,)), ('conv4_2_dw_bn_mean', (256,)), ('conv4_2_dw_bn_variance', (256,)), ('batch_norm_11.tmp_0', (256,)), ('batch_norm_11.tmp_1', (256,)), ('batch_norm_11.tmp_2', (-1, 256, 2, 2)), ('batch_norm_11.tmp_3', (-1, 256, 2, 2)), ('conv4_2_sep_weights', (512, 256, 1, 1)), ('conv2d_6.tmp_0', (-1, 512, 2, 2)), ('conv4_2_sep_bn_scale', (512,)), ('conv4_2_sep_bn_offset', (512,)), ('conv4_2_sep_bn_mean', (512,)), ('conv4_2_sep_bn_variance', (512,)), ('batch_norm_12.tmp_0', (512,)), ('batch_norm_12.tmp_1', (512,)), ('batch_norm_12.tmp_2', (-1, 512, 2, 2)), ('batch_norm_12.tmp_3', (-1, 512, 2, 2)), ('conv5_1_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_6.tmp_0', (-1, 512, 2, 2)), ('conv5_1_dw_bn_scale', (512,)), ('conv5_1_dw_bn_offset', (512,)), ('conv5_1_dw_bn_mean', (512,)), ('conv5_1_dw_bn_variance', (512,)), ('batch_norm_13.tmp_0', (512,)), ('batch_norm_13.tmp_1', (512,)), ('batch_norm_13.tmp_2', (-1, 512, 2, 2)), ('batch_norm_13.tmp_3', (-1, 512, 2, 2)), ('conv5_1_sep_weights', (512, 512, 1, 1)), ('conv2d_7.tmp_0', (-1, 512, 2, 2)), ('conv5_1_sep_bn_scale', (512,)), ('conv5_1_sep_bn_offset', (512,)), ('conv5_1_sep_bn_mean', (512,)), ('conv5_1_sep_bn_variance', (512,)), ('batch_norm_14.tmp_0', (512,)), ('batch_norm_14.tmp_1', (512,)), ('batch_norm_14.tmp_2', (-1, 512, 2, 2)), ('batch_norm_14.tmp_3', (-1, 512, 2, 2)), ('conv5_2_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_7.tmp_0', (-1, 512, 2, 2)), ('conv5_2_dw_bn_scale', (512,)), ('conv5_2_dw_bn_offset', (512,)), ('conv5_2_dw_bn_mean', (512,)), ('conv5_2_dw_bn_variance', (512,)), ('batch_norm_15.tmp_0', (512,)), ('batch_norm_15.tmp_1', (512,)), ('batch_norm_15.tmp_2', (-1, 512, 2, 2)), ('batch_norm_15.tmp_3', (-1, 512, 2, 2)), ('conv5_2_sep_weights', (512, 512, 1, 1)), ('conv2d_8.tmp_0', (-1, 512, 2, 2)), 
('conv5_2_sep_bn_scale', (512,)), ('conv5_2_sep_bn_offset', (512,)), ('conv5_2_sep_bn_mean', (512,)), ('conv5_2_sep_bn_variance', (512,)), ('batch_norm_16.tmp_0', (512,)), ('batch_norm_16.tmp_1', (512,)), ('batch_norm_16.tmp_2', (-1, 512, 2, 2)), ('batch_norm_16.tmp_3', (-1, 512, 2, 2)), ('conv5_3_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_8.tmp_0', (-1, 512, 2, 2)), ('conv5_3_dw_bn_scale', (512,)), ('conv5_3_dw_bn_offset', (512,)), ('conv5_3_dw_bn_mean', (512,)), ('conv5_3_dw_bn_variance', (512,)), ('batch_norm_17.tmp_0', (512,)), ('batch_norm_17.tmp_1', (512,)), ('batch_norm_17.tmp_2', (-1, 512, 2, 2)), ('batch_norm_17.tmp_3', (-1, 512, 2, 2)), ('conv5_3_sep_weights', (512, 512, 1, 1)), ('conv2d_9.tmp_0', (-1, 512, 2, 2)), ('conv5_3_sep_bn_scale', (512,)), ('conv5_3_sep_bn_offset', (512,)), ('conv5_3_sep_bn_mean', (512,)), ('conv5_3_sep_bn_variance', (512,)), ('batch_norm_18.tmp_0', (512,)), ('batch_norm_18.tmp_1', (512,)), ('batch_norm_18.tmp_2', (-1, 512, 2, 2)), ('batch_norm_18.tmp_3', (-1, 512, 2, 2)), ('conv5_4_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_9.tmp_0', (-1, 512, 2, 2)), ('conv5_4_dw_bn_scale', (512,)), ('conv5_4_dw_bn_offset', (512,)), ('conv5_4_dw_bn_mean', (512,)), ('conv5_4_dw_bn_variance', (512,)), ('batch_norm_19.tmp_0', (512,)), ('batch_norm_19.tmp_1', (512,)), ('batch_norm_19.tmp_2', (-1, 512, 2, 2)), ('batch_norm_19.tmp_3', (-1, 512, 2, 2)), ('conv5_4_sep_weights', (512, 512, 1, 1)), ('conv2d_10.tmp_0', (-1, 512, 2, 2)), ('conv5_4_sep_bn_scale', (512,)), ('conv5_4_sep_bn_offset', (512,)), ('conv5_4_sep_bn_mean', (512,)), ('conv5_4_sep_bn_variance', (512,)), ('batch_norm_20.tmp_0', (512,)), ('batch_norm_20.tmp_1', (512,)), ('batch_norm_20.tmp_2', (-1, 512, 2, 2)), ('batch_norm_20.tmp_3', (-1, 512, 2, 2)), ('conv5_5_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_10.tmp_0', (-1, 512, 2, 2)), ('conv5_5_dw_bn_scale', (512,)), ('conv5_5_dw_bn_offset', (512,)), ('conv5_5_dw_bn_mean', (512,)), ('conv5_5_dw_bn_variance', (512,)), 
('batch_norm_21.tmp_0', (512,)), ('batch_norm_21.tmp_1', (512,)), ('batch_norm_21.tmp_2', (-1, 512, 2, 2)), ('batch_norm_21.tmp_3', (-1, 512, 2, 2)), ('conv5_5_sep_weights', (512, 512, 1, 1)), ('conv2d_11.tmp_0', (-1, 512, 2, 2)), ('conv5_5_sep_bn_scale', (512,)), ('conv5_5_sep_bn_offset', (512,)), ('conv5_5_sep_bn_mean', (512,)), ('conv5_5_sep_bn_variance', (512,)), ('batch_norm_22.tmp_0', (512,)), ('batch_norm_22.tmp_1', (512,)), ('batch_norm_22.tmp_2', (-1, 512, 2, 2)), ('batch_norm_22.tmp_3', (-1, 512, 2, 2)), ('conv5_6_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_11.tmp_0', (-1, 512, 1, 1)), ('conv5_6_dw_bn_scale', (512,)), ('conv5_6_dw_bn_offset', (512,)), ('conv5_6_dw_bn_mean', (512,)), ('conv5_6_dw_bn_variance', (512,)), ('batch_norm_23.tmp_0', (512,)), ('batch_norm_23.tmp_1', (512,)), ('batch_norm_23.tmp_2', (-1, 512, 1, 1)), ('batch_norm_23.tmp_3', (-1, 512, 1, 1)), ('conv5_6_sep_weights', (1024, 512, 1, 1)), ('conv2d_12.tmp_0', (-1, 1024, 1, 1)), ('conv5_6_sep_bn_scale', (1024,)), ('conv5_6_sep_bn_offset', (1024,)), ('conv5_6_sep_bn_mean', (1024,)), ('conv5_6_sep_bn_variance', (1024,)), ('batch_norm_24.tmp_0', (1024,)), ('batch_norm_24.tmp_1', (1024,)), ('batch_norm_24.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_24.tmp_3', (-1, 1024, 1, 1)), ('conv6_dw_weights', (1024, 1, 3, 3)), ('depthwise_conv2d_12.tmp_0', (-1, 1024, 1, 1)), ('conv6_dw_bn_scale', (1024,)), ('conv6_dw_bn_offset', (1024,)), ('conv6_dw_bn_mean', (1024,)), ('conv6_dw_bn_variance', (1024,)), ('batch_norm_25.tmp_0', (1024,)), ('batch_norm_25.tmp_1', (1024,)), ('batch_norm_25.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_25.tmp_3', (-1, 1024, 1, 1)), ('conv6_sep_weights', (1024, 1024, 1, 1)), ('conv2d_13.tmp_0', (-1, 1024, 1, 1)), ('conv6_sep_bn_scale', (1024,)), ('conv6_sep_bn_offset', (1024,)), ('conv6_sep_bn_mean', (1024,)), ('conv6_sep_bn_variance', (1024,)), ('batch_norm_26.tmp_0', (1024,)), ('batch_norm_26.tmp_1', (1024,)), ('batch_norm_26.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_26.tmp_3', 
(-1, 1024, 1, 1)), ('pool2d_0.tmp_0', (-1, 1024, 1, 1)), ('fc7_weights', (1024, 10)), ('fc_0.tmp_0', (-1, 10)), ('fc7_offset', (10,)), ('fc_0.tmp_1', (-1, 10)), ('fc_0.tmp_2', (-1, 10)), ('cross_entropy2_0.tmp_0', (-1, 1)), ('cross_entropy2_0.tmp_1', (-1, 10, 0)), ('cross_entropy2_0.tmp_2', (-1, 1)), ('mean_0.tmp_0', (1,)), ('top_k_0.tmp_0', (-1, 1)), ('top_k_0.tmp_1', (-1, 1)), ('accuracy_0.tmp_0', (1,)), ('accuracy_0.tmp_1', (1,)), ('accuracy_0.tmp_2', (1,)), ('top_k_1.tmp_0', (-1, 5)), ('top_k_1.tmp_1', (-1, 5)), ('accuracy_1.tmp_0', (1,)), ('accuracy_1.tmp_1', (1,)), ('accuracy_1.tmp_2', (1,))]

After filtering, we see that 'bn5c_branch2b.output.1.tmp_3' in teacher_program and 'depthwise_conv2d_11.tmp_0' in student_program have the same shape, so they can form the distillation loss.

4. Merge the programs and add the distillation loss

The merge operation adds all Variables and Ops of student_program and teacher_program into a single Program. To avoid naming conflicts between same-named variables in the two programs, merge also adds a uniform name prefix, name_prefix, to the teacher's Variables; its default is 'teacher_'.

To make sure the teacher and student networks see the same input data, merge also merges the two programs' input data layers, so a mapping of data layer names, data_name_map, must be provided: its keys are the teacher's input data names and its values are the student's.

data_name_map = {'image': 'image'}
main = slim.dist.merge(teacher_program, student_program, data_name_map, fluid.CPUPlace()) # merge the teacher network into the student network
with fluid.program_guard(student_program, student_startup):
    # add the distillation loss
    l2_loss = slim.dist.l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
    # optimize the distillation loss together with the student's original classification loss
    loss = l2_loss + avg_cost
    opt = fluid.optimizer.Momentum(0.01, 0.9)
    opt.minimize(loss) # add backward computation and optimization ops
exe.run(student_startup) # initialize the student network
[]

5. Train the model

To keep the example fast we use the simple MNIST dataset; Paddle's paddle.dataset.mnist package handles downloading and reading MNIST. The code:

train_reader = paddle.batch(
    paddle.dataset.mnist.train(), batch_size=128, drop_last=True) # generator over the MNIST training data
train_feeder = fluid.DataFeeder(['image', 'label'], fluid.CPUPlace(), student_program) # feeder mapping the data to the student network's inputs


for data in train_reader(): # train for one epoch
    acc1, acc5, loss_np = exe.run(student_program, feed=train_feeder.feed(data), fetch_list=[acc_top1.name, acc_top5.name, loss.name])
    print("Acc1: {:.6f}, Acc5: {:.6f}, Loss: {:.6f}".format(acc1.mean(), acc5.mean(), loss_np.mean()))
Acc1: 0.062500, Acc5: 0.382812, Loss: 221.120117
Acc1: 0.101562, Acc5: 0.507812, Loss: 222.530258
Acc1: 0.070312, Acc5: 0.500000, Loss: 222.556686
Acc1: 0.117188, Acc5: 0.437500, Loss: 220.210587
Acc1: 0.101562, Acc5: 0.562500, Loss: 217.044830
Acc1: 0.117188, Acc5: 0.484375, Loss: 215.024094
Acc1: 0.109375, Acc5: 0.500000, Loss: 218.645142
Acc1: 0.125000, Acc5: 0.578125, Loss: 214.771973
Acc1: 0.203125, Acc5: 0.554688, Loss: 210.950424
Acc1: 0.171875, Acc5: 0.554688, Loss: 220.628342
Acc1: 0.140625, Acc5: 0.546875, Loss: 217.624451
Acc1: 0.093750, Acc5: 0.562500, Loss: 211.037888