经历了一些对未来选择的思考之后,最近在了解mlsys相关的内容,本文即为对TP的理解和总结,目前网上已经有大量的博文详细介绍了TP的实现细节,本文主要是为了自己未来查阅方便而写的文章,欢迎大家指正。
TP简介#
Tensor Parallelism是在DP, MP之后提出的一个方法,由Magatrion-LM首创。其出发点在于DP, MP仍然需要单卡在计算时凑齐一个完整的layer的参数和各种激活值、梯度、优化器状态,当一个layer过大的时候,单卡就放不下了。
而Tensor Parallelism将模型的计算拆成分布式的了,使得一层能够分布于不同卡上进行计算。
一个经典的Transformer 模型的架构大致如下图:
可以看到,一个layer主要由Attention和MLP层组成,TP的关键优化点也就是在这两层上,下面将具体说明。
MLP#
我们先从MLP层开始,简而言之,一个MLP层的数学描述大致这样:
O u t = D r o p o u t ( G e L U ( X W 1 ) W 2 ) \mathrm{Out} = \mathrm{Dropout}(\mathrm{GeLU}(X W_1) W_2) Out = Dropout ( GeLU ( X W 1 ) W 2 )
其中:
X : ( B , S , d model ) W 1 : ( d model , d ff ) W 2 : ( d ff , d model ) \begin{aligned}
X &: (B, S, d_{\text{model}}) \\
W_1 &: (d_{\text{model}}, d_{\text{ff}}) \\
W_2 &: (d_{\text{ff}}, d_{\text{model}})
\end{aligned} X W 1 W 2 : ( B , S , d model ) : ( d model , d ff ) : ( d ff , d model )
一般来说,d ff = 4 × d model d_{\text{ff}} = 4 \times d_{\text{model}} d ff = 4 × d model
我们先考虑不进行TP,仅仅进行单卡计算:
单卡forward#
参数量:
W 1 : d model × d ff W 2 : d ff × d model 总参数 : 8 d model 2 \begin{aligned}
W_1 &: d_{\text{model}} \times d_{\text{ff}} \\
W_2 &: d_{\text{ff}} \times d_{\text{model}} \\
\text{总参数} &: 8 d_{\text{model}}^2
\end{aligned} W 1 W 2 总参数 : d model × d ff : d ff × d model : 8 d model 2
计算量:
两个矩阵乘均贡献 2 × B × S × d model × d ff F L O P S = 16 × B × S × d model 2 \begin{aligned}
&\text{两个矩阵乘均贡献 } 2 \times B \times S \times d_{\text{model}} \times d_{\text{ff}} \\
&\mathrm{FLOPS} = 16 \times B \times S \times d_{\text{model}}^2
\end{aligned} 两个矩阵乘均贡献 2 × B × S × d model × d ff FLOPS = 16 × B × S × d model 2
激活量:在backward里考虑。
单卡backward#
首先对dropout反向:
∂ L ∂ Z = ∂ L ∂ O u t ⊙ m a s k 1 − p \frac{\partial L}{\partial Z} = \frac{\partial L}{\partial \mathrm{Out}} \odot \frac{\mathrm{mask}}{1 - p} ∂ Z ∂ L = ∂ Out ∂ L ⊙ 1 − p mask
这一步的FLOPS差一个数量级,可忽略不计,另外使用 L X LX L X 表示 ∂ L / ∂ X \partial L / \partial X ∂ L / ∂ X 。
然后对 W 2 W_2 W 2 进行反向:
L W 2 = A T ⋅ L Z L A = L Z ⋅ W 2 \begin{aligned}
LW_2 &= A^T \cdot LZ \\
LA &= LZ \cdot W_2
\end{aligned} L W 2 L A = A T ⋅ L Z = L Z ⋅ W 2
其中 A = G e L U ( ⋯ ) A = \mathrm{GeLU}(\cdots) A = GeLU ( ⋯ ) 。
这一步的FLOPS为 2 × ( 2 × B S d model d ff ) = 16 B S d model 2 2 \times (2 \times B S \, d_{\text{model}} \, d_{\text{ff}}) = 16 \, B S \, d_{\text{model}}^2 2 × ( 2 × BS d model d ff ) = 16 BS d model 2
GeLU的FLOPS几乎也可以忽略不计。
然后对 W 1 W_1 W 1 进行反向,几乎与 W 2 W_2 W 2 相同。
因此整个过程的FLOPS为 32 B S d model 2 32 \, B S \, d_{\text{model}}^2 32 BS d model 2 ,为前向传播的两倍。
然后我们从激活值占用角度分析,在没有梯度检查点的情况下,我们有:
X (BS, d_model) use for compute L_W1
H = XW_1 (BS, d_ff) use for compute the gelu
A = GeLU(H) (BS, d_ff) use for compute L_W2
Dropout mask(BS, d_model) use for Dropout plaintext
TP forward#
我们进行这样的切分方式:
W_1 -> (W_11, W_12, W_13, ... W_1n) # W_1i : (d_model, d_ff / n) python
这样我们在输入X的时候全部注入,然后得到:
H -> ( XW_11 , XW_12 , XW_13 , ... XW_1n ) # XW_1i : (B, S, d_ff / n) python
值得注意的是,我们选择按列切分 W 1 W_1 W 1 使得我们得到的结果是可以独立通过gelu的,省去了这一步通信的麻烦。
之后考虑 W 2 W_2 W 2
我们选择将 W 2 W_2 W 2 进行这样的切分:
W_2 -> [
W_21,
W_22,
W_23,
...
W_2n
] python
之后,显然我们现在可以每张卡计算XW_11 @ W_21,而且他的形状就是最后矩阵的形状,
因此,我们算出来然后最后采用all reduce就可以得到最后结果啦。
ok,我们现在对这整个过程进行分析:
参数量:
显然,我们现在把所有参数分散到了多卡上,而且分散均匀,
W 1 : d model × d ff W 2 : d ff × d model 总参数 : 8 d model 2 / n \begin{aligned}
W_1 &: d_{\text{model}} \times d_{\text{ff}} \\
W_2 &: d_{\text{ff}} \times d_{\text{model}} \\
\text{总参数} &: 8 d_{\text{model}}^2 / n
\end{aligned} W 1 W 2 总参数 : d model × d ff : d ff × d model : 8 d model 2 / n
两个矩阵乘均贡献 2 × B × S × d model × d ff / n F L O P S = 16 × B × S × d model 2 / n \begin{aligned}
&\text{两个矩阵乘均贡献 } 2 \times B \times S \times d_{\text{model}} \times d_{\text{ff}} / n \\
&\mathrm{FLOPS} = 16 \times B \times S \times d_{\text{model}}^2 / n
\end{aligned} 两个矩阵乘均贡献 2 × B × S × d model × d ff / n FLOPS = 16 × B × S × d model 2 / n
但是这里还要考虑一个问题,就是最后reduce-all操作还要对所有激活值进行累加,但是这部分数量级过小,可忽略。
TP backward#
在每张卡上的前向是:
Z i = G e L U ( X W 1 i ) W 2 i , AllReduce → Z = ∑ i Z i Z^i = \mathrm{GeLU}(X W_1^i) W_2^i, \quad \text{AllReduce} \to Z = \sum_i Z^i Z i = GeLU ( X W 1 i ) W 2 i , AllReduce → Z = i ∑ Z i
由于AllReduce之后每张卡上的Z完全相同,所以上游传回的梯度也完全一样,不需要额外通信。
此后,每一步的计算基本上与单张卡相同,但是要除以 N N N 。
因此,每张卡的反向FLOPS为:
F L O P S = 32 × B S d model 2 / N \mathrm{FLOPS} = 32 \times B S \, d_{\text{model}}^2 / N FLOPS = 32 × BS d model 2 / N
然后,之后需要注意的是我们在反向传播的最后仍然需要一步all-reduce,因为我们此前计算的都是独立的梯度。
激活值的占用:我们有:
X (BS, d_model) use for compute L_W1
H = XW_1 (BS, d_ff / N) use for compute the gelu
A = GeLU(H) (BS, d_ff / N) use for compute L_W2
Dropout mask(BS, d_model) use for Dropout plaintext
Attention#
单卡forward#
输入数据:
X (B, S, d_model)
X_h (B, S, h, d_head) W_Q (h, d_head, d_Q) W_K (h, d_head, d_K) W_V(h, d_head, d_V)
# 注意到在单卡情况下我们这一步计算Q, K, V通常不做维度划分,但可以这么理解,方便后续对TP的理解
Q (B, S, h, d_Q) K (B, S, h, d_K) V (B, S, h, d_V)
Q @ K.transpose -> S (B, h, S, S)
S @ V -> (B, h, S, d_V)
reshape -> (B, S, h * d_V)
# 接着引入一个W_O : (h * d_V, d_model)
O (B, S, d_model) python
ok,对整个过程清晰之后我们便可以分析其各个指标:
参数量:
W Q : h × d head × d Q W K : h × d head × d K W V : h × d head × d V \begin{aligned}
W_Q &: h \times d_{\text{head}} \times d_Q \\
W_K &: h \times d_{\text{head}} \times d_K \\
W_V &: h \times d_{\text{head}} \times d_V
\end{aligned} W Q W K W V : h × d head × d Q : h × d head × d K : h × d head × d V
通常来说这几个都相等,所以总参数为 4 d model 2 4 d_{\text{model}}^2 4 d model 2 。
计算量:
操作 形状 FLOPS X → Q , K , V X \to Q, K, V X → Q , K , V ( B S , d model ) × ( d model , d model ) (BS, d_{\text{model}}) \times (d_{\text{model}}, d_{\text{model}}) ( BS , d model ) × ( d model , d model ) 6 B S d model 2 6 \, BS \, d_{\text{model}}^2 6 BS d model 2 Q K T → S Q K^T \to S Q K T → S ( B , h , S , d head ) × ( B , h , d head , S ) (B, h, S, d_{\text{head}}) \times (B, h, d_{\text{head}}, S) ( B , h , S , d head ) × ( B , h , d head , S ) 2 B S 2 d model 2 \, BS^2 d_{\text{model}} 2 B S 2 d model S V → A V SV \to AV S V → A V ( B , h , S , S ) × ( B , h , S , d head ) (B, h, S, S) \times (B, h, S, d_{\text{head}}) ( B , h , S , S ) × ( B , h , S , d head ) 2 B S 2 d model 2 \, BS^2 d_{\text{model}} 2 B S 2 d model A V ⋅ W O AV \cdot W_O A V ⋅ W O ( B , S , d model ) × ( d model , d model ) (B, S, d_{\text{model}}) \times (d_{\text{model}}, d_{\text{model}}) ( B , S , d model ) × ( d model , d model ) 2 B S d model 2 2 \, BS \, d_{\text{model}}^2 2 BS d model 2
因此,总的FLOPS为:
8 B S d model 2 + 4 B S 2 d model 8 \, BS \, d_{\text{model}}^2 + 4 \, BS^2 d_{\text{model}} 8 BS d model 2 + 4 B S 2 d model
需保存的激活值:
Tensor shape num X X X ( B , S , d model ) (B, S, d_{\text{model}}) ( B , S , d model ) B S d model BS \, d_{\text{model}} BS d model Q , K , V Q, K, V Q , K , V ( B , S , d model ) (B, S, d_{\text{model}}) ( B , S , d model ) 3 B S d model 3 \, BS \, d_{\text{model}} 3 BS d model S S S ( B , h , S , S ) (B, h, S, S) ( B , h , S , S ) B h S 2 BhS^2 B h S 2 A V AV A V ( B , h , S , d head ) (B, h, S, d_{\text{head}}) ( B , h , S , d head ) B S d model BS \, d_{\text{model}} BS d model
单卡backward#
首先我们做 W O W_O W O 的反向:
L W O = ( A V ) T ⋅ L O ( d model , S , B ) × ( B , S , d model ) L A V = L O ⋅ W O T ( B , S , d model ) × ( d model , d model ) \begin{aligned}
LW_O &= (AV)^T \cdot LO \quad (d_{\text{model}}, S, B) \times (B, S, d_{\text{model}}) \\
LAV &= LO \cdot W_O^T \quad (B, S, d_{\text{model}}) \times (d_{\text{model}}, d_{\text{model}})
\end{aligned} L W O L A V = ( A V ) T ⋅ L O ( d model , S , B ) × ( B , S , d model ) = L O ⋅ W O T ( B , S , d model ) × ( d model , d model )
加起来是 4 B S d model 2 4 \, BS \, d_{\text{model}}^2 4 BS d model 2
然后我们回到 A V AV A V :
L S = L A V ⋅ V T ( B , h , S , d head ) × ( B , h , d head , S ) L V = S T ⋅ L A V ( B , h , S , S ) × ( B , h , S , d head ) \begin{aligned}
LS &= LAV \cdot V^T \quad (B, h, S, d_{\text{head}}) \times (B, h, d_{\text{head}}, S) \\
LV &= S^T \cdot LAV \quad (B, h, S, S) \times (B, h, S, d_{\text{head}})
\end{aligned} L S L V = L A V ⋅ V T ( B , h , S , d head ) × ( B , h , d head , S ) = S T ⋅ L A V ( B , h , S , S ) × ( B , h , S , d head )
这一步的FLOPS为 4 B S 2 d model 4 \, BS^2 d_{\text{model}} 4 B S 2 d model 。
然后经过Softmax反向,FLOPS可以忽略,然后计算Q, K,FLOPS为 4 B S 2 d model 4 \, BS^2 d_{\text{model}} 4 B S 2 d model 。
之后对 W W W 做反向,权重梯度和输入梯度均为 2 B S d model 2 2 \, BS \, d_{\text{model}}^2 2 BS d model 2 ,共计为12。
因此,总反向FLOPS为:
16 B S d model 2 + 8 B S 2 d model 16 \, BS \, d_{\text{model}}^2 + 8 \, BS^2 d_{\text{model}} 16 BS d model 2 + 8 B S 2 d model
为前向的两倍,所以我们在这里也可以认为反向传播的FLOPS为前向的两倍。
显然这时我们就可以完全将head分到多张卡上,所有的几乎均乘上一个 1 N \frac{1}{N} N 1 即可。
但此时仍然需要注意的是,我们得到O之后仍然需要all-reduce,这与mlp是一样的。
先写到这里.