cond¶
如果 pred 是 True ,该API返回 true_fn() ,否则返回 false_fn() 。 用户如果不想在 callable 中做任何事,可以把 true_fn 或 false_fn 设为 None ,此时本API会把该 callable 视为简单返回 None 。
true_fn 和 false_fn 需要返回同样嵌套结构(nest structure)的Tensor,如果不想返回任何值也可都返回 None 。 PaddlePaddle里Tensor的嵌套结构是指一个Tensor,或者Tensor的元组(tuple),或者Tensor的列表(list)。
- 注意:
-
true_fn和false_fn返回的元组必须形状相同,但是里面的Tensor形状可以不同。本接口在动态图和静态图模式下都可以运行,在动态图情况下就只会按
pred条件运行其中一支分支。静态图模式下,因为各个分支都要参与组网,因此不论运行哪个分支,在
true_fn和false_fn内外创建的Tensor和Op都会组网,即PaddlePaddle并不是惰性语法(lazy semantics)。例如import paddle a = paddle.zeros((1, 1)) b = paddle.zeros((1, 1)) c = a * b out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b)
不管
a < b是否成立,c = a * b都会被组网且运行,a + c和b * b都会参与组网,只是组网后运行时只会运行条件对应的分支。
参数¶
pred (Tensor) - 一个形状为[1]的布尔型(boolean)的Tensor,该布尔值决定要返回
true_fn还是false_fn的运行结果。true_fn (callable) - 一个当
pred是True时被调用的callable,默认值:None。false_fn (callable) - 一个当
pred是False时被调用的callable,默认值:None。name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值:
None。
返回¶
Tensor|list(Tensor)|tuple(Tensor),如果 pred 是 True ,该API返回 true_fn() ,否则返回 false_fn() 。
代码示例¶
import paddle
#
# pseudocode:
# if 0.1 < 0.23:
# return 1, True
# else:
# return 3, 2
#
def true_func():
return paddle.full(shape=[1, 2], dtype='int32',
fill_value=1), paddle.full(shape=[2, 3],
dtype='bool',
fill_value=True)
def false_func():
return paddle.full(shape=[3, 4], dtype='float32',
fill_value=3), paddle.full(shape=[4, 5],
dtype='int64',
fill_value=2)
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
pred = paddle.less_than(x=x, y=y, name=None)
ret = paddle.static.nn.cond(pred, true_func, false_func)
# ret 是包含两个tensors的元组
# ret[0] = [[1 1]]
# ret[1] = [[ True True True]
# [ True True True]]