[V1][Core] Add support for V1 Engine #295

Open · wants to merge 3 commits into main

Conversation

shen-shanshan (Collaborator)

What this PR does / why we need it?

Add support for V1 Engine.

Please note that this is just the initial version; some parts may still need to be fixed or optimized in the future, so feel free to leave comments for us.

Does this PR introduce any user-facing change?

To use the V1 Engine on an NPU device, you need to set the environment variables shown below:

export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=spawn
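
For a quick offline sanity check, the same switches can also be set from Python before vLLM is imported (a minimal sketch, not part of this PR; the model and max_model_len just mirror the test below):

import os

# Must be set before vLLM is imported.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=26240)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0, max_tokens=7))
print(outputs[0].outputs[0].text)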

How was this patch tested?

I have tested online serving with Qwen2.5-7B-Instruct using this command:

vllm serve Qwen/Qwen2.5-7B-Instruct --max_model_len 26240

Query the model with input prompts:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "prompt": "The future of AI is",
        "max_tokens": 7,
        "temperature": 0
    }'
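
For reference, the same request can be sent with the openai Python client (a sketch, not part of this PR; the api_key value is a placeholder since the server runs without authentication):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    prompt="The future of AI is",
    max_tokens=7,
    temperature=0,
)
print(completion.choices[0].text)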

The test logs are shown below:

INFO 03-11 06:18:03 [__init__.py:30] Available plugins for group vllm.platform_plugins:
INFO 03-11 06:18:03 [__init__.py:32] name=ascend, value=vllm_ascend:register
INFO 03-11 06:18:03 [__init__.py:34] all available plugins for group vllm.platform_plugins will be loaded.
INFO 03-11 06:18:03 [__init__.py:36] set environment variable VLLM_PLUGINS to control which plugins to load.
INFO 03-11 06:18:03 [__init__.py:44] plugin ascend loaded.
INFO 03-11 06:18:03 [__init__.py:247] Platform plugin ascend is activated
INFO 03-11 06:18:06 [core.py:51] Initializing a V1 LLM engine (v0.7.4.dev360+gc91b64f7) with config: model='Qwen/Qwen2.5-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=26240, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=npu, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-7B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":0,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}

...

INFO 03-11 06:18:30 [loader.py:429] Loading weights took 3.40 seconds
INFO 03-11 06:18:32 [kv_cache_utils.py:537] GPU KV cache size: 744,048 tokens
INFO 03-11 06:18:32 [kv_cache_utils.py:540] Maximum concurrency for 26,240 tokens per request: 28.36x
npu not support graph capture. current compilation level : CompilationLevel.NO_COMPILATION
INFO 03-11 06:18:32 [core.py:120] init engine (profile, create kv cache, warmup model) took 2.06 seconds

...

INFO 03-11 06:18:37 [api_server.py:958] Starting vLLM API server on http://0.0.0.0:8000
INFO:     127.0.0.1:46928 - "POST /v1/completions HTTP/1.1" 200 OK

Signed-off-by: shen-shanshan <467638484@qq.com>
wangxiyuan (Collaborator) left a comment:

I notice you copied a lot of code from the vLLM CUDA V1 implementation. Please check whether you can just import the code instead of rewriting it.

Another question: do we really need a v1 module? How about just moving the file to the right place and renaming it to something like model_runner_v1.py?

        return int(physical_device_id)
    else:
        return device_id
# def _device_id_to_physical_device_id(device_id: int) -> int:
Collaborator:

Do not leave commented-out code: please remove it if it's useless, otherwise add a note here to explain why the code is commented out.

Collaborator Author (shen-shanshan):

How about just moving the file to the right place and renaming it to something like model_runner_v1.py?

I think it's a good idea.

@@ -74,8 +77,9 @@ def get_device_capability(cls, device_id: int = 0):

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = _device_id_to_physical_device_id(device_id)
        return torch.npu.get_device_name(physical_device_id)
        # physical_device_id = _device_id_to_physical_device_id(device_id)
Collaborator:

ditto

@@ -103,17 +107,35 @@ def mem_get_info(cls) -> Tuple[int, int]:

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        compilation_config = vllm_config.compilation_config
        if compilation_config.level != CompilationLevel.NO_COMPILATION:
            logger.info("[NPU] Forcing NO_COMPILATION compilation level")
Collaborator:

Print the current compilation_config.level to the log as well, and change this to warning level.
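
A minimal sketch of what that could look like, reusing the names from the hunk above (whether the level is actually reset right here is an assumption):

        if compilation_config.level != CompilationLevel.NO_COMPILATION:
            logger.warning(
                "[NPU] Forcing NO_COMPILATION compilation level "
                "(current level: %s)", compilation_config.level)
            # Assumed: the forcing itself happens right after the log line.
            compilation_config.level = CompilationLevel.NO_COMPILATION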

logger.warning("[V1][NPU] Disable prefix caching")
cache_config.enable_prefix_caching = False

assert not vllm_config.speculative_config, (
Collaborator:

Speculative decoding works now for 0.7.3; we don't need this in main IMO.


import torch

try:
Collaborator:

No need for a try/except here; a plain import torch_npu is fine.
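
That is, roughly (sketch):

import torch
import torch_npu  # noqa: F401  (imported for its side effects)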

logger = init_logger(__name__)


class NPUModelRunner(LoRAModelRunnerMixin):
Collaborator:

Why is this based on the LoRA mixin?

Collaborator Author (shen-shanshan), Mar 11, 2025:

In vLLM V1, LoRAModelRunnerMixin is a base class for GPUModelRunner, but I see that TPUModelRunner doesn't extend it, so I think NPUModelRunner may not need to extend it either. I will modify this soon.
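
The change would then just drop the mixin from the class definition, roughly (a sketch; the constructor signature mirrors vLLM's V1 GPUModelRunner and may differ in this PR):

class NPUModelRunner:

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        ...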

            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
Collaborator:

Remove the CUDA-related code.

    def init_device(self):
        if self.device_config.device.type == "npu":
            # # This env var set by Ray causes exceptions with graph building.
            # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

Please remove the NCCL-related comments.

Collaborator Author (shen-shanshan):

OK👌

Signed-off-by: shen-shanshan <467638484@qq.com>