class paddle.fluid.layers.control_flow.Switch(name=None)

Static Graph

This class implements Switch branch control flow. A Switch contains one or more case branches and one default branch. It checks the case conditions in order and executes only the statements under the first case whose condition is True. If no case condition is satisfied, only the statements under the default branch are executed.
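The selection rule can be modeled in plain Python, outside Paddle. This is a conceptual sketch only: `run_switch` and its arguments are hypothetical names, and real Switch conditions are bool Variables evaluated by the executor, not Python booleans.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical model: each branch is (condition, body).
Branch = Tuple[bool, Callable[[], None]]

def run_switch(cases: List[Branch],
               default: Optional[Callable[[], None]] = None) -> None:
    """Conceptual model of Switch: execute only the first case whose condition is True."""
    for cond, body in cases:
        if cond:
            body()   # the first satisfied case wins...
            return   # ...and all later cases are skipped, even if True
    if default is not None:
        default()    # no case matched: fall back to the default branch

executed = []
run_switch(
    [(False, lambda: executed.append("case 1")),
     (True, lambda: executed.append("case 2")),
     (True, lambda: executed.append("case 3"))],
    default=lambda: executed.append("default"),
)
print(executed)  # ['case 2']
```

Even though "case 3" also has a True condition, it never runs, because "case 2" matched first.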


If each condition has shape [1], the newer OP case (fluid.layers.case) is highly recommended instead of Switch. case does the same thing as Switch but is easier to use and needs less code.
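The "less code" point comes from the calling style: case takes (predicate, callable) pairs and returns the chosen callable's result, whereas Switch bodies must write into a pre-created output variable with assign. The plain-Python sketch below models only that functional style; `case_like` is a hypothetical stand-in, not the Paddle case signature, so check the case reference for its real parameters.

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")

def case_like(pred_fn_pairs: List[Tuple[bool, Callable[[], T]]],
              default: Callable[[], T]) -> T:
    """Conceptual model of the functional style: return the chosen branch's value."""
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()  # the first True predicate decides the result
    return default()     # no predicate matched

step = 0
# One expression replaces a pre-created output variable plus an assign per branch.
lr = case_like([(step == 0, lambda: 1.0)], default=lambda: 2.0)
print(lr)  # 1.0
```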

Member Functions:

case(condition): A case branch of Switch. The parameter condition is a scalar Variable of bool type. The statements under this case are executed only if its condition is True and the conditions of all preceding case branches are False; once a case is executed, the statements under all later case branches are skipped.

default(): The default branch of Switch. Its statements are executed only when the conditions of all case branches are False.

The case and default functions can only be used inside the scope of a Switch, as shown below:

import paddle
import paddle.fluid as fluid

paddle.enable_static()  # Switch is a static-graph API

# cond1 and cond2 are bool Variables of shape [1] created earlier in the program.
with fluid.layers.Switch() as switch:
    with switch.case(cond1):
        i = paddle.full(shape=[1], dtype='int64', fill_value=1)
    with switch.case(cond2):
        i = paddle.full(shape=[1], dtype='int64', fill_value=2)
    with switch.default():
        i = paddle.full(shape=[1], dtype='int64', fill_value=0)
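In static-graph mode, the Python code under every case and default block runs once while the program is being built; only the recorded ops run later, and the executor picks which branch's ops execute. The toy recorder below models that two-phase behavior in plain Python. `ToySwitch`, `record`, and `run` are hypothetical names and do not mirror Paddle internals.

```python
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple

Op = Callable[[dict], None]

class ToySwitch:
    """Toy model: branch bodies are recorded at build time and chosen at run time."""

    def __init__(self) -> None:
        # Each entry is (condition or None for default, recorded ops).
        self._branches: List[Tuple[Optional[Callable[[dict], bool]], List[Op]]] = []

    @contextmanager
    def case(self, cond: Callable[[dict], bool]):
        self._branches.append((cond, []))
        yield

    @contextmanager
    def default(self):
        # Assumes default is declared last, after all cases.
        self._branches.append((None, []))
        yield

    def record(self, op: Op) -> None:
        # Append an "op" to the branch currently being built.
        self._branches[-1][1].append(op)

    def run(self, env: dict) -> None:
        # Executor phase: run the ops of the first branch whose condition holds.
        for cond, ops in self._branches:
            if cond is None or cond(env):
                for op in ops:
                    op(env)
                return

build_log = []
sw = ToySwitch()
with sw.case(lambda env: env["step"] == 0):
    build_log.append("case body built")
    sw.record(lambda env: env.update(lr=1.0))
with sw.default():
    build_log.append("default body built")
    sw.record(lambda env: env.update(lr=2.0))

env = {"step": 0}
sw.run(env)
print(build_log)  # both bodies ran once at build time
print(env["lr"])  # 1.0
```

This is also why, in the real API, results are usually written into a variable created outside the Switch (for example with assign) rather than rebound to a new Python name inside each branch.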

Parameters:

name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.


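Examples:

import paddle
import paddle.fluid as fluid

paddle.enable_static()  # Switch only works in static-graph mode

# lr is a persistable global variable that each branch overwrites with assign.
lr = paddle.static.create_global_var(
    shape=[1],
    value=0.0,
    dtype='float32',
    persistable=True,
    name="learning_rate")
zero_var = paddle.full(
    shape=[1], dtype='float32', fill_value=0.0)
one_var = paddle.full(
    shape=[1], dtype='float32', fill_value=1.0)
two_var = paddle.full(
    shape=[1], dtype='float32', fill_value=2.0)

# global_step starts at 0 and grows by 1 on each run of the main program.
global_step = fluid.layers.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

with fluid.layers.control_flow.Switch() as switch:
    with switch.case(global_step == zero_var):
        paddle.assign(one_var, output=lr)  # first step: lr = 1.0
    with switch.default():
        paddle.assign(two_var, output=lr)  # later steps: lr = 2.0

exe = fluid.Executor(fluid.CPUPlace())
# Initialize persistable variables (lr and the step counter) first.
exe.run(fluid.default_startup_program())

res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])
print(res) # [array([1.], dtype=float32)]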
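Because autoincreased_step_counter advances global_step on every run of the main program, calling exe.run repeatedly should give 1.0 on the first run and 2.0 afterward. The plain-Python trace below follows that schedule, assuming the counter reads begin on the first run and grows by step per run; it does not call Paddle, and `simulate_lr` is a hypothetical helper.

```python
def simulate_lr(num_runs: int, begin: int = 0, step: int = 1) -> list:
    """Trace lr across runs: 1.0 while global_step == 0 (case), else 2.0 (default)."""
    values = []
    global_step = begin
    for _ in range(num_runs):
        lr = 1.0 if global_step == 0 else 2.0  # case branch vs. default branch
        values.append(lr)
        global_step += step  # the counter advances once per run
    return values

print(simulate_lr(3))  # [1.0, 2.0, 2.0]
```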