- paddle.fluid.contrib.mixed_precision.bf16.amp_utils.rewrite_program_bf16(main_prog, amp_lists=None)
Traverse all ops in the current block and insert cast ops according to the set each op belongs to:
1. When an op belongs to the fp32 list, add it to the fp32 set.
2. When an op belongs to the bf16 list, add it to the bf16 set.
3. When an op belongs to the gray list: if one of its inputs is the output of an fp32 set op or fp32 list op, add it to the fp32 set. If none of its previous ops is an fp32 op and one of its inputs is the output of a bf16 set op or bf16 list op, add it to the bf16 set.
4. When an op isn’t in any of the lists, add it to the fp32 set.
5. Add the necessary cast ops so that fp32 set ops are computed in fp32 mode and bf16 set ops are computed in bf16 mode.
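The classification rules above can be sketched in plain Python, without Paddle. This is only an illustration of the decision logic, not the library's implementation: the `Op` class, the `classify_ops` helper, and the contents of the three lists are hypothetical, and the sketch assumes that a gray-list op with no classified producer is left unclassified (the description does not say what happens in that case). Cast insertion is not modeled.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """Hypothetical stand-in for a program op: a type plus variable names."""
    type: str
    inputs: list
    outputs: list


def classify_ops(ops, fp32_list, bf16_list, gray_list):
    """Return one of "fp32", "bf16", or None for each op, in program order."""
    produced_in = {}  # variable name -> mode of the op that produced it
    modes = []
    for op in ops:
        if op.type in fp32_list:
            mode = "fp32"
        elif op.type in bf16_list:
            mode = "bf16"
        elif op.type in gray_list:
            # Look at the modes of the ops that produced this op's inputs.
            input_modes = {produced_in[v] for v in op.inputs if v in produced_in}
            if "fp32" in input_modes:
                mode = "fp32"  # any fp32 producer wins
            elif "bf16" in input_modes:
                mode = "bf16"  # no fp32 producer, at least one bf16 producer
            else:
                mode = None    # assumption: left unclassified
        else:
            mode = "fp32"      # ops outside all lists default to fp32
        if mode is not None:
            for v in op.outputs:
                produced_in[v] = mode
        modes.append(mode)
    return modes


# Hypothetical list contents, chosen only to exercise each rule.
fp32_list = {"softmax"}
bf16_list = {"conv2d"}
gray_list = {"relu", "reshape2", "elementwise_add"}

ops = [
    Op("conv2d", ["x"], ["a"]),                 # bf16 list
    Op("relu", ["a"], ["b"]),                   # gray, fed by bf16 op
    Op("softmax", ["b"], ["c"]),                # fp32 list
    Op("reshape2", ["c"], ["d"]),               # gray, fed by fp32 op
    Op("my_custom_op", ["d"], ["e"]),           # in no list
    Op("elementwise_add", ["a", "c"], ["f"]),   # gray, fed by both
]

modes = classify_ops(ops, fp32_list, bf16_list, gray_list)
for op, mode in zip(ops, modes):
    print(f"{op.type}: {mode}")
```

In this sketch, `elementwise_add` ends up in the fp32 set even though one of its inputs comes from a bf16 op, because the fp32 check takes precedence for gray-list ops.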
main_prog (Program) – The main program for training.