FlashDecoding++: Faster Large Language Model Inference on GPUs: Asynchronized Softmax with Unified Maximum Value

15 Feb 2024

This paper is available on arxiv under CC BY-NC-SA 4.0 DEED license.

Authors:

(1) Ke Hong, Tsinghua University & Infinigence-AI;

(2) Guohao Dai, Shanghai Jiao Tong University & Infinigence-AI;

(3) Jiaming Xu, Shanghai Jiao Tong University & Infinigence-AI;

(4) Qiuli Mao, Tsinghua University & Infinigence-AI;

(5) Xiuhong Li, Peking University;

(6) Jun Liu, Shanghai Jiao Tong University & Infinigence-AI;

(7) Kangdi Chen, Infinigence-AI;

(8) Yuhan Dong, Tsinghua University;

(9) Yu Wang, Tsinghua University.

3 Asynchronized Softmax with Unified Maximum Value

Motivation. The partial softmax operation requires synchronization among different partial vectors, accounting for ∼20% of the overhead of the attention operation. As shown in Figure 2, synchronization is required after the maximum value of each partial vector is calculated. The maximum value is used to update previous partial softmax results (i.e., to recompute previous attention results). Thus, to reduce the synchronization overhead, the key problem is how to compute each partial softmax result without requiring results from other partial softmax computations.

Figure 5: The statistical distribution of x_i (elements in the input vectors of softmax) in typical LLMs with different inputs.

Challenge. Synchronization is required because the maximum value of each partial vector differs. The maximum value is used to avoid overflow in the exponent operation (f(x) in Figure 4(a)), and the exponents are summed (l(x) in Figure 4(a)) to form the denominator of the softmax operation. Such a non-linear operation on each partial maximum value makes synchronization among the partial softmax computations unavoidable.
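Concretely, the synchronized scheme of Figure 4(b) (as used in FlashAttention) keeps a running maximum and rescales all earlier partial results whenever it grows. In our notation, after the j-th partial vector the update is:

$$m^{(j)} = \max\!\left(m^{(j-1)},\, \max_i x_i^{(j)}\right), \qquad l^{(j)} = l^{(j-1)}\, e^{\,m^{(j-1)}-m^{(j)}} + \sum_i e^{\,x_i^{(j)}-m^{(j)}}$$

Because each step depends on m^(j−1), partial vector j cannot be finalized before partial vector j−1, which is exactly the synchronization this section removes.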

Analysis and Insights. According to the formula of softmax computation, the maximum value is used as the scaling factor for both the numerator and the denominator (i.e., f(x) and l(x) in Figure 4(a)). Our key insight is that, mathematically, the scaling factor can be an arbitrary number rather than the maximum value, as shown in Equation (3). When we set ϕ = 0, Equation (3) reduces to the original softmax computation [30].
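Reconstructed from the surrounding definitions, Equation (3) reads:

$$\mathrm{softmax}(x)_i \;=\; \frac{e^{x_i-\phi}}{\sum_{j} e^{x_j-\phi}}, \qquad \phi \in \mathbb{R}$$

Choosing ϕ = max_j x_j recovers the numerically safe softmax of Figure 4(a), while ϕ = 0 gives the original, unscaled form.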

However, the scaling factor cannot be an arbitrary number once overflow in the exponent computation is considered. When x_i ≫ ϕ, e^(x_i−ϕ) overflows and cannot be represented by a fixed-width floating-point number (e.g., float32 for exponent results in current LLM engines). Conversely, when x_i ≪ ϕ, e^(x_i−ϕ) → 0, leading to precision loss. Thus, a proper scaling factor ϕ should be carefully selected to avoid both cases. Figure 5 shows the statistical distribution of x_i (elements in the input vectors of softmax) in typical LLMs with different inputs [31]. Our key insight is that > 99.99% of the x_i fall within a narrow range. Specifically, for Llama2-7B we have −16.8 < x_i < 6.5 for > 99.99% of the x_i; writing this range as (a, b), both e^(b−a) and e^(a−b) can be represented in the float32 format, so we can set ϕ = a in Equation (3). For OPT-6.7B, we do not apply the technique in this section because of the large range shown in Figure 5.
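As a quick numerical check (a minimal NumPy sketch using the Llama2-7B range from Figure 5; the variable names are ours, not the paper's), setting ϕ = a keeps every exponent representable in float32:

```python
import numpy as np

# Range covering > 99.99% of x_i for Llama2-7B (Figure 5).
a, b = -16.8, 6.5
phi = a  # unified scaling factor shared by all partial vectors

# With phi = a, the exponent argument x_i - phi lies in (0, b - a).
worst_cases = np.array([b - a, a - b], dtype=np.float32)
print(np.exp(worst_cases))        # [~1.3e+10, ~7.6e-11] -- both finite in float32
print(np.finfo(np.float32).max)   # ~3.4e+38 >> e^(b-a), so no overflow

# By contrast, exp overflows float32 once its argument exceeds ~88.7:
print(np.exp(np.float32(89.0)))   # inf
```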

Approach: Asynchronization. Based on the insights above, each partial softmax computation shares a unified maximum value, ϕ. After the softmax operation, an inner product operation is executed between the softmax result and a column of V (i.e., v). Assume that the input vector x can be divided into p partial vectors, x = [x^(1), ..., x^(p)] (and correspondingly v = [v^(1), ..., v^(p)]); then we have:
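Written out (this display is reconstructed here to match the description that follows):

$$\mathrm{softmax}(x)\cdot v^{\top} \;=\; \frac{\sum_{j=1}^{p}\left(\sum_{i} e^{x_i^{(j)}-\phi}\, v_i^{(j)}\right)}{\sum_{j=1}^{p}\left(\sum_{i} e^{x_i^{(j)}-\phi}\right)}$$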

The inner accumulations in both the numerator and the denominator take only the partial vectors x^(j) and v^(j) as input, so they can be processed asynchronously and individually. The outer accumulation is processed only after all partial vectors are finished. As shown in Figure 4(c), each f(x^(j)) is calculated individually, and softmax(x) is calculated only after all x^(j) have been processed.
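As a sketch of this decomposition (a minimal NumPy version under our own naming; the actual implementation is a fused GPU kernel), each partial pair (x^(j), v^(j)) produces its numerator and denominator terms independently, and a single reduction combines them at the end:

```python
import numpy as np

def unified_max_softmax_dot(x_parts, v_parts, phi):
    """Compute softmax(x) . v with a unified scaling factor phi.

    x_parts, v_parts: (p, c) arrays -- the p partial vectors of the
    softmax input x and the matching partition of a column v of V.
    """
    num = np.float32(0.0)  # outer accumulation, numerator
    den = np.float32(0.0)  # outer accumulation, denominator
    # Each iteration reads only x^(j) and v^(j): with phi fixed up front,
    # the p iterations are independent and could run concurrently.
    for xj, vj in zip(x_parts, v_parts):
        e = np.exp(xj - phi).astype(np.float32)  # f(x^(j)) with unified phi
        num += np.dot(e, vj)                     # inner accumulation
        den += e.sum()                           # inner accumulation, l(x^(j))
    return num / den

# Sanity check against the standard max-scaled softmax:
x = np.random.randn(4, 8).astype(np.float32)
v = np.random.randn(4, 8).astype(np.float32)
flat = x.ravel()
ref = np.dot(np.exp(flat - flat.max()) / np.exp(flat - flat.max()).sum(), v.ravel())
print(np.allclose(unified_max_softmax_dot(x, v, phi=0.0), ref, atol=1e-5))  # True
```

The check passes because softmax is invariant to a constant shift of its input, which is precisely why ϕ can replace the per-partial maximum.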

Figure 6: Example of asynchronized partial softmax computation. (a) Each partial softmax result is processed individually without the synchronized update. (b) The recomputation process for all partial softmax computations is required when some x_i − ϕ falls outside the safe range (a, b).

Approach: Recomputation. Without loss of generality, we assume a < x_i − ϕ < b for each x_i to ensure precision and avoid overflow. The partial softmax operation is then processed individually. However, when x_i − ϕ ≤ a or x_i − ϕ ≥ b, the asynchronized partial softmax computation is terminated for the vector x to which x_i belongs. The softmax is then recomputed using the synchronized partial softmax scheme (used in FlashAttention [18, 19] and FlashDecoding [13]) shown in Figure 4(b). This recomputation scheme avoids overflow while introducing negligible overhead, given the statistical data shown in Figure 5.
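The following sketch (ours, in NumPy; the bounds a and b and all names are illustrative) shows the control flow: take the asynchronized path while every x_i − ϕ stays inside (a, b), and otherwise recompute with a synchronized, FlashAttention-style running maximum:

```python
import numpy as np

def softmax_dot_with_fallback(x_parts, v_parts, phi, a, b):
    """softmax(x) . v with the unified max phi, falling back to the
    synchronized scheme when some x_i - phi leaves the safe range (a, b)."""
    if np.any((x_parts - phi <= a) | (x_parts - phi >= b)):
        # Recomputation path (Figure 4(b)): a running maximum m forces the
        # partials to be processed in order, rescaling earlier results.
        m, l, acc = -np.inf, np.float32(0.0), np.float32(0.0)
        for xj, vj in zip(x_parts, v_parts):
            m_new = max(m, float(xj.max()))
            scale = np.exp(np.float32(m - m_new))  # rescales previous partials
            e = np.exp(xj - m_new)
            l = l * scale + e.sum()
            acc = acc * scale + np.dot(e, vj)
            m = m_new
        return acc / l
    # Fast path: all exponents are safe, so no synchronization is needed.
    e = np.exp(x_parts - phi)
    return (e * v_parts).sum() / e.sum()
```

Per Figure 5, the fallback branch is taken for fewer than 0.01% of inputs, which is why its cost is negligible in practice.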