学习笔记:Tensor Parallelism(TP)

引言

经历了一些对未来选择的思考之后,最近在了解 mlsys 相关的内容,本文即为对 TP 的理解和总结,目前网上已经有大量的博文详细介绍了 TP 的实现细节,本文主要是为了自己未来查阅方便而写的文章,欢迎大家指正。

TP简介

Tensor Parallelism 是在 DP, MP 之后提出的一个方法,由 Magatrion-LM 首创。其出发点在于 DP, MP 仍然需要单卡在计算时凑齐一个完整的 layer 的参数和各种激活值、梯度、优化器状态,当一个 layer 过大的时候,单卡就放不下了。
而 Tensor Parallelism 将模型的计算拆成分布式的了,使得一层能够分布于不同卡上进行计算。

Transformer-like model

一个经典的 Transformer 模型的架构大致如下图:

transformerarch

可以看到,一个 layer 主要由 Attention 和 MLP 层组成, TP 的关键优化点也就是在这两层上,下面将具体说明。

MLP

我们先从 MLP 层开始,简而言之,一个 MLP 层的数学描述大致这样:

其中:

一般来说,

我们先考虑不进行 TP ,仅仅进行单卡计算:

单卡forward

参数量:

计算量:

激活量:在 backward 里考虑。

单卡backward

首先对 dropout 反向:

这一步的 FLOPS 差一个数量级,可忽略不计,另外使用 表示

然后对 进行反向:

其中

这一步的 FLOPS 为

GeLU 的 FLOPS 几乎也可以忽略不计。

然后对 进行反向,几乎与 相同。

因此整个过程的 FLOPS 为 ,为前向传播的两倍。

然后我们从激活值占用角度分析,在没有梯度检查点的情况下,我们有:

1
2
3
4
X (BS, d_model) use for compute L_W1
H = XW_1 (BS, d_ff) use for compute the gelu
A = GeLU(H) (BS, d_ff) use for compute L_W2
Dropout mask(BS, d_model) use for Dropout

TP forward

sd

我们进行这样的切分方式:

1
W_1 -> (W_11, W_12, W_13, ... W_1n) # W_1i: (d_model, d_ff / n) 

这样我们在输入 X 的时候全部注入,然后得到:

1
H -> (XW_11, XW_12, XW_13, ... XW_1n) # XW_1i: (B, S, d_ff / n)

值得注意的是,我们选择按列切分 使得我们得到的结果是可以独立通过 gelu 的,省去了这一步通信的麻烦。

之后考虑

我们选择将 进行这样的切分:

1
2
3
4
5
6
7
W_2 -> [
W_21,
W_22,
W_23,
...
W_2n
]

之后,显然我们现在可以每张卡计算XW_11 @ W_21,而且他的形状就是最后矩阵的形状,
因此,我们算出来然后最后采用 all reduce 就可以得到最后结果啦。

ok ,我们现在对这整个过程进行分析:

  • 参数量:
    显然,我们现在把所有参数分散到了多卡上,而且分散均匀,

  • 计算量:

但是这里还要考虑一个问题,就是最后 reduce-all 操作还要对所有激活值进行累加,但是这部分数量级过小,可忽略。
  • 激活量:在 backward 里考虑。

TP backward

在每张卡上的前向是:

由于 AllReduce 之后每张卡上的 Z 完全相同,所以上游传回的梯度也完全一样,不需要额外通信。

此后,每一步的计算基本上与单张卡相同,但是要除以

因此,每张卡的反向 FLOPS 为:

然后,之后需要注意的是我们在反向传播的最后仍然需要一步 all-reduce ,因为我们此前计算的都是独立的梯度。

激活值的占用:我们有:

1
2
3
4
X (BS, d_model) use for compute L_W1
H = XW_1 (BS, d_ff / N) use for compute the gelu
A = GeLU(H) (BS, d_ff / N) use for compute L_W2
Dropout mask(BS, d_model) use for Dropout

Attention

单卡forward

输入数据:

1
2
3
4
5
6
7
8
9
10
X (B, S, d_model) 
X_h (B, S, h, d_head) W_Q (h, d_head, d_Q) W_K (h, d_head, d_K) W_V(h, d_head, d_V)
# 注意到在单卡情况下我们这一步计算 Q, K, V 通常不做维度划分,但可以这么理解,方便后续对 TP 的理解
Q (B, S, h, d_Q) K (B, S, h, d_K) V (B, S, h, d_V)
Q @ K.transpose -> S (B, h, S, S)
S @ V -> (B, h, S, d_V)
reshape -> (B, S, h *d_V)

# 接着引入一个 W_O: (h * d_V, d_model)
O (B, S, d_model)

ok ,对整个过程清晰之后我们便可以分析其各个指标:

参数量:

通常来说这几个都相等,所以总参数为

计算量:

操作 形状 FLOPS

因此,总的 FLOPS 为:

需保存的激活值:

Tensor shape num

单卡backward

首先我们做 的反向:

加起来是

然后我们回到

这一步的 FLOPS 为

然后经过 Softmax 反向, FLOPS 可以忽略,然后计算 Q, K , FLOPS 为

之后对 做反向,权重梯度和输入梯度均为 ,共计为 12 。

因此,总反向 FLOPS 为:

为前向的两倍,所以我们在这里也可以认为反向传播的 FLOPS 为前向的两倍。

TP

显然这时我们就可以完全将 head 分到多张卡上,所有的几乎均乘上一个 即可。

但此时仍然需要注意的是,我们得到 O 之后仍然需要 all-reduce ,这与 mlp 是一样的。

先写到这里.