1. Overview
Deploying large models with vLLM. Official site: https://vllm.ai — GitHub: https://github.com/vllm-project/vllm
vLLM is a fast, easy-to-use library for large language model (LLM) inference and serving.
Its main features:
- Fast: high serving throughput. With 3 parallel output completions per request, vLLM's reported throughput is 8.5x–15x that of HuggingFace Transformers (HF) and 3.3x–3.5x that of HuggingFace Text Generation Inference (TGI)
- Optimized CUDA kernels
- Flexible and easy to use:
- Seamless integration with popular Hugging Face models.
- High-throughput serving with various decoding algorithms, including parallel sampling, beam search, and more.
- Tensor parallelism support for distributed inference.
- Streaming output support.
- OpenAI-compatible API server.
Supported models
vLLM seamlessly supports many Hugging Face model architectures, including Aquila, Baichuan, BLOOM, Falcon, GPT-2, GPT BigCode, GPT-J, GPT-NeoX, InternLM, LLaMA, Mistral, MPT, OPT, Qwen, and others. (https://vllm.readthedocs.io/en/latest/models/supported_models.html)
Currently, glm3 and llama3 each ship their own OpenAI-style service; let's see what vLLM does differently.
2. Initial Experiments
Installation:

```shell
pip install vllm
```
Download the model:

```python
import os
import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer

model_dir = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct',
                              cache_dir='/root/autodl-tmp',
                              revision='master')
```

Run the code above.
Serving:

```shell
python -m vllm.entrypoints.openai.api_server \
  --model /root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct \
  --trust-remote-code \
  --port 6006
```
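If you launch the server from a script rather than a terminal, the same invocation can be assembled programmatically. This is a minimal sketch; the model path and port are the ones used above, and the actual launch line is left commented out since it needs a GPU machine.

```python
import subprocess

MODEL_PATH = "/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct"
PORT = 6006

def build_server_cmd(model_path: str, port: int) -> list:
    """Assemble the vLLM OpenAI-compatible server launch command."""
    return [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_path,
        "--trust-remote-code",
        "--port", str(port),
    ]

cmd = build_server_cmd(MODEL_PATH, PORT)
print(" ".join(cmd))
# To actually launch (blocking call, requires a GPU host):
# subprocess.run(cmd, check=True)
```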
Resource usage:
Trying a call via Postman (or curl):

```shell
curl http://localhost:6006/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct",
    "max_tokens": 60,
    "messages": [
      {"role": "user", "content": "Who are you?"}
    ]
  }'
```
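The reply comes back in the standard OpenAI chat-completion JSON shape, so the text has to be dug out of `choices[0].message.content`. A small sketch of that extraction, using a made-up sample payload (the field values are illustrative, not real server output):

```python
import json

# Trimmed example of the JSON an OpenAI-compatible /v1/chat/completions
# endpoint returns; values here are invented for illustration.
sample = json.loads("""
{
  "id": "chatcmpl-123",
  "object": "chat.completion",
  "model": "/root/autodl-tmp/LLM-Research/Meta-Llama-3-8B-Instruct",
  "choices": [
    {"index": 0,
     "message": {"role": "assistant", "content": "I am an AI assistant."},
     "finish_reason": "stop"}
  ],
  "usage": {"prompt_tokens": 12, "completion_tokens": 8, "total_tokens": 20}
}
""")

def extract_reply(resp: dict) -> str:
    """Pull the assistant text out of an OpenAI-style chat completion."""
    return resp["choices"][0]["message"]["content"]

print(extract_reply(sample))  # -> I am an AI assistant.
```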
Neither this way of retrieving results nor the speed is ideal:
By contrast, the service bundled with the model itself uses less GPU memory.
Its single-shot test code runs directly and integrates well with existing code.
```python
import json
import requests

def get_completion(prompt):
    headers = {'Content-Type': 'application/json'}
    data = {"prompt": prompt}
    response = requests.post(url='http://127.0.0.1:6006',
                             headers=headers,
                             data=json.dumps(data))
    return response.json()['response']

if __name__ == '__main__':
    print(get_completion('1+1=?'))
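When calling the vLLM server with `"stream": true`, the response arrives as server-sent events: lines of the form `data: {...}` carrying incremental deltas, terminated by `data: [DONE]`. A minimal parser sketch for that framing, fed a made-up sample stream rather than a live connection:

```python
import json

def parse_sse_chunks(lines):
    """Yield delta text from OpenAI-style SSE lines ('data: {...}')."""
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0].get("delta", {})
        text = delta.get("content")
        if text:  # skip role-only or empty deltas
            yield text

# Invented sample stream for illustration:
stream = [
    'data: {"choices":[{"delta":{"role":"assistant"}}]}',
    'data: {"choices":[{"delta":{"content":"1+1"}}]}',
    'data: {"choices":[{"delta":{"content":"=2"}}]}',
    'data: [DONE]',
]
print("".join(parse_sse_chunks(stream)))  # -> 1+1=2
```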
3. Dual-GPU Experiment (brief)
Distributed inference
vLLM supports distributed tensor-parallel inference and serving, using Ray to manage the distributed runtime. Install Ray with:

```shell
pip install ray
```
For the distributed inference experiment: to run a multi-GPU service, pass the --tensor-parallel-size argument when launching the server.
For example, to run the API server on 2 GPUs:

```shell
python -m vllm.entrypoints.openai.api_server \
  --model /root/autodl-tmp/Yi-6B-Chat \
  --dtype auto \
  --api-key token-agiclass \
  --trust-remote-code \
  --port 6006 \
  --tensor-parallel-size 2
```
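A back-of-envelope way to see why tensor parallelism helps memory: model weights are split across the TP ranks, so each card holds roughly 1/N of them. This is only a rule of thumb, not vLLM's actual accounting; it ignores the KV cache, activations, and CUDA context overhead, which in vLLM can dominate via --gpu-memory-utilization.

```python
def weight_gb_per_gpu(n_params_billion: float, bytes_per_param: int = 2,
                      tensor_parallel_size: int = 1) -> float:
    """Rough estimate: fp16/bf16 weight memory per GPU under tensor
    parallelism. 1e9 params * bytes_per_param ~= that many GB."""
    total_gb = n_params_billion * bytes_per_param
    return total_gb / tensor_parallel_size

# An 8B model in fp16 split across 2 GPUs: ~8 GB of weights per card.
print(weight_gb_per_gpu(8, bytes_per_param=2, tensor_parallel_size=2))  # -> 8.0
```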
Multi-GPU serving is clearly an essential capability, but for now I don't have a strong enough reason to dig into it.
4. Summary
From an initial read of the relevant code, vLLM takes a similar approach to the OpenAI-style interface; but, likely in order to support parallelism, it is considerably heavier, and some incompatibilities show up.
My main conclusion for now is to keep writing applications on top of the existing stack. The crucial point is to understand the underlying principles, so that you can handle whatever situation arises; the ability to dig into those principles is the core skill.
The server implementation, vllm/entrypoints/openai/api_server.py:
https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
```python
import asyncio
import importlib
import inspect
import os
from contextlib import asynccontextmanager
from http import HTTPStatus

import fastapi
import uvicorn
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app

import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

TIMEOUT_KEEP_ALIVE = 5  # seconds

openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion

logger = init_logger(__name__)


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
    parser = make_arg_parser()
    return parser.parse_args()


# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    return Response(status_code=200)


@app.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.get("/version")
async def show_version():
    ver = {"version": vllm.__version__}
    return JSONResponse(content=ver)


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := os.environ.get("VLLM_API_KEY") or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info(f"vLLM API server version {vllm.__version__}")
    logger.info(f"args: {args}")

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)
```
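One detail worth noting in the source above: the authentication middleware only guards paths under `/v1`, so `/health`, `/version`, and `/metrics` stay open even when an API key is set. The decision logic can be restated as a small pure function (a sketch of the same check, not code from vLLM):

```python
def is_authorized(path: str, auth_header, token: str,
                  root_path: str = "") -> bool:
    """Mirror of the middleware's decision: non-/v1 paths pass through;
    /v1 paths require an exact 'Bearer <token>' header."""
    if not path.startswith(f"{root_path}/v1"):
        return True  # health checks, /metrics, /version are never gated
    return auth_header == "Bearer " + token

print(is_authorized("/health", None, "token-agiclass"))                        # True
print(is_authorized("/v1/models", None, "token-agiclass"))                     # False
print(is_authorized("/v1/models", "Bearer token-agiclass", "token-agiclass"))  # True
```

This matches the experiment in section 3, where --api-key token-agiclass was passed: chat calls then need the matching Authorization header, while the health endpoint still answers without one.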