[ 仅 API 调用方式不一致 ]torch.autograd.grad_mode.set_grad_enabled
torch.autograd.grad_mode.set_grad_enabled
torch.autograd.grad_mode.set_grad_enabled(mode)
转写示例
# PyTorch 写法
with torch.autograd.grad_mode.set_grad_enabled(is_train):
y = x * 2
# Paddle 写法
with paddle.set_grad_enabled(is_train):
y = x * 2
