cond
- paddle.static.nn.cond(pred, true_fn=None, false_fn=None, name=None)
-
This API returns true_fn() if the predicate pred is true, otherwise false_fn(). Either true_fn or false_fn may be set to None when that branch should do nothing; the API then treats the branch as a callable that simply returns None. true_fn and false_fn must return the same nest structure of tensors, or both return None if nothing needs to be returned. In PaddlePaddle, a nest structure of tensors is a tensor, a tuple of tensors, or a list of tensors.
Note
1. The tuples or lists returned by true_fn and false_fn must have the same structure (length and nesting) because of PaddlePaddle's dataflow model, while the tensors inside those tuples or lists can have different shapes.
2. This API can be used in both static mode and dygraph mode. In dygraph mode, the API runs only one branch, chosen by the condition.
3. In static mode, any tensors or operations created outside or inside true_fn and false_fn are added to the network when it is built, regardless of which branch is selected at runtime. This has frequently surprised users who expected lazy semantics. For example:
    import paddle
    a = paddle.zeros((1, 1))
    b = paddle.zeros((1, 1))
    c = a * b
    out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b)
Whether or not a < b holds, c = a * b is built into the network and run. Both a + c and b * b are built into the network, but only one branch is executed at runtime.
- Parameters
-
pred (Tensor) – A boolean tensor whose numel must be 1. Its boolean value determines whether the result of true_fn or false_fn is returned.
true_fn (callable, optional) – A callable to run if pred is true. The default value is None.
false_fn (callable, optional) – A callable to run if pred is false. The default value is None.
name (str, optional) – The default value is None. Normally users don’t have to set this parameter. For more information, please refer to Name.
- Returns
-
true_fn() if the predicate pred is true, otherwise false_fn().
- Return type
-
Tensor|list(Tensor)|tuple(Tensor)
- Raises
-
TypeError – if true_fn or false_fn is not callable.
ValueError – if true_fn and false_fn don’t return the same nest structure of tensors.
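To make the ValueError condition more concrete, the following is a minimal pure-Python sketch of what "the same nest structure" could mean. The helper `same_nest_structure` is hypothetical and is not part of Paddle. It uses strings as stand-ins for tensors, and it assumes that only length and nesting matter (the container type, list versus tuple, is ignored here, which may differ from Paddle's own check).

```python
def same_nest_structure(a, b):
    """Hypothetical helper, NOT a Paddle API: compare nesting and length only."""
    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    if a_is_seq != b_is_seq:
        # One side is a container and the other is a leaf (tensor stand-in).
        return False
    if not a_is_seq:
        # Both are leaves; leaf shapes and dtypes are allowed to differ.
        return True
    if len(a) != len(b):
        return False
    # Recurse into each pair of corresponding elements.
    return all(same_nest_structure(x, y) for x, y in zip(a, b))


# Strings stand in for tensors in this sketch.
matching = same_nest_structure(("t1", "t2"), ("t3", "t4"))
mismatched_length = same_nest_structure(("t1",), ("t3", "t4"))
leaf_vs_container = same_nest_structure("t1", ("t3", "t4"))
```

Under these assumptions, two branches that each return a pair of tensors match, while a branch returning one tensor and a branch returning a pair do not.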
Examples
    import paddle

    #
    # pseudocode:
    # if 0.1 < 0.23:
    #     return 1, True
    # else:
    #     return 3, 2
    #

    def true_func():
        return paddle.full(shape=[1, 2], dtype='int32', fill_value=1), paddle.full(shape=[2, 3], dtype='bool', fill_value=True)

    def false_func():
        return paddle.full(shape=[3, 4], dtype='float32', fill_value=3), paddle.full(shape=[4, 5], dtype='int64', fill_value=2)

    x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
    y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
    pred = paddle.less_than(x=x, y=y, name=None)
    ret = paddle.static.nn.cond(pred, true_func, false_func)
    # ret is a tuple containing 2 tensors
    # ret[0] = [[1 1]]
    # ret[1] = [[ True  True  True]
    #           [ True  True  True]]
