The Evolution of AI Application Scenarios
With the country strongly pushing the rollout of 5G mobile networks, more mobile applications and smart devices will quickly take over consumers' lives, and artificial intelligence has become an indispensable part of improving both user experience and commercial services. In recent years the main battleground of the internet has been shifting from the mobile ecosystem to artificial intelligence, and "consolidate the mobile foundation, win the decisive AI moment" has become the strategic direction of many companies. Feed streams, short videos, and live streaming keep consumers' entertainment time locked to the smartphone, while the rise of smart speakers, wearables, in-vehicle smart devices, and smart home products shows that the IoT era has arrived. From the PC internet to the mobile internet and on to the intelligent IoT, Baidu has kept using AI to make search simpler: from the PC search box to the Baidu App feed to smart speakers, the way people get information has become simpler and smarter, and behind it all, computation, perception, and decision-making have gradually moved from the data center to edge devices.
Consumer Demand
Efficient deployment of deep learning models on all kinds of consumer smart devices has become an important requirement. The edge-device deployment environment poses new challenges for AI models: constrained by power consumption and device size, smart devices have relatively weak compute performance and storage, so the key demands come down to three points: smaller model size, faster inference, and lower power consumption.
Meeting these demands requires shrinking existing models for the target device, making them smaller, faster, and more power-efficient without losing accuracy.
Research Demand
Current deep learning research is largely built on open, standard datasets and tasks, such as the ImageNet dataset for classification, the COCO dataset for detection, and the Cityscapes dataset for segmentation, and the state-of-the-art models released by researchers and institutions target these open tasks.
However, the data and task complexity found in real production environments differ considerably from these open tasks. Directly applying a model designed and tuned for an open task to a real-world need inevitably carries irrelevant, weakly relevant, or redundant information, so existing open-source models need to be pruned, compressed, and optimized for the specific task and scenario.
Why not design a small model directly?
The common ways to produce a small model are essentially these: design a more efficient network structure, reduce the number of parameters, and reduce the amount of computation, while keeping or improving accuracy. One might ask: why not just design a small CNN directly? In practice there are many business sub-verticals with different task complexities, so hand-designing a small model for each is very difficult and requires strong domain knowledge. Model compression, on the other hand, can start from a classic small model and, with only modest effort, quickly improve its performance on every axis, achieving "more, faster, better, cheaper".
The figure above shows the effect of distillation and quantization on a classification model; the horizontal axis is inference latency and the vertical axis is model accuracy. Starting from a hand-designed classic small model, distillation and quantization further improve both accuracy and inference speed.
PaddleSlim is the first fully open-source deep learning model compression tool in China with complete independent intellectual property rights and the most complete feature set. It integrates the methods commonly used in model compression: quantization, pruning, distillation, neural architecture search, and hardware-aware architecture search. Its features include a development mechanism that balances flexibility and efficiency, industrial-grade compressed model output, large-scale parallel architecture search, compressed models that work seamlessly with the TensorRT and PaddleLite inference engines, full compression support for the PaddlePaddle model zoo, complete Chinese and English documentation, and a rich set of compression results and usage examples. PaddleSlim aims to make models produced with deep learning easier to deploy in industry.
Features of PaddleSlim
PaddleSlim supports all mainstream model compression methods, including pruning, quantization, distillation, and NAS; multiple compression strategies can be combined; and configuration is simple.
In short: pruning removes convolution parameters to reduce the parameter count of a large model. Quantization converts a model from float32 (4 bytes per value) to int8 (1 byte per value) to reduce computation and model size. Distillation transfers knowledge from a large model to a small one to improve the small model's accuracy. NAS is architecture search constrained by model size and inference speed, automatically designing more efficient network structures. The figure below gives an overview of PaddleSlim's capabilities.
Pruning
A deep learning model contains many redundant convolution parameters; removing them greatly reduces the parameter count and speeds up inference. The figure below gives an intuitive picture of what pruning is. A is a feature map; W1 on the right holds the parameters of a convolution layer, with many filters, each producing one output channel. Pruning removes some unimportant filters, so the output B of that convolution has fewer channels. In the following convolution layer W, each filter loses the corresponding input channels while the number of filters stays the same, so the final output C has the same shape as it did before pruning.
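To make the shape bookkeeping concrete, here is a small NumPy sketch (not PaddleSlim code) using 1x1 convolutions and made-up sizes: pruning filters from W1 shrinks B's channel count and the matching input channels of the next layer's weights, while C keeps its original shape.
import numpy as np

A = np.random.rand(1, 8, 16, 16).astype("float32")    # feature map A: (batch, channels, H, W)
W1 = np.random.rand(16, 8, 1, 1).astype("float32")     # conv weights: (out_channels, in_channels, 1, 1)
W2 = np.random.rand(32, 16, 1, 1).astype("float32")

def conv1x1(x, w):
    # a 1x1 convolution is just channel mixing: (n,c,h,w) x (o,c) -> (n,o,h,w)
    return np.einsum("nchw,oc->nohw", x, w[:, :, 0, 0])

keep = np.arange(12)          # keep 12 of the 16 filters in W1
W1_pruned = W1[keep]          # fewer output channels, so B will have 12 channels
W2_pruned = W2[:, keep]       # W2 drops the matching input channels

B = conv1x1(A, W1_pruned)
C = conv1x1(B, W2_pruned)
print(B.shape, C.shape)       # (1, 12, 16, 16) (1, 32, 16, 16): C's shape is unchanged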
Which filters to remove can be decided as follows: for the filters within one convolution layer, rank them by an importance criterion and then remove the least important ones. For example, suppose the input has a single channel and the layer has three filters, as shown in the figure below.
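As a concrete illustration of one common ranking criterion, the NumPy sketch below scores each filter by its L1 norm (the rule used in Pruning Filters for Efficient ConvNets, which this tutorial cites later) and removes the lowest-scoring one; the weight values and sizes are made up for illustration.
import numpy as np

W = np.random.rand(3, 1, 3, 3).astype("float32")   # one layer's weights: (num_filters, in_channels, k, k)

l1 = np.abs(W).reshape(W.shape[0], -1).sum(axis=1)  # L1 norm of each filter = importance score
order = np.argsort(l1)                               # filters from least to most important
prune_num = 1
W_pruned = np.delete(W, order[:prune_num], axis=0)   # drop the least important filter
print(l1, order, W_pruned.shape)                     # W_pruned has shape (2, 1, 3, 3)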
Once the filters within each convolution layer are ranked by importance, the next step is deciding what fraction of each layer to prune, which is where sensitivity analysis comes in. In the figure below, the horizontal axis is the pruning ratio of one convolution layer and the vertical axis is the model's accuracy on the test set after pruning at that ratio. Suppose the unpruned model's accuracy is 90. For convolution layer 1, prune it at ratios of 25%, 50%, and 75%, measure the test accuracy at each ratio, mark each result with a red point, and connect the red points into a sensitivity curve. Then restore layer 1 and repeat the same ratios on convolution layer K, marking its accuracies with green points. If, as in the curve shown for layer K, accuracy drops only slightly when 25%, 50%, or 75% of its filters are removed, the model is insensitive to pruning layer K, so layer K can be given a larger pruning ratio. After sensitivity analysis determines how much each convolution layer can be pruned, a criterion such as L0_norm, L1_norm, or FPGM is used to score the importance of individual filters and carry out the pruning.
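The toy Python sketch below mirrors that procedure; the evaluation function here is only a stand-in, since the real pruning and evaluation is what PaddleSlim's slim.prune.sensitivity API (used later in this tutorial) performs.
import random

random.seed(0)

def eval_after_pruning(layer, ratio):
    # stand-in for "prune `ratio` of `layer`'s filters, then evaluate on the test set"
    return 0.90 - random.uniform(0.0, 0.1) * ratio

baseline = 0.90            # accuracy of the unpruned model
sens = {}
for layer in ["conv1", "convK"]:
    sens[layer] = {}
    for ratio in [0.25, 0.50, 0.75]:
        sens[layer][ratio] = baseline - eval_after_pruning(layer, ratio)  # accuracy loss
print(sens)
# A layer whose loss stays small even at 0.75 is insensitive and can be pruned more aggressively.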
Quantization
Quantization maps floating-point values to low-bit integers; the goal is to reduce model size and computation and speed up inference. PaddleSlim's quantization includes the following approaches: quant-aware training (online), post-training quantization (offline), embedding quantization, and weight-only quantization.
Without training data, choose embedding quantization or weight-only quantization; embedding quantization targets NLP tasks only. Weight-only quantization quantizes convolution or fully connected parameters; its benefit is a smaller model and less data to read (effectively more read bandwidth), which noticeably speeds up IO-bound computation.
With training data available, use post-training quantization or quant-aware training; the difference is whether re-training is required. Quant-aware training gives higher accuracy while still shrinking the model, whereas post-training quantization avoids the cost of re-training, and in most cases models reach good accuracy even without re-training.
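To make the float32-to-int8 mapping concrete, here is a minimal NumPy sketch of symmetric per-tensor weight quantization; it illustrates the general idea rather than PaddleSlim's exact implementation, and the tensor sizes are made up.
import numpy as np

w = np.random.randn(64, 32).astype("float32")

scale = np.abs(w).max() / 127.0                                   # map the largest magnitude to 127
w_int8 = np.clip(np.round(w / scale), -127, 127).astype("int8")   # quantize
w_dequant = w_int8.astype("float32") * scale                      # dequantize for comparison

print(w.nbytes, w_int8.nbytes)        # 4 bytes per value vs 1 byte per value
print(np.abs(w - w_dequant).max())    # worst-case quantization error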
Distillation
Distillation transfers the knowledge of a complex network (the teacher model) into a small network (the student model); its purpose is to improve the small network's accuracy. The core question is how to transfer the teacher's knowledge to the student, and the usual approach is to use the teacher's outputs to supervise the student's training. The key to designing the algorithm is a suitable distillation loss: a loss built between the teacher and the student that supervises the student's learning.
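As one classic form of such a loss, the NumPy sketch below computes a soft-label cross-entropy between temperature-softened teacher and student outputs; the logits and temperature are made up, and PaddleSlim also provides other losses (for example the l2_loss on intermediate feature maps used later in this tutorial).
import numpy as np

def softmax(x, t=1.0):
    e = np.exp(x / t - np.max(x / t, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

teacher_logits = np.array([[2.0, 1.0, 0.1]])   # toy outputs for one sample
student_logits = np.array([[1.5, 0.8, 0.3]])

T = 4.0                                        # temperature softens the teacher distribution
p_teacher = softmax(teacher_logits, T)
p_student = softmax(student_logits, T)

# soft-label distillation loss: cross-entropy between teacher and student distributions
distill_loss = -(p_teacher * np.log(p_student + 1e-12)).sum(axis=-1).mean()
print(distill_loss)
# The total training loss is typically this term (weighted) plus the usual hard-label loss.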
NAS
NAS (neural architecture search) is a technique for automatically designing neural networks: given a search space, a search algorithm automatically finds high-performing network structures, with the goal of designing efficient architectures without manual effort. As shown in the figure below, the search process uses a search strategy to sample a model from the search space, evaluates the model's performance, and feeds the evaluation result back to the search strategy.
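The sketch below shows this loop in its simplest form, random search over a toy search space with a made-up reward; real NAS systems, including PaddleSlim's, use stronger search strategies and obtain the reward by actually training and measuring candidate networks under size and latency constraints.
import random

random.seed(0)

# Toy search space and reward; the names and numbers are illustrative only.
search_space = {"depth": [2, 3, 4], "width": [16, 32, 64], "kernel": [3, 5]}

def sample(space):
    return {k: random.choice(v) for k, v in space.items()}

def reward(arch):
    # stand-in for "train the candidate, measure accuracy, and penalize latency/size"
    return arch["depth"] * arch["width"] - 0.5 * arch["width"] * arch["kernel"]

best = max((sample(search_space) for _ in range(20)), key=reward)
print(best)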
Classification models
Data: ImageNet2012; Model: MobileNetV1
Compression strategy | Accuracy gain (baseline: 70.91%) | Model size (baseline: 17.0 MB) |
---|---|---|
Knowledge distillation (ResNet50) | +1.06% | - |
Knowledge distillation (ResNet50) + int8 quant-aware training | +1.10% | -71.76% |
Pruning (FLOPs -50%) + int8 quant-aware training | -1.71% | -86.47% |
Object detection models
Data: Pascal VOC; Model: MobileNet-V1-YOLOv3
Compression method | mAP (baseline: 76.2%) | Model size (baseline: 94 MB) |
---|---|---|
Knowledge distillation (ResNet34-YOLOv3) | +2.8% | - |
Pruning (FLOPs -52.88%) | +1.4% | -67.76% |
Knowledge distillation (ResNet34-YOLOv3) + pruning (FLOPs -69.57%) | +2.6% | -67.00% |
Data: COCO; Model: MobileNet-V1-YOLOv3
Compression method | mAP (baseline: 29.3%) | Model size |
---|---|---|
Knowledge distillation (ResNet34-YOLOv3) | +2.1% | - |
Knowledge distillation (ResNet34-YOLOv3) + pruning (FLOPs -67.56%) | -0.3% | -66.90% |
NAS-searched models
Data: ImageNet2012; Model: MobileNetV2
Hardware | Inference latency | Top-1 accuracy (baseline: 71.90%) |
---|---|---|
RK3288 | -23% | +0.07% |
Android cellphone | -20% | +0.16% |
iPhone 6s | -17% | +0.32% |
# Install PaddleSlim
!pip install --upgrade paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Requirement already up-to-date: paddleslim in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.0.1) Requirement already satisfied, skipping upgrade: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleslim) (4.36.1)
The following simple examples show how to prune and distill a model using the APIs provided by PaddleSlim.
For more model compression tutorials, see: https://paddlepaddle.github.io/PaddleSlim/index.html
This tutorial uses the image classification model MobileNetV1 to show how to quickly use PaddleSlim's sensitivity analysis API to analyze the channels of convolution layers and then prune the model based on that analysis. The same steps apply to any other model that contains convolution layers.
Parts of this example follow the paper Pruning Filters for Efficient ConvNets.
This example includes the following steps: importing the dependencies, building the model, defining the input data, defining a model evaluation method, training the model, selecting the convolution parameters to analyze, running the sensitivity analysis, and pruning the model. Each step depends on the previous one.
The following sections walk through each step in turn.
PaddleSlim depends on PaddlePaddle. Make sure PaddlePaddle is installed correctly, then import PaddlePaddle and PaddleSlim as follows.
import paddle
import paddle.fluid as fluid
import paddleslim as slim
use_gpu = False
place = fluid.CPUPlace()  # device on which the network will run
if paddle.fluid.is_compiled_with_cuda():
    use_gpu = True
    place = fluid.CUDAPlace(0)  # use GPU 0 when Paddle is built with CUDA support
# Build the Programs used for training and testing, and return the network's input and output variables
exe, train_program, val_program, inputs, outputs = slim.models.image_classification("MobileNet", [1, 28, 28], 10, use_gpu=use_gpu)
import paddle.dataset.mnist as reader
train_reader = paddle.batch(
    reader.train(), batch_size=128, drop_last=True)  # generator that yields training batches
test_reader = paddle.batch(
    reader.test(), batch_size=128, drop_last=True)   # generator that yields test batches
data_feeder = fluid.DataFeeder(inputs, place)  # maps the generator's output to the network inputs
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz Begin to download Download finished Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz Begin to download ........ Download finished Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz Begin to download Download finished Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz Begin to download .. Download finished
import numpy as np
def test(program):
    acc_top1_ns = []
    acc_top5_ns = []
    for data in test_reader():
        acc_top1_n, acc_top5_n, _ = exe.run(  # run the test network on one batch
            program,
            feed=data_feeder.feed(data),
            fetch_list=outputs)
        acc_top1_ns.append(np.mean(acc_top1_n))  # collect the metrics for the current batch
        acc_top5_ns.append(np.mean(acc_top5_n))
    print("Final eva - acc_top1: {}; acc_top5: {}".format(
        np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))  # average the metrics over all test batches and print them
    return np.mean(np.array(acc_top1_ns))  # return the top-1 accuracy on the whole test set
for data in train_reader():
    acc1, acc5, loss = exe.run(train_program, feed=data_feeder.feed(data), fetch_list=outputs)  # train on one batch
print(np.mean(acc1), np.mean(acc5), np.mean(loss))
1.0 1.0 0.014517799
Evaluate the current model's accuracy on the test set using the evaluation method defined above:
test(val_program)
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
0.9665465
params = []
for param in train_program.global_block().all_parameters():  # iterate over all parameters
    if "_sep_weights" in param.name:  # select the parameters to analyze by name
        params.append(param.name)
print(params)
params = params[:5]  # for brevity, only analyze the first 5 parameters
['conv2_1_sep_weights', 'conv2_2_sep_weights', 'conv3_1_sep_weights', 'conv3_2_sep_weights', 'conv4_1_sep_weights', 'conv4_2_sep_weights', 'conv5_1_sep_weights', 'conv5_2_sep_weights', 'conv5_3_sep_weights', 'conv5_4_sep_weights', 'conv5_5_sep_weights', 'conv5_6_sep_weights', 'conv6_sep_weights']
7-1 Basic sensitivity computation
Call the sensitivity API to run a sensitivity analysis of the trained model. During the computation, the results are appended to the file specified by the sensitivities_file option, and entries already present in that file are not recomputed. First delete any existing sensitivities_0.data file in the current directory with the following command.
!rm -rf sensitivities_0.data
Besides specifying which convolution parameters to analyze, we can also specify the granularity and range of the analysis, i.e., the ratios at which each convolution layer is pruned. If the model is very sensitive, say pruning 40% of a single convolution layer's channels already causes a 90% accuracy loss on the test set, then the largest value in pruned_ratios only needs to be 0.4, for example [0.1, 0.2, 0.3, 0.4]. For more precise sensitivity information, use a finer granularity for pruned_ratios, e.g. [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]; the finer the granularity, the slower the sensitivity computation.
sens_0 = slim.prune.sensitivity(  # call the sensitivity analysis API
    val_program,  # evaluate sensitivity on the test network
    place,  # the device defined earlier
    params,  # the parameters to analyze
    test,  # the model evaluation method
    sensitivities_file="sensitivities_0.data",  # file in which the sensitivity results are saved
    pruned_ratios=[0.1, 0.2])  # ratios to prune in turn
print(sens_0)
2020-05-11 11:22:23,898-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
2020-05-11 11:23:27,210-INFO: pruned param: conv2_1_sep_weights; 0.1; loss=0.003316054353490472 2020-05-11 11:23:27,211-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9633413553237915; acc_top5: 0.9986979365348816
2020-05-11 11:24:30,642-INFO: pruned param: conv2_1_sep_weights; 0.2; loss=0.003419717540964484 2020-05-11 11:24:30,644-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9632411599159241; acc_top5: 0.9987980723381042
2020-05-11 11:25:34,129-INFO: pruned param: conv2_2_sep_weights; 0.1; loss=0.0004145299026276916 2020-05-11 11:25:34,131-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9661458134651184; acc_top5: 0.9987980723381042
2020-05-11 11:26:37,465-INFO: pruned param: conv2_2_sep_weights; 0.2; loss=0.006321242079138756 2020-05-11 11:26:37,467-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9604367017745972; acc_top5: 0.9983974099159241
2020-05-11 11:27:41,218-INFO: pruned param: conv3_1_sep_weights; 0.1; loss=-0.000725396501366049 2020-05-11 11:27:41,221-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9672476053237915; acc_top5: 0.9989984035491943
2020-05-11 11:28:45,097-INFO: pruned param: conv3_1_sep_weights; 0.2; loss=0.007253903429955244 2020-05-11 11:28:45,099-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9595352411270142; acc_top5: 0.9985977411270142
2020-05-11 11:29:48,997-INFO: pruned param: conv3_2_sep_weights; 0.1; loss=0.00259065767750144 2020-05-11 11:29:48,999-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9640424847602844; acc_top5: 0.9988982081413269
2020-05-11 11:30:52,993-INFO: pruned param: conv3_2_sep_weights; 0.2; loss=0.016580332070589066 2020-05-11 11:30:52,994-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.1
Final eva - acc_top1: 0.9505208134651184; acc_top5: 0.9989984035491943
2020-05-11 11:31:56,862-INFO: pruned param: conv4_1_sep_weights; 0.1; loss=0.0009326614672318101 2020-05-11 11:31:56,863-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.2
Final eva - acc_top1: 0.9656450152397156; acc_top5: 0.9991987347602844
2020-05-11 11:33:00,752-INFO: pruned param: conv4_1_sep_weights; 0.2; loss=0.011813485994935036
Final eva - acc_top1: 0.9551281929016113; acc_top5: 0.9984976053237915 {'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486}}
7-2 Extending the sensitivity information
The previous computation used pruned_ratios=[0.1, 0.2]; we can now extend it to [0.1, 0.2, 0.3].
sens_0 = slim.prune.sensitivity(  # call the sensitivity analysis API
    val_program,  # evaluate sensitivity on the test network
    place,  # the device defined earlier
    params,  # the parameters to analyze
    test,  # the model evaluation method
    sensitivities_file="sensitivities_0.data",  # file in which the sensitivity results are saved
    pruned_ratios=[0.3])  # additional ratio to prune
print(sens_0)
2020-05-11 12:04:00,400-INFO: sensitive - param: conv2_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9665464758872986; acc_top5: 0.9988982081413269
2020-05-11 12:05:03,899-INFO: pruned param: conv2_1_sep_weights; 0.3; loss=0.007046638522297144 2020-05-11 12:05:03,901-INFO: sensitive - param: conv2_2_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9597355723381042; acc_top5: 0.9984976053237915
2020-05-11 12:06:07,507-INFO: pruned param: conv2_2_sep_weights; 0.3; loss=0.009015562944114208 2020-05-11 12:06:07,508-INFO: sensitive - param: conv3_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9578325152397156; acc_top5: 0.9983974099159241
2020-05-11 12:07:10,976-INFO: pruned param: conv3_1_sep_weights; 0.3; loss=0.03471500054001808 2020-05-11 12:07:10,978-INFO: sensitive - param: conv3_2_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9329928159713745; acc_top5: 0.9976963400840759
2020-05-11 12:08:14,431-INFO: pruned param: conv3_2_sep_weights; 0.3; loss=0.05564764887094498 2020-05-11 12:08:14,432-INFO: sensitive - param: conv4_1_sep_weights; ratios: 0.3
Final eva - acc_top1: 0.9127604365348816; acc_top5: 0.9971955418586731
2020-05-11 12:09:17,918-INFO: pruned param: conv4_1_sep_weights; 0.3; loss=0.03233160451054573
Final eva - acc_top1: 0.9352964758872986; acc_top5: 0.9974960088729858 {'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605}}
7-3 Speeding up sensitivity computation with multiple processes
The time taken by sensitivity analysis depends on the number of convolution layers analyzed and the speed of model evaluation. We can speed it up with multiple processes: set different pruned_ratios in different processes, then merge the results.
7-3-1 Computing sensitivity in multiple processes
In the sections above we computed the sensitivity for pruned_ratios=[0.1, 0.2, 0.3] and saved it to the file sensitivities_0.data. In another process we can set pruned_ratios=[0.4] and save the result to the file sensitivities_1.data. The code is as follows:
sens_1 = slim.prune.sensitivity(  # call the sensitivity analysis API
    val_program,  # evaluate sensitivity on the test network
    place,  # the device defined earlier
    params,  # the parameters to analyze
    test,  # the model evaluation method
    sensitivities_file="sensitivities_1.data",  # file in which the sensitivity results are saved
    pruned_ratios=[0.4])  # ratio to prune
print(sens_1)
{'conv2_1_sep_weights': {0.4: 0.015814325}, 'conv3_1_sep_weights': {0.4: 0.06366808}, 'conv4_1_sep_weights': {0.4: 0.056377064}, 'conv3_2_sep_weights': {0.4: 0.027110271}, 'conv2_2_sep_weights': {0.4: 0.03727665}}
7-3-2 Loading the sensitivity files produced by multiple processes
s_0 = slim.prune.load_sensitivities("sensitivities_0.data")  # load sensitivity information from file
s_1 = slim.prune.load_sensitivities("sensitivities_1.data")
print(s_0)
print(s_1)
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605}} {'conv2_1_sep_weights': {0.4: 0.015814325}, 'conv3_1_sep_weights': {0.4: 0.06366808}, 'conv4_1_sep_weights': {0.4: 0.056377064}, 'conv3_2_sep_weights': {0.4: 0.027110271}, 'conv2_2_sep_weights': {0.4: 0.03727665}}
7-3-3 Merging the sensitivity information
s = slim.prune.merge_sensitive([s_0, s_1])  # merge the sensitivity information
print(s)
{'conv2_1_sep_weights': {0.1: 0.0033160544, 0.2: 0.0034197175, 0.3: 0.0070466385, 0.4: 0.015814325}, 'conv2_2_sep_weights': {0.1: 0.0004145299, 0.2: 0.006321242, 0.3: 0.009015563, 0.4: 0.03727665}, 'conv3_1_sep_weights': {0.1: -0.0007253965, 0.2: 0.0072539034, 0.3: 0.034715, 0.4: 0.06366808}, 'conv3_2_sep_weights': {0.1: 0.0025906577, 0.2: 0.016580332, 0.3: 0.05564765, 0.4: 0.027110271}, 'conv4_1_sep_weights': {0.1: 0.00093266147, 0.2: 0.011813486, 0.3: 0.032331605, 0.4: 0.056377064}}
Using the sensitivity information produced above, prune the model.
8-1 Computing the pruning ratios
First, call PaddleSlim's get_ratios_by_loss method to compute pruning ratios from the sensitivity information; adjust the loss argument to obtain a suitable set of ratios:
loss = 0.01
ratios = slim.prune.get_ratios_by_loss(s_0, loss)  # compute a set of pruning ratios from the sensitivity information
print(ratios)
{'conv2_1_sep_weights': 0.3, 'conv2_2_sep_weights': 0.3, 'conv3_1_sep_weights': 0.20999995231656052, 'conv3_2_sep_weights': 0.15296293795649352, 'conv4_1_sep_weights': 0.1833331959422747}
8-2 Pruning the training network
pruner = slim.prune.Pruner()  # construct a Pruner instance
print("FLOPs before pruning: {}".format(slim.analysis.flops(train_program)))  # FLOPs before pruning
pruned_program, _, _ = pruner.prune(
    train_program,  # the training network to prune
    fluid.global_scope(),  # the scope that holds the parameter values
    params=ratios.keys(),  # names of the parameters to prune
    ratios=ratios.values(),  # pruning ratio for each parameter
    place=place)  # device on which the parameters live
print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_program)))  # FLOPs after pruning
FLOPs before pruning: 10896832.0 FLOPs after pruning: 9588167.0
8-3 Pruning the test network
Note: when pruning the test network, only_graph must be set to True; see the Pruner API documentation for the reason.
pruner = slim.prune.Pruner()
print("FLOPs before pruning: {}".format(slim.analysis.flops(val_program)))
pruned_val_program, _, _ = pruner.prune(
    val_program,  # prune the test network
    fluid.global_scope(),
    params=ratios.keys(),
    ratios=ratios.values(),
    place=place,
    only_graph=True)  # only prune the graph structure; leave the parameter values in the scope untouched
print("FLOPs after pruning: {}".format(slim.analysis.flops(pruned_val_program)))
FLOPs before pruning: 10896832.0 FLOPs after pruning: 9588167.0
Evaluate the pruned model's accuracy on the test set:
test(pruned_val_program)
Final eva - acc_top1: 0.7620192170143127; acc_top5: 0.9674479365348816
0.7620192
8-4 Training the pruned model
Train the pruned model on the training set for one epoch.
for data in train_reader():
acc1, acc5, loss = exe.run(pruned_program, feed=data_feeder.feed(data), fetch_list=outputs)
print(np.mean(acc1), np.mean(acc5), np.mean(loss))
1.0 1.0 0.004613471
Evaluate the model's accuracy after this training:
test(pruned_val_program)
Final eva - acc_top1: 0.9762620329856873; acc_top5: 0.9992988705635071
0.97626203
Note: restart the runtime environment before executing the code in this chapter.
This tutorial uses the image classification model MobileNetV1 to show how to quickly use PaddleSlim's knowledge distillation API. The example includes the following steps: importing the dependencies, defining the student network (student_program), defining the teacher network (teacher_program), choosing the feature maps used for distillation, merging the two programs and adding the distillation loss, and training the distilled model.
The following sections walk through each step in turn.
PaddleSlim depends on Paddle 1.7. Make sure Paddle is installed correctly, then import Paddle and PaddleSlim as follows:
import paddle
import paddle.fluid as fluid
import paddleslim as slim
from paddleslim.models.mobilenet import MobileNet
model = MobileNet()
student_program = fluid.Program()  # a new Program for training the student network
student_startup = fluid.Program()  # the Program used to initialize the student network
with fluid.program_guard(student_program, student_startup):
    image = fluid.data(
        name='image', shape=[None] + [1, 28, 28], dtype='float32')  # declare the image input
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')  # declare the label input
    out = model.net(input=image, class_dim=10)  # define the network structure
    cost = fluid.layers.cross_entropy(input=out, label=label)  # classification loss
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)  # accuracy metrics
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
from paddleslim.models.resnet import ResNet50
teacher_model = ResNet50()
teacher_program = fluid.Program()
teacher_startup = fluid.Program()
with fluid.program_guard(teacher_program, teacher_startup):
with fluid.unique_name.guard():
image = fluid.data(
name='image', shape=[None] + [1, 28, 28], dtype='float32')
predict = teacher_model.net(image, class_dim=10)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(teacher_startup)
[]
# get all student variables
student_vars = []
for v in student_program.list_vars():
student_vars.append((v.name, v.shape))
# print the student's variables to find tensors suitable for distillation
print("="*50+"student_model_vars"+"="*50)
print(student_vars)
# get all teacher variables
teacher_vars = []
for v in teacher_program.list_vars():
teacher_vars.append((v.name, v.shape))
#uncomment the following lines to observe teacher's variables for distillation
#print("="*50+"teacher_model_vars"+"="*50)
#print(teacher_vars)
==================================================student_model_vars================================================== [('image', (-1, 1, 28, 28)), ('label', (-1, 1)), ('conv1_weights', (32, 1, 3, 3)), ('conv2d_0.tmp_0', (-1, 32, 14, 14)), ('conv1_bn_scale', (32,)), ('conv1_bn_offset', (32,)), ('conv1_bn_mean', (32,)), ('conv1_bn_variance', (32,)), ('batch_norm_0.tmp_0', (32,)), ('batch_norm_0.tmp_1', (32,)), ('batch_norm_0.tmp_2', (-1, 32, 14, 14)), ('batch_norm_0.tmp_3', (-1, 32, 14, 14)), ('conv2_1_dw_weights', (32, 1, 3, 3)), ('depthwise_conv2d_0.tmp_0', (-1, 32, 14, 14)), ('conv2_1_dw_bn_scale', (32,)), ('conv2_1_dw_bn_offset', (32,)), ('conv2_1_dw_bn_mean', (32,)), ('conv2_1_dw_bn_variance', (32,)), ('batch_norm_1.tmp_0', (32,)), ('batch_norm_1.tmp_1', (32,)), ('batch_norm_1.tmp_2', (-1, 32, 14, 14)), ('batch_norm_1.tmp_3', (-1, 32, 14, 14)), ('conv2_1_sep_weights', (64, 32, 1, 1)), ('conv2d_1.tmp_0', (-1, 64, 14, 14)), ('conv2_1_sep_bn_scale', (64,)), ('conv2_1_sep_bn_offset', (64,)), ('conv2_1_sep_bn_mean', (64,)), ('conv2_1_sep_bn_variance', (64,)), ('batch_norm_2.tmp_0', (64,)), ('batch_norm_2.tmp_1', (64,)), ('batch_norm_2.tmp_2', (-1, 64, 14, 14)), ('batch_norm_2.tmp_3', (-1, 64, 14, 14)), ('conv2_2_dw_weights', (64, 1, 3, 3)), ('depthwise_conv2d_1.tmp_0', (-1, 64, 7, 7)), ('conv2_2_dw_bn_scale', (64,)), ('conv2_2_dw_bn_offset', (64,)), ('conv2_2_dw_bn_mean', (64,)), ('conv2_2_dw_bn_variance', (64,)), ('batch_norm_3.tmp_0', (64,)), ('batch_norm_3.tmp_1', (64,)), ('batch_norm_3.tmp_2', (-1, 64, 7, 7)), ('batch_norm_3.tmp_3', (-1, 64, 7, 7)), ('conv2_2_sep_weights', (128, 64, 1, 1)), ('conv2d_2.tmp_0', (-1, 128, 7, 7)), ('conv2_2_sep_bn_scale', (128,)), ('conv2_2_sep_bn_offset', (128,)), ('conv2_2_sep_bn_mean', (128,)), ('conv2_2_sep_bn_variance', (128,)), ('batch_norm_4.tmp_0', (128,)), ('batch_norm_4.tmp_1', (128,)), ('batch_norm_4.tmp_2', (-1, 128, 7, 7)), ('batch_norm_4.tmp_3', (-1, 128, 7, 7)), ('conv3_1_dw_weights', (128, 1, 3, 3)), ('depthwise_conv2d_2.tmp_0', (-1, 128, 7, 7)), ('conv3_1_dw_bn_scale', (128,)), ('conv3_1_dw_bn_offset', (128,)), ('conv3_1_dw_bn_mean', (128,)), ('conv3_1_dw_bn_variance', (128,)), ('batch_norm_5.tmp_0', (128,)), ('batch_norm_5.tmp_1', (128,)), ('batch_norm_5.tmp_2', (-1, 128, 7, 7)), ('batch_norm_5.tmp_3', (-1, 128, 7, 7)), ('conv3_1_sep_weights', (128, 128, 1, 1)), ('conv2d_3.tmp_0', (-1, 128, 7, 7)), ('conv3_1_sep_bn_scale', (128,)), ('conv3_1_sep_bn_offset', (128,)), ('conv3_1_sep_bn_mean', (128,)), ('conv3_1_sep_bn_variance', (128,)), ('batch_norm_6.tmp_0', (128,)), ('batch_norm_6.tmp_1', (128,)), ('batch_norm_6.tmp_2', (-1, 128, 7, 7)), ('batch_norm_6.tmp_3', (-1, 128, 7, 7)), ('conv3_2_dw_weights', (128, 1, 3, 3)), ('depthwise_conv2d_3.tmp_0', (-1, 128, 4, 4)), ('conv3_2_dw_bn_scale', (128,)), ('conv3_2_dw_bn_offset', (128,)), ('conv3_2_dw_bn_mean', (128,)), ('conv3_2_dw_bn_variance', (128,)), ('batch_norm_7.tmp_0', (128,)), ('batch_norm_7.tmp_1', (128,)), ('batch_norm_7.tmp_2', (-1, 128, 4, 4)), ('batch_norm_7.tmp_3', (-1, 128, 4, 4)), ('conv3_2_sep_weights', (256, 128, 1, 1)), ('conv2d_4.tmp_0', (-1, 256, 4, 4)), ('conv3_2_sep_bn_scale', (256,)), ('conv3_2_sep_bn_offset', (256,)), ('conv3_2_sep_bn_mean', (256,)), ('conv3_2_sep_bn_variance', (256,)), ('batch_norm_8.tmp_0', (256,)), ('batch_norm_8.tmp_1', (256,)), ('batch_norm_8.tmp_2', (-1, 256, 4, 4)), ('batch_norm_8.tmp_3', (-1, 256, 4, 4)), ('conv4_1_dw_weights', (256, 1, 3, 3)), ('depthwise_conv2d_4.tmp_0', (-1, 256, 4, 4)), ('conv4_1_dw_bn_scale', (256,)), ('conv4_1_dw_bn_offset', 
(256,)), ('conv4_1_dw_bn_mean', (256,)), ('conv4_1_dw_bn_variance', (256,)), ('batch_norm_9.tmp_0', (256,)), ('batch_norm_9.tmp_1', (256,)), ('batch_norm_9.tmp_2', (-1, 256, 4, 4)), ('batch_norm_9.tmp_3', (-1, 256, 4, 4)), ('conv4_1_sep_weights', (256, 256, 1, 1)), ('conv2d_5.tmp_0', (-1, 256, 4, 4)), ('conv4_1_sep_bn_scale', (256,)), ('conv4_1_sep_bn_offset', (256,)), ('conv4_1_sep_bn_mean', (256,)), ('conv4_1_sep_bn_variance', (256,)), ('batch_norm_10.tmp_0', (256,)), ('batch_norm_10.tmp_1', (256,)), ('batch_norm_10.tmp_2', (-1, 256, 4, 4)), ('batch_norm_10.tmp_3', (-1, 256, 4, 4)), ('conv4_2_dw_weights', (256, 1, 3, 3)), ('depthwise_conv2d_5.tmp_0', (-1, 256, 2, 2)), ('conv4_2_dw_bn_scale', (256,)), ('conv4_2_dw_bn_offset', (256,)), ('conv4_2_dw_bn_mean', (256,)), ('conv4_2_dw_bn_variance', (256,)), ('batch_norm_11.tmp_0', (256,)), ('batch_norm_11.tmp_1', (256,)), ('batch_norm_11.tmp_2', (-1, 256, 2, 2)), ('batch_norm_11.tmp_3', (-1, 256, 2, 2)), ('conv4_2_sep_weights', (512, 256, 1, 1)), ('conv2d_6.tmp_0', (-1, 512, 2, 2)), ('conv4_2_sep_bn_scale', (512,)), ('conv4_2_sep_bn_offset', (512,)), ('conv4_2_sep_bn_mean', (512,)), ('conv4_2_sep_bn_variance', (512,)), ('batch_norm_12.tmp_0', (512,)), ('batch_norm_12.tmp_1', (512,)), ('batch_norm_12.tmp_2', (-1, 512, 2, 2)), ('batch_norm_12.tmp_3', (-1, 512, 2, 2)), ('conv5_1_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_6.tmp_0', (-1, 512, 2, 2)), ('conv5_1_dw_bn_scale', (512,)), ('conv5_1_dw_bn_offset', (512,)), ('conv5_1_dw_bn_mean', (512,)), ('conv5_1_dw_bn_variance', (512,)), ('batch_norm_13.tmp_0', (512,)), ('batch_norm_13.tmp_1', (512,)), ('batch_norm_13.tmp_2', (-1, 512, 2, 2)), ('batch_norm_13.tmp_3', (-1, 512, 2, 2)), ('conv5_1_sep_weights', (512, 512, 1, 1)), ('conv2d_7.tmp_0', (-1, 512, 2, 2)), ('conv5_1_sep_bn_scale', (512,)), ('conv5_1_sep_bn_offset', (512,)), ('conv5_1_sep_bn_mean', (512,)), ('conv5_1_sep_bn_variance', (512,)), ('batch_norm_14.tmp_0', (512,)), ('batch_norm_14.tmp_1', (512,)), ('batch_norm_14.tmp_2', (-1, 512, 2, 2)), ('batch_norm_14.tmp_3', (-1, 512, 2, 2)), ('conv5_2_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_7.tmp_0', (-1, 512, 2, 2)), ('conv5_2_dw_bn_scale', (512,)), ('conv5_2_dw_bn_offset', (512,)), ('conv5_2_dw_bn_mean', (512,)), ('conv5_2_dw_bn_variance', (512,)), ('batch_norm_15.tmp_0', (512,)), ('batch_norm_15.tmp_1', (512,)), ('batch_norm_15.tmp_2', (-1, 512, 2, 2)), ('batch_norm_15.tmp_3', (-1, 512, 2, 2)), ('conv5_2_sep_weights', (512, 512, 1, 1)), ('conv2d_8.tmp_0', (-1, 512, 2, 2)), ('conv5_2_sep_bn_scale', (512,)), ('conv5_2_sep_bn_offset', (512,)), ('conv5_2_sep_bn_mean', (512,)), ('conv5_2_sep_bn_variance', (512,)), ('batch_norm_16.tmp_0', (512,)), ('batch_norm_16.tmp_1', (512,)), ('batch_norm_16.tmp_2', (-1, 512, 2, 2)), ('batch_norm_16.tmp_3', (-1, 512, 2, 2)), ('conv5_3_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_8.tmp_0', (-1, 512, 2, 2)), ('conv5_3_dw_bn_scale', (512,)), ('conv5_3_dw_bn_offset', (512,)), ('conv5_3_dw_bn_mean', (512,)), ('conv5_3_dw_bn_variance', (512,)), ('batch_norm_17.tmp_0', (512,)), ('batch_norm_17.tmp_1', (512,)), ('batch_norm_17.tmp_2', (-1, 512, 2, 2)), ('batch_norm_17.tmp_3', (-1, 512, 2, 2)), ('conv5_3_sep_weights', (512, 512, 1, 1)), ('conv2d_9.tmp_0', (-1, 512, 2, 2)), ('conv5_3_sep_bn_scale', (512,)), ('conv5_3_sep_bn_offset', (512,)), ('conv5_3_sep_bn_mean', (512,)), ('conv5_3_sep_bn_variance', (512,)), ('batch_norm_18.tmp_0', (512,)), ('batch_norm_18.tmp_1', (512,)), ('batch_norm_18.tmp_2', (-1, 512, 2, 2)), ('batch_norm_18.tmp_3', (-1, 512, 2, 
2)), ('conv5_4_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_9.tmp_0', (-1, 512, 2, 2)), ('conv5_4_dw_bn_scale', (512,)), ('conv5_4_dw_bn_offset', (512,)), ('conv5_4_dw_bn_mean', (512,)), ('conv5_4_dw_bn_variance', (512,)), ('batch_norm_19.tmp_0', (512,)), ('batch_norm_19.tmp_1', (512,)), ('batch_norm_19.tmp_2', (-1, 512, 2, 2)), ('batch_norm_19.tmp_3', (-1, 512, 2, 2)), ('conv5_4_sep_weights', (512, 512, 1, 1)), ('conv2d_10.tmp_0', (-1, 512, 2, 2)), ('conv5_4_sep_bn_scale', (512,)), ('conv5_4_sep_bn_offset', (512,)), ('conv5_4_sep_bn_mean', (512,)), ('conv5_4_sep_bn_variance', (512,)), ('batch_norm_20.tmp_0', (512,)), ('batch_norm_20.tmp_1', (512,)), ('batch_norm_20.tmp_2', (-1, 512, 2, 2)), ('batch_norm_20.tmp_3', (-1, 512, 2, 2)), ('conv5_5_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_10.tmp_0', (-1, 512, 2, 2)), ('conv5_5_dw_bn_scale', (512,)), ('conv5_5_dw_bn_offset', (512,)), ('conv5_5_dw_bn_mean', (512,)), ('conv5_5_dw_bn_variance', (512,)), ('batch_norm_21.tmp_0', (512,)), ('batch_norm_21.tmp_1', (512,)), ('batch_norm_21.tmp_2', (-1, 512, 2, 2)), ('batch_norm_21.tmp_3', (-1, 512, 2, 2)), ('conv5_5_sep_weights', (512, 512, 1, 1)), ('conv2d_11.tmp_0', (-1, 512, 2, 2)), ('conv5_5_sep_bn_scale', (512,)), ('conv5_5_sep_bn_offset', (512,)), ('conv5_5_sep_bn_mean', (512,)), ('conv5_5_sep_bn_variance', (512,)), ('batch_norm_22.tmp_0', (512,)), ('batch_norm_22.tmp_1', (512,)), ('batch_norm_22.tmp_2', (-1, 512, 2, 2)), ('batch_norm_22.tmp_3', (-1, 512, 2, 2)), ('conv5_6_dw_weights', (512, 1, 3, 3)), ('depthwise_conv2d_11.tmp_0', (-1, 512, 1, 1)), ('conv5_6_dw_bn_scale', (512,)), ('conv5_6_dw_bn_offset', (512,)), ('conv5_6_dw_bn_mean', (512,)), ('conv5_6_dw_bn_variance', (512,)), ('batch_norm_23.tmp_0', (512,)), ('batch_norm_23.tmp_1', (512,)), ('batch_norm_23.tmp_2', (-1, 512, 1, 1)), ('batch_norm_23.tmp_3', (-1, 512, 1, 1)), ('conv5_6_sep_weights', (1024, 512, 1, 1)), ('conv2d_12.tmp_0', (-1, 1024, 1, 1)), ('conv5_6_sep_bn_scale', (1024,)), ('conv5_6_sep_bn_offset', (1024,)), ('conv5_6_sep_bn_mean', (1024,)), ('conv5_6_sep_bn_variance', (1024,)), ('batch_norm_24.tmp_0', (1024,)), ('batch_norm_24.tmp_1', (1024,)), ('batch_norm_24.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_24.tmp_3', (-1, 1024, 1, 1)), ('conv6_dw_weights', (1024, 1, 3, 3)), ('depthwise_conv2d_12.tmp_0', (-1, 1024, 1, 1)), ('conv6_dw_bn_scale', (1024,)), ('conv6_dw_bn_offset', (1024,)), ('conv6_dw_bn_mean', (1024,)), ('conv6_dw_bn_variance', (1024,)), ('batch_norm_25.tmp_0', (1024,)), ('batch_norm_25.tmp_1', (1024,)), ('batch_norm_25.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_25.tmp_3', (-1, 1024, 1, 1)), ('conv6_sep_weights', (1024, 1024, 1, 1)), ('conv2d_13.tmp_0', (-1, 1024, 1, 1)), ('conv6_sep_bn_scale', (1024,)), ('conv6_sep_bn_offset', (1024,)), ('conv6_sep_bn_mean', (1024,)), ('conv6_sep_bn_variance', (1024,)), ('batch_norm_26.tmp_0', (1024,)), ('batch_norm_26.tmp_1', (1024,)), ('batch_norm_26.tmp_2', (-1, 1024, 1, 1)), ('batch_norm_26.tmp_3', (-1, 1024, 1, 1)), ('pool2d_0.tmp_0', (-1, 1024, 1, 1)), ('fc7_weights', (1024, 10)), ('fc_0.tmp_0', (-1, 10)), ('fc7_offset', (10,)), ('fc_0.tmp_1', (-1, 10)), ('fc_0.tmp_2', (-1, 10)), ('cross_entropy2_0.tmp_0', (-1, 1)), ('cross_entropy2_0.tmp_1', (-1, 10, 0)), ('cross_entropy2_0.tmp_2', (-1, 1)), ('mean_0.tmp_0', (1,)), ('top_k_0.tmp_0', (-1, 1)), ('top_k_0.tmp_1', (-1, 1)), ('accuracy_0.tmp_0', (1,)), ('accuracy_0.tmp_1', (1,)), ('accuracy_0.tmp_2', (1,)), ('top_k_1.tmp_0', (-1, 5)), ('top_k_1.tmp_1', (-1, 5)), ('accuracy_1.tmp_0', (1,)), ('accuracy_1.tmp_1', (1,)), 
('accuracy_1.tmp_2', (1,))]
From this list we can see that 'bn5c_branch2b.output.1.tmp_3' in teacher_program and 'depthwise_conv2d_11.tmp_0' in student_program have the same shape, so they can be paired to build a distillation loss.
The merge operation adds all Variables and Ops of student_program and teacher_program into a single Program. To avoid naming conflicts between variables with the same name in the two programs, merge also prepends a uniform prefix, name_prefix, to the Variables of teacher_program; its default value is 'teacher_'.
To guarantee that the teacher and student networks receive the same input data, merge also merges the input data layers of the two programs, so a mapping between data layer names, data_name_map, must be provided: the key is the teacher's input data name and the value is the student's.
data_name_map = {'image': 'image'}
main = slim.dist.merge(teacher_program, student_program, data_name_map, fluid.CPUPlace())  # merge the teacher network into the student network
with fluid.program_guard(student_program, student_startup):
    # add the distillation loss
    l2_loss = slim.dist.l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)
    # optimize the distillation loss together with the student's original classification loss
    loss = l2_loss + avg_cost
    opt = fluid.optimizer.Momentum(0.01, 0.9)
    opt.minimize(loss)  # add the backward and optimization ops
exe.run(student_startup)  # initialize the student network
[]
train_reader = paddle.batch(
    paddle.dataset.mnist.train(), batch_size=128, drop_last=True)  # generator that yields MNIST training batches
train_feeder = fluid.DataFeeder(['image', 'label'], fluid.CPUPlace(), student_program)  # maps the reader output to the inputs of student_program
for data in train_reader():  # train for one epoch
    acc1, acc5, loss_np = exe.run(student_program, feed=train_feeder.feed(data), fetch_list=[acc_top1.name, acc_top5.name, loss.name])
    print("Acc1: {:.6f}, Acc5: {:.6f}, Loss: {:.6f}".format(acc1.mean(), acc5.mean(), loss_np.mean()))
Acc1: 0.062500, Acc5: 0.382812, Loss: 221.120117 Acc1: 0.101562, Acc5: 0.507812, Loss: 222.530258 Acc1: 0.070312, Acc5: 0.500000, Loss: 222.556686 Acc1: 0.117188, Acc5: 0.437500, Loss: 220.210587 Acc1: 0.101562, Acc5: 0.562500, Loss: 217.044830 Acc1: 0.117188, Acc5: 0.484375, Loss: 215.024094 Acc1: 0.109375, Acc5: 0.500000, Loss: 218.645142 Acc1: 0.125000, Acc5: 0.578125, Loss: 214.771973 Acc1: 0.203125, Acc5: 0.554688, Loss: 210.950424 Acc1: 0.171875, Acc5: 0.554688, Loss: 220.628342 Acc1: 0.140625, Acc5: 0.546875, Loss: 217.624451 Acc1: 0.093750, Acc5: 0.562500, Loss: 211.037888