点云处理:实现PointNet点云分类
作者:Zhihao Cao
 日期:2022.5
 摘要:本示例旨在演示如何基于 PaddlePaddle 2.3.0 实现 PointNet,并在 ShapeNet 数据集上完成点云分类。
一、环境设置
本教程基于PaddlePaddle 2.3.0 编写,如果你的环境不是本版本,请先参考官网安装。
import os
import numpy as np
import random
import h5py
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
print(paddle.__version__)
         2.3.0
         二、数据集
2.1 数据介绍
ShapeNet数据集是一个注释丰富且规模较大的 3D 形状数据集,由斯坦福大学、普林斯顿大学和芝加哥丰田技术学院于 2015 年联合发布。
 ShapeNet数据集官方链接:https://vision.princeton.edu/projects/2014/3DShapeNets/
 AIStudio链接:ShapeNet数据集(经过整理)
 ShapeNet数据集的储存格式是h5文件,该文件中key值分别为:
- 1、data:这一份数据中所有点的xyz坐标, 
- 2、label:这一份数据所属类别,如airplane等, 
- 3、pid:这一份数据中所有点所属的类型,如这一份数据属airplane类,则它包含的所有点的类型有机翼、机身等类型。 
2.2 解压数据集
!unzip data/data70460/shapenet_part_seg_hdf5_data.zip
!mv hdf5_data dataset
          2.3 数据列表
ShapeNet数据集所有的数据文件。
# File names of the HDF5 shards that belong to each split.
train_list = [f"ply_data_train{idx}.h5" for idx in range(6)]
test_list = [f"ply_data_test{idx}.h5" for idx in range(2)]
val_list = ["ply_data_val0.h5"]
          2.4 搭建数据生成器
说明:将ShapeNet数据集全部读入。
def make_data(mode="train", path="./dataset/", num_point=2048):
    """Read every HDF5 shard of one dataset split into memory.

    Args:
        mode: ``"train"`` reads the files named in ``train_list`` and
            ``"test"`` those in ``test_list``; any other value reads the
            files in ``val_list``.
        path: Directory that contains the ``.h5`` files.
        num_point: Number of leading points kept from each sample (the
            second axis of the ``"data"`` dataset is sliced to this length).

    Returns:
        A ``(datas, labels)`` tuple of lists. ``datas`` holds one array per
        sample taken from each file's ``"data"`` dataset, and ``labels``
        holds the matching entries of each file's ``"label"`` dataset.
    """
    if mode == "train":
        file_names = train_list
    elif mode == "test":
        file_names = test_list
    else:
        # Every other mode value selects the validation split.
        file_names = val_list

    datas = []
    labels = []
    for file_name in file_names:
        # The context manager closes the file even when a read raises.
        with h5py.File(os.path.join(path, file_name), "r") as f:
            datas.extend(f["data"][:, :num_point, :])
            labels.extend(f["label"])
    return datas, labels
          说明:通过继承paddle.io.Dataset来完成数据集的构造。
class PointDataset(paddle.io.Dataset):
    """Dataset that serves in-memory samples and labels as Paddle tensors.

    ``datas`` and ``labels`` are index-aligned sequences; item ``i`` is the
    pair ``(datas[i], labels[i])`` after conversion.
    """

    def __init__(self, datas, labels):
        super().__init__()
        self.datas = datas
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.datas)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(data, label)`` tensors.

        The sample array is transposed and cast to float32; the label is
        cast to int64.
        """
        points = self.datas[index].T.astype("float32")
        target = self.labels[index].astype("int64")
        return paddle.to_tensor(points), paddle.to_tensor(target)
          说明:使用飞桨框架提供的API:paddle.io.DataLoader完成数据的加载,使得按照Batchsize生成Mini-batch的数据。
# Data import: read each split in turn and wrap it in a PointDataset.
train_dataset = PointDataset(*make_data(mode="train", num_point=2048))
val_dataset = PointDataset(*make_data(mode="val", num_point=2048))
datas, labels = make_data(mode="test", num_point=2048)
test_dataset = PointDataset(datas, labels)

# Instantiate the data loaders; only the training loader shuffles.
train_loader = paddle.io.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=False,
)
val_loader = paddle.io.DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    drop_last=False,
)
test_loader = paddle.io.DataLoader(
    dataset=test_dataset,
    batch_size=128,
    shuffle=False,
    drop_last=False,
)
          三、定义网络
PointNet是斯坦福大学研究人员提出的一种点云处理网络。该论文提出了空间变换网络(T-Net)来解决点云的旋转问题(注:同一物体的点云经过旋转后仍然是该物体,因此需要一个网络结构来学习并解决这个旋转问题),并且采用MaxPooling的方法最大程度地提取点云的全局特征。
3.1 定义网络结构
class PointNet(nn.Layer):
    """PointNet classifier with an input transform and a feature transform.

    The forward pass predicts a 3x3 matrix from the raw points and applies
    it, lifts every point to 64 features, predicts and applies a 64x64
    matrix on those features, lifts every point to 1024 features, takes the
    maximum over the point axis and classifies the pooled vector.

    Args:
        name_scope: Not used by this class; kept so existing callers that
            pass it keep working.
        num_classes: Size of the final linear layer, i.e. number of classes.
        num_point: Kernel size of the two ``MaxPool1D`` layers inside the
            transform branches. The squeeze that follows each pooling
            expects a single pooled position, so inputs are expected to
            carry exactly this many points.
    """

    def __init__(self, name_scope="PointNet_", num_classes=16, num_point=2048):
        super().__init__()
        # Branch that summarizes the raw points into one 1024-d vector.
        self.input_transform_net = nn.Sequential(
            nn.Conv1D(3, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
            nn.MaxPool1D(num_point),
        )
        # Regresses the 9 entries of the 3x3 input transform. The last layer
        # starts with zero weights and a flattened identity as bias, so the
        # initial prediction is the identity matrix.
        self.input_fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(
                256,
                9,
                weight_attr=paddle.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(
                        paddle.zeros((256, 9))
                    )
                ),
                bias_attr=paddle.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(
                        paddle.reshape(paddle.eye(3), [-1])
                    )
                ),
            ),
        )
        # Per-point MLP (kernel-size-1 convolutions): 3 -> 64 -> 64.
        self.mlp_1 = nn.Sequential(
            nn.Conv1D(3, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
        )
        # Branch that summarizes the 64-d point features into one vector.
        self.feature_transform_net = nn.Sequential(
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
            nn.MaxPool1D(num_point),
        )
        # Regresses the 64x64 feature transform.
        # NOTE(review): unlike input_fc, the last layer here keeps the
        # default initialization, so the predicted matrix does not start at
        # the identity.
        self.feature_fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 64 * 64),
        )
        # Per-point MLP: 64 -> 64 -> 128 -> 1024.
        self.mlp_2 = nn.Sequential(
            nn.Conv1D(64, 64, 1),
            nn.BatchNorm(64),
            nn.ReLU(),
            nn.Conv1D(64, 128, 1),
            nn.BatchNorm(128),
            nn.ReLU(),
            nn.Conv1D(128, 1024, 1),
            nn.BatchNorm(1024),
            nn.ReLU(),
        )
        # Classification head; ends in LogSoftmax, so the output is a
        # log-probability per class.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=0.7),
            nn.Linear(256, num_classes),
            nn.LogSoftmax(axis=-1),
        )

    def forward(self, inputs):
        """Classify a batch of point clouds.

        Args:
            inputs: Tensor laid out as ``[batch, 3, num_point]`` (channels
                first, as required by the leading ``Conv1D(3, 64, 1)``).

        Returns:
            Tensor of shape ``[batch, num_classes]`` holding log-probabilities.
        """
        batchsize = inputs.shape[0]

        # Input transform: predict a 3x3 matrix per sample and apply it.
        t_net = self.input_transform_net(inputs)
        t_net = paddle.squeeze(t_net, axis=-1)
        t_net = self.input_fc(t_net)
        t_net = paddle.reshape(t_net, [batchsize, 3, 3])
        # matmul needs points-major layout, hence the transpose round trip.
        x = paddle.transpose(inputs, (0, 2, 1))
        x = paddle.matmul(x, t_net)
        x = paddle.transpose(x, (0, 2, 1))
        x = self.mlp_1(x)

        # Feature transform: predict a 64x64 matrix per sample and apply it.
        t_net = self.feature_transform_net(x)
        t_net = paddle.squeeze(t_net, axis=-1)
        t_net = self.feature_fc(t_net)
        t_net = paddle.reshape(t_net, [batchsize, 64, 64])
        # x keeps its point axis here; it must stay 3-D for the transpose.
        x = paddle.transpose(x, (0, 2, 1))
        x = paddle.matmul(x, t_net)
        x = paddle.transpose(x, (0, 2, 1))
        x = self.mlp_2(x)

        # Global feature: maximum over the point axis, one vector per sample.
        x = paddle.max(x, axis=-1)
        x = self.fc(x)
        return x
          3.2 网络结构可视化
说明:使用飞桨API:paddle.summary完成模型结构可视化
# Build a model with default arguments and print its per-layer summary
# for an input of shape (64, 3, 2048).
pointnet = PointNet()
paddle.summary(pointnet, input_size=(64, 3, 2048))
          W0509 16:16:31.949033   135 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0509 16:16:31.957976   135 device_context.cc:465] device: 0, cuDNN Version: 7.6.
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv1D-1       [[64, 3, 2048]]       [64, 64, 2048]          256      
  BatchNorm-1     [[64, 64, 2048]]      [64, 64, 2048]          256      
    ReLU-1        [[64, 64, 2048]]      [64, 64, 2048]           0       
   Conv1D-2       [[64, 64, 2048]]     [64, 128, 2048]         8,320     
  BatchNorm-2    [[64, 128, 2048]]     [64, 128, 2048]          512      
    ReLU-2       [[64, 128, 2048]]     [64, 128, 2048]           0       
   Conv1D-3      [[64, 128, 2048]]     [64, 1024, 2048]       132,096    
  BatchNorm-3    [[64, 1024, 2048]]    [64, 1024, 2048]        4,096     
    ReLU-3       [[64, 1024, 2048]]    [64, 1024, 2048]          0       
  MaxPool1D-1    [[64, 1024, 2048]]     [64, 1024, 1]            0       
   Linear-1         [[64, 1024]]          [64, 512]           524,800    
    ReLU-4          [[64, 512]]           [64, 512]              0       
   Linear-2         [[64, 512]]           [64, 256]           131,328    
    ReLU-5          [[64, 256]]           [64, 256]              0       
   Linear-3         [[64, 256]]            [64, 9]             2,313     
   Conv1D-4       [[64, 3, 2048]]       [64, 64, 2048]          256      
  BatchNorm-4     [[64, 64, 2048]]      [64, 64, 2048]          256      
    ReLU-6        [[64, 64, 2048]]      [64, 64, 2048]           0       
   Conv1D-5       [[64, 64, 2048]]      [64, 64, 2048]         4,160     
  BatchNorm-5     [[64, 64, 2048]]      [64, 64, 2048]          256      
    ReLU-7        [[64, 64, 2048]]      [64, 64, 2048]           0       
   Conv1D-6       [[64, 64, 2048]]      [64, 64, 2048]         4,160     
  BatchNorm-6     [[64, 64, 2048]]      [64, 64, 2048]          256      
    ReLU-8        [[64, 64, 2048]]      [64, 64, 2048]           0       
   Conv1D-7       [[64, 64, 2048]]     [64, 128, 2048]         8,320     
  BatchNorm-7    [[64, 128, 2048]]     [64, 128, 2048]          512      
    ReLU-9       [[64, 128, 2048]]     [64, 128, 2048]           0       
   Conv1D-8      [[64, 128, 2048]]     [64, 1024, 2048]       132,096    
  BatchNorm-8    [[64, 1024, 2048]]    [64, 1024, 2048]        4,096     
    ReLU-10      [[64, 1024, 2048]]    [64, 1024, 2048]          0       
  MaxPool1D-2    [[64, 1024, 2048]]     [64, 1024, 1]            0       
   Linear-4         [[64, 1024]]          [64, 512]           524,800    
    ReLU-11         [[64, 512]]           [64, 512]              0       
   Linear-5         [[64, 512]]           [64, 256]           131,328    
    ReLU-12         [[64, 256]]           [64, 256]              0       
   Linear-6         [[64, 256]]           [64, 4096]         1,052,672   
   Conv1D-9       [[64, 64, 2048]]      [64, 64, 2048]         4,160     
  BatchNorm-9     [[64, 64, 2048]]      [64, 64, 2048]          256      
    ReLU-13       [[64, 64, 2048]]      [64, 64, 2048]           0       
   Conv1D-10      [[64, 64, 2048]]     [64, 128, 2048]         8,320     
 BatchNorm-10    [[64, 128, 2048]]     [64, 128, 2048]          512      
    ReLU-14      [[64, 128, 2048]]     [64, 128, 2048]           0       
   Conv1D-11     [[64, 128, 2048]]     [64, 1024, 2048]       132,096    
 BatchNorm-11    [[64, 1024, 2048]]    [64, 1024, 2048]        4,096     
    ReLU-15      [[64, 1024, 2048]]    [64, 1024, 2048]          0       
   Linear-7         [[64, 1024]]          [64, 512]           524,800    
    ReLU-16         [[64, 512]]           [64, 512]              0       
   Linear-8         [[64, 512]]           [64, 256]           131,328    
    ReLU-17         [[64, 256]]           [64, 256]              0       
   Dropout-1        [[64, 256]]           [64, 256]              0       
   Linear-9         [[64, 256]]            [64, 16]            4,112     
 LogSoftmax-1        [[64, 16]]            [64, 16]              0       
===========================================================================
Total params: 3,476,825
Trainable params: 3,461,721
Non-trainable params: 15,104
---------------------------------------------------------------------------
Input size (MB): 1.50
Forward/backward pass size (MB): 11333.40
Params size (MB): 13.26
Estimated Total Size (MB): 11348.16
---------------------------------------------------------------------------
{'total_params': 3476825, 'trainable_params': 3461721}
          四、训练
说明:模型训练的时候,将会使用paddle.optimizer.Adam优化器来进行优化。使用F.nll_loss来计算损失值。
def train():
    """Train PointNet for 10 epochs and validate after every epoch.

    Optimizes with Adam (``weight_decay=0.001``) on ``F.nll_loss`` of the
    model output. Progress is printed every 20 training batches. Model and
    optimizer state are written to ``./model/`` on every even epoch and
    after the last epoch. Validation prints the mean of the per-batch loss
    and accuracy over ``val_loader``.
    """
    model = PointNet(num_classes=16, num_point=2048)
    model.train()
    optim = paddle.optimizer.Adam(
        parameters=model.parameters(), weight_decay=0.001
    )
    epoch_num = 10
    for epoch in range(epoch_num):
        # train
        print(
            "===================================train==========================================="
        )
        for batch_id, data in enumerate(train_loader()):
            inputs, labels = data
            predicts = model(inputs)
            loss = F.nll_loss(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            if batch_id % 20 == 0:
                print(
                    "train: epoch: {}, batch_id: {}, loss is: {}, accuracy is: {}".format(
                        epoch, batch_id, loss.numpy(), acc.numpy()
                    )
                )
            loss.backward()
            optim.step()
            optim.clear_grad()
        # Also checkpoint after the final epoch; otherwise the weights of the
        # last (odd-numbered) epoch would never reach disk.
        if epoch % 2 == 0 or epoch == epoch_num - 1:
            paddle.save(model.state_dict(), "./model/PointNet.pdparams")
            paddle.save(optim.state_dict(), "./model/PointNet.pdopt")
        # validation
        print(
            "===================================val==========================================="
        )
        model.eval()
        accuracies = []
        losses = []
        # No gradients are needed for validation, so skip autograd tracking.
        with paddle.no_grad():
            for data in val_loader():
                inputs, labels = data
                predicts = model(inputs)
                loss = F.nll_loss(predicts, labels)
                acc = paddle.metric.accuracy(predicts, labels)
                losses.append(loss.numpy())
                accuracies.append(acc.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print(
            "validation: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc)
        )
        # Switch back to training mode for the next epoch.
        model.train()
if __name__ == "__main__":
    train()
         ===================================train===========================================
train: epoch: 0, batch_id: 0, loss is: [8.135595], accuracy is: [0.046875]
train: epoch: 0, batch_id: 20, loss is: [0.96110815], accuracy is: [0.7265625]
train: epoch: 0, batch_id: 40, loss is: [0.77762437], accuracy is: [0.8046875]
train: epoch: 0, batch_id: 60, loss is: [0.575164], accuracy is: [0.84375]
train: epoch: 0, batch_id: 80, loss is: [0.60243726], accuracy is: [0.8359375]
===================================val===========================================
validation: loss is: 0.5027859807014465, accuracy is: 0.848895251750946
===================================train===========================================
train: epoch: 1, batch_id: 0, loss is: [0.5886416], accuracy is: [0.8359375]
train: epoch: 1, batch_id: 20, loss is: [0.59509534], accuracy is: [0.8515625]
train: epoch: 1, batch_id: 40, loss is: [0.43501458], accuracy is: [0.875]
train: epoch: 1, batch_id: 60, loss is: [0.5497817], accuracy is: [0.8515625]
train: epoch: 1, batch_id: 80, loss is: [0.2889481], accuracy is: [0.8984375]
===================================val===========================================
validation: loss is: 0.2470872551202774, accuracy is: 0.9263771176338196
===================================train===========================================
train: epoch: 2, batch_id: 0, loss is: [0.43095332], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 20, loss is: [0.42620662], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 40, loss is: [0.31073096], accuracy is: [0.8984375]
train: epoch: 2, batch_id: 60, loss is: [0.21410619], accuracy is: [0.9375]
train: epoch: 2, batch_id: 80, loss is: [0.23696409], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.24663102626800537, accuracy is: 0.9278147220611572
===================================train===========================================
train: epoch: 3, batch_id: 0, loss is: [0.1000444], accuracy is: [0.96875]
train: epoch: 3, batch_id: 20, loss is: [0.2845613], accuracy is: [0.9296875]
train: epoch: 3, batch_id: 40, loss is: [0.46592], accuracy is: [0.859375]
train: epoch: 3, batch_id: 60, loss is: [0.3819336], accuracy is: [0.9140625]
train: epoch: 3, batch_id: 80, loss is: [0.08518291], accuracy is: [0.9765625]
===================================val===========================================
validation: loss is: 0.17066480219364166, accuracy is: 0.9491525292396545
===================================train===========================================
train: epoch: 4, batch_id: 0, loss is: [0.11713062], accuracy is: [0.9609375]
train: epoch: 4, batch_id: 20, loss is: [0.1716559], accuracy is: [0.953125]
train: epoch: 4, batch_id: 40, loss is: [0.15082854], accuracy is: [0.96875]
train: epoch: 4, batch_id: 60, loss is: [0.2787561], accuracy is: [0.96875]
train: epoch: 4, batch_id: 80, loss is: [0.11986132], accuracy is: [0.9609375]
===================================val===========================================
validation: loss is: 0.1389710158109665, accuracy is: 0.9608050584793091
===================================train===========================================
train: epoch: 5, batch_id: 0, loss is: [0.17427993], accuracy is: [0.9453125]
train: epoch: 5, batch_id: 20, loss is: [0.25355965], accuracy is: [0.9609375]
train: epoch: 5, batch_id: 40, loss is: [0.18881711], accuracy is: [0.9609375]
train: epoch: 5, batch_id: 60, loss is: [0.14433464], accuracy is: [0.953125]
train: epoch: 5, batch_id: 80, loss is: [0.13028377], accuracy is: [0.96875]
===================================val===========================================
validation: loss is: 0.09753856807947159, accuracy is: 0.9671609997749329
===================================train===========================================
train: epoch: 6, batch_id: 0, loss is: [0.12662013], accuracy is: [0.9765625]
train: epoch: 6, batch_id: 20, loss is: [0.1309431], accuracy is: [0.9609375]
train: epoch: 6, batch_id: 40, loss is: [0.29988244], accuracy is: [0.9453125]
train: epoch: 6, batch_id: 60, loss is: [0.114668], accuracy is: [0.9609375]
train: epoch: 6, batch_id: 80, loss is: [0.48784435], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.16411711275577545, accuracy is: 0.9576271176338196
===================================train===========================================
train: epoch: 7, batch_id: 0, loss is: [0.12558301], accuracy is: [0.9609375]
train: epoch: 7, batch_id: 20, loss is: [0.1776012], accuracy is: [0.953125]
train: epoch: 7, batch_id: 40, loss is: [0.12831621], accuracy is: [0.9609375]
train: epoch: 7, batch_id: 60, loss is: [0.15245995], accuracy is: [0.953125]
train: epoch: 7, batch_id: 80, loss is: [0.08825297], accuracy is: [0.9609375]
===================================val===========================================
validation: loss is: 0.06742173433303833, accuracy is: 0.9809321761131287
===================================train===========================================
train: epoch: 8, batch_id: 0, loss is: [0.07868354], accuracy is: [0.96875]
train: epoch: 8, batch_id: 20, loss is: [0.1875119], accuracy is: [0.96875]
train: epoch: 8, batch_id: 40, loss is: [0.04444], accuracy is: [0.9921875]
train: epoch: 8, batch_id: 60, loss is: [0.08977574], accuracy is: [0.9765625]
train: epoch: 8, batch_id: 80, loss is: [0.13062863], accuracy is: [0.9765625]
===================================val===========================================
validation: loss is: 0.13399624824523926, accuracy is: 0.9661017060279846
===================================train===========================================
train: epoch: 9, batch_id: 0, loss is: [0.14676869], accuracy is: [0.953125]
train: epoch: 9, batch_id: 20, loss is: [0.16409941], accuracy is: [0.9609375]
train: epoch: 9, batch_id: 40, loss is: [0.08795467], accuracy is: [0.96875]
train: epoch: 9, batch_id: 60, loss is: [0.05970801], accuracy is: [0.984375]
train: epoch: 9, batch_id: 80, loss is: [0.2631768], accuracy is: [0.9296875]
===================================val===========================================
validation: loss is: 0.11335306614637375, accuracy is: 0.9682203531265259
         五、评估与测试
说明:通过model.load_dict的方式加载训练好的模型对测试集上的数据进行评估与测试。
def evaluation():
    """Evaluate the saved PointNet weights on the test loader.

    Loads ``./model/PointNet.pdparams`` into a freshly built model, runs it
    in eval mode over ``test_loader`` and prints the mean of the per-batch
    ``F.nll_loss`` and accuracy values.
    """
    model = PointNet()
    model_state_dict = paddle.load("./model/PointNet.pdparams")
    model.load_dict(model_state_dict)
    model.eval()
    accuracies = []
    losses = []
    # Inference only: skip autograd tracking.
    with paddle.no_grad():
        for data in test_loader():
            inputs, labels = data
            predicts = model(inputs)
            loss = F.nll_loss(predicts, labels)
            acc = paddle.metric.accuracy(predicts, labels)
            losses.append(loss.numpy())
            accuracies.append(acc.numpy())
    avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
    # These are test-set metrics, so label them as such.
    print("test: loss is: {}, accuracy is: {}".format(avg_loss, avg_acc))
if __name__ == "__main__":
    evaluation()