switch_case

Note: This API is only available in static graph mode.

paddle.fluid.layers.switch_case(branch_index, branch_fns, default=None, name=None)

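This operator works like a C++ switch/case statement: it runs the callable in branch_fns whose index equals branch_index and, when no index matches, falls back to default (or, if default is None, to the callable with the largest index).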

Parameters
  • branch_index (Variable) – A Tensor with shape [1] that specifies which branch to execute. Its data type must be int32, int64 or uint8.

  • branch_fns (dict|list|tuple) – If it is a list or tuple, each element is either an (int, callable) pair or a plain callable; for a plain callable, its position in the sequence is used as its index. If it is a dict, each key is a Python integer and each value is a callable. All callables must return the same structure of Tensors.

  • default (callable, optional) – Callable that returns a structure of Tensors. It is called when no index in branch_fns matches branch_index. The default value is None.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
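The three accepted forms of branch_fns can be compared with a small pure-Python sketch. The helper `to_index_pairs` below is hypothetical (it is not a Paddle function) and assumes that a list or tuple holds either only (int, callable) pairs or only plain callables; it simply normalizes each form into (index, callable) pairs sorted by index.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union


# Hypothetical helper (not a Paddle API): normalizes the accepted forms of
# branch_fns into (index, callable) pairs sorted by index.
# Assumption: a list/tuple holds either only (int, callable) pairs or only
# plain callables.
def to_index_pairs(
    branch_fns: Union[Dict[int, Callable], Sequence]
) -> List[Tuple[int, Callable]]:
    if isinstance(branch_fns, dict):
        # dict: keys are indices, values are callables
        pairs = list(branch_fns.items())
    elif all(callable(fn) for fn in branch_fns):
        # plain callables: each one's position is used as its index
        pairs = list(enumerate(branch_fns))
    else:
        # explicit (index, callable) pairs
        pairs = list(branch_fns)
    # sort on the index only, so the callables themselves are never compared
    return sorted(pairs, key=lambda pair: pair[0])


def fn_a():
    return "a"


def fn_b():
    return "b"


# The three forms describe the same mapping {0: fn_a, 1: fn_b}.
for form in ({0: fn_a, 1: fn_b}, [(0, fn_a), (1, fn_b)], (fn_a, fn_b)):
    print([(index, fn()) for index, fn in to_index_pairs(form)])
# each iteration prints [(0, 'a'), (1, 'b')]
```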

Returns

The Tensors returned by the callable in branch_fns whose index matches branch_index. If no index matches, the Tensors returned by default are used when default is not None; otherwise, the Tensors returned by the callable with the largest index in branch_fns are used.

Return type

Variable|list(Variable)
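The selection rule described under Returns can be modelled in plain Python. `pick_branch` is a hypothetical helper, not part of Paddle: it chooses a Python callable eagerly, whereas switch_case builds the branches into a static graph and selects one when the program runs. The calls reuse the index and branch combinations from the Examples section.

```python
from typing import Callable, Dict, Optional


# Hypothetical, pure-Python model of the branch selection rule (not Paddle code).
def pick_branch(index: int,
                branches: Dict[int, Callable],
                default: Optional[Callable] = None) -> Callable:
    if index in branches:
        return branches[index]        # an index matches: use that branch
    if default is not None:
        return default                # no match: fall back to default
    return branches[max(branches)]    # no match, no default: largest index


def fn_1():
    return "fn_1"


def fn_2():
    return "fn_2"


def fn_3():
    return "fn_3"


print(pick_branch(1, {1: fn_1, 2: fn_2}, default=fn_3)())  # fn_1
print(pick_branch(2, {1: fn_1, 2: fn_2}, default=fn_3)())  # fn_2
print(pick_branch(2, {0: fn_1, 4: fn_2, 7: fn_3})())       # fn_3
```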

Raises
  • TypeError – If the type of branch_index is not Variable.

  • TypeError – If the data type of branch_index is not int32, int64 or uint8.

  • TypeError – If the type of branch_fns is not dict, list or tuple.

  • TypeError – If an element of branch_fns is not a 2-tuple.

  • TypeError – If the first element of a 2-tuple in branch_fns is not an integer.

  • ValueError – If the first elements of the 2-tuples in branch_fns are not unique.

  • TypeError – If the second element of a 2-tuple in branch_fns is not callable.

  • TypeError – If default is not None and is not callable.
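The branch_fns and default checks listed above can be sketched as a plain-Python validator. `validate_branch_fns` is hypothetical and only approximates Paddle's behavior: the checks on branch_index are omitted because they need a Paddle Variable, the "all plain callables" rule for lists is an assumption, and the error messages are invented for illustration.

```python
# Hypothetical validator modelling the branch_fns/default checks under
# "Raises". Not Paddle source; branch_index checks are omitted.
def validate_branch_fns(branch_fns, default=None):
    if not isinstance(branch_fns, (dict, list, tuple)):
        raise TypeError("branch_fns must be a dict, list or tuple")
    if isinstance(branch_fns, dict):
        pairs = list(branch_fns.items())
    elif all(callable(fn) for fn in branch_fns):
        # assumption: a sequence of plain callables is indexed by position
        pairs = list(enumerate(branch_fns))
    else:
        pairs = list(branch_fns)

    seen = set()
    for pair in pairs:
        if not isinstance(pair, tuple) or len(pair) != 2:
            raise TypeError("each element of branch_fns must be a 2-tuple")
        key, fn = pair
        if not isinstance(key, int):
            raise TypeError("the branch index must be an integer")
        if key in seen:
            raise ValueError("branch index %d is not unique" % key)
        seen.add(key)
        if not callable(fn):
            raise TypeError("the branch function must be callable")

    if default is not None and not callable(default):
        raise TypeError("default must be callable")
    return pairs


def f():
    return 0


# Each of these inputs violates one of the rules above.
for bad in ([(0, f), (0, f)], [(0, f, 1)], [("0", f)], [(0, 1)]):
    try:
        validate_branch_fns(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```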

Examples

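import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Paddle 2.x defaults to dynamic graph mode; switch_case needs static graph mode.
paddle.enable_static()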

def fn_1():
    return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)

def fn_2():
    return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)

def fn_3():
    return layers.fill_constant(shape=[3], dtype='int32', value=3)

main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program):
    index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
    index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)

    out_1 = layers.switch_case(
        branch_index=index_1,
        branch_fns={1: fn_1, 2: fn_2},
        default=fn_3)

    out_2 = layers.switch_case(
        branch_index=index_2,
        branch_fns=[(1, fn_1), (2, fn_2)],
        default=fn_3)

    # default is None and no index matches 2, so fn_3 (the callable with the largest index, 7) is called.
    out_3 = layers.switch_case(
        branch_index=index_2,
        branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])

    exe = fluid.Executor(fluid.CPUPlace())
    res_1, res_2, res_3 = exe.run(main_program,
                                  fetch_list=[out_1, out_2, out_3])
    print(res_1)  # [[1. 1.]]
    print(res_2)  # [[2 2] [2 2]]
    print(res_3)  # [3 3 3]