case

paddle.static.nn.case(pred_fn_pairs, default=None, name=None)
Api_attr

Static Graph

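This operator works like an if-elif-elif-else chain: the pairs in pred_fn_pairs are checked in order, and the outputs of the callable paired with the first pred that is True are returned.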

Parameters
  • pred_fn_pairs (list|tuple) – A list or tuple of (pred, fn) pairs. pred is a boolean Tensor with exactly one element (shape [] or shape [1]), and fn is a callable. All callables should return the same structure of Tensors.

  • default (callable, optional) – Callable that returns a structure of Tensors. Its outputs are returned when no pred in pred_fn_pairs is True. The default value is None.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.

Returns

Tensors returned by the callable of the first pair whose pred is True. If no pred in pred_fn_pairs is True, the Tensors returned by default are used when default is not None; otherwise, the Tensors returned by the last callable in pred_fn_pairs are used.

Return type

Tensor|list(Tensor)
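The selection rule described under Returns can be modeled eagerly in plain Python. This is a hypothetical sketch, not Paddle's implementation: case_model is an invented name, and each pred is a Python bool standing in for a one-element boolean Tensor.

```python
from typing import Callable, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def case_model(
    pred_fn_pairs: Sequence[Tuple[bool, Callable[[], T]]],
    default: Optional[Callable[[], T]] = None,
) -> T:
    """Hypothetical eager model of the documented selection rule.

    Each pred is a plain bool standing in for a one-element boolean
    Tensor; this is an illustration, not paddle.static.nn.case itself.
    """
    # Check the pairs in order, like an if/elif chain.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    # No pred was True: use default when it is given ...
    if default is not None:
        return default()
    # ... otherwise fall back to the last callable in pred_fn_pairs.
    return pred_fn_pairs[-1][1]()


if __name__ == "__main__":
    fn_1 = lambda: "fn_1"
    fn_2 = lambda: "fn_2"
    fn_3 = lambda: "fn_3"
    # Mirrors the comparisons used in the example below.
    print(case_model([(0.2 < 0.3, fn_1), (0.3 < 0.1, fn_2)], default=fn_3))  # fn_1
    print(case_model([(0.3 < 0.1, fn_2), (0.3 == 0.1, fn_3)]))               # fn_3
```

The two calls follow the same decisions as out_1 and out_2 in the example: the first True pred wins, and with default=None the last callable is the fallback.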

Raises
  • TypeError – If the type of pred_fn_pairs is not list or tuple.

  • TypeError – If the type of elements in pred_fn_pairs is not tuple.

  • TypeError – If the size of tuples in pred_fn_pairs is not 2.

  • TypeError – If the first element of a 2-tuple in pred_fn_pairs is not a Tensor.

  • TypeError – If the second element of a 2-tuple in pred_fn_pairs is not callable.

  • TypeError – If default is not None but it is not callable.
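The conditions listed above can be restated as a plain-Python sketch. Everything here is hypothetical: FakeTensor stands in for a Paddle Tensor, check_case_args and raises_type_error are illustrative names that are not part of the Paddle API, and the error messages are invented. Only the conditions themselves follow the list above.

```python
class FakeTensor:
    """Hypothetical stand-in for a Paddle Tensor in this sketch."""


def check_case_args(pred_fn_pairs, default=None):
    """Raise TypeError under the conditions listed in Raises (hypothetical)."""
    if not isinstance(pred_fn_pairs, (list, tuple)):
        raise TypeError("pred_fn_pairs must be a list or tuple")
    for pair in pred_fn_pairs:
        if not isinstance(pair, tuple):
            raise TypeError("each element of pred_fn_pairs must be a tuple")
        if len(pair) != 2:
            raise TypeError("each tuple in pred_fn_pairs must have size 2")
        pred, fn = pair
        if not isinstance(pred, FakeTensor):
            raise TypeError("the first element of each pair must be a Tensor")
        if not callable(fn):
            raise TypeError("the second element of each pair must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be None or a callable")


def raises_type_error(*args, **kwargs):
    """Return True if check_case_args raises TypeError for these arguments."""
    try:
        check_case_args(*args, **kwargs)
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    t = FakeTensor()
    print(raises_type_error([(t, lambda: 1)]))             # False: valid arguments
    print(raises_type_error([[t, lambda: 1]]))             # True: pair is a list, not a tuple
    print(raises_type_error([(t, lambda: 1)], default=3))  # True: default is not callable
```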

Examples

>>> import paddle
>>> paddle.enable_static()

>>> def fn_1():
...     return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

>>> def fn_2():
...     return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

>>> def fn_3():
...     return paddle.full(shape=[3], dtype='int32', fill_value=3)

>>> main_program = paddle.static.default_main_program()
>>> startup_program = paddle.static.default_startup_program()

>>> with paddle.static.program_guard(main_program, startup_program):
...     x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
...     y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
...     z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
...
...     pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
...     pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
...     pred_3 = paddle.equal(x, y)      # false: 0.3 == 0.1
...
...     # Call fn_1 because pred_1 is True
...     out_1 = paddle.static.nn.case(
...         pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
...
...     # default is None and no pred in pred_fn_pairs is True, so fn_3 is
...     # called because it is the last callable in pred_fn_pairs.
...     out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
...
...     exe = paddle.static.Executor(paddle.CPUPlace())
...     res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
...     print(res_1, res_2)
[[1. 1.]] [3 3 3]