where
根据 condition
来选择 x
或 y
中的对应元素来组成新的 Tensor。具体地,
\[\begin{split}out_i = \begin{cases} x_i, & \text{if} \ condition_i \ \text{is} \ True \\ y_i, & \text{if} \ condition_i \ \text{is} \ False \\ \end{cases}.\end{split}\]
备注
numpy.where(condition)
功能与 paddle.nonzero(condition, as_tuple=True)
相同,可以参考 nonzero。
备注
别名支持: 参数名 input
可替代 x
,参数名 other
可替代 y
,如 paddle.where(condition, input=x, other=y)
等价于 paddle.where(condition, x=x, y=y)
。
参数
condition (Tensor) - 选择
x
或y
元素的条件。在为 True(非零值)时,选择x
,否则选择y
。x (Tensor|scalar,可选) - 条件为 True 时选择的 Tensor 或 scalar,数据类型为 bfloat16、 float16、float32、float64、int32 或 int64。
x
和y
必须都给出或者都不给出。别名:input
。y (Tensor|scalar,可选) - 条件为 False 时选择的 Tensor 或 scalar,数据类型为 bfloat16、float16、float32、float64、int32 或 int64。
x
和y
必须都给出或者都不给出。别名:other
。name (str,可选) - 具体用法请参见 api_guide_Name,一般无需设置,默认值为 None。
关键字参数
out (Tensor,可选) - 输出 Tensor,若不为
None
,计算结果将保存在该 Tensor 中,默认值为None
。
返回
Tensor,形状与 condition
相同,数据类型与 x
和 y
相同。
代码示例
>>> import paddle
>>> x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
>>> y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
>>> out = paddle.where(x>1, x, y)
>>> print(out)
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[1. , 1. , 3.20000005, 1.20000005])
>>> out = paddle.where(x>1)
>>> print(out)
(Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
[2, 3]),)