Skip to content

Commit 1dc816d

Browse files
feat: add function call in SGLang (#1666)
1 parent b3d73bd commit 1dc816d

File tree

4 files changed

+238
-15
lines changed

4 files changed

+238
-15
lines changed

camel/configs/sglang_config.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
1414
from __future__ import annotations
1515

16-
from typing import Sequence, Union
16+
from typing import Any, Dict, List, Optional, Sequence, Union
1717

1818
from camel.configs.base_config import BaseConfig
1919
from camel.types import NOT_GIVEN, NotGiven
@@ -56,10 +56,11 @@ class SGLangConfig(BaseConfig):
5656
in the chat completion. The total length of input tokens and
5757
generated tokens is limited by the model's context length.
5858
(default: :obj:`None`)
59-
tools (list[FunctionTool], optional): A list of tools the model may
60-
call. Currently, only functions are supported as a tool. Use this
61-
to provide a list of functions the model may generate JSON inputs
62-
for. A max of 128 functions are supported.
59+
tools (list[Dict[str, Any]], optional): A list of tool definitions
60+
that the model can dynamically invoke. Each tool should be
61+
defined as a dictionary following OpenAI's function calling
62+
specification format. For more details, refer to the OpenAI
63+
documentation.
6364
"""
6465

6566
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
@@ -70,6 +71,7 @@ class SGLangConfig(BaseConfig):
7071
presence_penalty: float = 0.0
7172
stream: bool = False
7273
max_tokens: Union[int, NotGiven] = NOT_GIVEN
74+
tools: Optional[Union[List[Dict[str, Any]]]] = None
7375

7476

7577
SGLANG_API_PARAMS = {param for param in SGLangConfig.model_fields.keys()}

camel/models/sglang_model.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,16 @@ def __init__(
9797
def _start_server(self) -> None:
9898
try:
9999
if not self._url:
100+
tool_call_flag = self.model_config_dict.get("tools")
101+
tool_call_arg = (
102+
f"--tool-call-parser {self._api_key} "
103+
if tool_call_flag
104+
else ""
105+
)
100106
cmd = (
101107
f"python -m sglang.launch_server "
102108
f"--model-path {self.model_type} "
109+
f"{tool_call_arg}"
103110
f"--port 30000 "
104111
f"--host 0.0.0.0"
105112
)
@@ -265,6 +272,19 @@ def stream(self) -> bool:
265272
"""
266273
return self.model_config_dict.get('stream', False)
267274

275+
def __del__(self):
276+
r"""Properly clean up resources when the model is destroyed."""
277+
self.cleanup()
278+
279+
def cleanup(self):
280+
r"""Terminate the server process and clean up resources."""
281+
with self._lock:
282+
if self.server_process:
283+
_terminate_process(self.server_process)
284+
self.server_process = None
285+
self._client = None
286+
logging.info("Server process terminated during cleanup.")
287+
268288

269289
# Below are helper functions from sglang.utils
270290
def _terminate_process(process):
@@ -326,21 +346,25 @@ def _execute_shell_command(command: str) -> subprocess.Popen:
326346
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
327347

328348

329-
def _wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
349+
def _wait_for_server(base_url: str, timeout: Optional[int] = 30) -> None:
330350
r"""Wait for the server to be ready by polling the /v1/models endpoint.
331351
332352
Args:
333353
base_url: The base URL of the server
334-
timeout: Maximum time to wait in seconds. None means wait forever.
354+
timeout: Maximum time to wait in seconds. Default is 30 seconds.
335355
"""
336356
import requests
337357

358+
# Set a default value if timeout is None
359+
actual_timeout = 30 if timeout is None else timeout
360+
338361
start_time = time.time()
339362
while True:
340363
try:
341364
response = requests.get(
342365
f"{base_url}/v1/models",
343366
headers={"Authorization": "Bearer None"},
367+
timeout=5, # Add a timeout for the request itself
344368
)
345369
if response.status_code == 200:
346370
time.sleep(5)
@@ -356,9 +380,15 @@ def _wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
356380
)
357381
break
358382

359-
if timeout and time.time() - start_time > timeout:
383+
if time.time() - start_time > actual_timeout:
384+
raise TimeoutError(
385+
f"Server did not become ready within "
386+
f"{actual_timeout} seconds"
387+
)
388+
except (requests.exceptions.RequestException, TimeoutError) as e:
389+
if time.time() - start_time > actual_timeout:
360390
raise TimeoutError(
361-
"Server did not become ready within timeout period"
391+
f"Server did not become ready within "
392+
f"{actual_timeout} seconds: {e}"
362393
)
363-
except requests.exceptions.RequestException:
364394
time.sleep(1)

examples/models/sglang_model_example.py

+79-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,84 @@
5151

5252
"""
5353
===============================================================================
54-
CAMEL AI ReferentialAction
54+
Hello CAMEL AI. How can I assist you today?
55+
===============================================================================
56+
"""
57+
58+
weather_tool = [
59+
{
60+
"type": "function",
61+
"function": {
62+
"name": "get_current_weather",
63+
"description": "Get the current weather in a given location",
64+
"parameters": {
65+
"type": "object",
66+
"properties": {
67+
"city": {
68+
"type": "string",
69+
"description": "The city to find the weather for,\n"
70+
"e.g. 'San Francisco'",
71+
},
72+
"state": {
73+
"type": "string",
74+
"description": "The two-letter abbreviation for,\n"
75+
"the state (e.g., 'CA'), e.g. CA for California",
76+
},
77+
"unit": {
78+
"type": "string",
79+
"description": "Temperature unit (celsius/fahrenheit)",
80+
"enum": ["celsius", "fahrenheit"],
81+
},
82+
},
83+
"required": ["city", "state", "unit"],
84+
},
85+
},
86+
}
87+
]
88+
89+
90+
r"""
91+
Note that api_key defines the parser used to interpret responses.
92+
Currently supported parsers include:
93+
llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct,
94+
meta-llama/Llama-3.2-1B-Instruct).
95+
mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3,
96+
mistralai/Mistral-Nemo-Instruct-2407,
97+
mistralai/ Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).
98+
qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct).
99+
"""
100+
sglang_model_with_tool = ModelFactory.create(
101+
model_platform=ModelPlatformType.SGLANG,
102+
model_type="meta-llama/Llama-3.2-1B-Instruct",
103+
model_config_dict={"temperature": 0.0, "tools": weather_tool},
104+
api_key="llama3",
105+
)
106+
107+
assistant_sys_msg = (
108+
"You are a helpful assistant.\n"
109+
"Use the get_current_weather tool when asked about weather."
110+
)
111+
agent_with_tool = ChatAgent(
112+
assistant_sys_msg,
113+
model=sglang_model_with_tool,
114+
token_limit=4096,
115+
external_tools=weather_tool,
116+
)
117+
user_msg = "What's the weather in Boston today?"
118+
119+
assistant_response = agent_with_tool.step(user_msg)
120+
external_tool_call = assistant_response.info.get('external_tool_call_request')
121+
if external_tool_call:
122+
print(f"Detected external tool call: {external_tool_call.tool_name}")
123+
print(f"Arguments: {external_tool_call.args}")
124+
print(f"Tool Call ID: {external_tool_call.tool_call_id}")
125+
else:
126+
print("No external tool call detected")
127+
128+
"""
129+
===============================================================================
130+
Detected external tool call: get_current_weather
131+
Arguments: {'city': 'Boston', 'state': 'MA', 'unit': 'celsius'}
132+
Tool Call ID: 0
55133
===============================================================================
56134
"""

test/models/test_sglang_model.py

+117-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# limitations under the License.
1313
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
1414
import re
15+
from unittest.mock import MagicMock, patch
1516

1617
import pytest
1718

@@ -21,6 +22,47 @@
2122
from camel.utils import OpenAITokenCounter
2223

2324

25+
@pytest.fixture
26+
def sglang_model_cleanup():
27+
r"""Fixture to ensure SGLang model is cleaned up after each test."""
28+
models = []
29+
30+
# Mock the server-related functions to avoid actual server startup
31+
with (
32+
patch(
33+
'camel.models.sglang_model._execute_shell_command'
34+
) as mock_execute,
35+
patch('camel.models.sglang_model._wait_for_server') as mock_wait,
36+
patch('camel.models.sglang_model.OpenAI') as mock_client,
37+
patch('camel.models.sglang_model.AsyncOpenAI') as mock_async_client,
38+
):
39+
# Configure mocks
40+
mock_execute.return_value = MagicMock()
41+
mock_wait.return_value = None
42+
mock_client.return_value = MagicMock()
43+
mock_async_client.return_value = MagicMock()
44+
45+
def _create_model(
46+
model_type, model_config_dict=None, api_key="sglang"
47+
):
48+
model = SGLangModel(model_type, model_config_dict, api_key=api_key)
49+
# Set up the model to use our mocks
50+
model._url = "http://mock-server:30000/v1"
51+
model._client = mock_client.return_value
52+
model._async_client = mock_async_client.return_value
53+
models.append(model)
54+
return model
55+
56+
yield _create_model
57+
58+
# Clean up all models after test
59+
for model in models:
60+
try:
61+
model.cleanup()
62+
except Exception as e:
63+
print(f"Error during cleanup: {e}")
64+
65+
2466
@pytest.mark.model_backend
2567
@pytest.mark.parametrize(
2668
"model_type",
@@ -31,8 +73,8 @@
3173
ModelType.GPT_4O_MINI,
3274
],
3375
)
34-
def test_sglang_model(model_type: ModelType):
35-
model = SGLangModel(model_type, api_key="sglang")
76+
def test_sglang_model(model_type: ModelType, sglang_model_cleanup):
77+
model = sglang_model_cleanup(model_type)
3678
assert model.model_type == model_type
3779
assert model.model_config_dict == SGLangConfig().as_dict()
3880
assert isinstance(model.token_counter, OpenAITokenCounter)
@@ -41,7 +83,7 @@ def test_sglang_model(model_type: ModelType):
4183

4284

4385
@pytest.mark.model_backend
44-
def test_sglang_model_unexpected_argument():
86+
def test_sglang_model_unexpected_argument(sglang_model_cleanup):
4587
model_type = ModelType.GPT_4
4688
model_config_dict = {"model_path": "vicuna-7b-v1.5"}
4789

@@ -54,4 +96,75 @@ def test_sglang_model_unexpected_argument():
5496
)
5597
),
5698
):
57-
_ = SGLangModel(model_type, model_config_dict, api_key="sglang")
99+
_ = sglang_model_cleanup(model_type, model_config_dict)
100+
101+
102+
@pytest.mark.model_backend
103+
def test_sglang_function_call(sglang_model_cleanup):
104+
test_tool = {
105+
"type": "function",
106+
"function": {
107+
"name": "test_tool",
108+
"description": "Test function",
109+
"parameters": {"type": "object", "properties": {}},
110+
},
111+
}
112+
113+
model = sglang_model_cleanup(
114+
ModelType.GPT_4,
115+
model_config_dict={"tools": [test_tool]},
116+
)
117+
118+
# Create a mock response object
119+
from camel.types import (
120+
ChatCompletion,
121+
ChatCompletionMessage,
122+
ChatCompletionMessageToolCall,
123+
Choice,
124+
CompletionUsage,
125+
)
126+
127+
# create mock response
128+
mock_response = ChatCompletion(
129+
id="mock_id",
130+
object="chat.completion",
131+
created=1234567890,
132+
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
133+
choices=[
134+
Choice(
135+
index=0,
136+
message=ChatCompletionMessage(
137+
role="assistant",
138+
content=None,
139+
tool_calls=[
140+
ChatCompletionMessageToolCall(
141+
id="0",
142+
type="function",
143+
function={"name": "test_tool", "arguments": "{}"},
144+
)
145+
],
146+
),
147+
finish_reason="tool_calls",
148+
)
149+
],
150+
usage=CompletionUsage(
151+
prompt_tokens=10,
152+
completion_tokens=20,
153+
total_tokens=30,
154+
),
155+
)
156+
157+
# Patch the run method to return our mock response
158+
with patch.object(model, '_run', return_value=mock_response):
159+
messages = [
160+
{
161+
"role": "user",
162+
"content": "Use test_tool and respond with result",
163+
}
164+
]
165+
166+
response = model.run(messages=messages)
167+
168+
assert len(response.choices[0].message.tool_calls) > 0
169+
tool_call = response.choices[0].message.tool_calls[0]
170+
assert tool_call.function.name == "test_tool"

0 commit comments

Comments
 (0)