switch_case

paddle.static.nn.switch_case(branch_index, branch_fns, default=None, name=None)
Api_attr

Static Graph

This operator works like a C++ switch/case statement: it runs the callable in branch_fns whose index equals the value of branch_index.

Parameters
  • branch_index (Tensor) – A Tensor with exactly one element (shape [] or shape [1]) that specifies which branch to execute. Its data type must be int32, int64 or uint8.

  • branch_fns (dict|list|tuple) – If it is a list or tuple, its elements can be either (int, callable) pairs or plain callables; a plain callable's position in the list is used as its index. If it is a dict, each key is a Python integer and each value is a callable. All callables must return the same structure of Tensors.

  • default (callable, optional) – Callable that returns a structure of Tensors. It is called when no index in branch_fns matches branch_index. The default value is None.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
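The three accepted forms of branch_fns can be modeled in plain Python. The sketch below uses a hypothetical helper, normalize_branch_fns, written only for illustration. It is not part of Paddle's API, and Paddle's internal handling may differ in detail. It converts each form into a list of (index, callable) pairs sorted by index, assuming the input is already well formed.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

Branch = Callable[[], object]
BranchFns = Union[Dict[int, Branch], Sequence[Union[Branch, Tuple[int, Branch]]]]


def normalize_branch_fns(branch_fns: BranchFns) -> List[Tuple[int, Branch]]:
    """Hypothetical helper: turn any accepted form into sorted (index, fn) pairs."""
    if isinstance(branch_fns, dict):
        # dict form: keys are the indices, values are the callables.
        pairs = list(branch_fns.items())
    elif all(callable(item) for item in branch_fns):
        # list/tuple of plain callables: each callable's position is its index.
        pairs = list(enumerate(branch_fns))
    else:
        # list/tuple of (int, callable) pairs: used as given.
        pairs = [tuple(item) for item in branch_fns]
    return sorted(pairs, key=lambda pair: pair[0])


if __name__ == "__main__":
    def fn_a():
        return "a"

    def fn_b():
        return "b"

    print([i for i, _ in normalize_branch_fns([fn_a, fn_b])])            # [0, 1]
    print([i for i, _ in normalize_branch_fns({7: fn_b, 2: fn_a})])      # [2, 7]
    print([i for i, _ in normalize_branch_fns([(4, fn_b), (1, fn_a)])])  # [1, 4]
```

Sorting by index is a choice made here so that "the callable with the largest index" is simply the last pair.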

Returns

The Tensors returned by the callable in branch_fns whose index equals branch_index. If no index matches, the Tensors returned by default when default is not None; otherwise, the Tensors returned by the callable with the largest index in branch_fns.
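The selection rule above can be restated as a small eager Python function. select_branch is a hypothetical name used only for this illustration. It works on plain (index, callable) pairs rather than Tensors and runs immediately, so it says nothing about how Paddle records the branches in a static program.

```python
from typing import Callable, List, Optional, Tuple

Branch = Callable[[], object]


def select_branch(index: int,
                  pairs: List[Tuple[int, Branch]],
                  default: Optional[Branch] = None) -> object:
    """Hypothetical model of the documented selection rule (eager Python, not a graph op)."""
    for key, fn in pairs:
        if key == index:
            return fn()  # an index matches: run that branch
    if default is not None:
        return default()  # no match and default is given: run default
    # No match and no default: run the branch with the largest index.
    # (max() raises ValueError if pairs is empty.)
    _, fn_max = max(pairs, key=lambda pair: pair[0])
    return fn_max()


if __name__ == "__main__":
    # Mirrors the three calls in the Examples section, with strings instead of Tensors.
    print(select_branch(1, [(1, lambda: "fn_1"), (2, lambda: "fn_2")], lambda: "fn_3"))  # fn_1
    print(select_branch(2, [(1, lambda: "fn_1"), (2, lambda: "fn_2")], lambda: "fn_3"))  # fn_2
    print(select_branch(2, [(0, lambda: "fn_1"), (4, lambda: "fn_2"), (7, lambda: "fn_3")]))  # fn_3
```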

Return type

Tensor|list(Tensor)

Raises
  • TypeError – If the type of branch_index is not Tensor.

  • TypeError – If the data type of branch_index is not int32, int64 or uint8.

  • TypeError – If the type of branch_fns is not dict, list or tuple.

  • TypeError – If the elements of branch_fns are not 2-tuples.

  • TypeError – If the first element of a 2-tuple in branch_fns is not an integer.

  • ValueError – If the first elements of the 2-tuples in branch_fns are not unique.

  • TypeError – If the second element of a 2-tuple in branch_fns is not callable.

  • TypeError – If default is not None but it is not callable.
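The branch_fns and default checks listed above can be sketched in plain Python. check_branch_args and its error messages are hypothetical and written for illustration only; they are not Paddle's code or Paddle's actual messages. The branch_index checks are omitted because they need a real Tensor. This sketch pairs each plain callable with its position before checking, which is an assumption; Paddle may treat lists that mix plain callables and pairs differently.

```python
from typing import Any, Callable, List, Optional, Tuple


def check_branch_args(branch_fns: Any,
                      default: Optional[Callable] = None) -> List[Tuple[int, Callable]]:
    """Hypothetical re-creation of the branch_fns/default checks listed under Raises."""
    if isinstance(branch_fns, dict):
        items = list(branch_fns.items())
    elif isinstance(branch_fns, (list, tuple)):
        # Plain callables are paired with their position before checking.
        items = [(i, item) if callable(item) else item for i, item in enumerate(branch_fns)]
    else:
        raise TypeError("branch_fns must be a dict, list or tuple")

    seen = set()
    for item in items:
        if not isinstance(item, tuple) or len(item) != 2:
            raise TypeError("each element of branch_fns must be a 2-tuple")
        key, fn = item
        if not isinstance(key, int):
            raise TypeError("the first element of each 2-tuple must be an integer")
        if key in seen:
            raise ValueError(f"duplicate branch index {key}")
        seen.add(key)
        if not callable(fn):
            raise TypeError("the second element of each 2-tuple must be callable")

    if default is not None and not callable(default):
        raise TypeError("default must be None or a callable")
    return items


if __name__ == "__main__":
    try:
        check_branch_args([(1, abs), (1, len)])
    except ValueError as err:
        print(err)  # duplicate branch index 1
```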

Examples

>>> import paddle
>>> paddle.enable_static()

>>> def fn_1():
...    return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

>>> def fn_2():
...    return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

>>> def fn_3():
...    return paddle.full(shape=[3], dtype='int32', fill_value=3)

>>> startup_program = paddle.static.default_startup_program()
>>> main_program = paddle.static.default_main_program()
>>> with paddle.static.program_guard(main_program, startup_program):
...    index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
...    index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
...
...    out_1 = paddle.static.nn.switch_case(
...        branch_index=index_1,
...        branch_fns={1: fn_1, 2: fn_2},
...        default=fn_3)
...
...    out_2 = paddle.static.nn.switch_case(
...        branch_index=index_2,
...        branch_fns=[(1, fn_1), (2, fn_2)],
...        default=fn_3)
...
...    # default is None and index 2 matches no key, so fn_3 (the callable with the largest index, 7) is called.
...    out_3 = paddle.static.nn.switch_case(
...        branch_index=index_2,
...        branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
...
...    exe = paddle.static.Executor(paddle.CPUPlace())
...    res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])

>>> print(res_1)
[[1. 1.]]

>>> print(res_2)
[[2 2]
 [2 2]]

>>> print(res_3)
[3 3 3]