vLLM v1 Speculative Decoding 模块超深度分析

分析对象:vllm/vllm/v1/spec_decode
代码规模:11 Python 文件,4,548 行有效代码


第一章 模块定位与架构概览

在这里插入图片描述

一、模块定位

1. 业务职责

vLLM v1 Speculative Decoding 模块是推测解码(Speculative Decoding)的提议端实现,负责在目标模型(Target Model)正式解码之前,快速生成"草稿"token序列供目标模型批量验证。核心职责包括:

  1. 多种推测策略:支持7种推测方法(Eagle、DraftModel、Medusa、DFlash、Ngram、NgramGPU、SuffixDecoding)
  2. 草稿token生成:每个请求生成1~k个draft tokens,与目标模型的验证结果对比
  3. 隐藏状态提取:从目标模型中间层提取hidden_states,供Eagle/DFlash/Medusa等需要的方法使用
  4. 性能指标采集:推测接受率、吞吐量、逐位置接受率等关键指标
  5. KV Cache管理:草稿token的slot mapping更新、attention metadata构建
2. 在系统中的位置
Scheduler → Worker → ModelRunner
                       ├── TargetModel.forward() → hidden_states + logits
                       ├── [Proposer.propose()] → draft_tokens
                       ├── TargetModel.forward(draft+base) → verification logits
                       ├── Sampler → verify/reject draft tokens
                       └── [SpecDecodeMetrics] → 统计上报

Proposer位于目标模型前向传播与采样之间

  • 上游:ModelRunner调用propose()获取draft tokens
  • 下游:Sampler验证draft tokens,接受/拒绝后更新KV cache
3. 核心业务价值
价值 说明
推理加速 接受k个draft + 1个bonus = 1步产生k+1个token → 2-3x加速
无损质量 目标模型验证确保接受token的概率分布等价于自回归
灵活策略 从无模型的n-gram到强模型的Eagle,适应不同场景
资源效率 n-gram/suffix无需额外模型 → 零额外显存开销

二、模块文件全景

文件 行数 核心类/函数
__init__.py 0 空包标识
metadata.py 66 SpecDecodeMetadata dataclass
medusa.py 78 MedusaProposer
draft_model.py 88 DraftModelProposer
suffix_decoding.py 101 SuffixDecodingProposer
metrics.py 215 SpecDecodingStats/Logging/Prom
dflash.py 282 DFlashProposer
ngram_proposer.py 285 NgramProposer + Numba LPS
extract_hidden_states.py 382 ExtractHiddenStatesProposer
utils.py 596 Triton kernels + 辅助函数
ngram_proposer_gpu.py 662 NgramProposerGPU + torch.compile
eagle.py 1,793 SpecDecodeBaseProposer + EagleProposer

第二章 metadata.py — 推测解码元数据(66行)

一、模块定位

SpecDecodeMetadata 是推测解码的核心索引数据结构,告诉 Sampler 如何从展平的 logits 中提取每个draft token和bonus token对应的logits。

二、逐行深度解析

@dataclass
class SpecDecodeMetadata:
    # [num_tokens]
    draft_token_ids: torch.Tensor
  • draft_token_ids:所有请求的draft token ID展平拼接,形状 [num_tokens]
  • num_tokens = sum(num_draft_tokens_per_request)
  • 设计意图:展平存储 → 避免不规则2D tensor → GPU友好
    # [batch_size]
    num_draft_tokens: list[int]
  • 每个请求的draft token数量,长度=批次大小
  • 注意:用 list[int] 而非Tensor → CPU侧标量操作更方便
  • 不同请求可有不同数量的draft(特别是SuffixDecoding动态长度)
    cu_num_draft_tokens: torch.Tensor  # [batch_size]
  • draft token数量的前缀和(cumulative sum),GPU int32
  • 例:num_draft_tokens=[5,3,4]cu_num=[5,8,12]
  • 用途:从展平的 draft_token_ids 中按请求切片
    cu_num_sampled_tokens: torch.Tensor  # [batch_size]
  • 采样token数量(draft+1)的前缀和
  • num_sampled = num_draft + 1(每个draft后多采一个bonus)
  • 例:[5,3,4] → sampled=[6,4,5] → cu=[6,10,15]
    target_logits_indices: torch.Tensor  # [num_tokens]
  • 指向target model logits中每个draft token前一个位置的索引
  • 推理:draft[i]的验证需要对比 target 在位置i的预测 → target_logits的索引i
  • 这些索引从target model前向的输出logits中选取对应位置
    bonus_logits_indices: torch.Tensor  # [batch_size]
  • 每个请求的bonus token对应的logits索引
  • bonus token:验证后额外采样的1个token(接受最后一个draft后自然产生的token)
    logits_indices: torch.Tensor  # [num_tokens + batch_size]
  • target_logits_indices + bonus_logits_indices 的合并索引
  • 设计意图:Sampler一次采样需要所有位置的logits → 合并后一次gather
    def __post_init__(self):
        self.max_spec_len = max(self.num_draft_tokens)
  • 自动计算最大推测长度
  • 用途:Tree Attention需要知道最大的draft树深度
make_dummy 类方法
    @classmethod
    def make_dummy(cls, draft_token_ids: list[list[int]], device) -> "SpecDecodeMetadata":
  • 设计意图:构建全零占位metadata,用于CUDA Graph捕获
  • CUDA Graph要求固定的tensor形状 → make_dummy提供一致形状
  • 所有索引字段填零(因为CUDA Graph捕获时不执行实际验证逻辑)

逐行计算流程

  1. batch_size = len(draft_token_ids) — 请求数量
  2. num_draft_tokens = [len(ids) for ids in draft_token_ids] — 每请求draft数
  3. num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids] — 每请求采样数
  4. flattened_draft_token_ids = sum(draft_token_ids, []) — 展平
  5. cumsum计算 → GPU tensor
  6. 全零索引tensor → 返回

第三章 medusa.py — Medusa多头推测(78行)

逐行解析

Medusa是单步多头部并行推测方法:在target model的hidden_states上,添加多个独立的线性头(Medusa heads),每个头独立预测下一个token。

class MedusaProposer:
    def __init__(self, vllm_config, device):
        self.spec_config = vllm_config.speculative_config
        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
  • hidden_size:从draft_model_config获取 → Medusa头的输入维度需匹配target hidden_size
    def propose(self, target_hidden_states, sampling_metadata, slot_mappings=None):
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks)
        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
        return draft_tokens

逐行解析

  1. self.model(target_hidden_states) — Medusa头前向传播,输入target的hidden_states
  2. compute_logits(blocks) — 每个头计算vocab维度的logits
  3. [logit.argmax(dim=-1) for logit in logits] — 每个头取argmax → 贪心解码
  4. torch.stack(..., dim=1) — 形状 [batch, num_heads]
  • 设计意图:单次前向 → 多头并行 → 无自回归 → 低延迟但精度较低
  • 与Eagle的区别:Medusa各头独立,无attention → 更快但接受率更低
    def load_model(self, target_model):
        with set_model_tag("medusa_head"):
            self.model = get_model(vllm_config=self.vllm_config, model_config=self.spec_config.draft_model_config)
  • set_model_tag("medusa_head") — 标记模型标签,用于torch.compile缓存隔离
  • EPLB断言:MoE + EPLB不支持Medusa

第四章 draft_model.py — 独立草稿模型推测(88行)

逐行解析

DraftModelProposer使用一个完全独立的小模型作为草稿生成器,自回归地生成draft tokens。

class DraftModelProposer(SpecDecodeBaseProposer):
    def __init__(self, vllm_config, device, runner=None):
        super().__init__(pass_hidden_states_to_model=False)
  • pass_hidden_states_to_model=False — 草稿模型不使用target的hidden_states
  • 原因:独立模型有自己的embedding层和transformer → 不需要target的中间表示
    def _raise_if_vocab_size_mismatch(self):
        self.speculative_config.verify_equal_vocab_size_if_draft_model()
  • 词汇表大小必须匹配 → 否则draft token ID在target模型中无效
    def _raise_if_draft_tp_mismatch(self):
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(...)
  • TP必须匹配的原因:不同TP rank编译draft model时,torch compile缓存会被覆盖损坏
  • 需要类似PR#5414的rank-aware compile缓存机制
    def _create_draft_vllm_config(self):
        return replace(base, quant_config=None, parallel_config=..., model_config=spec.draft_model_config)
  • quant_config=None — 草稿模型不量化(保持精度)
  • 使用replace()创建新配置 → 不可变配置模式
    def _maybe_share_embeddings(self, target_language_model): pass
    def _maybe_share_lm_head(self, target_language_model): pass
  • 空实现:独立草稿模型不与target共享embedding/lm_head
  • 对比Eagle:Eagle共享embedding和lm_head → 更高效

第五章 suffix_decoding.py — 后缀树推测(101行)

逐行解析

SuffixDecodingProposer使用后缀树模式匹配进行推测,无需任何额外模型。

class SuffixDecodingProposer:
    def __init__(self, vllm_config):
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
  • max_tree_depth:后缀树最大深度 → 限制搜索范围
  • max_spec_factor:推测token数上限因子 → 防止过长推测
  • min_token_prob:最小token概率阈值 → 过滤低置信度draft
        from arctic_inference.suffix_decoding import SuffixDecodingCache
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=..., max_cached_requests=...
        )
  • 延迟导入:arctic_inference可能未安装 → 仅使用时才报错
  • SuffixDecodingCache:Snowflake Arctic Inference的官方实现
    def propose(self, input_batch, sampled_token_ids, slot_mappings=None):

逐行解析核心逻辑

  1. 遍历每个请求的sampled_token_ids
  2. 跳过partial prefillif not sampled_ids: continue — 未完成预填充的请求不推测
  3. 跳过长序列if num_tokens >= max_model_len: continue
  4. 请求生命周期管理
    • 新请求:start_request(req_id, prompt_token_ids) → 构建prompt的后缀树
    • 已缓存请求:evict_cached_response(req_id) → 重置后重新开始
    • 活跃请求:add_active_response(req_id, sampled_ids) → 追加新token到缓存
  5. 模式提取pattern = token_ids_cpu[start:num_tokens] — 取最近max_tree_depth个token
  6. 推测suffix_cache.speculate(req_id, pattern, max_spec_tokens, ...) → 返回draft
  7. 清理:不在当前batch的活跃请求 → stop_request(req_id)

关键设计:动态推测长度 — 每个请求的draft数量不同 → SpecDecodeMetadata.num_draft_tokens 是list[int]


第六章 metrics.py — 推测解码指标体系(215行)

在这里插入图片描述

逐行解析

6.1 SpecDecodingStats — 逐步统计
@dataclass
class SpecDecodingStats:
    num_spec_tokens: int
    num_drafts: int = 0           # 推测次数
    num_draft_tokens: int = 0     # 总draft token数
    num_accepted_tokens: int = 0  # 总接受token数
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)  # 逐位置

逐位置统计的设计意图

  • 位置0的接受率通常最高(第一个draft最容易正确)
  • 位置越后接受率越低 → 帮助调优num_speculative_tokens
  • 例:5位置统计 [0.85, 0.72, 0.60, 0.48, 0.35] → 前3位置有价值,后2位置可能不划算
    def observe_draft(self, num_draft_tokens, num_accepted_tokens):
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1
  • assert num_accepted_tokens <= self.num_spec_tokens — 接受数不超过最大推测长度
  • 逐位置递增:前N个位置各+1 → N=接受数
6.2 SpecDecodingLogging — 时间窗口聚合
    def log(self, log_fn=logger.info):
        mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)
  • 关键公式1 + accepted/drafts — 包含bonus token
  • 直觉:每次推测至少产生1个token(bonus),加上接受的draft → 平均长度 >= 2
        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
  • 2D矩阵每列求和/总draft数 → 每个位置的独立接受率
6.3 SpecDecodingProm — Prometheus导出
  • 4个Counter:drafts、draft_tokens、accepted_tokens、accepted_per_pos
  • per_pos使用position label → PromQL向量查询
  • PromQL示例rate(accepted_total[5m]) / rate(draft_total[5m]) = 接受率时序

第七章 ngram_proposer.py — CPU N-gram推测(285行)

在这里插入图片描述

核心算法:LPS (Longest Proper Prefix which is also Suffix)

这是KMP字符串匹配算法的变体,在翻转的token序列上寻找最长n-gram匹配。

7.1 NgramProposer类
class NgramProposer:
    def __init__(self, vllm_config):
        self.min_n = config.prompt_lookup_min  # 最小n-gram长度
        self.max_n = config.prompt_lookup_max  # 最大n-gram长度
        self.k = config.num_speculative_tokens # 提取的draft数量

Numba线程控制

        self.num_numba_thread_available = min(1, (cpu_count // 2)) // tp_size
  • 当前硬编码为1线程 → TODO(ekagra-ranjan)提升到8
  • 除以tp_size:所有TP rank都会运行此代码 → 每rank分得部分线程

JIT预热

        self.propose([[]] * 1024, np.zeros(1024), np.zeros((1024, max_model_len)))
  • 首次调用触发Numba JIT编译 → 后续调用直接执行编译后的机器码
7.2 batch_propose — 批量推测
    def batch_propose(self, num_requests, valid_ngram_requests, num_tokens_no_spec, token_ids_cpu):
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                set_num_threads(max(1, min(self.num_numba_thread_available, num_ngram_requests)))
            else:
                set_num_threads(1)  # 小batch不开多线程(开销大于收益)
            batch_propose_numba(...)
            set_num_threads(original_num_numba_threads)  # 恢复

自适应线程数

  • num_tokens_threshold = 8192 — 总token数低于此值时单线程
  • 多线程有启动开销 → 小batch反而更慢
7.3 _find_longest_matched_ngram_and_propose_tokens — LPS核心
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(origin_tokens, min_ngram, max_ngram, max_model_len, k):
    # 翻转tokens → 找最长前缀(= 原序列的最长后缀)
    tokens = origin_tokens[::-1]
    
    lps = np.zeros(max_ngram, dtype=np.int32)
    longest_ngram = 0
    position = 0
    prev_lps = 0
    i = 1
    while i < total_token:
        if tokens[prev_lps] == tokens[i]:
            prev_lps += 1
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                prev_lps = lps[max_ngram - 1]  # 截断搜索
            i += 1
        elif prev_lps != 0:
            prev_lps = lps[prev_lps - 1]  # 回退到次长匹配
        else:
            i += 1

算法详解

  1. 翻转:将"找最长后缀匹配"转化为"找最长前缀匹配" → 经典KMP场景
  2. LPS数组lps[i] = tokens[0:i+1]的最长真前缀也是其后缀的长度
  3. prev_lps指针:当前正在比较的前缀长度
  4. 匹配时prev_lps += 1,更新longest_ngram和position
  5. 不匹配时:回退到lps[prev_lps-1](次长匹配)→ 避免从头扫描
  6. 截断prev_lps == max_ngram时截断 → 限制n-gram长度不超过max_ngram
  7. 翻转回原位置start = total - 1 - position + longest_ngram
  8. 提取draftorigin[start : start + k]

复杂度:O(n)单次扫描,Numba prange并行多序列。


第八章 ngram_proposer_gpu.py — GPU N-gram推测(662行)

核心设计:全向量化PyTorch实现

与CPU版本不同,GPU版本使用torch.unfold + argmax实现完全向量化的n-gram搜索。

@support_torch_compile()
class NgramGPUKernel(nn.Module):
    def _find_first_and_extract_all_n_parallel(self, token_ids, seq_lengths, min_ngram_len, max_ngram_len, num_draft_tokens):
        # 滑动窗口: O(1) view
        search_windows = token_ids.unfold(1, ngram_len, 1)
        
        # 提取尾部n-gram
        suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))
        
        # 批量匹配
        matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)
        
        # 找最早匹配位置
        first_match = matches.float().argmax(dim=1)

关键设计决策

  1. unfold — PyTorch的滑动窗口操作,返回view而非copy → 零额外内存
  2. 向量化匹配 — 所有序列并行比较 → GPU并行度高
  3. torch.compile@support_torch_compile() 装饰器 → 自动编译优化
  4. 多n-gram长度 — 尝试min_n到max_n所有长度 → 取最长匹配

与CPU版本的区别

维度 CPU Ngram GPU Ngram
计算设备 CPU (Numba) GPU (PyTorch)
算法 KMP/LPS O(n) unfold暴力 O(n*w)
并行度 prange多序列 全batch向量化
延迟 ~1ms ~0.1ms
CUDA Graph 不支持 支持
内存 O(max_ngram) O(batchmax_seqngram)

第九章 dflash.py — DFlash并行推测(282行)

逐行解析

DFlashProposer是并行推测方法,基于DFlash(Draft Flash Attention)的cross-attention机制。

class DFlashProposer(SpecDecodeBaseProposer):
    def __init__(self, vllm_config, device, runner=None):
        super().__init__(pass_hidden_states_to_model=True)
  • pass_hidden_states_to_model=True — DFlash使用target的hidden_states作为context K/V

关键设计

        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        self.max_positions = self.max_num_tokens + self.max_query_tokens
  • DFlash将token分为context(target hidden states)和query(draft tokens)
  • Context作为K/V → 不参与自回归 → 一次写入
  • Query = bonus_token + mask_tokens → 单次forward生成所有draft
set_inputs_first_pass — 核心输入准备
    def set_inputs_first_pass(self, target_token_ids, next_token_ids, ...):
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        num_query_per_req = 1 + self.num_speculative_tokens
  • 分离context和query:context slot mapping和query slot mapping使用不同buffer
  • causal=False — DFlash使用非因果注意力(cross-attention)
  • Fused Triton kernelcopy_and_expand_dflash_inputs_kernel 一次完成input_ids/positions/slot_mapping/indices
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu
  • 被拒绝的token不计入有效序列长度 → attention只看有效前缀

第十章 extract_hidden_states.py — 隐藏状态提取(382行)

核心设计

ExtractHiddenStatesProposer不做推测,而是将target model的中间层hidden_states缓存到KV cache中,供后续KV transfer使用。

class ExtractHiddenStatesProposer:
    def __init__(self, vllm_config, device):
        assert vllm_config.speculative_config.num_speculative_tokens == 1
  • 固定num_spec=1 — 不做推测,只缓存1步的hidden_states
        layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
        self.num_hidden_states = len(layer_ids)
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.num_hidden_states, self.hidden_size),
            dtype=self.dtype, device=device,
        )
  • 从多层提取hidden_states → 堆叠为3D tensor [tokens, num_layers, hidden]
  • 用于Eagle3等需要多层辅助hidden_states的方法
propose — 核心方法
    def propose(self, sampled_token_ids, target_hidden_states, common_attn_metadata, ...):
        stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
        self.hidden_states[:num_tokens] = stacked_hidden_states
        
        attn_metadata = self.attn_metadata_builder.build_for_drafting(...)
        with set_forward_context(per_layer_attn_metadata, ...):
            self.model(hidden_states=self.hidden_states[:num_input_tokens])
        
        return sampled_token_ids[:, :1]  # 返回target采样的token作为"draft"

关键逻辑

  1. Stack多个层的hidden_states → 3D buffer
  2. 调用ExtractHiddenStatesModel → 将hidden_states写入KV cache(cache-only attention)
  3. 返回target采样的token → 始终验证通过(draft=target采样结果)
  4. 实际不做推测 — 目的是缓存hidden_states供后续使用

第十一章 utils.py — Triton内核与辅助函数(596行)

核心Triton Kernels

Kernel 功能 调用者
eagle_step_slot_mapping_metadata_kernel 融合:position+1 / slot_mapping / seq_lens更新 EagleProposer自回归步
eagle_prepare_inputs_padded_kernel 构建Eagle第一遍输入:input_ids/positions/slot_mapping EagleProposer首遍
eagle_prepare_next_token_padded_kernel 准备下一步token的input_ids和positions EagleProposer单步
copy_and_expand_eagle_inputs_kernel 并行推测:扩展input_ids到所有draft位置 EagleProposer并行模式
copy_and_expand_dflash_inputs_kernel DFlash专用:context/query分离 DFlashProposer
next_power_of_2
def next_power_of_2(n: int) -> int:
    if n <= 0: return 1
    n -= 1
    n |= n >> 1; n |= n >> 2; n |= n >> 4; n |= n >> 8; n |= n >> 16; n |= n >> 32
    return n + 1
  • 经典位运算技巧:将n以下最高位置1的所有低位也置1 → +1得到下一个2的幂
  • 用途:Triton kernel的BLOCK_SIZE必须是2的幂
compute_new_slot_mapping
def compute_new_slot_mapping(block_size, seq_len, block_table, ...):
  • 从block_table和position计算slot_mapping
  • slot = block_table[position // block_size] * block_size + position % block_size

第十二章 eagle.py — Eagle推测核心(1793行)

在这里插入图片描述

模块定位

eagle.py 是整个spec_decode模块的核心文件,包含:

  • SpecDecodeBaseProposer — 所有模型类Proposer的抽象基类
  • EagleProposer — EAGLE/EAGLE3自回归推测器
  • 大量辅助方法:slot mapping管理、attention metadata构建、CUDA Graph支持

SpecDecodeBaseProposer — 抽象基类

class SpecDecodeBaseProposer:
    def __init__(self, vllm_config, device, pass_hidden_states_to_model, runner=None):
        self.parallel_drafting = self.speculative_config.parallel_drafting
        self.extra_slots_per_request = (
            1 if not self.parallel_drafting else self.num_speculative_tokens
        )
        self.net_num_new_slots_per_request = self.extra_slots_per_request - (
            1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
        )

关键计算

  • extra_slots_per_request:每个请求额外需要的KV cache slot数
    • 自回归模式=1(仅bonus token) / 并行模式=k(所有draft + bonus)
  • net_num_new_slots_per_request:实际新增slot数
    • Eagle自回归:1 - 1 = 0(hidden_states传入占1个slot → 无额外slot)
    • DraftModel:1 - 0 = 1(不传hidden_states → 需额外1 slot)
    • DFlash并行:k - 0 = k

核心方法列表

方法 功能
propose() 生成draft tokens(子类实现)
load_model() 加载推测模型
set_inputs_first_pass() 构建第一遍forward输入
build_model_inputs_first_pass() 构建model forward参数
set_inputs_subsequent_pass() 自回归后续步输入
_get_slot_mapping() 构建slot mapping字典
_determine_batch_execution_and_padding() CUDA Graph批大小决策
_get_eagle3_use_aux_hidden_state_from_config() Eagle3多层hidden states配置

EagleProposer — EAGLE推测器

两种模式

  1. 自回归模式 (parallel_drafting=False):每步生成1个draft token,循环k步
  2. 并行推测模式 (parallel_drafting=True):单次forward生成所有draft tokens

Eagle模型架构

target_hidden_states [1, h] → linear → [1, h]
                                  ↓ + prev_token_embedding
                              attention (1 layer)
                                  ↓
                              linear → logits → sample → draft_token

Eagle3增强:使用多层aux hidden_states拼接 → 更丰富的draft信息

关键创新

  • Token shift:Eagle将上一步的token embedding拼接到hidden_states → 类似GPT的causal shift
  • Single-layer attention:仅1层attention → 极轻量 → 推测速度远快于target
  • 共享embedding/lm_head:与target模型共享 → 节省显存 + 词汇一致性

第十三章 全书总结

13.1 核心设计模式

# 模式 应用 说明
1 策略模式 SpecDecodeBaseProposer → 7种Proposer 统一接口,运行时选择推测策略
2 模板方法 BaseProposer.propose() → 子类实现 通用slot/metadata逻辑在基类
3 Fused Kernel Triton kernels 融合多操作 → 减少kernel launch开销
4 LPS/KMP NgramProposer O(n)最长n-gram匹配
5 向量化解码 NgramGPUKernel.unfold GPU并行批量n-gram搜索
6 延迟导入 SuffixDecoding(arctic_inference) 可选依赖不强制
7 零拷贝View unfold / tensor切片 避免GPU内存拷贝
8 CUDA Graph友好 make_dummy + 固定shape buffer 推测步骤可被CUDA Graph捕获
9 自适应线程 Numba set_num_threads 小batch单线程 / 大batch多线程
10 非因果注意力 DFlash causal=False cross-attention context→query
11 Token Shift Eagle prev_token_emb 拼接前一步token embedding
12 逐位置指标 accepted_per_pos 精确调优num_speculative_tokens

13.2 七种推测方法对比

方法 额外模型 推测策略 隐状态 并行draft TP限制 典型加速
Eagle Eagle头 自回归/并行 需要 可选 同TP 2-3x
Eagle3 Eagle3头 自回归/并行 多层 可选 同TP 2.5-3.5x
DraftModel 独立小模型 自回归 不需要 同TP 1.5-2x
Medusa Medusa多头 单步并行 需要 1.5-2x
DFlash DFlash头 并行cross-attn 需要 2-3x
Ngram 无(CPU) LPS查表 不需要 1.3-1.8x
NgramGPU 无(GPU) unfold查表 不需要 1.3-1.8x
Suffix 无(CPU) 后缀树 不需要 1.3-1.8x
ExtractHS Cache-only 不推测 多层 辅助

13.3 验证-拒绝机制

在这里插入图片描述

所有方法共享统一的验证流程:

  1. Proposer生成draft tokens → SpecDecodeMetadata
  2. Target model前向传播(base + draft) → 全部logits
  3. Sampler用target_logits_indices + bonus_logits_indices提取关键位置logits
  4. 逐位置验证:draft[i] == target_sample[i+1]?
  5. 首个不匹配处截断 → 丢弃后续draft → 保留bonus token
  6. 数学保证:接受的token分布等价于自回归采样 → 无损质量

13.4 性能关键路径

热路径(每步推理)

target_forward → hidden_states → proposer.propose() → draft_tokens
→ target_forward(draft) → sampler(verify) → accept/reject → update_kv

瓶颈分析

  • Eagle自回归:k次draft forward → 推测延迟 = k * eagle_forward_time
  • 并行模式:1次draft forward → 但attention复杂度O(k*n)
  • N-gram:CPU LPS ~1ms → 几乎无开销,但接受率低
  • DFlash:单次forward + cross-attn → 最优延迟/精度平衡

总结:vLLM v1 Speculative Decoding 模块是一个设计精良的多策略推测解码框架,通过统一的SpecDecodeBaseProposer接口支持7种推测方法。核心创新包括:Eagle的token-shift单层attention、DFlash的非因果cross-attention并行推测、N-gram的LPS算法GPU化、Suffix的后缀树动态推测。所有方法共享验证-拒绝机制,数学保证等价于自回归采样。模块通过Triton fused kernel、CUDA Graph支持、自适应线程等优化,确保推测开销最小化。

Logo

一站式 AI 云服务平台

更多推荐