case

api_attr

declarative programming (static graph)

paddle.fluid.layers.case(pred_fn_pairs, default=None, name=None)

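This operator works like an if-elif-elif-else chain: it checks the preds in pred_fn_pairs in order and selects the callable paired with the first pred that is True.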

Parameters
  • pred_fn_pairs (list|tuple) – A list or tuple of (pred, fn) pairs, where pred is a boolean Tensor with shape [1] and fn is a callable. All callables must return the same structure of Tensors.

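  • default (callable, optional) – Callable that returns the same structure of Tensors as the callables in pred_fn_pairs. The default value is None, in which case the last callable in pred_fn_pairs is used as the fallback.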

  • name (str, optional) – The default value is None. Normally there is no need for user to set this property. For more information, please refer to Name.

Returns

The Tensors returned by the selected callable: the callable from the first pair whose pred is True; if no pred in pred_fn_pairs is True, default when it is not None; otherwise the last callable in pred_fn_pairs.
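The selection rule above can be modeled in plain Python. This is a minimal sketch, not Paddle's implementation: `case_semantics` is a hypothetical helper, preds are ordinary Python bools instead of boolean Tensors of shape [1], and the model ignores static-graph construction entirely and simply calls the chosen callable.

```python
from typing import Callable, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def case_semantics(
    pred_fn_pairs: Sequence[Tuple[bool, Callable[[], T]]],
    default: Optional[Callable[[], T]] = None,
) -> T:
    """Hypothetical plain-Python model of which callable layers.case selects."""
    # Scan the pairs in order, like an if/elif chain, and stop at the first True pred.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()
    # No pred was True: use default when one was given ...
    if default is not None:
        return default()
    # ... otherwise the last callable in pred_fn_pairs acts as the else branch.
    return pred_fn_pairs[-1][1]()


# Mirrors the two calls in the example below, with strings standing in for Tensors.
print(case_semantics([(True, lambda: "fn_1"), (False, lambda: "fn_2")], default=lambda: "fn_3"))
print(case_semantics([(False, lambda: "fn_2"), (False, lambda: "fn_3")]))
```

The first call prints `fn_1` because the first pred is True; the second prints `fn_3` because no pred is True, default is None, and fn_3 is the last callable.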

Return type

Variable|list(Variable)

Raises
  • TypeError – If pred_fn_pairs is not a list or tuple.

  • TypeError – If an element of pred_fn_pairs is not a tuple.

  • TypeError – If a tuple in pred_fn_pairs does not have exactly 2 elements.

  • TypeError – If the first element of a 2-tuple in pred_fn_pairs is not a Variable.

  • TypeError – If the second element of a 2-tuple in pred_fn_pairs is not callable.

  • TypeError – If default is neither None nor callable.
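The argument checks listed under Raises can be mirrored in plain Python. This is a sketch of the documented conditions, not Paddle's own validation code: `check_case_args` and `FakeVariable` are hypothetical names, and `FakeVariable` only stands in for fluid's `Variable` type.

```python
from typing import Any, Optional


class FakeVariable:
    """Hypothetical stand-in for fluid's Variable type, used only in this sketch."""


def check_case_args(pred_fn_pairs: Any, default: Optional[Any] = None) -> None:
    """Raise TypeError for each condition listed under Raises (hypothetical helper)."""
    # pred_fn_pairs itself must be a list or tuple.
    if not isinstance(pred_fn_pairs, (list, tuple)):
        raise TypeError("pred_fn_pairs must be a list or tuple")
    for pair in pred_fn_pairs:
        # Each element must be a tuple ...
        if not isinstance(pair, tuple):
            raise TypeError("each element of pred_fn_pairs must be a tuple")
        # ... holding exactly two items: (pred, fn).
        if len(pair) != 2:
            raise TypeError("each tuple in pred_fn_pairs must have 2 elements")
        pred, fn = pair
        if not isinstance(pred, FakeVariable):
            raise TypeError("pred must be a Variable")
        if not callable(fn):
            raise TypeError("fn must be callable")
    # default is optional, but when given it must be callable.
    if default is not None and not callable(default):
        raise TypeError("default must be None or callable")


v = FakeVariable()
check_case_args([(v, lambda: 1)], default=lambda: 2)  # valid: no exception
```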

Examples

import paddle.fluid as fluid
import paddle.fluid.layers as layers

def fn_1():
    return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)

def fn_2():
    return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)

def fn_3():
    return layers.fill_constant(shape=[3], dtype='int32', value=3)

# Build the ops in the main program; the startup program holds initialization ops.
main_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(main_program, startup_program):
    x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
    y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
    z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)

    pred_1 = layers.less_than(z, x)  # true: 0.2 < 0.3
    pred_2 = layers.less_than(x, y)  # false: 0.3 < 0.1
    pred_3 = layers.equal(x, y)      # false: 0.3 == 0.1

    # Call fn_1 because pred_1 is True
    out_1 = layers.case(
        pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)

    # default is None and no pred in pred_fn_pairs is True, so fn_3 is called
    # because it is the last callable in pred_fn_pairs.
    out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])

    exe = fluid.Executor(fluid.CPUPlace())
    res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
    print(res_1)  # [[1. 1.]]
    print(res_2)  # [3 3 3]