Switch

class paddle.fluid.layers.Switch(name=None)[source]

This class implements Switch branch control. A Switch contains several case branches and one default branch. The control flow checks the case conditions in order and executes only the statements of the first case branch whose condition is satisfied. If no case branch is satisfied, only the statements of the default branch are executed.

Member Functions:

case(cond): A case branch of Switch. The parameter cond is a scalar Variable of bool type. The statements of a case branch are executed only if its cond is True and the cond of every preceding case branch is False; the statements of all later case branches are then skipped.

default(): The default branch of Switch. When the cond of every case branch is False, the statements of the default branch are executed.
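The selection rule described for case and default can be modeled in plain Python. The sketch below only illustrates the first-match-wins behavior and is not Paddle's implementation: run_switch and make_branch are hypothetical helpers, and the conditions are ordinary Python bools rather than bool Variables.

```python
def run_switch(cases, default=None):
    """Run the body of the first case whose condition is True.

    cases is a list of (cond, body) pairs checked in order; default is
    an optional zero-argument callable used when no condition holds.
    This is a hypothetical helper, not part of Paddle.
    """
    for cond, body in cases:
        if cond:
            # First satisfied case wins; later cases are never evaluated.
            return body()
    if default is not None:
        return default()
    return None


def make_branch(name, value, log):
    """Build a branch body that records its name and returns value."""
    def body():
        log.append(name)
        return value
    return body


executed = []
result = run_switch(
    cases=[
        (False, make_branch("case1", 1, executed)),
        (True, make_branch("case2", 2, executed)),
        (True, make_branch("case3", 3, executed)),  # skipped: case2 matched first
    ],
    default=make_branch("default", 0, executed),
)
print(result, executed)  # 2 ['case2']
```

Although the third condition is also True, only the second branch runs, which mirrors the rule that a case executes only when every preceding cond is False.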

The case and default functions can only be used inside the scope of a Switch, as shown below:

'''
with fluid.layers.Switch() as switch:
    with switch.case(cond1):
        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
    with switch.case(cond2):
        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2)
    with switch.default():
        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
'''
Parameters

name (str, optional) – The default value is None. Normally users do not need to set this property. For more information, please refer to Name.

Examples

import paddle.fluid as fluid

lr = fluid.layers.create_global_var(
    shape=[1],
    value=0.0,
    dtype='float32',
    persistable=True,
    name="learning_rate")
zero_var = fluid.layers.fill_constant(
    shape=[1], dtype='float32', value=0.0)
one_var = fluid.layers.fill_constant(
    shape=[1], dtype='float32', value=1.0)
two_var = fluid.layers.fill_constant(
    shape=[1], dtype='float32', value=2.0)

# Step counter configured with begin=0 and step=1.
global_step = fluid.layers.autoincreased_step_counter(
    counter_name='@LR_DECAY_COUNTER@', begin=0, step=1)

# Assign 1.0 to lr when global_step equals 0, otherwise assign 2.0.
with fluid.layers.control_flow.Switch() as switch:
    with switch.case(global_step == zero_var):
        fluid.layers.assign(input=one_var, output=lr)
    with switch.default():
        fluid.layers.assign(input=two_var, output=lr)

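# Run the startup program once to initialize the variables.
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Run the main program once and fetch the resulting lr.
res = exe.run(fluid.default_main_program(), feed={}, fetch_list=[lr])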
print(res) # [array([1.], dtype=float32)]
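
The example above fetches lr only once. As a rough illustration of what the same Switch would select over several runs, the plain-Python sketch below assumes that the step counter reads 0 on the first run and increases by 1 on each later run, as begin=0 and step=1 suggest. select_lr is a hypothetical stand-in for the Switch block and does not call Paddle.

```python
def select_lr(global_step):
    """Hypothetical stand-in for the Switch block in the example."""
    if global_step == 0:
        # Corresponds to switch.case(global_step == zero_var).
        return 1.0
    # Corresponds to switch.default().
    return 2.0


# Assumed counter values for four successive runs: 0, 1, 2, 3.
schedule = [select_lr(step) for step in range(4)]
print(schedule)  # [1.0, 2.0, 2.0, 2.0]
```

Under that assumption, only the first run takes the case branch; every later run falls through to default.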