torch.nn.Module.bfloat16()
paddle.nn.Layer.to(dtype=paddle.bfloat16)
Paddle 相比 PyTorch 支持更多其他参数,具体如下:
# PyTorch 写法: module = torch.nn.Module() module.bfloat16() # Paddle 写法: module = paddle.nn.Layer() module.to(dtype=paddle.bfloat16)