paddle.amp¶
paddle.amp 目录下包含飞桨框架支持的动态图自动混合精度(AMP)相关的 API。具体如下:
paddle.amp 目录下包含 debugging 目录, debugging 目录中存放用于算子模型精度问题定位的 API。具体如下:
AMP 相关 API¶
API 名称  |  
           API 功能  |  
          
|---|---|
| 
             |  
           创建 AMP 上下文环境  |  
          
| 
             |  
           根据选定混合精度训练模式,改写神经网络参数数据类型  |  
          
| 
             |  
           控制 loss 的缩放比例  |  
          
开启 AMP 后默认转化为 float16 计算的相关 OP¶
OP 名称  |  
           OP 功能  |  
          
|---|---|
conv2d  |  
           卷积计算  |  
          
matmul  |  
           矩阵乘法  |  
          
matmul_v2  |  
           矩阵乘法  |  
          
mul  |  
           矩阵乘法  |  
          
开启 AMP 后默认使用 float32 计算的相关 OP¶
OP 名称  |  
           OP 功能  |  
          
|---|---|
exp  |  
           指数运算  |  
          
square  |  
           平方运算  |  
          
log  |  
           对数运算  |  
          
mean  |  
           取平均值  |  
          
sum  |  
           求和运算  |  
          
cos_sim  |  
           余弦相似度  |  
          
softmax  |  
           softmax 操作  |  
          
softmax_with_cross_entropy  |  
           softmax 交叉熵损失函数  |  
          
sigmoid_cross_entropy_with_logits  |  
           按元素的概率误差  |  
          
cross_entropy  |  
           交叉熵  |  
          
cross_entropy2  |  
           交叉熵  |  
          
AMP 场景下判断设备是否支持特定数据类型¶
API 名称  |  
           API 功能  |  
          
|---|---|
| 
             |  
           判断设备是否支持 bfloat16  |  
          
| 
             |  
           判断设备是否支持 float16  |  
          
Debug 相关辅助类¶
类名称  |  
           辅助类功能  |  
          
|---|---|
| 
             |  
           精度调试模式  |  
          
| 
             |  
           精度调试配置类  |  
          
算子调用统计相关 API¶
API 名称  |  
           API 功能  |  
          
|---|---|
| 
             |  
           收集不同数据类型的算子调用次数  |  
          
| 
             |  
           启用以收集不同数据类型的算子调用次数  |  
          
| 
             |  
           禁用收集不同数据类型的算子调用次数  |  
          
模块级别精度定位 API¶
API 名称  |  
           API 功能  |  
          
|---|---|
| 
             |  
           开启模块级别的精度检查  |  
          
| 
             |  
           关闭模块级别的精度检查  |  
          
| 
             |  
           精度比对接口  |  
          
数值检查相关 API¶
API 名称  |  
           API 功能  |  
          
|---|---|
| 
             |  
           Layer 输入、输出数据的数值检查  |  
          
| 
             |  
           调试 Tensor 数值,检查其异常值(NaN、Inf) 和零元素  |