Kernel 函数声明
飞桨通过头文件发布函数式 Kernel 声明,框架内外一致。
编写自定义 Kernel 需基于具体的 Kernel 函数声明,头文件位于飞桨安装路径的include/paddle/phi/kernels/下。
Kernel 函数声明的格式如下:
template <typename T, typename Context>
void KernelNameKernel(const Context& dev_ctx,
InputTensor(s),
Attribute(s),
OutTensor(s));
约定:
模板参数:固定写法,第一个模板参数为数据类型
T,第二个模板参数为设备上下文Context。函数返回:固定为
void。函数命名:Kernel 名称+Kernel 后缀,驼峰式命名,如
SoftmaxKernel。函数参数:依次为设备上下文参数,输入 Tensor 参数(InputTensor),属性参数(Attribute)和输出 Tensor 参数(OutTensor)。其中:
设备上下文参数:固定为
const Context&类型;自定义 Kernel 对应
CustomContext类型,请参照custom_context.h
InputTensor:数量>=0,支持的类型包括:
const DenseTensor&请参照dense_tensor.hconst SelectedRows&请参照selected_rows.hconst SparseCooTensor&请参照sparse_coo_tensor.hconst SparseCsrTensor&请参照sparse_csr_tensor.hconst std::vector<DenseTensor*>&const std::vector<SparseCooTensor*>&const std::vector<SparseCsrTensor*>&
Attribute:数量>=0,支持的类型包括:
boolfloatdoubleintint64_tphi::dtype::float16请参照float16.hconst Scalar&请参照scalar.hDataType请参照data_type.hDataLayout请参照layout.hPlace请参照place.hconst std::vector<int64_t>&const ScalarArray&请参照int_array.hconst std::vector<int>&const std::string&const std::vector<bool>&const std::vector<float>&const std::vector<double>&const std::vector<std::string>&
OutTensor:数量>0,支持的类型包括:
DenseTensor*SelectedRows*SparseCooTensor*SparseCsrTensor*std::vector<DenseTensor*>std::vector<SparseCooTensor*>std::vector<SparseCsrTensor*>
示例,如softmax的 Kernel 函数位于softmax_kernel.h中,具体如下:
// Softmax 内核函数
// 模板参数: T - 数据类型
// Context - 设备上下文
// 参数: dev_ctx - Context 对象
// x - DenseTensor 对象
// axis - int 类型
// dtype - DataType 类型
// out - DenseTensor 指针
// 返回: None
template <typename T, typename Context>
void SoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DataType dtype,
DenseTensor* out);
注意:
Kernel 函数声明是自定义 Kernel 能够被注册和框架调用的基础,由框架发布,需要严格遵守
Kernel 函数声明与头文件可能不完全对应,可以按照函数命名约定等查找所需 Kernel 函数声明