hsplit
hsplit 全称 Horizontal Split ,即水平分割,将输入 Tensor 沿着水平轴分割成多个子 Tensor,存在以下两种情况:
- 当 x 的维度等于 1 时,等价于将 tensor_split API 的参数 axis 固定为 0;   
- 当 x 的维度大于 1 时,等价于将 tensor_split API 的参数 axis 固定为 1。   
参数
x (Tensor) - 输入变量,数据类型为 bool、bfloat16、float16、float32、float64、uint8、int8、int32、int64 的多维 Tensor,其维度必须大于 0。
num_or_indices (int|list|tuple) - 如果
num_or_indices是一个整数n,则x拆分为n部分。如果num_or_indices是整数索引的列表或元组,则在每个索引处分割x。
name (str,可选) - 具体用法请参见 api_guide_Name,一般无需设置,默认值为 None。
返回
list[Tensor],分割后的 Tensor 列表。
代码示例
>>> import paddle
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1 = paddle.hsplit(x, num_or_indices=2)
>>> print(out0.shape)
[7, 4]
>>> print(out1.shape)
[7, 4]
>>> out0, out1, out2 = paddle.hsplit(x, num_or_indices=[1, 4])
>>> print(out0.shape)
[7, 1]
>>> print(out1.shape)
[7, 3]
>>> print(out2.shape)
[7, 4]