百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

LLM分布式训练第四课-模型并行之张量并行

toyiye 2024-06-21 12:39 9 浏览 0 评论

前文提要:

上一篇:LLM分布式训练第三课-模型并行之流水线并行 (qq.com)

上上一篇:LLM分布式训练第二课(数据并行) (qq.com)

上上上一篇:LLM分布式训练第二课(数据并行) (qq.com)


张量并行不是张亮并行,不是麻辣烫

如果上一节介绍的流水线并行是把模型基于层给进行了划分,来让多张GPU的显存可以承载规模较大的模型,那么这一节介绍的张量并行就正好用另外一个角度来解决单张GPU显存不足的问题。

张量并行其实也有两个细分的子流派,行并行和列并行。

我们用 GEMM 来拆解模型如何并行,以Y =XA 举例,对于模型来说,X 是输入,A是权重,Y是输出。

行并行(Row Parallelism):

行并行简单说就是把权重A给按照行来分割为2部分,为了输入X要去匹配A被按行切分的状态来进行计算,所以把X也给切成2部分,因为要矩阵乘,所以X得竖着切,如下图所示。

而Y=XA就被拆解成:

行并行

X1和A1就被放在一块GPU上进行计算,而X2和A2就放另一块GPU上进行计算,通过这个方式完成了模型的并行。

行并行计算-1

行并行计算-2

如上两图所示,两块GPU分别算出来了各自的矩阵Y1和Y2,然后用矩阵加法将两个矩阵的值进行相加,得到和原来计算结果等值的矩阵Y。

列并行(Column Parallelism):

列并行计算-1


列并行计算-2

列并行的计算方式和行并行有很大的区别,期中最重要的就是三点:

1- 列并行的A是按着矩阵的列去进行拆分的

2- 列并行的输入X是不需要拆分的,因为矩阵乘,行乘以列,A进行列切分,列维度没变,维度是相等的。

3- 最后的Y1和Y2不是相加的关系,是contact的关系,将两个矩阵合为一个矩阵Y

目前看起来似乎行并行和列并行没有什么太大区别,得到的值也是一样的,而且列并行需要把X复制两份分别和A1和A2进行矩阵乘会消耗更多的显存。

但是如果考虑了激活函数呢?

比如要连续过两个激活函数层,例如2层以上的Transformer,每一层都会有一个MLP,就要过一遍Relu或者Gelu函数,我们以Gelu为例:

上面的式子在列并行的情况下:

因为列并行并没有进行任何的输入拆分,所以只要把A激活函数层和B激活函数层划分好,就可以独立计算,在计算出GeLu(GeLu(XAi)Bi)后(i=和X计算的被拆解的子矩阵号,如1,2,3…),最后进行contact就可以,换个说法只要在得到最终结果之前通信一次就行。

如果是行并行呢?

由于GeLu是非线性的函数,所以:

也就是说,在整个计算流程中,每经过一个全连接层,都必须要通过通信来聚合成最终的结果,然后才能进入到下一个层来进行计算,过大的通信量会极大的降低模型的训练速度,增加延迟。

2D/2.5D和3D并行:

其实对于2D并行有两种解释的说法,比如TP+PP,或者TP+DDP都算2D并行的范畴,因为是从两个维度来支持更好的分布式,降低单卡显存和计算压力。

另外一种关于2D并行的解释是专门针对ColossalAI来讲的,一般称为2D张量并行。

一般会把基于Megatron的Tensor方式称为1D并行,1D并行的一个弊端是,对于刚才的函数Y=XA,在计算的过程中,并没有对激活Activation进行划分,导致激活这部分会消耗大量的显存,也就是每块GPU虽然参数被分开了,但是激活还是每块都有,还有一个重要的点是,如果采用1D并行,那么所有的GPU都要和其他的GPU进行通信,all-reduce或者其他的源语通信,通信成本越高,整体训练的性能越差。

基于以上的原因,Colossal-AI引入的2D张量并行的概念。

还是一个简单的函数Y=XA,如果我们拥有P个GPU,P必须满足q的平方,比如拥有4个GPU,那么q就是2,q*q=4,在这个前置条件下,我们把输入X和权重A都拆成q*q的子矩阵,即2个拥有用4个子矩阵的矩阵。

这个计算一共包含q个步骤,如上式而言实际上是2个步,我们首先让X矩阵的第一列和A矩阵的第一行在所有的4个GPU中进行广播即:


然后让上式的每2个子矩阵在相应的GPU上进行矩阵乘,在单位时间里,这个计算是并行的,并且4个GPU的任何一个都没有保存其他GPU的Activation的必要。

同样的在第2步,我们可以得到:


最终我们把Y=XA分解为:

虽然两个大矩阵中间要进行串行操作,但是在大矩阵内部的4个子矩阵都是进行并行的操作。

加入有1万张GPU卡,如果是1D并行的话,其中任意一张卡都要和其他9999个机器通信,而2D并行划分了子单元,每个机器理论上只需要和96个机器进行通信,极大的节省了通信的代偿和开销。

2.5D并行其实就是在2D并行的基础上加了一个维度,如图所示

2.5D并行

还是以Y=XA这个函数为例,这次P个GPU被分解成d*q*q,为了计算流程看起来清楚,假设d=q也为2,所以这个tensor为[2,2,2]。

现在把输入的X划分为d*q*q,来满足P个GPU均匀分布,得到下面式子:

这个式子其实可以被表达为下面两个子矩阵的contact,我们把大矩阵拆解成下面两个子矩阵。

然后权重A被分解成q*q个单位:

对于X的每一层,我们都使用2D算法和A做矩阵乘,就得到以下两个式子:

将这两个式子进行垂直contact就能得到最终的结果。

2.5D主要能进一步优化2D的通信代偿,但是实际生产中使用的不多,仅作为算法让大家理解它的原理。

3D并行在2.5D的基础上,把A矩阵也拆解成d*q*q,或者可以理解为X和X矩阵都被拆解成q*q*q,如下式:

每升一个维度,通信代偿都会得到进一步的下降,看明白原理就可以,在这里就不赘述了。

关于3D并行业界比较通用的解释是立体的并行手段,如图所示:

3D并行

3D并行,目前业界共识的叫法主要是针对用多种并行训练方式来进行训练,如上图所示,在GPU4/12/20/28这个维度,通过流水线将模型切割成不同的stage,然后在每个stage内部,又通过模型并行来进行横向的划分,然后在GPU0和GPU4这之间,因为他们的模型参数都是相同的,所以又可以采用数据并行来增大训练的dataset吞吐量,提升训练速度,这是一个典型的3D并行训练的案例。

Pytorch TORCH.DISTRIBUTED.TENSOR.PARALLEL

不管是 Megatron 还是Colossal 中的张量并行,都是基于 Transformer模型来实现的张量并行,不具备通用性。Pytorch 作为一个通用框架,提出了自己的并行方式,Dtensor,即TORCH.DISTRIBUTED.TENSOR.PARALLEL,可以更简单的在 SPMD(单程序多设备)中进行分布式计算。

Dtensor在Pytorch2.0中被引入。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码