
[Upstream] fix: support aws cross region inferences #6


Merged · 11 commits · Mar 18, 2025
2 changes: 1 addition & 1 deletion Dockerfile
@@ -44,4 +44,4 @@ COPY --from=builder2 /build/one-api /

EXPOSE 3000
WORKDIR /data
ENTRYPOINT ["/one-api"]
ENTRYPOINT ["/one-api"]
2 changes: 1 addition & 1 deletion README.md
@@ -385,7 +385,7 @@ graph LR
+ Example: `NODE_TYPE=slave`
9. `CHANNEL_UPDATE_FREQUENCY`: when set, channel balances are updated periodically at this interval, in minutes; if unset, no updates are performed.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`: when set, channels are tested periodically at this interval, in minutes; if unset, no tests are performed.
10. `CHANNEL_TEST_FREQUENCY`: when set, channels are tested periodically at this interval, in minutes; if unset, no tests are performed.
+Example: `CHANNEL_TEST_FREQUENCY=1440`
11. `POLLING_INTERVAL`: the interval between requests when batch-updating channel balances or testing availability, in seconds; no interval by default.
+ Example: `POLLING_INTERVAL=5`
3 changes: 3 additions & 0 deletions common/config/config.go
@@ -164,3 +164,6 @@ var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")

// OpenrouterProviderSort determines the provider ordering used for OpenRouter requests
var OpenrouterProviderSort = env.String("OPENROUTER_PROVIDER_SORT", "")
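
For context, a minimal sketch of how this new setting might feed into an OpenRouter request. The `provider.sort` payload shape and struct names below are assumptions for illustration; only the environment variable itself is introduced by this diff.

```go
package openrouter

import "github.com/songquanpeng/one-api/common/config"

// ProviderOptions mirrors OpenRouter's provider routing block (assumed shape).
type ProviderOptions struct {
	Sort string `json:"sort,omitempty"` // e.g. "price" or "throughput" (illustrative values)
}

// providerOptions returns the routing block to attach to outgoing requests,
// or nil when OPENROUTER_PROVIDER_SORT is unset so OpenRouter's defaults apply.
func providerOptions() *ProviderOptions {
	if config.OpenrouterProviderSort == "" {
		return nil
	}
	return &ProviderOptions{Sort: config.OpenrouterProviderSort}
}
```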
7 changes: 5 additions & 2 deletions common/conv/any.go
@@ -1,6 +1,9 @@
package conv

func AsString(v any) string {
str, _ := v.(string)
return str
if str, ok := v.(string); ok {
return str
}

return ""
}
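
For reference, the observable behaviour is unchanged by this rewrite (a failed type assertion already yielded the empty string); the new form just states the intent explicitly. A small illustrative use:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/conv"
)

func main() {
	fmt.Println(conv.AsString("hello")) // "hello"
	fmt.Println(conv.AsString(42))      // "" (non-string values fall through safely)
	fmt.Println(conv.AsString(nil))     // "" (same result as before the change)
}
```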
6 changes: 3 additions & 3 deletions relay/adaptor/anthropic/adaptor.go
@@ -36,8 +36,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me

// https://x.com/alexalbert__/status/1812921642143900036
// claude-3-5-sonnet can support 8k context
if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
if strings.HasPrefix(meta.ActualModelName, "claude-3-7-sonnet") {
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
}

return nil
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
return ConvertRequest(c, *request)
}

func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
3 changes: 2 additions & 1 deletion relay/adaptor/anthropic/constants.go
@@ -3,10 +3,11 @@ package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-5-haiku-latest",
"claude-3-5-haiku-20241022",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-5-sonnet-latest",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
83 changes: 77 additions & 6 deletions relay/adaptor/anthropic/main.go
@@ -2,18 +2,21 @@ package anthropic

import (
"bufio"
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"math"
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -36,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string {
}
}

func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
// isModelSupportThinking is used to check if the model supports extended thinking
func isModelSupportThinking(model string) bool {
if strings.Contains(model, "claude-3-7-sonnet") {
return true
}

return false
}

func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))

for _, tool := range textRequest.Tools {
@@ -61,7 +73,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
TopK: textRequest.TopK,
Stream: textRequest.Stream,
Tools: claudeTools,
Thinking: textRequest.Thinking,
}

if isModelSupportThinking(textRequest.Model) &&
c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
claudeRequest.Thinking = &model.Thinking{
Type: "enabled",
BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
}
}

if isModelSupportThinking(textRequest.Model) &&
claudeRequest.Thinking != nil {
if claudeRequest.MaxTokens <= 1024 {
return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking")
}

// top_p must be nil when using extended thinking
claudeRequest.TopP = nil
}

if len(claudeTools) > 0 {
claudeToolChoice := struct {
Type string `json:"type"`
@@ -142,13 +174,14 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
return &claudeRequest
return &claudeRequest, nil
}
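
A quick sketch of how a caller would opt in to the new extended-thinking path. The endpoint and headers are one-api's usual OpenAI-compatible surface; the model name, token values, and environment variable are illustrative, not part of this diff. When the body does not set `thinking` itself, the default budget computed above is min(1024, max_tokens/2), and requests with max_tokens <= 1024 are rejected once thinking is enabled.

```go
package main

import (
	"log"
	"net/http"
	"os"
	"strings"
)

func main() {
	// The "?thinking" query parameter triggers the default Thinking block
	// for claude-3-7-sonnet models when the request body does not set one.
	body := `{
	  "model": "claude-3-7-sonnet-20250219",
	  "max_tokens": 4096,
	  "messages": [{"role": "user", "content": "Why is the sky blue?"}]
	}`
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:3000/v1/chat/completions?thinking",
		strings.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer "+os.Getenv("ONE_API_TOKEN")) // assumed token source
	req.Header.Set("Content-Type", "application/json")
	// With max_tokens = 4096 the generated Anthropic request carries
	// {"thinking": {"type": "enabled", "budget_tokens": 1024}} and top_p is dropped.
	if _, err := http.DefaultClient.Do(req); err != nil {
		log.Fatal(err)
	}
}
```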

// https://docs.anthropic.com/claude/reference/messages-streaming
func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var reasoningText string
var stopReason string
tools := make([]model.Tool, 0)

@@ -158,6 +191,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
if claudeResponse.ContentBlock.Thinking != nil {
reasoningText = *claudeResponse.ContentBlock.Thinking
}

if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, model.Tool{
Id: claudeResponse.ContentBlock.Id,
@@ -172,6 +209,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
if claudeResponse.Delta.Thinking != nil {
reasoningText = *claudeResponse.Delta.Thinking
}

if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, model.Tool{
Function: model.Function{
@@ -189,9 +230,20 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
case "thinking_delta":
if claudeResponse.Delta != nil && claudeResponse.Delta.Thinking != nil {
reasoningText = *claudeResponse.Delta.Thinking
}
case "ping",
"message_stop",
"content_block_stop":
default:
logger.SysErrorf("unknown stream response type %q", claudeResponse.Type)
}

var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Reasoning = &reasoningText
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
@@ -209,11 +261,23 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo

func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
var reasoningText string

tools := make([]model.Tool, 0)
for _, v := range claudeResponse.Content {
switch v.Type {
case "thinking":
if v.Thinking != nil {
reasoningText += *v.Thinking
} else {
logger.Errorf(context.Background(), "thinking is nil in response")
}
case "text":
responseText += v.Text
default:
logger.Warnf(context.Background(), "unknown response type %q", v.Type)
}

if v.Type == "tool_use" {
args, _ := json.Marshal(v.Input)
tools = append(tools, model.Tool{
@@ -226,11 +290,13 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
})
}
}

choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Reasoning: &reasoningText,
Name: nil,
ToolCalls: tools,
},
@@ -277,6 +343,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)

logger.Debugf(c.Request.Context(), "stream <- %q\n", data)

var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
@@ -344,6 +412,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}

logger.Debugf(c.Request.Context(), "response <- %s\n", string(responseBody))

var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
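
As a sanity check on the new reasoning plumbing, a test-style sketch of the stream conversion (a hypothetical test inside the anthropic package; struct and field names follow this diff, the input text is illustrative):

```go
func TestThinkingDeltaMapsToReasoning(t *testing.T) {
	thought := "Consider the prime factorisation first..."
	chunk := &StreamResponse{
		Type:  "thinking_delta",
		Delta: &Delta{Thinking: &thought},
	}

	openaiChunk, _ := StreamResponseClaude2OpenAI(chunk)

	// Thinking text is surfaced on the OpenAI-style delta as `reasoning`,
	// while regular text deltas keep flowing through `content` as before.
	got := openaiChunk.Choices[0].Delta.Reasoning
	if got == nil || *got != thought {
		t.Fatalf("expected reasoning %q, got %v", thought, got)
	}
}
```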
8 changes: 8 additions & 0 deletions relay/adaptor/anthropic/model.go
@@ -1,5 +1,7 @@
package anthropic

import "github.com/songquanpeng/one-api/relay/model"

// https://docs.anthropic.com/claude/reference/messages_post

type Metadata struct {
@@ -22,6 +24,9 @@ type Content struct {
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking
Thinking *string `json:"thinking,omitempty"`
Signature *string `json:"signature,omitempty"`
}

type Message struct {
@@ -54,6 +59,7 @@ type Request struct {
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
//Metadata `json:"metadata,omitempty"`
Thinking *model.Thinking `json:"thinking,omitempty"`
}

type Usage struct {
@@ -84,6 +90,8 @@ type Delta struct {
PartialJson string `json:"partial_json,omitempty"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Thinking *string `json:"thinking,omitempty"`
Signature *string `json:"signature,omitempty"`
}

type StreamResponse struct {
6 changes: 5 additions & 1 deletion relay/adaptor/aws/claude/adapter.go
@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}

claudeReq := anthropic.ConvertRequest(*request)
claudeReq, err := anthropic.ConvertRequest(c, *request)
if err != nil {
return nil, errors.Wrap(err, "convert request")
}

c.Set(ctxkey.RequestModel, request.Model)
c.Set(ctxkey.ConvertedRequest, claudeReq)
return claudeReq, nil
10 changes: 6 additions & 4 deletions relay/adaptor/aws/claude/main.go
@@ -48,13 +48,14 @@ func awsModelID(requestModel string) (string, error) {
}

func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
@@ -102,13 +103,14 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*

func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
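
`utils.ConvertModelID2CrossRegionProfile` is the heart of the cross-region support, but its implementation is not part of this excerpt. A minimal sketch of what such a helper plausibly does, assuming it maps the Bedrock region to the geo prefix (`us.`, `eu.`, `apac.`) used by Bedrock cross-region inference profiles and otherwise leaves the model ID alone:

```go
// Sketch only: the real helper lives in relay/adaptor/aws/utils and may
// additionally check that a cross-region inference profile actually exists
// for the given model before rewriting the ID.
func ConvertModelID2CrossRegionProfile(modelID, region string) string {
	var prefix string
	switch {
	case strings.HasPrefix(region, "us-"):
		prefix = "us."
	case strings.HasPrefix(region, "eu-"):
		prefix = "eu."
	case strings.HasPrefix(region, "ap-"):
		prefix = "apac."
	default:
		return modelID // no known geo prefix; keep the on-demand model ID
	}
	// e.g. "anthropic.claude-3-7-sonnet-20250219-v1:0" becomes
	//      "us.anthropic.claude-3-7-sonnet-20250219-v1:0" in us-east-1
	return prefix + modelID
}
```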
6 changes: 5 additions & 1 deletion relay/adaptor/aws/claude/model.go
@@ -1,6 +1,9 @@
package aws

import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
import (
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/model"
)

// Request is the request to AWS Claude
//
@@ -17,4 +20,5 @@ type Request struct {
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *model.Thinking `json:"thinking,omitempty"`
}
10 changes: 6 additions & 4 deletions relay/adaptor/aws/llama3/main.go
@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
}

func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
@@ -140,13 +141,14 @@ func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {

func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
createdTime := helper.GetTimestamp()
awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
if err != nil {
return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
}

awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
ModelId: aws.String(awsModelID),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}