Skip to content

Commit

Permalink
feat(ali provider): Add support for OpenAI-compatible API mode (#490)
Browse files Browse the repository at this point in the history
- Introduce new plugin option to use OpenAI-compatible API for Alibaba Cloud
- Modify AliProvider to support direct OpenAI API calls when enabled
- Update base configuration to use compatible mode endpoint
- Add stream handling with JSON escaping for OpenAI-style responses
  • Loading branch information
MartialBE committed Feb 11, 2025
1 parent c701ffb commit 4307251
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 12 deletions.
39 changes: 33 additions & 6 deletions providers/ali/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,56 @@ import (
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/providers/openai"
"one-api/types"
)

// 定义供应商工厂
type AliProviderFactory struct{}

type AliProvider struct {
base.BaseProvider
openai.OpenAIProvider

UseOpenaiAPI bool
}

// 创建 AliProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
func (f AliProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
useOpenaiAPI := false

if channel.Plugin != nil {
plugin := channel.Plugin.Data()
if pOpenAI, ok := plugin["use_openai_api"]; ok {
if enable, ok := pOpenAI["enable"].(bool); ok && enable {
useOpenaiAPI = true
}
}
}

return &AliProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(useOpenaiAPI),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
StreamEscapeJSON: true,
SupportStreamOptions: true,
},
UseOpenaiAPI: useOpenaiAPI,
}
}

func getConfig() base.ProviderConfig {
func getConfig(useOpenaiAPI bool) base.ProviderConfig {
if useOpenaiAPI {
return base.ProviderConfig{
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
}
}

return base.ProviderConfig{
BaseURL: "https://dashscope.aliyuncs.com",
ChatCompletions: "/api/v1/services/aigc/text-generation/generation",
Expand Down
8 changes: 8 additions & 0 deletions providers/ali/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type aliStreamHandler struct {
}

func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
if p.UseOpenaiAPI {
return p.OpenAIProvider.CreateChatCompletion(request)
}

req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
Expand All @@ -35,6 +39,10 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest)
}

func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
if p.UseOpenaiAPI {
return p.OpenAIProvider.CreateChatCompletionStream(request)
}

req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
Expand Down
1 change: 1 addition & 0 deletions providers/openai/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type OpenAIProvider struct {
IsAzure bool
BalanceAction bool
SupportStreamOptions bool
StreamEscapeJSON bool
}

// 创建 OpenAIProvider
Expand Down
20 changes: 14 additions & 6 deletions providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
)

type OpenAIStreamHandler struct {
Usage *types.Usage
ModelName string
isAzure bool
Usage *types.Usage
ModelName string
isAzure bool
EscapeJSON bool
}

func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
Expand Down Expand Up @@ -94,9 +95,10 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
}

chatHandler := OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
isAzure: p.IsAzure,
Usage: p.Usage,
ModelName: request.Model,
isAzure: p.IsAzure,
EscapeJSON: p.StreamEscapeJSON,
}

return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
Expand Down Expand Up @@ -157,5 +159,11 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
}
}

if h.EscapeJSON {
if data, err := json.Marshal(openaiResponse.ChatCompletionStreamResponse); err == nil {
dataChan <- string(data)
return
}
}
dataChan <- string(*rawLine)
}
12 changes: 12 additions & 0 deletions web/src/views/Channel/type/Plugin.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@
"required": true
}
}
},
"use_openai_api": {
"name": "使用OpenAI API",
"description": "使用OpenAI API",
"params": {
"enable": {
"name": "启用",
"description": "是否启用使用OpenAI API, 开启用直接使用ali官方兼容OpenAI的API,不再做类型转换, 且上述插件无效",
"type": "bool",
"required": true
}
}
}
},
"24": {
Expand Down

0 comments on commit 4307251

Please sign in to comment.