• <li id="00i08"><input id="00i08"></input></li>
  • <sup id="00i08"><tbody id="00i08"></tbody></sup>
    <abbr id="00i08"></abbr>
  • 博客專欄

    EEPW首頁 > 博客 > 手撕大模型|KVCache 原理及代碼解析

    手撕大模型|KVCache 原理及代碼解析

    發布人:地平線開發者 時間:2025-09-13 來源:工程師 發布文章

    在大型語言模型(LLM)的推理過程中,KV Cache 是一項關鍵技術,它通過緩存中間計算結果顯著提升了模型的運行效率。本文將深入解析 KV Cache 的工作原理、實現方式,并通過代碼示例展示其在實際應用中的效果。

    一、為什么需要 KV Cache?

    在 Transformer 進行自回歸推理(如文本生成,每次生成一個 token 的時候需要結合前面所有的 token 做 attention 操作)時,計算注意力機制時需要存儲 Key(K) 和 Value(V),以便下一個時間步可以復用這些緩存,而不必重新計算整個序列。

    在標準 Transformer 解碼時,每次生成新 token 時:

    • 需要 重新計算所有之前 token 的 K 和 V,并與當前 token 進行注意力計算。

    • 計算復雜度是 O(n2)(對于長度為 n 的序列)。

    img

    而 KV Cache 通過存儲 K 和 V 的歷史值,避免重復計算:

    • 只需計算 新 token 的 K 和 V,然后將其與緩存的值結合使用。

    • 計算復雜度下降到 O(n)(每個 token 只與之前緩存的 token 計算注意力)。

    二、KV Cache 的工作原理

    KV Cache 的核心思想是緩存歷史計算中的鍵(Key)和值(Value)矩陣,避免重復計算。具體來說:

    1. 在生成第一個 token 時,模型計算并緩存所有輸入 token 的 K 和 V 矩陣

    2. 生成后續 token 時,只需要計算新 token 的查詢(Query)矩陣

    3. 將新的 Q 矩陣與緩存的 K、V 矩陣進行注意力計算,同時將新 token 的 K、V 追加到緩存中

    這個過程可以用偽代碼直觀展示:

    初始輸入: [t0, t1, t2]
    首次計算: K=[K0,K1,K2], V=[V0,V1,V2] → 生成t3
    緩存狀態: K=[K0,K1,K2], V=[V0,V1,V2]
    第二次計算: 新Q=Q3
    注意力計算: Attention(Q3, [K0,K1,K2]) → 生成t4
    更新緩存: K=[K0,K1,K2,K3], V=[V0,V1,V2,V3]
    第三次計算: 新Q=Q4
    注意力計算: Attention(Q4, [K0,K1,K2,K3]) → 生成t5
    更新緩存: K=[K0,K1,K2,K3,K4], V=[V0,V1,V2,V3,V4]
    ...

    通過這種方式,每次新生成 token 時,只需計算新的 Q 矩陣并與歷史 KV 矩陣進行注意力計算,將時間復雜度從 O (n2) 降低到 O (n),極大提升了長序列生成的效率。

    下面,我們結合示意圖進一步剖析一下 KV Cache 部分的邏輯。

    img

    img

    img

    img

    KV Cache 核心節約的時間有三大塊:

    1. 前面 n-1 次的 Q 的計算,當然這塊對于一次一個 token 的輸出本來也沒有用;

    2. 同理還有 Attention 計算時對角矩陣變為最后一行,和 b 是同理的,這樣 mask 矩陣也就沒有什么用了;

    3. 前面 n-1 次的 K 和 V 的計算,也就是上圖紫色部分,這部分是實打實被 Cache 過不需要再重新計算的部分。

    這里還有個 softmax 的問題,softmax 原本就是針對同一個 query 的所有 key 的計算,所以并不受影響。

    2.1 KV Cache 的技術細節
    1. 緩存結構

    KV Cache 通常為每個注意力頭維護獨立的緩存,結構如下:

    1. Key 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

    2. Value 緩存:形狀為 [batch_size, num_heads, seq_len, head_dim]

    其中,seq_len 會隨著生成過程動態增長,直到達到模型最大序列長度限制。

    1. 內存與速度的權衡

    KV Cache 雖然提升了速度,但需要額外的內存存儲緩存數據。以 GPT-3 175B 模型為例,每個 token 的 KV 緩存約占用 20KB 內存,當生成 1000 個 token 時,單個樣本就需要約 20MB 內存。在批量處理時,內存消耗會線性增加。

    實際應用中需要根據硬件條件在以下方面進行權衡:

    1. 最大緩存長度(影響能處理的序列長度)

    2. 批量大?。ㄓ绊懖l處理能力)

    3. 精度選擇(FP16 比 FP32 節省一半內存)

    4. 滑動窗口機制

    當處理超長序列時,一些模型(如 Llama 2)采用滑動窗口機制,只保留最近的 N 個 token 的 KV 緩存,以控制內存占用。這種機制在犧牲少量上下文信息的情況下,保證了模型能處理更長的對話。

    四、代碼實現解析

    下面以 PyTorch 為例,展示 KV Cache 在自注意力計算中的實現方式。

    1. 基礎自注意力實現(無緩存)

    首先看一下標準的自注意力計算,沒有緩存機制:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class SelfAttention(nn.Module):
        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            
            # 定義Q、K、V投影矩陣
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        def forward(self, x):
            batch_size, seq_len, embed_dim = x.shape
            
            # 計算Q、K、V
            q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            
            # 計算注意力分數
            attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn_probs = F.softmax(attn_scores, dim=-1)
            
            # 應用注意力權重
            output = attn_probs @ v
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
            
            return self.out_proj(output)
    1. 帶 KV Cache 的自注意力實現

    下面修改代碼,加入 KV Cache 機制:

    class CachedSelfAttention(nn.Module):
        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            
            # 定義投影矩陣
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)
            
            # 初始化緩存
            self.cache_k = None
            self.cache_v = None
        
        def forward(self, x, use_cache=False):
            batch_size, seq_len, embed_dim = x.shape
            
            # 計算Q、K、V
            q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            
            # 如果使用緩存且緩存存在,則拼接歷史KV
            if use_cache and self.cache_k is not None:
                k = torch.cat([self.cache_k, k], dim=-2)
                v = torch.cat([self.cache_v, v], dim=-2)
            
            # 如果使用緩存,更新緩存
            if use_cache:
                self.cache_k = k
                self.cache_v = v
            
            # 計算注意力分數(注意這里的k是包含歷史緩存的)
            attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn_probs = F.softmax(attn_scores, dim=-1)
            
            # 應用注意力權重
            output = attn_probs @ v
            output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
            
            return self.out_proj(output)
        
        def reset_cache(self):
            """重置緩存,用于新序列的生成"""
            self.cache_k = None
            self.cache_v = None
    1. 生成過程中的緩存使用

    在文本生成時,我們可以這樣使用帶緩存的注意力機制:

    def generate_text(model, input_ids, max_length=50):
        # 初始化模型緩存
        model.reset_cache()
        
        # 處理初始輸入
        output = model(input_ids, use_cache=True)
        next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
        generated = [next_token]
        
        # 生成后續token
        for _ in range(max_length - 1):
            # 只輸入新生成的token
            output = model(next_token, use_cache=True)
            next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True)
            generated.append(next_token)
            
            # 如果生成結束符則停止
            if next_token.item() == 102:  # 假設102是[SEP]的id
                break
        
        return torch.cat(generated, dim=1)
    五、KV Cache 的優化策略

    在實際部署中,為了進一步提升 KV Cache 的效率,還會采用以下優化策略:

    1. 分頁 KV Cache(Paged KV Cache):借鑒內存分頁機制,將連續的 KV 緩存分割成固定大小的塊,提高內存利用率,代表實現有 vLLM。

    2. 動態緩存管理:根據輸入序列長度動態調整緩存大小,在批量處理時優化內存分配。

    3. 量化緩存:使用 INT8 或 INT4 等低精度格式存儲 KV 緩存,在犧牲少量精度的情況下大幅減少內存占用。

    4. 選擇性緩存:對于一些不重要的層或注意力頭,選擇性地不進行緩存,平衡速度和內存。

    六、總結

    KV Cache 通過緩存中間計算結果,有效解決了 Transformer 模型在生成式任務中的效率問題,是大模型能夠實現實時交互的關鍵技術之一。理解 KV Cache 的工作原理和實現方式,對于優化大模型推理性能、解決實際部署中的挑戰具有重要意義。

    七、參考鏈接

    https://zhuanlan.zhihu.com/p/670515231

    https://zhuanlan.zhihu.com/p/714288577

    https://zhuanlan.zhihu.com/p/715921106https://zhuanlan.zhihu.com/p/19489285169

    https://medium.com/@joaolages/kv-caching-explained-276520203249


    *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



    相關推薦

    技術專區

    關閉
    主站蜘蛛池模板: 托克逊县| 武宣县| 崇文区| 烟台市| 阳春市| 东莞市| 沙田区| 拉孜县| 安义县| 迭部县| 扶余县| 汝州市| 景德镇市| 定南县| 夏邑县| 民乐县| 江津市| 高要市| 洛宁县| 崇左市| 灵丘县| 华宁县| 集安市| 岱山县| 嵩明县| 且末县| 清远市| 红安县| 晋州市| 怀宁县| 嘉善县| 祁门县| 旌德县| 呼和浩特市| 西平县| 静安区| 南澳县| 通河县| 福鼎市| 盈江县| 安西县|