Layer

class paddle.nn. Layer ( name_scope=None, dtype='float32' ) [源代码]

基于 OOD 实现的动态图 Layer,包含该 Layer 的参数、前序运行的结构等信息。

参数

  • name_scope (str,可选) - 为 Layer 内部参数命名而采用的名称前缀。如果前缀为“my_layer”,在一个类名为 MyLayer 的 Layer 中,参数名为“mylayer_0.w_n”,其中 w 是参数的名称,n 为自动生成的具有唯一性的后缀。如果为 None,前缀名将为小写的类名。默认值为 None。

  • dtype (str 可选) - Layer 中参数数据类型。如果设置为 str,则可以是“bool”,“float16”,“float32”,“float64”,“int8”,“int16”,“int32”,“int64”,“uint8”或“uint16”。默认值为 "float32"。

返回

代码示例

>>> import paddle
>>> paddle.seed(100)

>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         self._dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         temp = self._linear(input)
...         temp = self._dropout(temp)
...         return temp
...
>>> x = paddle.randn([10, 1], 'float32')
>>> mylayer = MyLayer()
>>> mylayer.eval()  # set mylayer._dropout to eval mode
>>> out = mylayer(x)
>>> mylayer.train()  # set mylayer._dropout to train mode
>>> out = mylayer(x)
>>> print(out)
Tensor(shape=[10, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-3.44879317],
 [ 0.        ],
 [ 0.        ],
 [-0.73825276],
 [ 0.        ],
 [ 0.        ],
 [ 0.64444798],
 [-3.22185946],
 [ 0.        ],
 [-0.68077987]])

方法

train()

将此层及其所有子层设置为训练模式。这只会影响某些模块,如 Dropout 和 BatchNorm。

代码示例

>>> import paddle
>>> paddle.seed(100)

>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         self._dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         temp = self._linear(input)
...         temp = self._dropout(temp)
...         return temp
...
>>> x = paddle.randn([10, 1], 'float32')
>>> mylayer = MyLayer()
>>> mylayer.eval()  # set mylayer._dropout to eval mode
>>> out = mylayer(x)
>>> mylayer.train()  # set mylayer._dropout to train mode
>>> out = mylayer(x)
>>> print(out)
Tensor(shape=[10, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-3.44879317],
 [ 0.        ],
 [ 0.        ],
 [-0.73825276],
 [ 0.        ],
 [ 0.        ],
 [ 0.64444798],
 [-3.22185946],
 [ 0.        ],
 [-0.68077987]])

eval()

将此层及其所有子层设置为预测模式。这只会影响某些模块,如 Dropout 和 BatchNorm。

返回

代码示例

>>> import paddle
>>> paddle.seed(100)
>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         self._dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         temp = self._linear(input)
...         temp = self._dropout(temp)
...         return temp
...
>>> x = paddle.randn([10, 1], 'float32')
>>> mylayer = MyLayer()
>>> mylayer.eval()  # set mylayer._dropout to eval mode
>>> out = mylayer(x)
>>> print(out)
Tensor(shape=[10, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-1.72439659],
 [ 0.31532824],
 [ 0.01192369],
 [-0.36912638],
 [-1.63426113],
 [-0.93169814],
 [ 0.32222399],
 [-1.61092973],
 [ 0.77209264],
 [-0.34038994]])

apply(fn)

将一个函数 fn 递归地应用到网络的每一个子层(即在函数的 .sublayers() 中返回的子层)以及模块自身。该方法通常用来初始化一个模型中的参数。

参数

  • fn (function) - 应用到每一个子层的函数

返回 Layer (返回网络层), self (返回自身)

代码示例

>>> import paddle
>>> import paddle.nn as nn
>>> paddle.seed(2023)

>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

>>> def init_weights(layer):
...     if type(layer) == nn.Linear:
...         print('before init weight:', layer.weight.numpy())
...         new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
...         layer.weight.set_value(new_weight)
...         print('after init weight:', layer.weight.numpy())
...
>>> net.apply(init_weights)

>>> print(net.state_dict())
before init weight: [[ 0.89611185  0.04935038]
                     [-0.5888344   0.99266374]]
after init weight: [[0.9 0.9]
                    [0.9 0.9]]
before init weight: [[-0.18615901 -0.22924072]
                     [ 1.1517721   0.59859073]]
after init weight: [[0.9 0.9]
                    [0.9 0.9]]
OrderedDict([('0.weight', Parameter containing:
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.89999998, 0.89999998],
 [0.89999998, 0.89999998]])), ('0.bias', Parameter containing:
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0.])), ('1.weight', Parameter containing:
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.89999998, 0.89999998],
 [0.89999998, 0.89999998]])), ('1.bias', Parameter containing:
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0.]))])

full_name()

Layer 的全名。组成方式为:name_scope + “/” + MyLayer.__class__.__name__ 。

返回 str, Layer 的全名

代码示例

>>> import paddle

>>> class LinearNet(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__(name_scope = "demo_linear_net")
...         self._linear = paddle.nn.Linear(1, 1)
...
...     def forward(self, x):
...         return self._linear(x)
...
>>> linear_net = LinearNet()
>>> print(linear_net.full_name())
demo_linear_net_0

register_forward_pre_hook(hook)

为 Layer 注册一个 forward pre-hook 函数,该 hook 函数将会在 forward 函数调用之前被调用。

hook 函数具有以下形式:它的 inputLayerinput,并且可以返回一个元组或者单个修改值;如果返回单个修改值,则将值包装到一个元组中。用户可以使用该函数来查看或修改 Layer forward 函数的输入。

hook(Layer, input) -> None or modified input

参数

  • hook (function) - 被注册为 forward pre-hook 的函数

返回 HookRemoveHelper,可通过调用 hook_remove_helper.remove() 来删除注册的 hook 函数。

代码示例

>>> import paddle
>>> import numpy as np

>>> # the forward_pre_hook change the input of the layer: input = input * 2
>>> def forward_pre_hook(layer, input):
...     # user can use layer and input for information statistis tasks
...
...     # change the input
...     input_return = (input[0] * 2)
...     return input_return
...
>>> linear = paddle.nn.Linear(13, 5)

>>> # register the hook
>>> forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

>>> value0 = np.arange(26).reshape(2, 13).astype("float32")
>>> in0 = paddle.to_tensor(value0)
>>> out0 = linear(in0)

>>> # remove the hook
>>> forward_pre_hook_handle.remove()

>>> value1 = value0 * 2
>>> in1 = paddle.to_tensor(value1)
>>> out1 = linear(in1)

>>> # hook change the linear's input to input * 2, so out0 is equal to out1.
>>> assert (out0.numpy() == out1.numpy()).any()

register_forward_post_hook(hook)

为 Layer 注册一个 forward post-hook 函数,该 hook 函数将会在 forward 函数调用之后被调用。

hook 函数具有以下形式,它的 inputoutputLayerinputoutput。用户可以用该函数来查看和修改 Layer forward 函数的输出。

hook(Layer, input, output) -> None or modified output

参数

  • hook (function) - 被注册为 forward post-hook 的函数

返回 HookRemoveHelper,可通过调用 hook_remove_helper.remove() 来删除注册的 hook 函数。

代码示例

>>> import paddle
>>> import numpy as np

>>> # the forward_post_hook change the output of the layer: output = output * 2
>>> def forward_post_hook(layer, input, output):
...     # user can use layer, input and output for information statistis tasks
...
...     # change the output
...     return output * 2
...
>>> linear = paddle.nn.Linear(13, 5)

>>> # register the hook
>>> forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

>>> value1 = np.arange(26).reshape(2, 13).astype("float32")
>>> in1 = paddle.to_tensor(value1)

>>> out0 = linear(in1)

>>> # remove the hook
>>> forward_post_hook_handle.remove()

>>> out1 = linear(in1)

>>> # hook change the linear's output to output * 2, so out0 is equal to out1 * 2.
>>> assert (out0.numpy() == (out1.numpy()) * 2).any()

create_parameter(shape, attr=None, dtype="float32", is_bias=False, default_initializer=None)

为 Layer 创建参数。

参数

  • shape (list) - 参数的形状。列表中的数据类型必须为 int。

  • attr (ParamAttr,可选) - 指定权重参数属性的对象,表示使用默认的权重参数属性。具体用法请参见 ParamAttr。默认值为 None。

  • dtype (str|core.VarDesc.VarType,可选) - Layer 中参数数据类型。如果设置为 str,则可以是“bool”,“float16”,“float32”,“float64”,“int8”,“int16”,“int32”,“int64”,“uint8”或“uint16”。默认值为“float32”。

  • is_bias (bool,可选) - 是否是偏置参数。默认值:False。

  • default_initializer (Initializer,可选) - 默认的参数初始化方法。如果设置为 None,则设置非 bias 参数的初始化方式为 paddle.nn.initializer.Xavier,设置 bias 参数的初始化方式为 paddle.nn.initializer.Constant。默认值:None。

返回 Tensor,创建的参数变量

代码示例

>>> import paddle
>>> paddle.seed(2023)

>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         w_tmp = self.create_parameter([1,1])
...         self.add_parameter("w_tmp", w_tmp)
...
...     def forward(self, input):
...         return self._linear(input)
...
>>> mylayer = MyLayer()
>>> for name, param in mylayer.named_parameters():
...     print(name, param)      # will print w_tmp,_linear.weight,_linear.bias
w_tmp Parameter containing:
Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.06979191]])
_linear.weight Parameter containing:
Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[1.26729357]])
_linear.bias Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])

create_variable(name=None, persistable=None, dtype=None)

为 Layer 创建变量。

参数

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

  • persistable (bool,可选) - 是否为持久性变量,后续会被移出。默认值:None。

  • dtype (str,可选) - Layer 中参数数据类型。如果设置为 str,则可以是“bool”,“float16”,“float32”,“float64”,“int8”,“int16”,“int32”,“int64”,“uint8”或“uint16”。默认值为 "float32" 。

返回 Tensor,返回创建的 Tensor

代码示例

>>> import paddle

>>> class MyLinear(paddle.nn.Layer):
...     def __init__(self,
...                 in_features,
...                 out_features):
...         super().__init__()
...         self.linear = paddle.nn.Linear( 10, 10)
...
...         self.back_var = self.create_variable(name = "linear_tmp_0", dtype=self._dtype)
...
...     def forward(self, input):
...         out = self.linear(input)
...         paddle.assign( out, self.back_var)
...
...         return out

create_tensor(name=None, persistable=None, dtype=None)

为 Layer 创建变量。

参数

  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

  • persistable (bool,可选) - 是否为持久性变量,后续会被移出。默认值:None。

  • dtype (str,可选) - Layer 中参数数据类型。如果设置为 str,则可以是“bool”,“float16”,“float32”,“float64”,“int8”,“int16”,“int32”,“int64”,“uint8”或“uint16”。默认值为 "float32" 。

返回 Tensor,返回创建的 Tensor

代码示例

>>> import paddle

>>> class MyLinear(paddle.nn.Layer):
...     def __init__(self,
...                  in_features,
...                  out_features):
...         super().__init__()
...         self.linear = paddle.nn.Linear(10, 10)
...
...         self.back_var = self.create_tensor(name = "linear_tmp_0", dtype=self._dtype)
...
...     def forward(self, input):
...         out = self.linear(input)
...         paddle.assign(out, self.back_var)
...
...         return out

parameters(include_sublayers=True)

返回一个由当前层及其子层的所有参数组成的列表。

参数

  • include_sublayers (bool,可选) - 是否返回子层的参数。如果为 True,返回的列表中包含子层的参数。默认值:True。

返回 list,一个由当前层及其子层的所有参数组成的列表,列表中的元素类型为 Parameter(Tensor)。

代码示例

>>> import paddle
>>> paddle.seed(100)

>>> linear = paddle.nn.Linear(1, 1)
>>> print(linear.parameters())
[Parameter containing:
Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.18551230]]), Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])]

children()

返回所有子层的迭代器。

返回 iterator,子层的迭代器。

代码示例

>>> import paddle

>>> linear1 = paddle.nn.Linear(10, 3)
>>> linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
>>> model = paddle.nn.Sequential(linear1, linear2)

>>> layer_list = list(model.children())

>>> print(layer_list)
[Linear(in_features=10, out_features=3, dtype=float32), Linear(in_features=3, out_features=10, dtype=float32)]

named_children()

返回所有子层的迭代器,生成子层名称和子层的元组。

返回 iterator,产出子层名称和子层的元组的迭代器。

代码示例

>>> import paddle

>>> linear1 = paddle.nn.Linear(10, 3)
>>> linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
>>> model = paddle.nn.Sequential(linear1, linear2)
>>> for prefix, layer in model.named_children():
...     print(prefix, layer)
0 Linear(in_features=10, out_features=3, dtype=float32)
1 Linear(in_features=3, out_features=10, dtype=float32)

sublayers(include_self=False)

返回一个由所有子层组成的列表。

参数

  • include_self (bool,可选) - 是否包含本层。如果为 True,则包括本层。默认值:False

返回

list,一个由所有子层组成的列表,列表中的元素类型为 Layer。

代码示例

>>> import paddle

>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         self._dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         temp = self._linear(input)
...         temp = self._dropout(temp)
...         return temp
...
>>> mylayer = MyLayer()
>>> print(mylayer.sublayers())
[Linear(in_features=1, out_features=1, dtype=float32), Dropout(p=0.5, axis=None, mode=upscale_in_train)]

clear_gradients()

清除该层所有参数的梯度。

返回

代码示例

>>> import paddle
>>> import numpy as np

>>> value = np.arange(26).reshape(2, 13).astype("float32")
>>> a = paddle.to_tensor(value)
>>> linear = paddle.nn.Linear(13, 5)
>>> adam = paddle.optimizer.Adam(learning_rate=0.01,
...                              parameters=linear.parameters())
>>> out = linear(a)
>>> out.backward()
>>> adam.step()
>>> linear.clear_gradients()

named_parameters(prefix='', include_sublayers=True)

返回层中所有参数的迭代器,生成名称和参数的元组。

参数

  • prefix (str,可选) - 在所有参数名称前加的前缀。默认值:''。

  • include_sublayers (bool,可选) - 是否返回子层的参数。如果为 True,返回的列表中包含子层的参数。默认值:True。

返回 iterator,产出名称和参数的元组的迭代器。

代码示例

>>> import paddle
>>> paddle.seed(100)

>>> fc1 = paddle.nn.Linear(10, 3)
>>> fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
>>> model = paddle.nn.Sequential(fc1, fc2)
>>> for name, param in model.named_parameters():
...     print(name, param)
0.weight Parameter containing:
Tensor(shape=[10, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
[[ 0.07276392, -0.39791510, -0.66356444],
 [ 0.02143478, -0.18519843, -0.32485050],
 [-0.42249614,  0.08450919, -0.66838276],
 [ 0.38208580, -0.24303678,  0.55127048],
 [ 0.47745085,  0.62117910, -0.08336520],
 [-0.28653207,  0.47237599, -0.05868882],
 [-0.14385653,  0.29945642,  0.12832761],
 [-0.21237159,  0.38539791, -0.62760031],
 [ 0.02637231,  0.20621127,  0.43255770],
 [-0.19984481, -0.26259184, -0.29696006]])
0.bias Parameter containing:
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0., 0.])
1.weight Parameter containing:
Tensor(shape=[3, 10], dtype=float32, place=Place(cpu), stop_gradient=False,
[[ 0.01985580, -0.40268910,  0.41172385, -0.47249708, -0.09002256,
 -0.00533628, -0.52048630,  0.62360322,  0.20848787, -0.02033746],
 [ 0.58281910,  0.12841827,  0.12907702,  0.02325618, -0.07746267,
 0.31950659, -0.37924835, -0.59209681, -0.11732036, -0.58378261],
 [-0.62100595,  0.22293305,  0.28229684, -0.03687060, -0.59323978,
 0.08411229,  0.53275704,  0.40431368,  0.03171402, -0.17922515]])

named_sublayers(prefix='', include_self=False, layers_set=None)

返回层中所有子层上的迭代器,生成名称和子层的元组。重复的子层只产生一次。

参数

  • prefix (str,可选) - 在所有参数名称前加的前缀。默认值:''。

  • include_self (bool,可选) - 是否包含该层自身。默认值:False。

  • layers_set (set,可选):记录重复子层的集合。默认值:None。

返回 iterator,产出名称和子层的元组的迭代器。

代码示例

>>> import paddle

>>> fc1 = paddle.nn.Linear(10, 3)
>>> fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
>>> model = paddle.nn.Sequential(fc1, fc2)
>>> for prefix, layer in model.named_sublayers():
...     print(prefix, layer)
0 Linear(in_features=10, out_features=3, dtype=float32)
1 Linear(in_features=3, out_features=10, dtype=float32)

register_buffer(name, tensor, persistable=True)

将一个 Tensor 注册为 buffer。

buffer 是一个不可训练的变量,不会被优化器更新,但在评估或预测阶段可能是必要的状态变量。比如 BatchNorm 中的均值和方差。

注册的 buffer 默认是可持久性的,会被保存到 state_dict 中。如果指定 persistable 参数为 False,则会注册一个非持久性的 buffer,即不会同步和保存到 state_dict 中。

参数

  • name (str) - 注册 buffer 的名字。可以通过此名字来访问已注册的 buffer。

  • tensor (Tensor) - 将被注册为 buffer 的变量。

  • persistable (bool,可选) - 注册的 buffer 是否需要可持久性地保存到 state_dict 中。

返回 None

代码示例

>>> import numpy as np
>>> import paddle

>>> linear = paddle.nn.Linear(10, 3)
>>> value = np.array([0]).astype("float32")
>>> buffer = paddle.to_tensor(value)
>>> linear.register_buffer("buf_name", buffer, persistable=True)

>>> # get the buffer by attribute.
>>> print(linear.buf_name)
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.])

buffers(include_sublayers=True)

返回一个由当前层及其子层的所有 buffers 组成的列表。

参数

  • include_sublayers (bool,可选) - 是否返回子层的 buffers。如果为 True,返回的列表中包含子层的 buffers。默认值:True。

返回 list,一个由当前层及其子层的所有 buffers 组成的列表,列表中的元素类型为 Tensor。

代码示例

>>> import numpy as np
>>> import paddle

>>> linear = paddle.nn.Linear(10, 3)
>>> value = np.array([0]).astype("float32")
>>> buffer = paddle.to_tensor(value)
>>> linear.register_buffer("buf_name", buffer, persistable=True)

>>> print(linear.buffers())
[Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.])]

named_buffers(prefix='', include_sublayers=True)

返回层中所有 buffers 的迭代器,生成名称和 buffer 的元组。

参数

  • prefix (str,可选) - 在所有 buffer 名称前加的前缀。默认值:''。

  • include_sublayers (bool,可选) - 是否返回子层的 buffers。如果为 True,返回的列表中包含子层的 buffers。默认值:True。

返回 iterator,产出名称和 buffer 的元组的迭代器。

代码示例

>>> import numpy as np
>>> import paddle

>>> fc1 = paddle.nn.Linear(10, 3)
>>> buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
>>> # register a tensor as buffer by specific `persistable`
>>> fc1.register_buffer("buf_name_1", buffer1, persistable=True)

>>> fc2 = paddle.nn.Linear(3, 10)
>>> buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
>>> # register a buffer by assigning an attribute with Tensor.
>>> # The `persistable` can only be False by this way.
>>> fc2.buf_name_2 = buffer2

>>> model = paddle.nn.Sequential(fc1, fc2)

>>> # get all named buffers
>>> for name, buffer in model.named_buffers():
...     print(name, buffer)
0.buf_name_1 Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.])
1.buf_name_2 Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[1.])

forward(*inputs, **kwargs)

定义每次调用时执行的计算。应该被所有子类覆盖。

参数

  • *inputs (tuple) - 解包后的 tuple 参数。

  • **kwargs (dict) - 解包后的 dict 参数。

返回

add_sublayer(name, sublayer)

添加子层实例。可以通过 self.name 访问该 sublayer。

参数

  • name (str) - 子层名。

  • sublayer (Layer) - Layer 实例。

返回 Layer,添加的子层

代码示例

>>> import paddle

>>> class MySequential(paddle.nn.Layer):
...     def __init__(self, *layers):
...         super().__init__()
...         if len(layers) > 0 and isinstance(layers[0], tuple):
...             for name, layer in layers:
...                 self.add_sublayer(name, layer)
...         else:
...             for idx, layer in enumerate(layers):
...                 self.add_sublayer(str(idx), layer)
...
...     def forward(self, input):
...         for layer in self._sub_layers.values():
...             input = layer(input)
...         return input
...
>>> fc1 = paddle.nn.Linear(10, 3)
>>> fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
>>> model = MySequential(fc1, fc2)
>>> for prefix, layer in model.named_sublayers():
...     print(prefix, layer)
0 Linear(in_features=10, out_features=3, dtype=float32)
1 Linear(in_features=3, out_features=10, dtype=float32)

add_parameter(name, parameter)

添加参数实例。可以通过 self.name 访问该 parameter。

参数

  • name (str) - 参数名。

  • parameter (Parameter) - Parameter 实例。

返回 Parameter,传入的参数实例

代码示例

>>> import paddle
>>> paddle.seed(100)

>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self._linear = paddle.nn.Linear(1, 1)
...         w_tmp = self.create_parameter([1,1])
...         self.add_parameter("w_tmp", w_tmp)
...
...     def forward(self, input):
...         return self._linear(input)
...
>>> mylayer = MyLayer()
>>> for name, param in mylayer.named_parameters():
...     print(name, param)
w_tmp Parameter containing:
Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-1.01448846]])
_linear.weight Parameter containing:
Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.18551230]])
_linear.bias Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])

state_dict(destination=None, include_sublayers=True, use_hook=True)

获取当前层及其子层的所有参数和可持久性 buffers。并将所有参数和 buffers 存放在 dict 结构中。

参数

  • destination (dict,可选) - 如果提供 destination,则所有参数和可持久性 buffers 都将存放在 destination 中。默认值:None。

  • include_sublayers (bool,可选) - 如果设置为 True,则包括子层的参数和 buffers。默认值:True。

  • use_hook (bool,可选) - 如果设置为 True,将_state_dict_hooks 中注册的函数应用于 destination。默认值:True。

返回 dict,包含所有参数和可持久行 buffers 的 dict

代码示例

>>> import paddle

>>> emb = paddle.nn.Embedding(10, 10)

>>> state_dict = emb.state_dict()
>>> paddle.save( state_dict, "paddle_dy.pdparams")

set_state_dict(state_dict, use_structured_name=True)

根据传入的 state_dict 设置参数和可持久性 buffers。所有参数和 buffers 将由 state_dict 中的 Tensor 设置。

参数

  • state_dict (dict) - 包含所有参数和可持久性 buffers 的 dict。

  • use_structured_name (bool,可选) - 如果设置为 True,将使用 Layer 的结构性变量名作为 dict 的 key,否则将使用 Parameter 或者 Buffer 的变量名作为 key。默认值:True。

返回
  • missing_keys (list) - 没有匹配到的参数名列表

  • unexpected_keys (list) - state_dict 传入的无效的参数名列表

代码示例

>>> import paddle

>>> emb = paddle.nn.Embedding(10, 10)

>>> state_dict = emb.state_dict()
>>> paddle.save(state_dict, "paddle_dy.pdparams")
>>> para_state_dict = paddle.load("paddle_dy.pdparams")
>>> emb.set_state_dict(para_state_dict)

to(device=None, dtype=None, blocking=None)

根据给定的 device、dtype 和 blocking 转换 Layer 中的 parameters 和 buffers。

参数

  • device (str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None,可选) - 希望存储 Layer 的设备位置。如果为 None,设备位置和原始的 Tensor 的设备位置一致。如果设备位置是 string 类型,取值可为 cpu, gpu:x and xpu:x,这里的 x 是 GPUs 或者 XPUs 的编号。默认值:None。

  • dtype (str|numpy.dtype|paddle.dtype|None,可选) - 数据的类型。如果为 None,数据类型和原始的 Tensor 一致。默认值:None。

  • blocking (bool|None,可选)- 如果为 False 并且当前 Tensor 处于固定内存上,将会发生主机到设备端的异步拷贝。否则,会发生同步拷贝。如果为 None,blocking 会被设置为 True。默认为 False。

代码示例

>>> import paddle
>>> paddle.seed(2023)

>>> linear=paddle.nn.Linear(2, 2)
>>> linear.weight
>>> print(linear.weight)
Parameter containing:
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
[[ 0.89611185,  0.04935038],
 [-0.58883440,  0.99266374]])

>>> linear.to(dtype='float64')
>>> linear.weight
>>> print(linear.weight)
Parameter containing:
Tensor(shape=[2, 2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
[[ 0.89611185,  0.04935038],
 [-0.58883440,  0.99266374]])

>>> linear.to(device='cpu')
>>> linear.weight
>>> print(linear.weight)
Parameter containing:
Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=False,
[[ 0.89611185,  0.04935038],
 [-0.58883440,  0.99266374]])

>>> linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
>>> linear.weight
>>> print(linear.weight)
Tensor(shape=[2, 2], dtype=float64, place=Place(gpu_pinned), stop_gradient=False,
[[ 0.89611185,  0.04935038],
 [-0.58883440,  0.99266374]])

astype(dtype=None)

将 Layer 的所有 parametersbuffers 的数据类型转换为 dtype,并返回这个 Layer。

参数

  • dtype (str | paddle.dtype | numpy.dtype) - 转换后的 dtype,str 类型支持"bool", "bfloat16", "float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "complex64", "complex128"。

返回:类型转换后的 Layer

返回类型:Layer

代码示例

>>> import paddle
>>> import paddle.nn as nn
>>> weight_attr = paddle.ParamAttr(name="weight",initializer=paddle.nn.initializer.Constant(value=1.5))
>>> bias_attr = paddle.ParamAttr(name="bias",initializer=paddle.nn.initializer.Constant(value=2.5))

>>> linear = paddle.nn.Linear(2, 2, weight_attr=weight_attr, bias_attr=bias_attr).to(device="cpu",dtype="float32")
>>> print(linear)
Linear(in_features=2, out_features=2, dtype=float32)
>>> print(linear.parameters())
[Parameter containing:
Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
    [[1.50000000, 1.50000000],
        [1.50000000, 1.50000000]]), Parameter containing:
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=False,
    [2.50000000, 2.50000000])]

>>> linear=linear.astype("int8")
>>> print(linear)
Linear(in_features=2, out_features=2, dtype=paddle.int8)
>>> print(linear.parameters())
[Parameter containing:
Tensor(shape=[2, 2], dtype=int8, place=Place(cpu), stop_gradient=False,
    [[1, 1],
        [1, 1]]), Parameter containing:
Tensor(shape=[2], dtype=int8, place=Place(cpu), stop_gradient=False,
    [2, 2])]

float(excluded_layers=None)

将所有浮点型的参数和通过 register_buffers() 注册的 Buffer 变量转换为 float 数据类型。

参数

  • excluded_layers (list|tuple|nn.Layer|None,可选) - 不需要转换数据类型的层。如果 excluded_layers 为 None,则转换所有浮点参数和缓冲区,默认值:None。

代码示例

>>> import paddle

>>> class Model(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.linear = paddle.nn.Linear(1, 1)
...         self.dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         out = self.linear(input)
...         out = self.dropout(out)
...         return out
...
>>> model = Model()
>>> model.float()
Model(
    (linear): Linear(in_features=1, out_features=1, dtype=paddle.float32)
    (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)
)

float16(excluded_layers=None)

将所有浮点型的参数和通过 register_buffers() 注册的 Buffer 变量转换为 float16 数据类型。

注解

nn.BatchNorm 不支持 float16 类型的权重,默认不对其权重进行类型转换。

参数

  • excluded_layers (list|tuple|nn.Layer|None,可选) - 不需要转换数据类型的层。如果 excluded_layers 为 None,则转换除 nn.BatchNorm 之外的所有浮点参数和缓冲区,默认值:None。

代码示例

>>> import paddle

>>> class Model(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.linear = paddle.nn.Linear(1, 1)
...         self.dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         out = self.linear(input)
...         out = self.dropout(out)
...         return out
...
>>> model = Model()
>>> model.float16()
Model(
    (linear): Linear(in_features=1, out_features=1, dtype=float32)
    (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)
)

bfloat16(excluded_layers=None)

将所有浮点型的参数和通过 register_buffers() 注册的 Buffer 变量转换为 bfloat16 数据类型。

注解

nn.BatchNorm 不支持 bfloat16 类型的权重,默认不对其权重进行类型转换。

参数

  • excluded_layers (list|tuple|nn.Layer|None,可选) - 不需要转换数据类型的层。如果 excluded_layers 为 None,则转换除 nn.BatchNorm 之外的所有浮点参数和缓冲区,默认值:None。

代码示例

>>> import paddle

>>> class Model(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.linear = paddle.nn.Linear(1, 1)
...         self.dropout = paddle.nn.Dropout(p=0.5)
...
...     def forward(self, input):
...         out = self.linear(input)
...         out = self.dropout(out)
...         return out
...
>>> model = Model()
>>> model.bfloat16()
>>> #UserWarning: Paddle compiled by the user does not support bfloat16, so keep original data type.
Model(
    (linear): Linear(in_features=1, out_features=1, dtype=float32)
    (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)
)