大模型推理框架 vLLM 源碼解析(一)

1. Quick Start

創建如下代碼,命名爲 run.py

from vllm import LLM, SamplingParams

prompts = [
	"Have you followed marsggbo in Zhihu?",
	"你一鍵三連了嗎?"
] # 輸入prompts
sampling_params = SamplingParams(temperature=0.8, top_k=50) # 採樣策略
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2) # 初始化 LLM
outputs = llm.generate(prompts, sampling_params) # 完成推理
for output in outputs:
	prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

執行命令:python run.py。該腳本會自動將模型以張量並行的方式在兩個 GPU 上進行推理計算。

整個推理過程大大致流程如下圖所示,即
1 給定一定數量的 prompts(字符串數組)
2. vllm 會使用 Scheduler 模塊自動對需要推理句子進行調度
3. 根據調度的結果,使用 tokenizer 將字符串轉換成 prompt id,然後餵給 model 進行計算得到 logits 預測結果
4. 根據 logits 預測結果和提前設置好的採樣策略對結果進行採樣得到新的 token id
5. 將採樣結果保存到 output

inferencce pipeline

2. 整體核心模塊

vllm 核心模塊結構
上圖給出了 vLLM 核心模塊之間的結構關係。接下來我們從簡單的模塊(即輸入、採樣和輸出)開始介紹,最後詳細介紹 LLM 模塊。

3. Sequence

句子模塊
如上圖我們可以看到 vLLM 爲輸入的句子設計了很多子模塊,這些模塊的用處各不相同,但是有彼此之間有關係,下面分別詳細介紹一下。

3.1 SequenceStatus

首先看到 SequenceStatus,其源代碼如下:

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto() # 等待中,句子還沒開始推理,或者推理還未結束
    RUNNING = enum.auto() # 運行中
    SWAPPED = enum.auto() # 已交換
    FINISHED_STOPPED = enum.auto() # 已停止
    FINISHED_LENGTH_CAPPED = enum.auto() # 已長度限制
    FINISHED_ABORTED = enum.auto() # 已中止
    FINISHED_IGNORED = enum.auto() # 已忽略

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        # 判斷狀態是否爲已停止、已長度限制、已中止或已忽略
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

3.2 SequenceData

SequenceData 用於存儲與序列相關的數據。這個類有三個屬性:prompt_token_ids(提示詞的標記ID)、output_token_ids(生成文本的標記ID)和cumulative_logprob(累計對數概率)。

class SequenceData:
    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

3.3 Sequence

Sequence 用於存儲序列的數據、狀態和塊信息,且每個序列有唯一標識,即seq_id。注意看下面的代碼:

  • 數據其實是通過上面的 SequenceData 保存的
  • 默認初始化狀態,所有句子序列的狀態都是 SequenceStatus.WAITING
  • 所謂塊信息,其實就是 vLLM 會在初始化階段預留出一定數量的CPU 和 GPU 內存,一般是以 token 爲單位的,例如在初始化的時候會使用值全爲 0,大小爲 (256, 128)的 prompt_ids做 warm up。每個序列會按照實際大小申請 block 來記錄內存使用情況,即序列 token 數越多,屬性logical_token_blocks包含的 block 個數也就越多。
class Sequence:
    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids) # 數據

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids) # 塊信息
        self.status = SequenceStatus.WAITING # 狀態
		...

3.3 SequenceGroup

Sequence只是單個序列的表示方式,seq_id是它的唯一標識。SequenceGroup則是爲了表示多個序列,request_id是它的唯一標識,表示是第幾個請求。

具體而言,可以看到__init__函數有個參數是 seqs: List[Sequence],它表示由一個或多個 Sequence 組成的列表,然後會通過self.seqs_dict = {seq.seq_id: seq for seq in seqs}轉化成字典方便管理,這個字典的 key 是每個 Sequence 的唯一標識seq_id

class SequenceGroup:
    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        prefix: Optional[Prefix] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
		...

下面是 vLLm 中 LLMEngine 使用 Sequence 和 SequenceGroup 的場景示例:

class LLMEngine:
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prefix_pos: Optional[int] = None,
    ) -> None:
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request) # 將字符串序列轉換成 id

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       lora_request)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

可以看到SequenceGroupseqs參數在最初階段其實只是單個序列 ,即[seq]。但是我們知道其實一個 prompt 可以有多個輸出結果,所以SequenceGroup的目的是管理一個輸入 prompt的多個生成序列信息。如果我們設置SamplingParams.n=2(第 4 節會介紹),那麼在推理過程中,SequenceGroup會新增一個 Sequence,這個新增的 Sequence 的 seq_id 和原來的那個 Sequence 不一樣,具體的代碼細節會在下一篇文章中介紹。

3.5 SequenceGroupMetadata

class SequenceGroupMetadata:
    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
		...

SequenceGroupMetadata 記錄了一些元信息,下面代碼展示了 Scheduler 模塊是如何生成這些信息的:

  • request_id 就是 SequenceGroup的 request_id
  • seq_data 是一個字典,key 是每個 Sequence的 seq_id,value 則是對應的 data (即 SequenceData)
  • block_tables也是一個字典,key 也是每個 Sequence的 seq_id,value 這是對應 Sequence 申請的 block
class Scheduler:
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data # 單個 SequenceData
                block_tables[seq_id] = self.block_manager.get_block_table(seq) # 對應Sequence的block信息

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                prefix=seq_group.prefix,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

3.6 SequenceOutput 和 SequenceGroupOutput

SequenceOutput 和 SequenceGroupOutput的關係就類似 Sequence 和 SequenceGroup。SequenceOutput其實就是記錄了上一個 輸入 token id 以及對應輸出的 token id。

class SequenceOutput:
    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

class SequenceGroupOutput:
    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

4. SamplingParams

SamplingParams

SamplingParams 包含以下參數:

  • n:要生成的序列的數量,默認爲 1。
  • best_of:從多少個序列中選擇最佳序列,需要大於 n,默認等於 n。
  • temperature:用於控制生成結果的隨機性,較低的溫度會使生成結果更確定性,較高的溫度會使生成結果更隨機。
  • top_p:用於過濾掉生成詞彙表中概率低於給定閾值的詞彙,控制隨機性。
  • top_k:選擇前 k 個候選 token,控制多樣性。
  • presence_penalty:用於控制生成結果中特定詞彙的出現頻率。
  • frequency_penalty:用於控制生成結果中詞彙的頻率分佈。
  • repetition_penalty:用於控制生成結果中的詞彙重複程度。
  • use_beam_search:是否使用束搜索來生成序列。
  • length_penalty:用於控制生成結果的長度分佈。
  • early_stopping:是否在生成過程中提前停止。
  • stop:要停止生成的詞彙列表。
  • stop_token_ids:要停止生成的詞彙的ID列表。
  • include_stop_str_in_output:是否在輸出結果中包含停止字符串。
  • ignore_eos:在生成過程中是否忽略結束符號。
  • max_tokens:生成序列的最大長度。
  • logprobs:用於記錄生成過程的概率信息。
  • prompt_logprobs:用於記錄生成過程的概率信息,用於特定提示。
  • skip_special_tokens:是否跳過特殊符號。
  • spaces_between_special_tokens:是否在特殊符號之間添加空格。

這些參數的設置通常取決於具體需求和模型性能。以下是一些常見的設置指導方法:

  • temperature:較低的溫度(如0.2)會產生更確定性的結果,而較高的溫度(如0.8)會產生更隨機的結果。您可以根據您的需求進行調整。
  • presence_penalty、frequency_penalty 和 repetition_penalty:這些參數可以用於控制生成結果中的詞彙分佈和重複程度。您可以根據您的需求進行調整。
  • use_beam_search:束搜索通常用於生成更高質量的結果,但可能會降低生成速度。您可以根據您的需求進行調整。
  • length_penalty:這個參數可以用於控制生成結果的長度。較高的值會產生更長的結果,而較低的值會產生更短的結果。您可以根據您的需求進行調整。
  • early_stopping:如果您不希望生成過長的結果,可以設置此參數爲True。
  • stop 和 stop_token_ids:您可以使用這些參數來指定生成結果的結束條件。

5. Output 模塊

Output模塊

Output 主要用於表示語言模型(LLM)的生成結果,包含如下兩個模塊:

  • CompletionOutput
  • RequestOutput

通過上面的介紹我們知道一個 request 可能包含多個序列,CompletionOutput 用來表示一個 request 中某個序列的完整輸出的數據,其中下面的index就表示該序列在 request 中的索引位置

class CompletionOutput:
    def __init__(
        self,
        index: int, # 輸出結果在請求中的索引
        text: str, # 生成的文本
        token_ids: List[int], # 生成的文本對應的 token ID 列表
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None, # 序列完成的原因(SequenceStatus)
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.finish_reason = finish_reason
		...

RequestOutput則表示 request 所有序列的輸出結果,有它的初始化函數可以看到它記錄了對應的 request_id

class RequestOutput:
    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.finished = finished
		...

我們看看RequestOutput的from_seq_group就能很好理解CompletionOutputRequestOutput是如何使用的了。爲方便理解,代碼有刪減,但是不影響最終結果:

class RequestOutput:
    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        # 1. Get the top-n sequences.
        n = seq_group.sampling_params.n # 每個序列返回的生成序列數量
        seqs = seq_group.get_seqs()
		# 根據累積 logprob 值來選擇出前 n 個生成序列
		sorting_key = lambda seq: seq.get_cumulative_logprob()
        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
        top_n_seqs = sorted_seqs[:n]

        # 2. Create the outputs.
        outputs: List[CompletionOutput] = []
        for seq in top_n_seqs:
            logprobs = seq.output_logprobs
            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
            output = CompletionOutput(seqs.index(seq), seq.output_text,
                                      seq.get_output_token_ids(),
                                      seq.get_cumulative_logprob(), logprobs,
                                      finshed_reason)
            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   lora_request=seq_group.lora_request)

RequestOutput是通過對傳入的seq_group: SequenceGroup進行解析後得到的。解析過程主要有兩個階段:

  1. Get the top-n sequences:這一階段就是對生成序列按照 cumulative_logprob 進行排序,最後選擇出top-n 序列。
  2. Create the outputs:將所有top-n生成序列分別轉換成 CompletionOutput列表,並作爲RequestOutput的初始化參數。

微信公衆號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯繫~
郵箱:[email protected]

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章