
Commit 5244b7f

russellb, lochuynh1412, mmoskal
committed
Add Guidance backend to V0 structured output
This commit is based on PR #10217, updated to be compatible with `main`.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <lohuynh@microsoft.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
1 parent b0746fa commit 5244b7f
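With this backend in place, a request can opt into guidance-based structured output by name. The sketch below is illustrative only: the model name and JSON schema are placeholders, and it assumes the offline `LLM` entry point forwards `guided_decoding_backend` to the engine config; the `GuidedDecodingParams` fields themselves match the ones read by the code added in this commit.

```python
# Hedged usage sketch (not part of the diff): select the new backend by name.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# Assumption: guided_decoding_backend is forwarded to the engine config.
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta",
          guided_decoding_backend="guidance")

params = SamplingParams(guided_decoding=GuidedDecodingParams(
    json=schema, backend="guidance"))
outputs = llm.generate(["Answer in JSON: what is 2 + 2?"], params)
print(outputs[0].outputs[0].text)
```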

File tree

8 files changed, +841 -7 lines changed


benchmarks/benchmark_serving_structured_output.py (+1 -1)

@@ -992,7 +992,7 @@ def main(args: argparse.Namespace):
     parser.add_argument(
         "--structured-output-backend",
         type=str,
-        choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"],
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
         default="xgrammar",
         help="Backend to use for structured outputs")
requirements/common.txt (+2 -1)

@@ -17,6 +17,7 @@ pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
+llguidance>=0.6.15
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.11; platform_machine == "x86_64"
@@ -37,4 +38,4 @@ depyf==0.18.0 # required for profiling and debugging with compilation config
 cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
 watchfiles # required for http server to monitor the updates of TLS files
 python-json-logger # Used by logging as per examples/other/logging_configuration.md
-scipy # Required for phi-4-multimodal-instruct
+scipy # Required for phi-4-multimodal-instruct

tests/model_executor/test_guided_processors.py (+3 -1)

@@ -15,7 +15,9 @@
 from vllm.sampling_params import GuidedDecodingParams
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
 REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 
vllm/config.py (+3 -1)

@@ -2751,7 +2751,9 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        ]
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name

vllm/model_executor/guided_decoding/__init__.py (+14 -3)

@@ -130,10 +130,15 @@ async def get_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
-
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
 
 
 def get_local_guided_decoding_logits_processor(
@@ -163,7 +168,13 @@ def get_local_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
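The dispatch added above only needs the request's `GuidedDecodingParams` and the model's tokenizer. A hedged sketch of exercising the guidance path directly (the tokenizer load and regex are illustrative; the helper's `(guided_params, tokenizer)` signature comes from the diff itself):

```python
# Hedged sketch (not part of the diff): call the new dispatch target directly.
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.guidance_decoding import (
    get_local_guidance_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
guided_params = GuidedDecodingParams(regex=r"(yes|no)", backend="guidance")

# Returns a GuidanceLogitsProcessor, or None when no guide was supplied.
processor = get_local_guidance_guided_decoding_logits_processor(
    guided_params, tokenizer)
```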
vllm/model_executor/guided_decoding/guidance_decoding.py (new file, +53)

@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+from enum import Enum
+from re import escape as regex_escape
+from typing import Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.model_executor.guided_decoding.guidance_logits_processors import (
+    GuidanceLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams
+
+
+class GuidedDecodingMode(Enum):
+    JSON = "json"
+    REGEX = "regex"
+    CHOICE = "choice"
+    GRAMMAR = "grammar"
+
+
+def get_local_guidance_guided_decoding_logits_processor(
+    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+) -> Union[GuidanceLogitsProcessor, None]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+    guide = None
+    mode = None
+
+    if guided_params.json:
+        guide = guided_params.json
+        mode = GuidedDecodingMode.JSON.value
+    elif guided_params.regex:
+        guide = guided_params.regex
+        mode = GuidedDecodingMode.REGEX.value
+    elif guided_params.choice:
+        # choice just uses regex
+        choices = (regex_escape(str(choice))
+                   for choice in guided_params.choice)
+        choices_regex = "(" + "|".join(choices) + ")"
+        guide = choices_regex
+        mode = GuidedDecodingMode.CHOICE.value
+    elif guided_params.grammar:
+        guide = guided_params.grammar
+        mode = GuidedDecodingMode.GRAMMAR.value
+
+    if not guide or not mode:
+        return None
+
+    return GuidanceLogitsProcessor(mode, guide, tokenizer,
+                                   guided_params.whitespace_pattern)
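The `choice` branch above reduces a list of choices to a regex before handing it to the same `GuidanceLogitsProcessor` path as plain regex guides. A small standalone illustration of that conversion (the choice list is a made-up example):

```python
# Each choice is regex-escaped, then the alternatives are joined in a group,
# mirroring the choice branch of the function above.
from re import escape as regex_escape

choices = ["yes", "no", "n/a"]
choices_regex = "(" + "|".join(regex_escape(str(c)) for c in choices) + ")"
print(choices_regex)  # -> (yes|no|n/a)
```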
vllm/model_executor/guided_decoding/guidance_logits_processors.py (new file, +156)

@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+import json
+import os
+from typing import Any, List, Type, Union
+
+import llguidance  # type: ignore[import-untyped]
+import llguidance.hf
+import numpy as np
+import torch
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
+
+from vllm.model_executor.guided_decoding.guidance_utils import (
+    LLInterpreterResponse)
+
+
+class GuidanceLogitsProcessor:
+    """Base Guidance Logits Processor"""
+
+    cached_tokenizers: dict[str, Any] = {}
+
+    def __init__(
+        self,
+        mode: str,
+        guide: Union[dict, Type[BaseModel], str],
+        tokenizer: PreTrainedTokenizerBase,
+        whitespace_pattern: Union[str, None] = None,
+    ) -> None:
+        """Base Guidance Logits Processor
+
+        Args:
+            mode (str)
+                guided generation mode.
+                Must be one of "json", "regex", "choice", "grammar"
+            guide (Union[dict, Type[BaseModel], str])
+                guide for guided generation
+            tokenizer (PreTrainedTokenizerBase)
+                model's tokenizer
+            whitespace_pattern (Union[str, None], optional)
+                Json-string to indicate pattern to use \
+                    for JSON syntactic whitespace
+                Example: '{"whitespace_flexible":true}'
+        """
+        self.mode = mode
+        self.guide = guide
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer.name_or_path
+        self.whitespace_pattern = whitespace_pattern
+
+        self.is_stopped = False
+        self.pending_ff_tokens: list[int] = []
+        self.new_sampling = False
+        self.initialized = False
+
+    def _initialize(self):
+        if self.initialized:
+            return
+
+        if self.mode.lower() == "json":
+            if isinstance(self.guide, dict):
+                schema = json.dumps(self.guide)
+            elif isinstance(self.guide, BaseModel):
+                schema = json.dumps(self.guide.model_json_schema())
+            else:
+                schema = str(self.guide)
+
+            whitespaces_config = {}
+            if isinstance(self.whitespace_pattern, str):
+                whitespaces_config = json.loads(self.whitespace_pattern)
+
+            whitespace_flexible = whitespaces_config.get(
+                "whitespace_flexible", False)
+            compiler = llguidance.JsonCompiler(
+                whitespace_flexible=whitespace_flexible)
+            self.serialized_grammar = compiler.compile(schema)
+        elif self.mode.lower() in ["regex", "choice"]:
+            compiler = llguidance.RegexCompiler()
+            self.serialized_grammar = compiler.compile(regex=self.guide)
+        elif self.mode.lower() == "grammar":
+            serialized_grammar = self.guide
+            if isinstance(self.guide, dict):
+                serialized_grammar = json.dumps(self.guide)
+            self.serialized_grammar = serialized_grammar
+
+        ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
+                                                  None)
+        if ll_tokenizer is None:
+            ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
+            self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
+        self.ll_tokenizer = ll_tokenizer
+        self.ll_interpreter = llguidance.LLInterpreter(
+            self.ll_tokenizer,
+            self.serialized_grammar,
+            enable_backtrack=False,
+            enable_ff_tokens=False,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        self.initialized = True
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+    ) -> torch.Tensor:
+        # we initialize the guidance model here
+        # to avoid pickling ll_tokenizer and ll_interpreter
+        self._initialize()
+
+        if self.is_stopped:
+            return scores
+
+        if self.new_sampling and len(input_ids) > 0:
+            backtrack, ff_tokens = self.ll_interpreter.commit_token(
+                input_ids[-1])
+            if len(ff_tokens) > 0 and backtrack == 0:
+                # first token is last generated token
+                ff_tokens = ff_tokens[1:]
+            self.pending_ff_tokens.extend(ff_tokens)
+            self.new_sampling = False
+
+        if len(self.pending_ff_tokens) > 0:
+            # if we have pending fast-forward tokens,
+            # just return them immediately
+            ff_token = self.pending_ff_tokens.pop(0)
+            scores.add_(-scores)
+            scores[ff_token] = 200.0
+            return scores
+
+        mask, resp = self.ll_interpreter.compute_mask()
+        r = LLInterpreterResponse.model_validate_json(resp)
+
+        if r.stop:
+            mask = np.zeros(scores.shape[-1], dtype=np.uint8)
+            if self.ll_tokenizer.eos_token is not None:
+                mask[self.ll_tokenizer.eos_token] = 200
+            self.is_stopped = True
+        elif mask is None:
+            # NOTE: mask should not be None unless r.stop is True
+            # However, we are handling this case just in case
+            # llguidance allows free-style generation
+            mask = np.zeros(scores.shape[-1], dtype=np.uint8)
+        else:
+            mask = np.frombuffer(mask, dtype=np.uint8)
+
+        # Force all invalid tokens to have 0 value
+        scores.add_(-torch.min(scores))
+        zero_indices = np.where(mask == 0)[0]
+        scores[zero_indices] = 0.0
+        non_zero_indices = np.nonzero(mask)[0]
+        scores[non_zero_indices] += 200.0
+        # set special tokens not in vocab to 0
+        scores[mask.shape[0]:] = 0.0
+        self.new_sampling = True
+
+        return scores
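To see what the masking arithmetic in `__call__` does to a logits vector, here is a tiny self-contained sketch (toy logits and mask, not part of the diff): scores are shifted so the minimum becomes zero, grammar-forbidden tokens are zeroed, and allowed tokens get a large additive boost (200.0) so they dominate sampling.

```python
# Toy illustration of the mask application in GuidanceLogitsProcessor.__call__.
import numpy as np
import torch

scores = torch.tensor([1.5, -2.0, 0.3, 4.0, -1.0])  # toy logits, vocab size 5
mask = np.array([0, 1, 0, 1, 0], dtype=np.uint8)     # 1 = allowed by the grammar

scores.add_(-torch.min(scores))        # shift so the minimum becomes 0
scores[np.where(mask == 0)[0]] = 0.0   # forbidden tokens drop to 0
scores[np.nonzero(mask)[0]] += 200.0   # allowed tokens get a large boost
print(scores)  # tensor([0., 200., 0., 206., 0.]) -> only indices 1 and 3 survive
```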
