\u200E
训练精度媲美 AlphaFold2、速度翻倍,飞桨螺旋桨HelixFold训练和推理代码全面开源
发布日期:2022-07-21T08:12:21.000+0000 浏览量:1604次

2021年7月15日,DeepMind公司在Nature杂志上发表了题为"Highly accurate protein structure prediction with AlphaFold"的文章,系统介绍了一种端到端的从蛋白质序列预测蛋白质三维结构的神经网络算法—AlphaFold2。该算法预测的蛋白质结构能达到原子水平的准确度,被Science评选为2021年十大科学突破之首。虽然DeepMind公司开源了AlphaFold2推理代码,但是其训练代码一直未开源。从DeepMind公司发表的AlphaFold2论文看,完整从头训练AlphaFold2需要使用128张TPUv3训练11天,对计算资源的消耗是巨大的。科研机构和普通公司想要基于AlphaFold2探索解决蛋白领域的更多问题,例如蛋白质设计,新靶点发现等,也更加困难。因此,如何搭建一套性能更优、更加节省算力资源、支持适配国产硬件的蛋白结构预测模型,就成为亟待解决的问题。

在飞桨强大的高性能并行计算能力支持下,飞桨螺旋桨PaddleHelix 生物计算团队发布了蛋白结构预测模型HelixFold,围绕着显存峰值、训练速度、分布式策略进行了全面性能优化。通过与原版AlphaFold2模型和哥伦比亚大学Mohammed AlQuraishi 教授团队基于PyTorch复现的OpenFold模型的性能对比测试显示,HelixFold模型的训练性能相比AlphaFold2提升106.97%,相比 OpenFold 提升104.86%。

HelixFold 与AlphaFold2、OpenFold 端到端训练速度对比

HelixFold 之所以能够得到如此大的性能提升,源于如下几项技术创新:

分支并行与混合并行策略

AlphaFold2在使用 TPUv3训练模型时,每张卡上的 batch size只设置为 1,限制了数据样本维度扩卡加速训练的可能性。HelixFold创新性的提出分支并行(Branch Parallelism, BP)策略,将不同的网络模型分支放在不同的卡上并行计算,从而在 initial training 阶段大幅提高了模型并行效率和训练速度。并且,分支并行与已有的动态轴并行 (Dynamic Axial Parallelism, DAP) 和数据并行(Data Parallelism, DP) 结合使用,通过 BP-DAP-DP 三维混合并行,进一步加快了模型的整体训练速度。


算子融合优化技术和张量融合低频次访存技术

针对 AlphaFold2 中 Gated Self-Attention 小算子组合 CPU 调度开销大、模型参数小、参数个数多的问题,HelixFold 将 Gated Self-Attention 整个模块融合用一个算子实现,将CPU 调度开销优化到极致。同时,将数千个小张量融合成一个连续的大张量,模型参数的梯度、优化器状态都相应更新,大幅减少了访存次数、CPU 调度开销和显存碎片,从而提升了训练速度。


多维度显存优化方案

采用 Recompute、BFloat16、显存复用、Subbatch(Chunking)等技术,将显存峰值降低到 40G 以内,同时支持 MSA 长度为512、ExtraMSA 长度为 5120、残基序列长度为 384 的最大模型配置的微调训练,从而解决了模型结构深,中间结果计算量大,ExtraMSAStack 输入过长等导致无法训练的问题。


在性能大幅度提升的同时,HelixFold 从头端到端完整训练可以达到 AlphaFold2论文媲美的精度。在包含87个蛋白的CASP14数据集和包含371个蛋白的CAMEO数据集上,HelixFold模型 TM-score 指标分别达到0.8771和0.8885,与原版 AlphaFold2准确率相当甚至更优。

HelixFold 与AlphaFold2 精度对比
HelixFold是运用飞桨的高性能计算技术,显著提升模型性能的典型案例。不仅如此,飞桨与曙光 AC智算平台深度合作,将 HelixFold在曙光AC智算平台全面部署上线,通过曙光智算中心对外提供服务。同时,飞桨螺旋桨也正在全力支持“先导杯” AI for science赛道的比赛,希望能对参赛选手们有所启发。激发大家在AI for science领域的更多探索。也欢迎大家在曙光AC智算平台调用HelixFold模型。
HelixFold 端到端训练和推理代码现已全面向社区开源。

GitHub地址:https://github.com/PaddlePaddle/PaddleHelix/tree/dev/apps/protein_folding/helixfold

更多性能优化细节和数据分析参考技术报告:

HelixFold: An Efficient Implementation of AlphaFold2 using PaddlePaddle

https://arxiv.org/abs/2207.05477


拓展阅读:



关注【飞桨PaddlePaddle】公众号

获取更多技术内容~