
[llm] ray.llm support custom accelerators #51359

Merged

Conversation

@liuxsh9 (Contributor) commented Mar 14, 2025

Why are these changes needed?

Support using ray.data.llm and ray.serve.llm on custom accelerators beyond just GPUs.

Related issue number

Roadmap for Data and Serve LLM APIs

This PR should contribute to advancing the implementation of TPU support as outlined in the Roadmap.

Usage Example

Users should install the corresponding vllm platform plugin. For example, the installation process for vllm-ascend is as follows:

# Install vLLM from the main branch:
git clone --depth 1 https://github.com/vllm-project/vllm.git
cd vllm
pip install -r requirements/build.txt
VLLM_TARGET_DEVICE=empty pip install .

# Install vllm-ascend main branch
git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
pip install -e .

Then users can utilize NPUs (and other accelerators) through ray.data.llm:

from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

config = vLLMEngineProcessorConfig(
    model="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={},
    concurrency=1,
    batch_size=64,
    # an additional line of code compared to the default GPU setup
    accelerator_name="NPU",
)
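
For completeness, here is a minimal sketch of how the configured processor could be applied to a dataset. This is an illustration based on the ray.data.llm batch-inference examples rather than code from this PR; the preprocess/postprocess hooks and the generated_text output column may differ across Ray versions.

import ray

# Build the processor from the config above. The preprocess hook maps each row
# to chat messages plus sampling parameters; the postprocess hook extracts the
# generated text.
processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params=dict(temperature=0.3, max_tokens=64),
    ),
    postprocess=lambda row: dict(answer=row["generated_text"]),
)

ds = ray.data.from_items([{"prompt": "What is Ray?"}])
ds = processor(ds)
print(ds.take_all())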

Similarly, for ray.serve.llm:

from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # An additional line of code compared to the default GPU setup.
    accelerator_name="NPU",
    # If you want to specify more precise resource types.
    accelerator_type="910B4",
    # You can customize the engine arguments (e.g. vLLM engine kwargs)
    engine_kwargs=dict(),
)

# Deploy the application
deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)
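
Once deployed, the application can be queried like any other ray.serve.llm deployment. Below is a hedged client-side sketch, assuming the router exposes the usual OpenAI-compatible endpoint on the default Serve address (http://localhost:8000/v1) and that the openai client package is installed; it is not part of this PR.

from openai import OpenAI

# The API key is a placeholder here; adjust if your deployment enforces auth.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

response = client.chat.completions.create(
    model="qwen-0.5b",  # matches model_id in the LLMConfig above
    messages=[{"role": "user", "content": "Hello from an NPU-backed deployment!"}],
)
print(response.choices[0].message.content)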

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@liuxsh9 liuxsh9 requested a review from a team as a code owner March 14, 2025 04:31
@liuxsh9 liuxsh9 changed the title Llm support custom accelerators [llm] ray.llm support custom accelerators Mar 14, 2025
@kouroshHakha (Contributor) commented Mar 14, 2025

Hi @liuxsh9,

Thanks for your contribution. accelerator_name is not general enough for what we can support. It would be better to use custom resources from Ray Core to do this in a more generic way. The above use cases would then become something like:

data:

from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

config = vLLMEngineProcessorConfig(
    model="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={},
    concurrency=1,
    batch_size=64,
    # an additional line of code compared to the default GPU setup
    resources_per_worker={"NPU": 1}
)

serve:

from ray import serve
from ray.serve.llm.configs import LLMConfig
from ray.serve.llm.deployments import VLLMService, LLMRouter

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="qwen-0.5b",
        model_source="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1, max_replicas=2,
        )
    ),
    # If you want to specify more precise resource types.
    resources_per_worker={"NPU-910B4": 1}
    # You can customize the engine arguments (e.g. vLLM engine kwargs)
    engine_kwargs=dict(),
)

# Deploy the application
deployment = VLLMService.as_deployment(llm_config.get_serve_options(name_prefix="VLLM:")).bind(llm_config)
llm_app = LLMRouter.as_deployment().bind([deployment])
serve.run(llm_app)

This way the user has full control over how to set up their cluster labeling and how to pass in the resource requirements.
I think we should change this PR to do this instead. Also, for data, you should add this to the base config and not just the vLLM processor config.
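
For reference, a minimal sketch of the Ray Core custom-resource mechanism this relies on (the resource name "NPU" and the counts here are purely illustrative): nodes advertise custom logical resources, and tasks or actors request them by name.

import ray

# Start a local cluster that advertises 8 logical "NPU" units. On a real
# cluster this would be configured per node, e.g. `ray start --resources='{"NPU": 8}'`.
ray.init(resources={"NPU": 8})

@ray.remote(resources={"NPU": 1})
def runs_on_npu_node() -> str:
    # Ray schedules this task only on a node advertising "NPU" capacity;
    # mapping the logical resource to a physical device is left to the engine.
    return "scheduled against the custom NPU resource"

print(ray.get(runs_on_npu_node.remote()))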

Also I am not sure if NPU is an officially supported accelerator by Ray. cc @jjyao to chime in here.

@liuxsh9 (Contributor, Author) commented Mar 14, 2025

Thanks! @kouroshHakha First, both Ray (#41256) and vLLM now support NPU, meeting the basic adaptation requirements. Your observation is valid: while we considered using the generic resources field, this would require explicitly specifying the accelerator count. However, since the engine_kwargs parameter already determines the required accelerator count, having two count-related configs might increase complexity. Our proposal is to maintain the GPU-like pattern where the count is handled by engine_kwargs, while simply adding a new field for accelerator category selection. These are our design considerations, and we welcome community feedback on this approach!

@kouroshHakha (Contributor) commented:

> Ray and vLLM now support NPU

👍

> […] Our proposal is to maintain the GPU-like pattern where the count is handled by engine_kwargs, while simply adding a new field for accelerator category selection.

The number you'd set on resources_per_worker would be the logical resources for each worker, while the things set in engine_kwargs are LLM-specific parameters. Regardless of TP or PP, you'd always want the bundle to be {"resources": {"NPU": 1}}, which is not something you'd have to think about as the end user. This is the pattern used in other Ray libraries as well. For example, resources_per_worker in Ray Train is similar, and in Ray Data you can directly pass in the actor remote args, which include a resources field for custom resource specification.
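
As a concrete point of comparison, here is a sketch of the Ray Train pattern mentioned above (the resource name and counts are illustrative): resources_per_worker there likewise describes the logical resources of each individual worker, independent of any model-parallel layout.

from ray.train import ScalingConfig

# Four training workers, each claiming one logical NPU; the per-worker bundle
# stays {"NPU": 1} regardless of how the workers are grouped for TP/PP.
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=False,
    resources_per_worker={"NPU": 1},
)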

@liuxsh9 liuxsh9 force-pushed the llm-support-custom-accelerators branch 2 times, most recently from 0e224fa to f5ce9c5 on March 18, 2025 02:21
Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
@liuxsh9 liuxsh9 force-pushed the llm-support-custom-accelerators branch from f5ce9c5 to 6e1a87a on March 18, 2025 02:25
liuxsh9 added 2 commits March 18, 2025 10:32
Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
@liuxsh9 (Contributor, Author) commented Mar 18, 2025

Fine-tuned! Thanks for the feedback; more is welcome! @kouroshHakha

@kouroshHakha (Contributor) left a review:

There is some more to address; I'd be happy to push to your PR and make these last-mile changes. Let me know what you think:

if not resources_per_worker:
    map_batches_kwargs["num_gpus"] = num_mp_workers
else:
    ray_remote_args["resources"] = {
@kouroshHakha commented on the diff above:

I realize now that there is some naming confusion. resources_per_worker in the top-level API refers to the resources required per worker within the replica, while it might also be interpreted as resources per replica. Say you want to do tp=2 and pp=2 on NPUs: is resources_per_worker={"NPU": 4} the correct value, or is resources_per_worker={"NPU": 1} the right thing? "Worker" could mean a worker as seen from Ray's perspective. Inside this function, however, resources_per_worker seems to refer to the resources per vLLM worker, i.e. per worker from vLLM's perspective. We need consistent naming to differentiate them. Here is my suggested implementation (after a brief illustration below); I can push over your changes if that's OK?
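
A tiny illustration of the intended reading (illustrative numbers only, not code from this PR): with tp=2 and pp=2, treating the value as resources per bundle means the replica is scheduled as four {"NPU": 1} bundles rather than one {"NPU": 4} bundle.

tp_size, pp_size = 2, 2
resources_per_bundle = {"NPU": 1}

# One bundle per TP/PP rank; the bundles are later packed together per replica.
num_bundles_per_replica = tp_size * pp_size
bundles = [dict(resources_per_bundle) for _ in range(num_bundles_per_replica)]
print(bundles)  # [{'NPU': 1}, {'NPU': 1}, {'NPU': 1}, {'NPU': 1}]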

  1. Change _ray_scheduling_strategy_fn to be more explicit about the meaning of these arguments and about what can be None (accelerator_type can be None and should also be ignored when custom resources are passed in).
def _ray_scheduling_strategy_fn(
    num_bundles_per_replica: int, 
    accelerator_type: Optional[str] = None,
    resources_per_bundle: Optional[Dict[str, float]] = None,
):
    """Create a Ray scheduling strategy for the engine.

    Args:
        num_bundles_per_replica: The number of device bundles per 
            engine replica.
        accelerator_type: The accelerator type. If None, the 
            accelerator_type label will not be set.
        resources_per_bundle: The custom resources per bundle. 
            If None, we default to 1xGPU + 1xCPU bundle.

    Returns:
        The Ray scheduling strategy.
    """

    def _get_bundle() -> Dict[str, float]:

        # Custom resources
        if resources_per_bundle:
            return resources_per_bundle
        
        # GPU bundles
        bundle = {"GPU": 1, "CPU": 1}
        if accelerator_type:
            bundle[f"accelerator_type:{accelerator_type}"] = 0.001
        return bundle

    pg = ray.util.placement_group(
        [_get_bundle()] * num_bundles_per_replica,
        strategy="STRICT_PACK",
    )
    return dict(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            pg, placement_group_capture_child_tasks=True
        )
    )
  2. Change the stage post_init implementation to reflect the new names, and consistently use ray_remote_args in case the ray_remote_args_fn condition does not get exercised.
class vLLMEngineStage(StatefulStage):
    """
    A stage that runs vLLM engine.
    """

    fn: Type[StatefulStageUDF] = vLLMEngineStageUDF

    @root_validator(pre=True)
    def post_init(cls, values):
        """Post-initialize the stage. Specifically,
        this function determines the num_gpus and Ray remote args
        for the .map_batches() call in this stage.

        Args:
            values: The raw stage values.
        Returns:
            The updated values.
        """
        map_batches_kwargs = values["map_batches_kwargs"]
        resources_per_bundle = map_batches_kwargs.get("resources_per_bundle")
        accelerator_type = map_batches_kwargs.get("accelerator_type", "")
        fn_constructor_kwargs = values["fn_constructor_kwargs"]
        engine_kwargs = fn_constructor_kwargs.get("engine_kwargs", {})

        ray_remote_args = {}
        if accelerator_type:
            ray_remote_args["accelerator_type"] = accelerator_type

        # Setup num_workers required per vLLM engine.
        tp_size = engine_kwargs.get("tensor_parallel_size", 1)
        pp_size = engine_kwargs.get("pipeline_parallel_size", 1)
        num_bundles_per_replica = tp_size * pp_size

        # Use the MP backend by default.
        engine_kwargs.setdefault("distributed_executor_backend", "mp")
        executor_backend = engine_kwargs.get("distributed_executor_backend")

        # When Ray is used in the vLLM engine, we set num_devices to 0 so that
        # Ray Data won't reserve GPUs in advance. Instead, we specify scheduling
        # strategy in .map_batches() arguments and let vLLM Ray executor to
        # create placement groups for each TP/PP worker.
        if executor_backend == "ray" and num_bundles_per_replica > 1:
            # Note that we have to use partial() to pass a function
            # instead of an object.
            map_batches_kwargs["ray_remote_args_fn"] = partial(
                _ray_scheduling_strategy_fn,
                num_bundles_per_replica,
                accelerator_type,
                resources_per_bundle,
            )

        if not resources_per_bundle:
            # Default to GPUs per bundle if custom resources are not specified.
            ray_remote_args["num_gpus"] = num_bundles_per_replica
        else:
            ray_remote_args["resources"] = {
                resource_key: resource_count * num_bundles_per_replica
                for resource_key, resource_count in resources_per_bundle.items()
            }

        map_batches_kwargs.update(ray_remote_args)
        return values
  3. Expose the name resources_per_bundle in the public API to save the user from the confusion.

@@ -134,7 +140,10 @@ def placement_strategy(self) -> str:

    @property
    def placement_bundles(self) -> List[Dict[str, float]]:
        bundle = {"GPU": 1}
        if not self.resources_per_worker:
@kouroshHakha commented on the diff above:

A similar change should happen here as well.
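
For illustration, a hedged sketch of what the adapted property might look like, mirroring the diff hunk and the data-side suggestion above. This is not the merged code, and the attribute names resources_per_bundle, accelerator_type, and num_bundles_per_replica on the config are assumptions.

@property
def placement_bundles(self) -> List[Dict[str, float]]:
    if self.resources_per_bundle:
        # Custom resources: one bundle per TP/PP rank.
        bundle = dict(self.resources_per_bundle)
    else:
        # Default GPU bundle, optionally pinned to an accelerator type.
        bundle = {"GPU": 1}
        if self.accelerator_type:
            bundle[f"accelerator_type:{self.accelerator_type}"] = 0.001
    return [bundle] * self.num_bundles_per_replica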

@liuxsh9 (Contributor, Author) commented Mar 19, 2025

Differentiating between Ray and vLLM in resource allocation makes perfect sense. Please feel free to push your changes! @kouroshHakha

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
@kouroshHakha kouroshHakha requested a review from a team as a code owner March 20, 2025 18:04
@kouroshHakha kouroshHakha added the go add ONLY when ready to merge, run all tests label Mar 20, 2025
@kouroshHakha (Contributor) commented Mar 20, 2025

Running release tests: https://buildkite.com/ray-project/release/builds/36492

The failure case should be fixed on master #51528

@kouroshHakha (Contributor) commented:

@GeneDer and @comaniac Please take a look as a second pair of eyes. TY

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
@kouroshHakha (Contributor) commented:

@GeneDer I added the unit test for serve. For data, I'll take a rain check because that is a bit harder to test; I don't want to do it in this PR.

@GeneDer (Contributor) left a review:

Thanks for adding the test, LGTM!

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
@kouroshHakha kouroshHakha enabled auto-merge (squash) March 21, 2025 05:18
@kouroshHakha kouroshHakha disabled auto-merge March 21, 2025 05:18
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
@kouroshHakha kouroshHakha enabled auto-merge (squash) March 21, 2025 05:31
@kouroshHakha kouroshHakha self-assigned this Mar 21, 2025
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
@github-actions github-actions bot disabled auto-merge March 21, 2025 15:46
@kouroshHakha kouroshHakha enabled auto-merge (squash) March 21, 2025 15:58
@kouroshHakha kouroshHakha merged commit 360ede3 into ray-project:master Mar 21, 2025
6 checks passed
dhakshin32 pushed a commit to dhakshin32/ray that referenced this pull request Mar 27, 2025
Signed-off-by: liuxsh9 <liuxiaoshuang4@huawei.com>
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Co-authored-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Signed-off-by: Dhakshin Suriakannu <d_suriakannu@apple.com>
Labels: go (add ONLY when ready to merge, run all tests), llm

4 participants