
Compare commits

...

3 Commits

Author SHA1 Message Date
aea360bce8 chore: update openspec config, add project context and rules
- backend context: add README.md notes and shared-component usage priority
- design rule: the technical approach must be reflected in the design doc
- task rule: one task per line, no multi-line tasks
2026-04-20 16:43:29 +08:00
d92db73937 refactor: backend code quality - reuse shared libraries, use the standard library, type-safe error checks
## High-priority fixes
- stats_service_impl: use strings.SplitN instead of error-prone index splitting
- provider_handler: use errors.Is(err, gorm.ErrDuplicatedKey) instead of string matching
- client: rewrite isNetworkError with type-safe errors.As/Is checks
- proxy_handler: parse JSON with the encoding/json standard library (extractModelName, isStreamRequest)

## Medium-priority fixes
- stats_handler: add a parseDateParam helper to remove duplicated date parsing
- pkg/errors: add ErrRequestCreate/Send/ResponseRead error types and a WithCause method
- client: use structured errors instead of fmt.Errorf
- ConversionEngine: inject the logger dependency, replacing all zap.L() calls

## Low-priority fixes
- encoder: delete joinStrings, use strings.Join
- adapter: delete the modelInfoRegex regexp, use the isModelInfoPath string function

## Documentation updates
- README.md: add shared-library usage guide and coding-standards sections
- specs: sync delta specs into main specs (error-handling, structured-logging, request-validation)

## Archive
- openspec/changes/archive/2026-04-20-refactor-backend-code-quality/
2026-04-20 16:42:48 +08:00
bc1ee612d9 refactor: implement the ConversionEngine protocol conversion engine, replacing the old protocol package
- add the ConversionEngine core engine, supporting OpenAI and Anthropic protocol conversion
- add stream decoder/encoder implementations
- update the provider client to support the new engine
- add unit and integration tests
- update the specs docs
2026-04-20 13:02:28 +08:00
66 changed files with 12590 additions and 1178 deletions

View File

@@ -329,3 +329,61 @@ make lint
### Requirements
- Go 1.26 or later
## Shared Library Usage Guide
### pkg/errors — structured errors
Use the predefined error types and check them with `errors.Is` / `errors.As`:
```go
import (
    "errors"

    pkgErrors "nex/backend/pkg/errors"
)

// Use a predefined error
return pkgErrors.ErrRequestSend.WithCause(err)

// Check the error type
var appErr *pkgErrors.AppError
if errors.As(err, &appErr) {
    // appErr.Code, appErr.HTTPStatus, appErr.Message
}
```
Available functions: `NewAppError`, `Wrap`, `WithContext`, `WithMessage`, `AsAppError`
Predefined errors: `ErrModelNotFound`, `ErrProviderNotFound`, `ErrInvalidRequest`, `ErrRequestCreate`, `ErrRequestSend`, `ErrResponseRead`
### pkg/logger — logging
Use dependency injection: constructors take a `*zap.Logger` parameter and fall back to `zap.L()` when it is nil:
```go
func NewMyService(repo Repository, logger *zap.Logger) *MyService {
    if logger == nil {
        logger = zap.L()
    }
    return &MyService{repo: repo, logger: logger}
}
```
Do not use the global `zap.L()` logger directly in business code; inject the logger through the constructor.
### pkg/validator — request validation
```go
import "nex/backend/pkg/validator"

v := validator.Get()
err := v.Validate(myStruct)
```
## Coding Standards
- **JSON parsing**: use the `encoding/json` standard library (`json.Unmarshal` / `json.Marshal`); never scan bytes by hand
- **String concatenation**: use `strings.Join`; never hand-write concatenation loops
- **Error checks**: use `errors.Is` / `errors.As`; never match on error strings (`strings.Contains(err.Error(), ...)`)
- **Logging**: inject a `*zap.Logger` dependency; never call `zap.L()` directly
- **String splitting**: use precise splits such as `strings.SplitN(key, "/", 2)`; never slice by index

View File

@@ -81,7 +81,7 @@ func main() {
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
}
engine := conversion.NewConversionEngine(registry)
engine := conversion.NewConversionEngine(registry, zapLogger)
// 7. Initialize the provider client
providerClient := provider.NewClient()

View File

@@ -2,7 +2,6 @@ package anthropic
import (
"encoding/json"
"regexp"
"strings"
"nex/backend/internal/conversion"
@@ -17,8 +16,6 @@ func NewAdapter() *Adapter {
return &Adapter{}
}
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)
// ProtocolName returns the protocol name
func (a *Adapter) ProtocolName() string { return "anthropic" }
@@ -35,13 +32,22 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
return conversion.InterfaceTypeChat
case nativePath == "/v1/models":
return conversion.InterfaceTypeModels
case modelInfoRegex.MatchString(nativePath):
case isModelInfoPath(nativePath):
return conversion.InterfaceTypeModelInfo
default:
return conversion.InterfaceTypePassthrough
}
}
// isModelInfoPath reports whether path is a model-detail path (/v1/models/{id})
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
}
// BuildUrl builds the URL for the given interface type
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
switch interfaceType {
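For reference, the regexp replacement shown in this hunk can be exercised in isolation. The sketch below copies `isModelInfoPath` as-is and prints its verdict on a few paths:

```go
package main

import (
	"fmt"
	"strings"
)

// isModelInfoPath mirrors the string-based check introduced above:
// it accepts exactly /v1/models/{id} with a non-empty, slash-free id.
func isModelInfoPath(path string) bool {
	if !strings.HasPrefix(path, "/v1/models/") {
		return false
	}
	suffix := path[len("/v1/models/"):]
	return suffix != "" && !strings.Contains(suffix, "/")
}

func main() {
	for _, p := range []string{
		"/v1/models/claude-3-opus", // true: one segment after the prefix
		"/v1/models",               // false: no trailing segment
		"/v1/models/",              // false: empty id
		"/v1/models/a/b",           // false: extra slash in the id
	} {
		fmt.Println(p, isModelInfoPath(p))
	}
}
```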

View File

@@ -251,7 +251,7 @@ func (d *StreamDecoder) processMessageDelta(data []byte) []canonical.CanonicalSt
}
if d.accumulatedUsage != nil {
d.accumulatedUsage.OutputTokens += raw.Usage.OutputTokens
d.accumulatedUsage.OutputTokens = raw.Usage.OutputTokens
}
return []canonical.CanonicalStreamEvent{

View File

@@ -272,3 +272,218 @@ func TestStreamDecoder_RedactedDeltaSuppressed(t *testing.T) {
events := d.ProcessChunk(raw)
assert.Empty(t, events)
}
func TestStreamDecoder_ServerToolUse_Suppressed(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_start",
"index": 2,
"content_block": map[string]any{
"type": "server_tool_use",
"id": "server_tool_1",
"name": "web_search",
},
}
raw := makeAnthropicEvent("content_block_start", payload)
events := d.ProcessChunk(raw)
assert.Empty(t, events)
assert.True(t, d.redactedBlocks[2])
}
func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_start",
"index": 3,
"content_block": map[string]any{
"type": "web_search_tool_result",
"tool_use_id": "search_1",
},
}
raw := makeAnthropicEvent("content_block_start", payload)
events := d.ProcessChunk(raw)
assert.Empty(t, events)
assert.True(t, d.redactedBlocks[3])
}
func TestStreamDecoder_CodeExecutionToolResult_Suppressed(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_start",
"index": 4,
"content_block": map[string]any{
"type": "code_execution_tool_result",
},
}
raw := makeAnthropicEvent("content_block_start", payload)
events := d.ProcessChunk(raw)
assert.Empty(t, events)
assert.True(t, d.redactedBlocks[4])
}
func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{
"type": "citations_delta",
"citation": map[string]any{"title": "ref1"},
},
}
raw := makeAnthropicEvent("content_block_delta", payload)
events := d.ProcessChunk(raw)
assert.Empty(t, events)
}
func TestStreamDecoder_SignatureDelta_Discarded(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{
"type": "signature_delta",
"signature": "sig_123",
},
}
raw := makeAnthropicEvent("content_block_delta", payload)
events := d.ProcessChunk(raw)
assert.Empty(t, events)
}
func TestStreamDecoder_UnknownEventType(t *testing.T) {
d := NewStreamDecoder()
raw := makeAnthropicEvent("unknown_event", map[string]any{"type": "unknown_event"})
events := d.ProcessChunk(raw)
assert.Empty(t, events)
}
func TestStreamDecoder_InvalidJSON(t *testing.T) {
d := NewStreamDecoder()
raw := []byte("event: message_start\ndata: {invalid}\n\n")
events := d.ProcessChunk(raw)
assert.Empty(t, events)
}
func TestStreamDecoder_MultipleEventsInSingleChunk(t *testing.T) {
d := NewStreamDecoder()
startPayload := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_multi",
"model": "claude-3",
},
}
deltaPayload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{
"type": "text_delta",
"text": "Hello",
},
}
stopPayload := map[string]any{"type": "message_stop"}
var raw []byte
raw = append(raw, makeAnthropicEvent("message_start", startPayload)...)
raw = append(raw, makeAnthropicEvent("content_block_delta", deltaPayload)...)
raw = append(raw, makeAnthropicEvent("message_stop", stopPayload)...)
events := d.ProcessChunk(raw)
require.Len(t, events, 3)
assert.Equal(t, canonical.EventMessageStart, events[0].Type)
assert.Equal(t, canonical.EventContentBlockDelta, events[1].Type)
assert.Equal(t, canonical.EventMessageStop, events[2].Type)
}
func TestStreamDecoder_ErrorInvalidJSON(t *testing.T) {
d := NewStreamDecoder()
raw := []byte("event: error\ndata: {invalid}\n\n")
events := d.ProcessChunk(raw)
require.Len(t, events, 1)
assert.Equal(t, canonical.EventError, events[0].Type)
assert.Contains(t, events[0].Error.Message, "解析错误事件失败")
}
func TestStreamDecoder_MessageStartWithUsage(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_usage",
"model": "claude-3",
"usage": map[string]any{"input_tokens": 25, "output_tokens": 0},
},
}
raw := makeAnthropicEvent("message_start", payload)
events := d.ProcessChunk(raw)
require.Len(t, events, 1)
assert.Equal(t, canonical.EventMessageStart, events[0].Type)
require.NotNil(t, events[0].Message.Usage)
assert.Equal(t, 25, events[0].Message.Usage.InputTokens)
}
func TestStreamDecoder_ThinkingBlockStart(t *testing.T) {
d := NewStreamDecoder()
payload := map[string]any{
"type": "content_block_start",
"index": 0,
"content_block": map[string]any{
"type": "thinking",
"thinking": "",
},
}
raw := makeAnthropicEvent("content_block_start", payload)
events := d.ProcessChunk(raw)
require.Len(t, events, 1)
assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
require.NotNil(t, events[0].ContentBlock)
assert.Equal(t, "thinking", events[0].ContentBlock.Type)
}
func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
d := NewStreamDecoder()
startPayload := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": "msg_usage_test",
"model": "claude-3",
"usage": map[string]any{"input_tokens": 10, "output_tokens": 0},
},
}
deltaPayload1 := map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 25},
}
d.ProcessChunk(makeAnthropicEvent("message_start", startPayload))
events := d.ProcessChunk(makeAnthropicEvent("message_delta", deltaPayload1))
require.Len(t, events, 1)
assert.Equal(t, 25, events[0].Usage.OutputTokens)
deltaPayload2 := map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 30},
}
events = d.ProcessChunk(makeAnthropicEvent("message_delta", deltaPayload2))
require.Len(t, events, 1)
assert.Equal(t, 30, events[0].Usage.OutputTokens, "output_tokens should be replaced, not accumulated")
assert.Equal(t, 30, d.accumulatedUsage.OutputTokens, "accumulated usage should match last value")
}

View File

@@ -0,0 +1,477 @@
package anthropic
import (
"encoding/json"
"testing"
"time"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecodeTools(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}],
"tools": [
{"name": "search", "description": "Search", "input_schema": {"type":"object"}},
{"name": "calc", "input_schema": {"type":"object"}}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Len(t, req.Tools, 2)
assert.Equal(t, "search", req.Tools[0].Name)
assert.Equal(t, "Search", req.Tools[0].Description)
assert.Equal(t, "calc", req.Tools[1].Name)
}
func TestDecodeToolChoice(t *testing.T) {
tests := []struct {
name string
jsonBody string
wantType string
wantName string
}{
{
"auto string",
`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
"auto", "",
},
{
"none string",
`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
"none", "",
},
{
"any string",
`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"any"}`,
"any", "",
},
{
"tool object",
`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"tool","name":"search"}}`,
"tool", "search",
},
{
"auto object",
`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"auto"}}`,
"auto", "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := decodeRequest([]byte(tt.jsonBody))
require.NoError(t, err)
require.NotNil(t, req.ToolChoice)
assert.Equal(t, tt.wantType, req.ToolChoice.Type)
assert.Equal(t, tt.wantName, req.ToolChoice.Name)
})
}
}
func TestDecodeParameters_TopK(t *testing.T) {
topK := 10
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}],
"top_k": 10,
"stop_sequences": ["STOP"]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
require.NotNil(t, req.Parameters.TopK)
assert.Equal(t, topK, *req.Parameters.TopK)
assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}
func TestDecodeRequest_MetadataUserID(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}],
"metadata": {"user_id": "user-123"}
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Equal(t, "user-123", req.UserID)
}
func TestDecodeSystem_Empty(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"system": "",
"messages": [{"role": "user", "content": "hi"}]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.System)
}
func TestDecodeSystem_Nil(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.System)
}
func TestDecodeThinking_WithEffort(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}],
"thinking": {"type": "enabled", "budget_tokens": 5000},
"output_config": {"effort": "high"}
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
require.NotNil(t, req.Thinking)
assert.Equal(t, "enabled", req.Thinking.Type)
assert.Equal(t, "high", req.Thinking.Effort)
}
func TestDecodeOutputFormat_NilOutputConfig(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [{"role": "user", "content": "hi"}]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.OutputFormat)
}
func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": [{"type": "tool_use", "id": "t1", "name": "fn", "input": {}}]},
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "t1", "content": "result"}]
}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
lastMsg := req.Messages[len(req.Messages)-1]
assert.Equal(t, canonical.RoleTool, lastMsg.Role)
assert.Equal(t, "t1", lastMsg.Content[0].ToolUseID)
}
func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil)
assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text)
}
func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello")
assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text)
}
func TestParseTimestamp(t *testing.T) {
tests := []struct {
name string
input string
want int64
}{
{"valid RFC3339", "2024-01-15T00:00:00Z", time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC).Unix()},
{"empty", "", 0},
{"invalid", "not-a-date", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, parseTimestamp(tt.input))
})
}
}
func TestEncodeToolChoice(t *testing.T) {
tests := []struct {
name string
choice *canonical.ToolChoice
want map[string]any
}{
{"auto", canonical.NewToolChoiceAuto(), map[string]any{"type": "auto"}},
{"none", canonical.NewToolChoiceNone(), map[string]any{"type": "none"}},
{"any", canonical.NewToolChoiceAny(), map[string]any{"type": "any"}},
{"tool", canonical.NewToolChoiceNamed("search"), map[string]any{"type": "tool", "name": "search"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
})
}
}
func TestEncodeThinkingConfig(t *testing.T) {
budget := 5000
tests := []struct {
name string
cfg *canonical.ThinkingConfig
want map[string]any
}{
{"enabled", &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget}, map[string]any{"type": "enabled", "budget_tokens": float64(5000)}},
{"disabled", &canonical.ThinkingConfig{Type: "disabled"}, map[string]any{"type": "disabled"}},
{"adaptive", &canonical.ThinkingConfig{Type: "adaptive"}, map[string]any{"type": "adaptive"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := encodeThinkingConfig(tt.cfg)
assert.Equal(t, tt.want["type"], result["type"])
})
}
}
func TestEncodeRequest_PublicFields(t *testing.T) {
maxTokens := 1024
parallel := false
req := &canonical.CanonicalRequest{
Model: "claude-3",
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
UserID: "user-123",
ParallelToolUse: &parallel,
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, map[string]any{"user_id": "user-123"}, result["metadata"])
assert.Equal(t, true, result["disable_parallel_tool_use"])
}
func TestEncodeRequest_DefaultMaxTokens(t *testing.T) {
req := &canonical.CanonicalRequest{
Model: "claude-3",
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, float64(4096), result["max_tokens"])
}
func TestEncodeRequest_TopK(t *testing.T) {
maxTokens := 1024
topK := 10
req := &canonical.CanonicalRequest{
Model: "claude-3",
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens, TopK: &topK},
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, float64(10), result["top_k"])
}
func TestEncodeRequest_WithTools(t *testing.T) {
maxTokens := 1024
req := &canonical.CanonicalRequest{
Model: "claude-3",
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
Tools: []canonical.CanonicalTool{
{Name: "search", Description: "Search things", InputSchema: json.RawMessage(`{"type":"object"}`)},
},
ToolChoice: canonical.NewToolChoiceAuto(),
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any)
assert.Len(t, tools, 1)
tool := tools[0].(map[string]any)
assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any)
assert.Equal(t, "auto", tc["type"])
}
func TestEncodeRequest_ThinkingWithEffort(t *testing.T) {
maxTokens := 1024
req := &canonical.CanonicalRequest{
Model: "claude-3",
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"},
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
oc, ok := result["output_config"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "high", oc["effort"])
}
func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
cacheRead := 30
cacheCreation := 10
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "msg-1",
Model: "claude-3",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cacheRead,
CacheCreationTokens: &cacheCreation,
},
}
body, err := encodeResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
}
func TestEncodeResponse_StopReasons(t *testing.T) {
tests := []struct {
name string
stopReason canonical.StopReason
want string
}{
{"end_turn", canonical.StopReasonEndTurn, "end_turn"},
{"max_tokens", canonical.StopReasonMaxTokens, "max_tokens"},
{"tool_use", canonical.StopReasonToolUse, "tool_use"},
{"stop_sequence", canonical.StopReasonStopSequence, "stop_sequence"},
{"refusal", canonical.StopReasonRefusal, "refusal"},
{"content_filter→end_turn", canonical.StopReasonContentFilter, "end_turn"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sr := tt.stopReason
resp := &canonical.CanonicalResponse{
ID: "r1",
Model: "claude-3",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
}
body, err := encodeResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, tt.want, result["stop_reason"])
})
}
}
func TestEncodeSystem_SystemBlocks(t *testing.T) {
result := encodeSystem([]canonical.SystemBlock{{Text: "part1"}, {Text: "part2"}})
blocks, ok := result.([]map[string]any)
require.True(t, ok)
assert.Len(t, blocks, 2)
assert.Equal(t, "part1", blocks[0]["text"])
}
func TestEncodeModelInfoResponse(t *testing.T) {
info := &canonical.CanonicalModelInfo{
ID: "claude-3-opus",
Name: "Claude 3 Opus",
Created: time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC).Unix(),
}
body, err := encodeModelInfoResponse(info)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "claude-3-opus", result["id"])
assert.Equal(t, "Claude 3 Opus", result["display_name"])
}
func TestDecodeModelInfoResponse(t *testing.T) {
body := []byte(`{"id":"claude-3-opus","type":"model","display_name":"Claude 3 Opus","created_at":"2024-01-15T00:00:00Z"}`)
info, err := decodeModelInfoResponse(body)
require.NoError(t, err)
assert.Equal(t, "claude-3-opus", info.ID)
assert.Equal(t, "Claude 3 Opus", info.Name)
assert.NotEqual(t, int64(0), info.Created)
}
func TestDecodeResponse_PauseTurn(t *testing.T) {
body := []byte(`{
"id": "msg-1", "type": "message", "role": "assistant", "model": "claude-3",
"content": [{"type": "text", "text": "ok"}],
"stop_reason": "pause_turn",
"usage": {"input_tokens": 1, "output_tokens": 1}
}`)
resp, err := decodeResponse(body)
require.NoError(t, err)
assert.Equal(t, canonical.StopReason("pause_turn"), *resp.StopReason)
}
func TestEncodeResponse_NoStopReason(t *testing.T) {
resp := &canonical.CanonicalResponse{
ID: "msg-1",
Model: "claude-3",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
}
body, err := encodeResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "end_turn", result["stop_reason"])
}
func TestDecodeRequest_MaxTokensZero(t *testing.T) {
body := []byte(`{
"model": "claude-3",
"max_tokens": 0,
"messages": [{"role": "user", "content": "hi"}]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.Parameters.MaxTokens)
}

View File

@@ -0,0 +1,114 @@
package canonical
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetSystemString(t *testing.T) {
tests := []struct {
name string
system any
want string
}{
{"string", "hello", "hello"},
{"nil", nil, ""},
{"empty string", "", ""},
{"system blocks", []SystemBlock{{Text: "part1"}, {Text: "part2"}}, "part1\n\npart2"},
{"single block", []SystemBlock{{Text: "only"}}, "only"},
{"other type", 123, "123"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &CanonicalRequest{System: tt.system}
assert.Equal(t, tt.want, req.GetSystemString())
})
}
}
func TestSetSystemString(t *testing.T) {
req := &CanonicalRequest{}
req.SetSystemString("hello")
assert.Equal(t, "hello", req.System)
req.SetSystemString("")
assert.Nil(t, req.System)
}
func TestNewTextBlock(t *testing.T) {
b := NewTextBlock("hello")
assert.Equal(t, "text", b.Type)
assert.Equal(t, "hello", b.Text)
}
func TestNewToolUseBlock(t *testing.T) {
input := json.RawMessage(`{"key":"val"}`)
b := NewToolUseBlock("id-1", "tool_name", input)
assert.Equal(t, "tool_use", b.Type)
assert.Equal(t, "id-1", b.ID)
assert.Equal(t, "tool_name", b.Name)
assert.Equal(t, input, b.Input)
}
func TestNewToolResultBlock(t *testing.T) {
b := NewToolResultBlock("tool-1", "result", false)
assert.Equal(t, "tool_result", b.Type)
assert.Equal(t, "tool-1", b.ToolUseID)
assert.NotNil(t, b.IsError)
assert.False(t, *b.IsError)
}
func TestNewThinkingBlock(t *testing.T) {
b := NewThinkingBlock("thought")
assert.Equal(t, "thinking", b.Type)
assert.Equal(t, "thought", b.Thinking)
}
func TestNewToolChoice(t *testing.T) {
assert.Equal(t, &ToolChoice{Type: "auto"}, NewToolChoiceAuto())
assert.Equal(t, &ToolChoice{Type: "none"}, NewToolChoiceNone())
assert.Equal(t, &ToolChoice{Type: "any"}, NewToolChoiceAny())
assert.Equal(t, &ToolChoice{Type: "tool", Name: "fn"}, NewToolChoiceNamed("fn"))
}
func TestCanonicalRequest_RoundTrip(t *testing.T) {
req := &CanonicalRequest{
Model: "gpt-4",
System: "system prompt",
Messages: []CanonicalMessage{{Role: RoleUser, Content: []ContentBlock{NewTextBlock("hi")}}},
Stream: true,
}
data, err := json.Marshal(req)
require.NoError(t, err)
var decoded CanonicalRequest
require.NoError(t, json.Unmarshal(data, &decoded))
assert.Equal(t, "gpt-4", decoded.Model)
assert.Equal(t, "system prompt", decoded.System)
assert.True(t, decoded.Stream)
}
func TestCanonicalResponse_RoundTrip(t *testing.T) {
sr := StopReasonEndTurn
resp := &CanonicalResponse{
ID: "resp-1",
Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")},
StopReason: &sr,
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
}
data, err := json.Marshal(resp)
require.NoError(t, err)
var decoded CanonicalResponse
require.NoError(t, json.Unmarshal(data, &decoded))
assert.Equal(t, "resp-1", decoded.ID)
assert.Equal(t, StopReasonEndTurn, *decoded.StopReason)
}

View File

@@ -28,13 +28,18 @@ type HTTPResponseSpec struct {
type ConversionEngine struct {
registry AdapterRegistry
middlewareChain *MiddlewareChain
logger *zap.Logger
}
// NewConversionEngine creates the conversion engine
func NewConversionEngine(registry AdapterRegistry) *ConversionEngine {
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
if logger == nil {
logger = zap.L()
}
return &ConversionEngine{
registry: registry,
middlewareChain: NewMiddlewareChain(),
logger: logger,
}
}
@@ -251,12 +256,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil {
zap.L().Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
encoded, err := clientAdapter.EncodeModelsResponse(models)
if err != nil {
zap.L().Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
return encoded, nil
@@ -265,12 +270,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil {
zap.L().Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil {
zap.L().Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
return encoded, nil
@@ -279,7 +284,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil {
zap.L().Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
return body, nil
}
return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -288,7 +293,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil {
zap.L().Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
return clientAdapter.EncodeEmbeddingResponse(resp)
@@ -297,7 +302,7 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil {
zap.L().Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
return body, nil
}
return providerAdapter.EncodeRerankRequest(req, provider)

View File

@@ -0,0 +1,323 @@
package conversion
import (
"encoding/json"
"errors"
"testing"
"nex/backend/internal/conversion/canonical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConversionError_WithProviderProtocol(t *testing.T) {
err := NewConversionError(ErrorCodeInvalidInput, "test").WithProviderProtocol("anthropic")
assert.Equal(t, "anthropic", err.ProviderProtocol)
}
func TestConversionError_WithInterfaceType(t *testing.T) {
err := NewConversionError(ErrorCodeInvalidInput, "test").WithInterfaceType("CHAT")
assert.Equal(t, "CHAT", err.InterfaceType)
}
func TestConversionError_FullBuilder(t *testing.T) {
err := NewConversionError(ErrorCodeInvalidInput, "bad").
WithClientProtocol("openai").
WithProviderProtocol("anthropic").
WithInterfaceType("CHAT").
WithDetail("field", "model").
WithCause(errors.New("root"))
assert.Equal(t, ErrorCodeInvalidInput, err.Code)
assert.Equal(t, "openai", err.ClientProtocol)
assert.Equal(t, "anthropic", err.ProviderProtocol)
assert.Equal(t, "CHAT", err.InterfaceType)
assert.Equal(t, "model", err.Details["field"])
assert.Equal(t, "root", err.Cause.Error())
}
func TestEngine_Use(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
called := false
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
called = true
return req, nil
}})
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return &canonical.CanonicalRequest{Model: "test"}, nil
}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
return json.Marshal(req)
}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
require.NoError(t, err)
assert.True(t, called)
}
func TestConvertHttpRequest_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return nil, errors.New("decode failed")
}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("", "", ""))
assert.Error(t, err)
}
func TestConvertHttpRequest_EncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("client", false))
providerAdapter := newMockAdapter("provider", false)
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
return nil, errors.New("encode failed")
}
_ = engine.RegisterAdapter(providerAdapter)
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("", "", ""))
assert.Error(t, err)
}
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
return json.Marshal(map[string]string{"id": resp.ID})
}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return &canonical.CanonicalResponse{ID: "resp-1", Model: "test"}, nil
}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
}, "client", "provider", InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Contains(t, string(result.Body), "resp-1")
}
func TestConvertHttpResponse_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return nil, errors.New("decode error")
}
_ = engine.RegisterAdapter(providerAdapter)
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
assert.Error(t, err)
}
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeEmbeddings
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return &canonical.CanonicalRequest{Model: "test"}, nil
}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.ifaceType = InterfaceTypeEmbeddings
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/embeddings", Method: "POST", Body: []byte(`{"model":"text-embedding","input":"hello"}`),
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeRerank
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.ifaceType = InterfaceTypeRerank
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/rerank", Method: "POST", Body: []byte(`{"model":"rerank","query":"test","documents":["a"]}`),
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeRerank)
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeModels
providerAdapter := newMockAdapter("provider", false)
providerAdapter.ifaceType = InterfaceTypeModels
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"object":"list","data":[]}`)
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/models", Method: "GET", Body: body,
}, "client", "provider", NewTargetProvider("https://example.com", "key", ""))
require.NoError(t, err)
assert.Equal(t, body, result.Body)
}
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
}, "client", "provider", InterfaceTypeModels)
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
}, "client", "provider", InterfaceTypeModelInfo)
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestRegistry_ListProtocols(t *testing.T) {
registry := NewMemoryRegistry()
_ = registry.Register(newMockAdapter("openai", true))
_ = registry.Register(newMockAdapter("anthropic", true))
protocols := registry.ListProtocols()
assert.Len(t, protocols, 2)
assert.Contains(t, protocols, "openai")
assert.Contains(t, protocols, "anthropic")
}
func TestRegistry_ConcurrentAccess(t *testing.T) {
registry := NewMemoryRegistry()
done := make(chan bool, 2)
go func() {
for i := 0; i < 100; i++ {
_ = registry.Register(newMockAdapter("proto-"+string(rune(i)), true))
}
done <- true
}()
go func() {
for i := 0; i < 100; i++ {
_, _ = registry.Get("proto-" + string(rune(i)))
}
_ = registry.ListProtocols()
done <- true
}()
<-done
<-done
}
func TestNewConversionContext(t *testing.T) {
ctx := NewConversionContext(InterfaceTypeChat)
assert.NotEmpty(t, ctx.ConversionID)
assert.Equal(t, InterfaceTypeChat, ctx.InterfaceType)
assert.NotNil(t, ctx.Metadata)
}
type testMiddleware struct {
fn func(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error)
}
func (m *testMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
if m.fn != nil {
return m.fn(req, clientProtocol, providerProtocol, ctx)
}
return req, nil
}
func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
return event, nil
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
// mockProtocolAdapter is a mock protocol adapter
@@ -170,14 +171,29 @@ func (e *noopStreamEncoder) Flush() [][]byte
func TestNewConversionEngine(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
assert.NotNil(t, engine)
assert.Equal(t, registry, engine.GetRegistry())
}
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
t.Run("nil_logger_uses_global", func(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
assert.NotNil(t, engine.logger)
})
t.Run("custom_logger", func(t *testing.T) {
registry := NewMemoryRegistry()
customLogger := zap.NewNop()
engine := NewConversionEngine(registry, customLogger)
assert.Equal(t, customLogger, engine.logger)
})
}
func TestRegisterAdapter(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
adapter := newMockAdapter("test-proto", true)
err := engine.RegisterAdapter(adapter)
@@ -189,7 +205,7 @@ func TestRegisterAdapter(t *testing.T) {
func TestIsPassthrough_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
adapter := newMockAdapter("openai", true)
_ = engine.RegisterAdapter(adapter)
@@ -198,7 +214,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
@@ -207,7 +223,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
func TestIsPassthrough_NoPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
assert.False(t, engine.IsPassthrough("custom", "custom"))
@@ -215,7 +231,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
func TestDetectInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
adapter := newMockAdapter("test", true)
adapter.ifaceType = InterfaceTypeChat
_ = engine.RegisterAdapter(adapter)
@@ -227,7 +243,7 @@ func TestDetectInterfaceType(t *testing.T) {
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
assert.Error(t, err)
@@ -235,7 +251,7 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
func TestConvertHttpRequest_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
@@ -253,7 +269,7 @@ func TestConvertHttpRequest_Passthrough(t *testing.T) {
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client-proto", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
@@ -285,7 +301,7 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
spec := HTTPResponseSpec{
@@ -301,7 +317,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai")
@@ -312,7 +328,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Canonical(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
@@ -324,7 +340,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
func TestEncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
@@ -336,7 +352,7 @@ func TestEncodeError(t *testing.T) {
func TestEncodeError_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry)
engine := NewConversionEngine(registry, nil)
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")

View File

@@ -2,7 +2,7 @@ package openai
import (
"encoding/json"
"regexp"
"strings"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
@@ -16,8 +16,6 @@ func NewAdapter() *Adapter {
return &Adapter{}
}
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)
// ProtocolName 返回协议名称
func (a *Adapter) ProtocolName() string { return "openai" }
@@ -34,7 +32,7 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
return conversion.InterfaceTypeChat
case nativePath == "/v1/models":
return conversion.InterfaceTypeModels
case modelInfoRegex.MatchString(nativePath):
case isModelInfoPath(nativePath):
return conversion.InterfaceTypeModelInfo
case nativePath == "/v1/embeddings":
return conversion.InterfaceTypeEmbeddings
@@ -45,6 +43,15 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
}
}
// isModelInfoPath reports whether the path is a model-detail path (/v1/models/{id})
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
}
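The helper above replaces the `modelInfoRegex` match with plain string checks. As a minimal standalone sketch (hypothetical `main` package, using `strings.CutPrefix` as an equivalent of the prefix-slice in the diff), it behaves the same:

```go
package main

import (
	"fmt"
	"strings"
)

// isModelInfoPath reports whether path matches /v1/models/{id}:
// the prefix must be present and followed by exactly one non-empty segment.
func isModelInfoPath(path string) bool {
	suffix, ok := strings.CutPrefix(path, "/v1/models/")
	if !ok {
		return false
	}
	return suffix != "" && !strings.Contains(suffix, "/")
}

func main() {
	fmt.Println(isModelInfoPath("/v1/models/gpt-4"))          // true
	fmt.Println(isModelInfoPath("/v1/models"))                // false (list endpoint)
	fmt.Println(isModelInfoPath("/v1/models/"))               // false (empty id)
	fmt.Println(isModelInfoPath("/v1/models/gpt-4/versions")) // false (nested path)
}
```

Avoiding the regex removes a package-level `regexp.MustCompile` and keeps the hot-path check allocation-free.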
// BuildUrl builds the URL for the given interface type
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
switch interfaceType {

View File

@@ -112,6 +112,28 @@ func TestAdapter_SupportsInterface(t *testing.T) {
}
}
func TestIsModelInfoPath(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"model_info", "/v1/models/gpt-4", true},
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
{"models_list", "/v1/models", false},
{"nested_path", "/v1/models/gpt-4/versions", false},
{"empty_suffix", "/v1/models/", false},
{"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/v1/model", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
a := NewAdapter()
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

View File

@@ -2,6 +2,7 @@ package openai
import (
"encoding/json"
"strings"
"time"
"nex/backend/internal/conversion"
@@ -89,7 +90,7 @@ func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any {
for _, b := range v {
parts = append(parts, b.Text)
}
text := joinStrings(parts, "\n\n")
text := strings.Join(parts, "\n\n")
if text != "" {
messages = append(messages, map[string]any{
"role": "system",
@@ -132,7 +133,7 @@ func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
if len(toolUses) > 0 {
if len(textParts) > 0 {
m["content"] = joinStrings(textParts, "")
m["content"] = strings.Join(textParts, "")
} else {
m["content"] = nil
}
@@ -149,7 +150,7 @@ func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
}
m["tool_calls"] = tcs
} else if len(textParts) > 0 {
m["content"] = joinStrings(textParts, "")
m["content"] = strings.Join(textParts, "")
} else {
m["content"] = ""
}
@@ -286,7 +287,7 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
message := map[string]any{"role": "assistant"}
if len(toolUses) > 0 {
if len(textParts) > 0 {
message["content"] = joinStrings(textParts, "")
message["content"] = strings.Join(textParts, "")
} else {
message["content"] = nil
}
@@ -303,13 +304,13 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
}
message["tool_calls"] = tcs
} else if len(textParts) > 0 {
message["content"] = joinStrings(textParts, "")
message["content"] = strings.Join(textParts, "")
} else {
message["content"] = ""
}
if len(thinkingParts) > 0 {
message["reasoning_content"] = joinStrings(thinkingParts, "")
message["reasoning_content"] = strings.Join(thinkingParts, "")
}
var finishReason *string
@@ -488,18 +489,6 @@ func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, erro
})
}
// joinStrings joins a string slice with a separator
func joinStrings(parts []string, sep string) string {
result := ""
for i, p := range parts {
if i > 0 {
result += sep
}
result += p
}
return result
}
// mergeConsecutiveRoles merges consecutive same-role messages (concatenating their content)
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
if len(messages) <= 1 {

View File

@@ -353,3 +353,120 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
}
assert.Equal(t, []string{"你好", "世界"}, deltas)
}
func TestStreamDecoder_UTF8Truncation(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-utf8",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{"content": "你"},
},
},
}
data, _ := json.Marshal(chunk)
sseData := []byte("data: " + string(data) + "\n\n")
mid := len(sseData) - 5
part1 := sseData[:mid]
part2 := sseData[mid:]
events1 := d.ProcessChunk(part1)
for _, e := range events1 {
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
assert.Equal(t, "你", e.Delta.Text)
}
}
events2 := d.ProcessChunk(part2)
_ = events2
}
func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
d := NewStreamDecoder()
idx := 0
chunk1 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{
"tool_calls": []any{
map[string]any{
"index": &idx,
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": "",
},
},
},
},
},
},
}
chunk2 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
"delta": map[string]any{
"tool_calls": []any{
map[string]any{
"index": &idx,
"function": map[string]any{
"arguments": "{\"city\":\"Beijing\"}",
},
},
},
},
},
},
}
events1 := d.ProcessChunk(makeChunkSSE(chunk1))
require.NotEmpty(t, events1)
events2 := d.ProcessChunk(makeChunkSSE(chunk2))
require.NotEmpty(t, events2)
foundInputJSON := false
for _, e := range events2 {
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "input_json_delta" {
foundInputJSON = true
assert.Equal(t, "{\"city\":\"Beijing\"}", e.Delta.PartialJSON)
}
}
assert.True(t, foundInputJSON, "subsequent tool call delta should emit input_json_delta")
}
func TestStreamDecoder_InvalidJSON(t *testing.T) {
d := NewStreamDecoder()
raw := []byte("data: {invalid json}\n\n")
events := d.ProcessChunk(raw)
assert.Nil(t, events)
}
func TestStreamDecoder_NonDataLines(t *testing.T) {
d := NewStreamDecoder()
raw := []byte(": comment line\ndata: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
events := d.ProcessChunk(raw)
require.NotEmpty(t, events)
found := false
for _, e := range events {
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
found = true
assert.Equal(t, "hi", e.Delta.Text)
}
}
assert.True(t, found)
}

View File

@@ -137,15 +137,10 @@ func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEven
}
// 后续 delta仅含 arguments
// 通过 index 查找 tool call
// 使用 canonical 事件中的 index 直接映射到 OpenAI tool_calls index
tcIdx := 0
if event.Index != nil {
for id, idx := range e.toolCallIndexMap {
if idx == tcIdx {
_ = id
break
}
}
tcIdx = *event.Index
}
delta := map[string]any{
"tool_calls": []map[string]any{{

View File

@@ -170,3 +170,116 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
assert.Contains(t, s, "usage")
assert.Contains(t, s, "prompt_tokens")
}
func TestStreamEncoder_InputJSONDelta_SubsequentDelta(t *testing.T) {
e := NewStreamEncoder()
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
Type: "tool_use",
ID: "call_1",
Name: "get_weather",
}))
e.EncodeEvent(canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: "{\"city\":",
}))
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: "\"Beijing\"}",
})
chunks := e.EncodeEvent(event)
require.NotEmpty(t, chunks)
s := string(chunks[0])
assert.Contains(t, s, "tool_calls")
assert.Contains(t, s, "Beijing")
}
func TestStreamEncoder_MessageStart_NilMessage(t *testing.T) {
e := NewStreamEncoder()
event := canonical.CanonicalStreamEvent{Type: canonical.EventMessageStart}
chunks := e.EncodeEvent(event)
require.Len(t, chunks, 1)
s := string(chunks[0])
assert.Contains(t, s, "chat.completion.chunk")
}
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
e := NewStreamEncoder()
event := canonical.CanonicalStreamEvent{Type: "unknown_type"}
chunks := e.EncodeEvent(event)
assert.Nil(t, chunks)
}
func TestStreamEncoder_ContentBlockDelta_NilDelta(t *testing.T) {
e := NewStreamEncoder()
event := canonical.CanonicalStreamEvent{Type: canonical.EventContentBlockDelta}
chunks := e.EncodeEvent(event)
assert.Nil(t, chunks)
}
func TestStreamEncoder_MultiToolCall_IndexMapping(t *testing.T) {
e := NewStreamEncoder()
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
Type: "tool_use",
ID: "call_1",
Name: "get_weather",
}))
firstDelta := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: `{"city":"北京"}`,
})
chunks := e.EncodeEvent(firstDelta)
require.NotEmpty(t, chunks)
s := string(chunks[0])
assert.Contains(t, s, `"index":0`)
assert.Contains(t, s, "get_weather")
assert.Contains(t, s, "北京")
e.EncodeEvent(canonical.NewContentBlockStopEvent(0))
e.EncodeEvent(canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{
Type: "tool_use",
ID: "call_2",
Name: "get_time",
}))
secondDelta := canonical.NewContentBlockDeltaEvent(1, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: `{"tz":"Asia/Shanghai"}`,
})
chunks = e.EncodeEvent(secondDelta)
require.NotEmpty(t, chunks)
s = string(chunks[0])
assert.Contains(t, s, `"index":1`)
assert.Contains(t, s, "get_time")
assert.Contains(t, s, "Asia/Shanghai")
subsequentDelta0 := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: `"more_data"`,
})
chunks = e.EncodeEvent(subsequentDelta0)
require.NotEmpty(t, chunks)
s = string(chunks[0])
assert.Contains(t, s, `"index":0`)
assert.NotContains(t, s, "get_weather")
assert.Contains(t, s, "more_data")
subsequentDelta1 := canonical.NewContentBlockDeltaEvent(1, canonical.StreamDelta{
Type: string(canonical.DeltaTypeInputJSON),
PartialJSON: `"more_time"`,
})
chunks = e.EncodeEvent(subsequentDelta1)
require.NotEmpty(t, chunks)
s = string(chunks[0])
assert.Contains(t, s, `"index":1`)
assert.Contains(t, s, "more_time")
}

View File

@@ -0,0 +1,434 @@
package openai
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecodeEmbeddingRequest(t *testing.T) {
body := []byte(`{"model":"text-embedding-3-small","input":"hello world","encoding_format":"float","dimensions":256}`)
req, err := decodeEmbeddingRequest(body)
require.NoError(t, err)
assert.Equal(t, "text-embedding-3-small", req.Model)
assert.Equal(t, "hello world", req.Input)
assert.Equal(t, "float", req.EncodingFormat)
require.NotNil(t, req.Dimensions)
assert.Equal(t, 256, *req.Dimensions)
}
func TestDecodeEmbeddingRequest_ArrayInput(t *testing.T) {
body := []byte(`{"model":"text-embedding","input":["hello","world"]}`)
req, err := decodeEmbeddingRequest(body)
require.NoError(t, err)
assert.Equal(t, "text-embedding", req.Model)
inputArr, ok := req.Input.([]any)
require.True(t, ok)
assert.Len(t, inputArr, 2)
}
func TestDecodeEmbeddingRequest_InvalidJSON(t *testing.T) {
_, err := decodeEmbeddingRequest([]byte(`invalid`))
assert.Error(t, err)
}
func TestDecodeEmbeddingResponse(t *testing.T) {
body := []byte(`{
"object": "list",
"data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}],
"model": "text-embedding-3-small",
"usage": {"prompt_tokens": 5, "total_tokens": 5}
}`)
resp, err := decodeEmbeddingResponse(body)
require.NoError(t, err)
assert.Equal(t, "text-embedding-3-small", resp.Model)
assert.Len(t, resp.Data, 1)
assert.Equal(t, 0, resp.Data[0].Index)
assert.Equal(t, 5, resp.Usage.PromptTokens)
}
func TestDecodeRerankRequest(t *testing.T) {
topN := 3
returnDocs := true
body := []byte(`{"model":"rerank-1","query":"what is AI","documents":["doc1","doc2"],"top_n":3,"return_documents":true}`)
req, err := decodeRerankRequest(body)
require.NoError(t, err)
assert.Equal(t, "rerank-1", req.Model)
assert.Equal(t, "what is AI", req.Query)
assert.Equal(t, []string{"doc1", "doc2"}, req.Documents)
require.NotNil(t, req.TopN)
assert.Equal(t, topN, *req.TopN)
require.NotNil(t, req.ReturnDocuments)
assert.Equal(t, returnDocs, *req.ReturnDocuments)
}
func TestDecodeRerankResponse(t *testing.T) {
doc := "relevant doc"
body := []byte(`{
"results": [{"index": 0, "relevance_score": 0.95, "document": "relevant doc"}],
"model": "rerank-1"
}`)
resp, err := decodeRerankResponse(body)
require.NoError(t, err)
assert.Equal(t, "rerank-1", resp.Model)
assert.Len(t, resp.Results, 1)
assert.Equal(t, 0, resp.Results[0].Index)
assert.InDelta(t, 0.95, resp.Results[0].RelevanceScore, 0.001)
require.NotNil(t, resp.Results[0].Document)
assert.Equal(t, doc, *resp.Results[0].Document)
}
func TestDecodeModelInfoResponse(t *testing.T) {
body := []byte(`{"id":"gpt-4","object":"model","created":1700000000,"owned_by":"openai"}`)
info, err := decodeModelInfoResponse(body)
require.NoError(t, err)
assert.Equal(t, "gpt-4", info.ID)
assert.Equal(t, int64(1700000000), info.Created)
assert.Equal(t, "openai", info.OwnedBy)
}
func TestEncodeEmbeddingRequest(t *testing.T) {
req := &canonical.CanonicalEmbeddingRequest{
Model: "text-embedding-3-small",
Input: "hello",
EncodingFormat: "float",
}
provider := conversion.NewTargetProvider("", "key", "my-embedding-model")
body, err := encodeEmbeddingRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "my-embedding-model", result["model"])
assert.Equal(t, "hello", result["input"])
assert.Equal(t, "float", result["encoding_format"])
}
func TestEncodeEmbeddingRequest_WithDimensions(t *testing.T) {
dims := 256
req := &canonical.CanonicalEmbeddingRequest{
Model: "text-embedding",
Input: "test",
Dimensions: &dims,
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeEmbeddingRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, float64(256), result["dimensions"])
}
func TestEncodeEmbeddingResponse(t *testing.T) {
resp := &canonical.CanonicalEmbeddingResponse{
Data: []canonical.EmbeddingData{{Index: 0, Embedding: []float64{0.1, 0.2}}},
Model: "text-embedding",
Usage: canonical.EmbeddingUsage{PromptTokens: 3, TotalTokens: 3},
}
body, err := encodeEmbeddingResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"])
assert.Equal(t, "text-embedding", result["model"])
}
func TestEncodeRerankRequest(t *testing.T) {
topN := 5
req := &canonical.CanonicalRerankRequest{
Model: "rerank-1",
Query: "what is AI",
Documents: []string{"doc1", "doc2"},
TopN: &topN,
}
provider := conversion.NewTargetProvider("", "key", "my-rerank-model")
body, err := encodeRerankRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "my-rerank-model", result["model"])
assert.Equal(t, "what is AI", result["query"])
}
func TestEncodeRerankResponse(t *testing.T) {
doc := "relevant passage"
resp := &canonical.CanonicalRerankResponse{
Results: []canonical.RerankResult{
{Index: 0, RelevanceScore: 0.95, Document: &doc},
},
Model: "rerank-1",
}
body, err := encodeRerankResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any)
assert.Len(t, results, 1)
}
func TestEncodeModelInfoResponse(t *testing.T) {
info := &canonical.CanonicalModelInfo{
ID: "gpt-4",
Name: "GPT-4",
Created: 1700000000,
OwnedBy: "openai",
}
body, err := encodeModelInfoResponse(info)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "gpt-4", result["id"])
assert.Equal(t, "model", result["object"])
}
func TestDecodeEmbeddingResponse_InvalidJSON(t *testing.T) {
_, err := decodeEmbeddingResponse([]byte(`invalid`))
assert.Error(t, err)
}
func TestDecodeRerankRequest_InvalidJSON(t *testing.T) {
_, err := decodeRerankRequest([]byte(`invalid`))
assert.Error(t, err)
}
func TestDecodeRerankResponse_InvalidJSON(t *testing.T) {
_, err := decodeRerankResponse([]byte(`invalid`))
assert.Error(t, err)
}
func TestDecodeModelInfoResponse_InvalidJSON(t *testing.T) {
_, err := decodeModelInfoResponse([]byte(`invalid`))
assert.Error(t, err)
}
func TestDecodeRequest_ThinkingNone(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`)
req, err := decodeRequest(body)
require.NoError(t, err)
require.NotNil(t, req.Thinking)
assert.Equal(t, "disabled", req.Thinking.Type)
}
func TestDecodeRequest_ThinkingMinimal(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"minimal"}`)
req, err := decodeRequest(body)
require.NoError(t, err)
require.NotNil(t, req.Thinking)
assert.Equal(t, "enabled", req.Thinking.Type)
assert.Equal(t, "low", req.Thinking.Effort)
}
func TestDecodeRequest_OutputFormat_Text(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"response_format":{"type":"text"}}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.OutputFormat)
}
func TestDecodeRequest_DeprecatedFunctionCall(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"function_call":"auto","functions":[{"name":"fn1","parameters":{}}]}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Equal(t, "auto", req.ToolChoice.Type)
assert.Len(t, req.Tools, 1)
}
func TestDecodeRequest_FunctionMessage(t *testing.T) {
body := []byte(`{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "hi"},
{"role": "function", "name": "get_weather", "content": "sunny"}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Len(t, req.Messages, 2)
assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
}
func TestDecodeRequest_StopString(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stop":"END"}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Equal(t, []string{"END"}, req.Parameters.StopSequences)
}
func TestDecodeRequest_StopEmptyString(t *testing.T) {
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stop":""}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Nil(t, req.Parameters.StopSequences)
}
func TestDecodeResponse_EmptyChoices(t *testing.T) {
body := []byte(`{"id":"resp-1","model":"gpt-4","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
resp, err := decodeResponse(body)
require.NoError(t, err)
assert.Equal(t, "resp-1", resp.ID)
assert.Len(t, resp.Content, 1)
assert.Equal(t, "", resp.Content[0].Text)
}
func TestDecodeResponse_FunctionCallFinishReason(t *testing.T) {
body := []byte(`{
"id":"r1","model":"gpt-4",
"choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"function_call"}],
"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
}`)
resp, err := decodeResponse(body)
require.NoError(t, err)
assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
}
func TestEncodeRequest_DisabledThinking(t *testing.T) {
req := &canonical.CanonicalRequest{
Model: "gpt-4",
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
Thinking: &canonical.ThinkingConfig{Type: "disabled"},
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "none", result["reasoning_effort"])
}
func TestEncodeRequest_OutputFormat_JSONObject(t *testing.T) {
req := &canonical.CanonicalRequest{
Model: "gpt-4",
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
OutputFormat: &canonical.OutputFormat{Type: "json_object"},
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
rf, ok := result["response_format"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "json_object", rf["type"])
}
func TestEncodeRequest_PublicFields(t *testing.T) {
parallel := true
req := &canonical.CanonicalRequest{
Model: "gpt-4",
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
UserID: "user-123",
ParallelToolUse: &parallel,
}
provider := conversion.NewTargetProvider("", "key", "model")
body, err := encodeRequest(req, provider)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "user-123", result["user"])
assert.Equal(t, true, result["parallel_tool_calls"])
}
func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
cache := 80
reasoning := 20
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cache,
ReasoningTokens: &reasoning,
},
}
body, err := encodeResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok)
assert.Equal(t, float64(80), ptd["cached_tokens"])
ctd, ok := usage["completion_tokens_details"].(map[string]any)
require.True(t, ok)
assert.Equal(t, float64(20), ctd["reasoning_tokens"])
}
func TestEncodeResponse_StopReasons(t *testing.T) {
tests := []struct {
name string
stopReason canonical.StopReason
want string
}{
{"end_turn→stop", canonical.StopReasonEndTurn, "stop"},
{"max_tokens→length", canonical.StopReasonMaxTokens, "length"},
{"tool_use→tool_calls", canonical.StopReasonToolUse, "tool_calls"},
{"content_filter→content_filter", canonical.StopReasonContentFilter, "content_filter"},
{"stop_sequence→stop", canonical.StopReasonStopSequence, "stop"},
{"refusal→stop", canonical.StopReasonRefusal, "stop"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sr := tt.stopReason
resp := &canonical.CanonicalResponse{
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{},
}
body, err := encodeResponse(resp)
require.NoError(t, err)
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
assert.Equal(t, tt.want, choice["finish_reason"])
})
}
}
func TestMapErrorCode_AllCodes(t *testing.T) {
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeInvalidInput))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeMissingRequiredField))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeIncompatibleFeature))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeFieldMappingFailure))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeToolCallParseError))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeJSONParseError))
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeProtocolConstraint))
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeStreamStateError))
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeUTF8DecodeError))
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeEncodingFailure))
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeInterfaceNotSupported))
}

View File

@@ -1,6 +1,7 @@
 package conversion
 
 import (
+	"fmt"
 	"testing"
 
 	"nex/backend/internal/conversion/canonical"
@@ -128,3 +129,71 @@ func TestCanonicalStreamConverter_EmptyDecoder(t *testing.T) {
assert.Nil(t, result)
}
func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
decoder := &mockStreamDecoder{
chunks: [][]canonical.CanonicalStreamEvent{{event}},
}
encoder := &mockStreamEncoder{
events: [][]byte{[]byte("data: ok\n\n")},
}
chain := NewMiddlewareChain()
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
result := converter.ProcessChunk([]byte("raw"))
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
}
func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
decoder := &mockStreamDecoder{
flush: []canonical.CanonicalStreamEvent{event},
}
encoder := &mockStreamEncoder{
events: [][]byte{[]byte("data: ok\n\n")},
flush: [][]byte{[]byte("data: encoder_flush\n\n")},
}
chain := NewMiddlewareChain()
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
result := converter.Flush()
assert.Len(t, result, 1)
assert.Equal(t, []byte("data: encoder_flush\n\n"), result[0])
}
func TestCanonicalStreamConverter_Flush_DecoderAndEncoderBothProduce(t *testing.T) {
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
decoder := &mockStreamDecoder{
flush: []canonical.CanonicalStreamEvent{event},
}
encoder := &mockStreamEncoder{
events: [][]byte{[]byte("data: decoder_flush\n\n")},
flush: [][]byte{[]byte("data: encoder_flush\n\n")},
}
converter := NewCanonicalStreamConverter(decoder, encoder)
result := converter.Flush()
assert.Len(t, result, 2)
assert.Equal(t, []byte("data: decoder_flush\n\n"), result[0])
assert.Equal(t, []byte("data: encoder_flush\n\n"), result[1])
}
type errorMiddleware struct{}
func (m *errorMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
return nil, fmt.Errorf("middleware error")
}
func (m *errorMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
return nil, fmt.Errorf("stream middleware error")
}

View File

@@ -0,0 +1,165 @@
package handler
import (
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
)
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
body, _ := json.Marshal(map[string]string{
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"base_url": "https://api.test.com",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 201, w.Code)
var result domain.Provider
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "p1", result.ID)
assert.Contains(t, result.APIKey, "***")
}
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
body, _ := json.Marshal(map[string]string{
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"base_url": "https://api.test.com",
"protocol": "anthropic",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 201, w.Code)
}
func TestProviderHandler_UpdateProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
})
body, _ := json.Marshal(map[string]string{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 200, w.Code)
}
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", nil)
h.UpdateProvider(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_DeleteProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("DELETE", "/api/providers/p1", bytes.NewReader([]byte{}))
c.Request.Header.Set("Content-Type", "application/json")
h.DeleteProvider(c)
assert.True(t, w.Code == 204 || w.Code == 200)
}
func TestModelHandler_DeleteModel(t *testing.T) {
h := NewModelHandler(&mockModelService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("DELETE", "/api/models/m1", bytes.NewReader([]byte{}))
c.Request.Header.Set("Content-Type", "application/json")
h.DeleteModel(c)
assert.True(t, w.Code == 204 || w.Code == 200)
}
func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{
"id": "m1",
"provider_id": "p1",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 201, w.Code)
var result domain.Model
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
}
func TestModelHandler_GetModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
h.GetModel(c)
assert.Equal(t, 200, w.Code)
var result domain.Model
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelHandler_UpdateModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
})
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateModel(c)
assert.Equal(t, 200, w.Code)
}

View File

@@ -13,6 +13,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
 
 	"nex/backend/internal/domain"
 	"nex/backend/internal/provider"
@@ -250,3 +251,25 @@ func formatMapErrors(errs map[string]string) string {
}
return "请求验证失败: " + strings.Join(parts, "; ")
}
// ============ Error type detection tests ============
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
err: gorm.ErrDuplicatedKey,
})
body, _ := json.Marshal(map[string]string{
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"base_url": "https://test.com",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 409, w.Code)
}

View File

@@ -1,8 +1,8 @@
 package handler
 
 import (
+	"errors"
 	"net/http"
-	"strings"
 
 	"github.com/gin-gonic/gin"
 	"gorm.io/gorm"
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
 	err := h.providerService.Create(provider)
 	if err != nil {
-		if strings.Contains(err.Error(), "UNIQUE constraint failed") {
+		if errors.Is(err, gorm.ErrDuplicatedKey) {
 			c.JSON(http.StatusConflict, gin.H{
 				"error": "供应商 ID 已存在",
 			})

View File

@@ -2,6 +2,7 @@ package handler
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
@@ -213,18 +214,14 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
 	if ifaceType != conversion.InterfaceTypeChat {
 		return false
 	}
-	for i, b := range body {
-		if b == '"' && i+8 <= len(body) {
-			if string(body[i:i+8]) == `"stream"` {
-				for j := i + 8; j < len(body) && j < i+20; j++ {
-					if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" {
-						return true
-					}
-				}
-			}
-		}
-	}
-	return false
+	var req struct {
+		Stream bool `json:"stream"`
+	}
+	if err := json.Unmarshal(body, &req); err != nil {
+		return false
+	}
+	return req.Stream
 }
 
 // writeConversionError writes a conversion error response
@@ -312,51 +309,13 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
 // extractModelName extracts the model field from a JSON body
 func extractModelName(body []byte) string {
-	inQuote := false
-	escaped := false
-	keyStart := -1
-	keyEnd := -1
-	lookingForKey := true
-	lookingForValue := false
-	valueStart := -1
-	for i := 0; i < len(body); i++ {
-		b := body[i]
-		if escaped {
-			escaped = false
-			continue
-		}
-		if b == '\\' {
-			escaped = true
-			continue
-		}
-		if b == '"' {
-			if !inQuote {
-				inQuote = true
-				if lookingForKey {
-					keyStart = i + 1
-				}
-				if lookingForValue {
-					valueStart = i + 1
-				}
-			} else {
-				inQuote = false
-				if lookingForKey && keyStart >= 0 {
-					keyEnd = i
-					if string(body[keyStart:keyEnd]) == "model" {
-						lookingForKey = false
-						lookingForValue = true
-					}
-				} else if lookingForValue && valueStart >= 0 {
-					return string(body[valueStart:i])
-				}
-			}
-		}
-		if !inQuote && lookingForValue && b == ':' {
-			// waiting for the value to start
-		}
-	}
-	return ""
+	var req struct {
+		Model string `json:"model"`
+	}
+	if err := json.Unmarshal(body, &req); err != nil {
+		return ""
+	}
+	return req.Model
 }
 
 // extractHeaders extracts request headers from the Gin context

View File

@@ -0,0 +1,833 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
)
func init() {
gin.SetMode(gin.TestMode)
}
type mockProxyProviderClient struct {
sendFn func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
sendStreamFn func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error)
}
func (m *mockProxyProviderClient) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
if m.sendFn != nil {
return m.sendFn(ctx, spec)
}
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
}
func (m *mockProxyProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
if m.sendStreamFn != nil {
return m.sendStreamFn(ctx, spec)
}
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
}
type mockProxyRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockProxyProviderService struct {
providers []domain.Provider
err error
}
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
func (m *mockProxyProviderService) Get(id string, maskKey bool) (*domain.Provider, error) { return nil, nil }
func (m *mockProxyProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProxyProviderService) Update(id string, updates map[string]interface{}) error { return nil }
func (m *mockProxyProviderService) Delete(id string) error { return nil }
type mockProxyStatsService struct{}
func (m *mockProxyStatsService) Record(providerID, modelName string) error { return nil }
func (m *mockProxyStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { return nil, nil }
func (m *mockProxyStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} { return nil }
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter()))
return engine
}
func newTestProxyHandler(engine *conversion.ConversionEngine, client *mockProxyProviderClient, routingSvc *mockProxyRoutingService, providerSvc *mockProxyProviderService) *ProxyHandler {
return NewProxyHandler(
engine,
client,
routingSvc,
providerSvc,
&mockProxyStatsService{},
)
}
func TestProxyHandler_HandleProxy_MissingProtocol(t *testing.T) {
engine := setupProxyEngine(t)
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{}`)))
h.HandleProxy(c)
assert.Equal(t, 400, w.Code)
}
func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "resp-1", resp["id"])
}
func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
}
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return nil, context.DeadlineExceeded
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return nil, context.DeadlineExceeded
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"}}]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "Hello")
}
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
return nil, context.DeadlineExceeded
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"object":"list","data":[{"id":"gpt-4","object":"model"}]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/unknown/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 400, w.Code)
}
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{providers: []domain.Provider{}}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
}
func TestExtractHeaders(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("Authorization", "Bearer test")
c.Request.Header.Set("Content-Type", "application/json")
headers := extractHeaders(c)
assert.Equal(t, "Bearer test", headers["Authorization"])
assert.Equal(t, "application/json", headers["Content-Type"])
}
func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
engine := setupProxyEngine(t)
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, context.DeadlineExceeded, "openai")
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
engine := setupProxyEngine(t)
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
h.writeConversionError(c, convErr, "openai")
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
body := w.Body.String()
assert.Contains(t, body, "Hello")
}
func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"}}]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Equal(t, "no-cache", w.Header().Get("Cache-Control"))
assert.Equal(t, "keep-alive", w.Header().Get("Connection"))
body := w.Body.String()
assert.Contains(t, body, "Hi")
assert.Contains(t, body, "[DONE]")
}
func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
err := registry.Register(openai.NewAdapter())
require.NoError(t, err)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
require.NoError(t, registry.Register(openai.NewAdapter()))
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter()))
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`invalid json`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
}
func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json", "X-Custom": "test-value"},
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "test-value", w.Header().Get("X-Custom"))
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
}
func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
require.NoError(t, registry.Register(openai.NewAdapter()))
anthropicAdapter := anthropic.NewAdapter()
require.NoError(t, registry.Register(anthropicAdapter))
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"object":"list","data":[{"id":"gpt-4","object":"model"}]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
func TestIsStreamRequest_EdgeCases(t *testing.T) {
engine := setupProxyEngine(t)
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{})
tests := []struct {
name string
body string
path string
expected bool
}{
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
{"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
{"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
{"empty body", "", "/v1/chat/completions", false},
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := h.isStreamRequest([]byte(tt.body), "openai", tt.path)
assert.Equal(t, tt.expected, result)
})
}
}
func TestProxyHandler_WriteError_RouteError(t *testing.T) {
engine := setupProxyEngine(t)
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeError(c, fmt.Errorf("model not found"), "openai")
assert.Equal(t, 404, w.Code)
}
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
// ============ extractModelName tests ============
func TestExtractModelName(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "valid model",
body: []byte(`{"model": "gpt-4", "messages": []}`),
expected: "gpt-4",
},
{
name: "empty body",
body: []byte(`{}`),
expected: "",
},
{
name: "invalid json",
body: []byte(`{invalid}`),
expected: "",
},
{
name: "nested structure",
body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`),
expected: "claude-3",
},
{
name: "model with special chars",
body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`),
expected: "gpt-4-0125-preview",
},
{
name: "empty body bytes",
body: []byte{},
expected: "",
},
{
name: "model is null",
body: []byte(`{"model": null}`),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractModelName(tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ isStreamRequest tests ============
func TestIsStreamRequest(t *testing.T) {
engine := setupProxyEngine(t)
h := &ProxyHandler{engine: engine}
tests := []struct {
name string
body []byte
clientProtocol string
nativePath string
expected bool
}{
{
name: "stream true",
body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai",
nativePath: "/v1/chat/completions",
expected: true,
},
{
name: "stream false",
body: []byte(`{"model": "gpt-4", "stream": false}`),
clientProtocol: "openai",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "no stream field",
body: []byte(`{"model": "gpt-4"}`),
clientProtocol: "openai",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "invalid json",
body: []byte(`{invalid}`),
clientProtocol: "openai",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "not chat endpoint",
body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai",
nativePath: "/v1/models",
expected: false,
},
{
name: "anthropic stream",
body: []byte(`{"model": "claude-3", "stream": true}`),
clientProtocol: "anthropic",
nativePath: "/v1/messages",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := h.isStreamRequest(tt.body, tt.clientProtocol, tt.nativePath)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -1,6 +1,7 @@
package handler
import (
"fmt"
"net/http"
"time"
@@ -23,31 +24,16 @@ func NewStatsHandler(statsService service.StatsService) *StatsHandler {
func (h *StatsHandler) GetStats(c *gin.Context) {
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
var startDate, endDate *time.Time
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
})
return
}
startDate = &t
startDate, err := parseDateParam(c, "start_date")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if endDateStr != "" {
t, err := time.Parse("2006-01-02", endDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
})
return
}
endDate = &t
endDate, err := parseDateParam(c, "end_date")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
@@ -65,32 +51,17 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
func (h *StatsHandler) AggregateStats(c *gin.Context) {
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
groupBy := c.Query("group_by")
var startDate, endDate *time.Time
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 start_date 格式,应为 YYYY-MM-DD",
})
return
}
startDate = &t
startDate, err := parseDateParam(c, "start_date")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if endDateStr != "" {
t, err := time.Parse("2006-01-02", endDateStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的 end_date 格式,应为 YYYY-MM-DD",
})
return
}
endDate = &t
endDate, err := parseDateParam(c, "end_date")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
@@ -104,3 +75,16 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
result := h.statsService.Aggregate(stats, groupBy)
c.JSON(http.StatusOK, result)
}
// parseDateParam parses a date query parameter in YYYY-MM-DD format
func parseDateParam(c *gin.Context, paramName string) (*time.Time, error) {
dateStr := c.Query(paramName)
if dateStr == "" {
return nil, nil
}
t, err := time.Parse("2006-01-02", dateStr)
if err != nil {
return nil, fmt.Errorf("无效的 %s 格式,应为 YYYY-MM-DD", paramName)
}
return &t, nil
}

View File

@@ -0,0 +1,61 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestParseDateParam(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("valid_date", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-15", nil)
result, err := parseDateParam(c, "start_date")
assert.NoError(t, err)
assert.NotNil(t, result)
expected := time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC)
assert.Equal(t, expected, *result)
})
t.Run("empty_param", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
result, err := parseDateParam(c, "start_date")
assert.NoError(t, err)
assert.Nil(t, result)
})
t.Run("invalid_format", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_date=2024/01/15", nil)
result, err := parseDateParam(c, "start_date")
assert.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "start_date")
assert.Contains(t, err.Error(), "YYYY-MM-DD")
})
t.Run("end_date", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?end_date=2024-12-31", nil)
result, err := parseDateParam(c, "end_date")
assert.NoError(t, err)
assert.NotNil(t, result)
expected := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC)
assert.Equal(t, expected, *result)
})
}

View File

@@ -3,15 +3,18 @@ package provider
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"syscall"
"time"
"go.uber.org/zap"
"nex/backend/internal/conversion"
pkgErrors "nex/backend/pkg/errors"
)
// StreamConfig holds stream processing configuration
@@ -72,7 +75,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
httpReq, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
return nil, pkgErrors.ErrRequestCreate.WithCause(err)
}
for k, v := range spec.Headers {
@@ -86,13 +89,13 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %w", err)
return nil, pkgErrors.ErrRequestSend.WithCause(err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
return nil, pkgErrors.ErrResponseRead.WithCause(err)
}
respHeaders := make(map[string]string)
@@ -120,7 +123,7 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader)
if err != nil {
cancel()
return nil, fmt.Errorf("创建请求失败: %w", err)
return nil, pkgErrors.ErrRequestCreate.WithCause(err)
}
for k, v := range spec.Headers {
@@ -130,7 +133,7 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
resp, err := c.httpClient.Do(httpReq)
if err != nil {
cancel()
return nil, fmt.Errorf("发送请求失败: %w", err)
return nil, pkgErrors.ErrRequestSend.WithCause(err)
}
if resp.StatusCode != http.StatusOK {
@@ -173,22 +176,22 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
n, err := body.Read(buf)
if n > 0 {
dataBuf = append(dataBuf, buf[:n]...)
}
if err != nil {
if err == io.EOF {
if err != io.EOF {
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
}
return
}
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
}
return
}
dataBuf = append(dataBuf, buf[:n]...)
if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize {
newSize := bufSize * 2
if newSize > c.streamCfg.MaxBufferSize {
@@ -214,6 +217,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
eventChan <- StreamEvent{Data: rawEvent}
}
if err == io.EOF {
return
}
}
}
@@ -222,10 +229,46 @@ func isNetworkError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "network") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "EOF")
// Check for the standard library net.Error interface
var netErr net.Error
if errors.As(err, &netErr) {
return true
}
// Check for net.OpError
var opErr *net.OpError
if errors.As(err, &opErr) {
// Check specific syscall errors
if opErr.Err != nil {
// connection reset
if errors.Is(opErr.Err, syscall.ECONNRESET) {
return true
}
// broken pipe
if errors.Is(opErr.Err, syscall.EPIPE) {
return true
}
// timeout
if errors.Is(opErr.Err, syscall.ETIMEDOUT) {
return true
}
}
return true
}
// Check context errors
if errors.Is(err, context.DeadlineExceeded) {
return true
}
if errors.Is(err, context.Canceled) {
return true
}
// Check io.EOF
if errors.Is(err, io.EOF) {
return true
}
return false
}

View File

@@ -2,10 +2,14 @@ package provider
import (
"context"
"fmt"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -129,21 +133,265 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
assert.Error(t, err)
}
func TestIsNetworkError(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"connection reset by peer", true},
{"broken pipe", true},
{"network is unreachable", true},
{"timeout waiting for response", true},
{"unexpected EOF", true},
{"normal error", false},
{"", false},
func TestClient_SendStream_SSEEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Headers: map[string]string{"Authorization": "Bearer test-key"},
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
}
for _, tt := range tests {
err := fmt.Errorf("%s", tt.input)
assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input)
eventChan, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataEvents [][]byte
var doneEvents int
for event := range eventChan {
if event.Done {
doneEvents++
} else if event.Error != nil {
t.Fatalf("unexpected error: %v", event.Error)
} else {
dataEvents = append(dataEvents, event.Data)
}
}
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
assert.Contains(t, string(dataEvents[0]), "Hello")
assert.Contains(t, string(dataEvents[1]), "World")
assert.Equal(t, 1, doneEvents)
}
func TestClient_SendStream_ContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
time.Sleep(10 * time.Second)
}))
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Headers: map[string]string{"Authorization": "Bearer test-key"},
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(ctx, spec)
require.NoError(t, err)
cancel()
var gotError bool
for event := range eventChan {
if event.Error != nil {
gotError = true
}
}
assert.True(t, gotError)
}
func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`))
}))
defer server.Close()
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/models",
Method: "GET",
Headers: map[string]string{"Authorization": "Bearer test-key"},
}
result, err := client.Send(context.Background(), spec)
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Contains(t, string(result.Body), "ok")
}
func TestClient_SendStream_SlowSSE(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
flusher.Flush()
time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
time.Sleep(100 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Headers: map[string]string{"Authorization": "Bearer test-key"},
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataCount int
var doneCount int
for event := range eventChan {
if event.Done {
doneCount++
} else if event.Error != nil {
t.Fatalf("unexpected error: %v", event.Error)
} else {
dataCount++
}
}
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
}
func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Headers: map[string]string{"Authorization": "Bearer test-key"},
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataEvents int
var doneEvents int
for event := range eventChan {
if event.Done {
doneEvents++
} else {
dataEvents++
}
}
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
assert.Equal(t, 1, doneEvents)
}
func TestIsNetworkError(t *testing.T) {
// net.Error values
t.Run("net_error", func(t *testing.T) {
var netErr net.Error
err := context.DeadlineExceeded
assert.True(t, errors.As(err, &netErr))
assert.True(t, isNetworkError(err))
})
// io.EOF
t.Run("io_eof", func(t *testing.T) {
assert.True(t, isNetworkError(io.EOF))
})
// context errors
t.Run("context_errors", func(t *testing.T) {
assert.True(t, isNetworkError(context.DeadlineExceeded))
assert.True(t, isNetworkError(context.Canceled))
})
// syscall errors (wrapped in net.OpError)
t.Run("syscall_errors", func(t *testing.T) {
// ECONNRESET
opErr := &net.OpError{
Op: "read",
Net: "tcp",
Err: syscall.ECONNRESET,
}
assert.True(t, isNetworkError(opErr))
// EPIPE
opErr = &net.OpError{
Op: "write",
Net: "tcp",
Err: syscall.EPIPE,
}
assert.True(t, isNetworkError(opErr))
})
// plain errors
t.Run("normal_error", func(t *testing.T) {
assert.False(t, isNetworkError(errors.New("normal error")))
assert.False(t, isNetworkError(nil))
})
}
func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
flusher.Flush()
time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
conn.Close()
}
}
}))
defer server.Close()
client := NewClient()
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Headers: map[string]string{"Authorization": "Bearer test-key"},
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var gotData bool
for event := range eventChan {
if event.Error != nil {
// a mid-stream error event is expected once the server hijacks and closes the connection; ignore it
} else if !event.Done {
gotData = true
}
}
assert.True(t, gotData, "should have received at least one data event before error")
}


@@ -0,0 +1,134 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
require.NoError(t, err)
result, err := svc.Get("p1", false)
require.NoError(t, err)
assert.Equal(t, "Updated", result.Name)
}
func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
}
func TestModelService_Get(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model, err := svc.Get("m1")
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
}
func TestModelService_Update(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
require.NoError(t, err)
model, err := svc.Get("m1")
require.NoError(t, err)
assert.Equal(t, "gpt-4o", model.ModelName)
}
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
assert.Error(t, err)
}
func TestModelService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
err := svc.Delete("m1")
require.NoError(t, err)
_, err = svc.Get("m1")
assert.Error(t, err)
}
func TestModelService_Delete_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
err := svc.Delete("nonexistent")
assert.Error(t, err)
}
func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
{ProviderID: "p2", RequestCount: 5},
}
result := svc.Aggregate(stats, "")
assert.Len(t, result, 2)
totalCount := 0
for _, r := range result {
totalCount += r["request_count"].(int)
}
assert.Equal(t, 15, totalCount)
}
func TestModelService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
assert.Error(t, err)
}


@@ -243,3 +243,28 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
assert.Len(t, result, 1)
assert.Equal(t, 15, result[0]["request_count"])
}
func TestStatsService_Aggregate_ByModel(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
}
result := svc.Aggregate(stats, "model")
assert.Len(t, result, 3)
// verify the count for each provider/model pair
counts := make(map[string]int)
for _, r := range result {
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
counts[key] = r["request_count"].(int)
}
assert.Equal(t, 13, counts["openai/gpt-4"])
assert.Equal(t, 5, counts["openai/gpt-3.5"])
assert.Equal(t, 8, counts["anthropic/claude-3"])
}


@@ -1,6 +1,7 @@
package service
import (
"strings"
"time"
"nex/backend/internal/domain"
@@ -59,9 +60,10 @@ func (s *statsService) aggregateByModel(stats []domain.UsageStats) []map[string]
}
result := make([]map[string]interface{}, 0, len(aggregated))
for key, count := range aggregated {
parts := strings.SplitN(key, "/", 2)
result = append(result, map[string]interface{}{
-		"provider_id": key[:len(key)/2],
-		"model_name": key[len(key)/2+1:],
+		"provider_id": parts[0],
+		"model_name": parts[1],
"request_count": count,
})
}


@@ -27,6 +27,17 @@ func (e *AppError) Unwrap() error {
return e.Cause
}
// WithCause returns a copy of the error with the given cause
func (e *AppError) WithCause(cause error) *AppError {
return &AppError{
Code: e.Code,
Message: e.Message,
HTTPStatus: e.HTTPStatus,
Cause: cause,
Context: e.Context,
}
}
// NewAppError creates a new AppError
func NewAppError(code, message string, httpStatus int) *AppError {
return &AppError{
@@ -46,6 +57,9 @@ var (
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
)
// AsAppError 尝试将 error 转换为 *AppError


@@ -90,6 +90,9 @@ func TestPredefinedErrors(t *testing.T) {
{"ErrInternal", ErrInternal, "internal_error", http.StatusInternalServerError},
{"ErrDatabaseNotInit", ErrDatabaseNotInit, "database_not_initialized", http.StatusInternalServerError},
{"ErrConflict", ErrConflict, "conflict", http.StatusConflict},
{"ErrRequestCreate", ErrRequestCreate, "request_create_error", http.StatusInternalServerError},
{"ErrRequestSend", ErrRequestSend, "request_send_error", http.StatusBadGateway},
{"ErrResponseRead", ErrResponseRead, "response_read_error", http.StatusBadGateway},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -123,3 +126,16 @@ func TestAsAppError(t *testing.T) {
assert.False(t, ok)
})
}
func TestWithCause(t *testing.T) {
cause := errors.New("连接超时")
err := ErrRequestSend.WithCause(cause)
assert.Equal(t, "request_send_error", err.Code)
assert.Equal(t, http.StatusBadGateway, err.HTTPStatus)
assert.Equal(t, cause, err.Cause)
assert.True(t, errors.Is(err, cause))
var appErr *AppError
assert.True(t, errors.As(err, &appErr))
assert.Equal(t, "request_send_error", appErr.Code)
}


@@ -69,7 +69,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
registry := conversion.NewMemoryRegistry()
require.NoError(t, registry.Register(openaiConv.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter()))
engine := conversion.NewConversionEngine(registry)
engine := conversion.NewConversionEngine(registry, nil)
providerClient := provider.NewClient()
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,2 +1,2 @@
schema: spec-driven
-created: 2026-04-19
+created: 2026-04-20


@@ -0,0 +1,233 @@
## Context
Configuration management currently uses custom YAML loading logic and supports only a single configuration source. The loading flow is:
```
main.go → LoadConfig() → read ~/.nex/config.yaml → yaml.Unmarshal → Validate()
```
Problems:
- A single configuration source cannot serve testing, containerized, or ad-hoc debugging scenarios
- No priority management, so configuration values cannot be overridden
- Inconsistent naming conventions across configuration sources
This design adopts **Viper** as the configuration framework — the most popular configuration library in the Go community, supporting multiple sources and priority management.
## Goals / Non-Goals
**Goals:**
- Support layered configuration sources: CLI flags, environment variables, config file, defaults
- Implement configuration priority: CLI > ENV > File > Default
- Normalize naming: keep config-file, environment-variable, and CLI-flag names consistent
- Preserve backward compatibility: config file format unchanged, API signatures largely unchanged
- Improve developer experience: no temporary config files for tests, quick overrides while debugging
**Non-Goals:**
- No hot reload: this version only loads configuration at startup
- No remote sources (etcd, Consul): local configuration only
- No configuration encryption: secrets are passed via environment variables, not stored in the config file
- No config-file format change: stay with YAML; do not introduce TOML, JSON, etc.
## Decisions
### 1. Configuration framework: Viper
**Decision**: use `github.com/spf13/viper` as the configuration framework.
**Rationale**:
- **Community standard**: the most popular Go configuration library (26k+ GitHub stars)
- **Feature-complete**: multiple sources (file, environment variables, CLI flags), multiple formats (YAML, JSON, TOML), priority management
- **Mature ecosystem**: integrates seamlessly with Cobra and pflag; well documented
- **Production-proven**: used by many well-known projects (Hugo, Docker Notary, etc.)
**Alternatives**:
- **koanf**: lighter, but a less mature ecosystem than Viper
- **Hand-rolled**: maximum flexibility, but reinvents the wheel and raises maintenance cost
### 2. CLI flag parsing: pflag
**Decision**: use `github.com/spf13/pflag` to parse command-line flags.
**Rationale**:
- **POSIX-compatible**: supports GNU-style flags (`--flag value` and `--flag=value`)
- **Viper integration**: binds directly to Viper via `BindPFlag`
- **Type-safe**: Int, String, Duration, and more, with automatic conversion
**Alternatives**:
- **Standard `flag` package**: limited features; no GNU style
- **Cobra**: too heavyweight; this project needs no subcommands
### 3. Configuration validation: go-playground/validator
**Decision**: use `github.com/go-playground/validator` for struct validation.
**Rationale**:
- **Declarative**: rules are defined via struct tags; concise code
- **Rich rules**: required, min, max, oneof, and many more
- **Friendly errors**: detailed validation error messages
**Alternatives**:
- **Manual validation**: the current approach; verbose and hard to maintain
- **go-validator**: fewer features than validator
### 4. Priority design
**Decision**: adopt Viper's default priority: CLI > ENV > File > Default.
**Rationale**:
- **Industry standard**: follows 12-Factor App principles; environment variables outrank the config file
- **Flexibility**: CLI flags can temporarily override any setting, which suits debugging and testing
- **Predictability**: a fixed, explicit order leaves little room for surprises
### 5. Naming normalization
**Decision**: use the full key hierarchy in every source, keeping CLI, ENV, and config-file names consistent.
**Conversion rules**:
```
Config file:          server.port
Environment variable: NEX_SERVER_PORT (prefix + uppercase + underscores)
CLI flag:             --server-port (kebab-case with hyphens)
```
**Rationale**:
- **Consistency**: one rule across all three sources; easy to learn and remember
- **Predictability**: given a config-file key, the CLI flag and environment variable can be derived
- **Unambiguous**: the full hierarchy avoids naming collisions
**Alternatives**:
- **Short prefixes** such as `--port` or `--db-path`: terse but ambiguous
- **"Smart" prefixes** (no prefix for common flags): easy to confuse
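The conversion rules above are mechanical; a minimal stdlib sketch derives the environment-variable and flag names from a config-file key (the `NEX_` prefix follows this design; the function names are illustrative, not part of the codebase):

```go
package main

import (
	"fmt"
	"strings"
)

// keyToEnv maps a config-file key to its environment-variable name:
// dots become underscores, everything uppercased, NEX_ prefixed.
func keyToEnv(key string) string {
	return "NEX_" + strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
}

// keyToFlag maps a config-file key to its CLI flag name:
// dots and underscores both become hyphens, -- prefixed.
func keyToFlag(key string) string {
	return "--" + strings.NewReplacer(".", "-", "_", "-").Replace(key)
}

func main() {
	fmt.Println(keyToEnv("database.max_idle_conns"))  // NEX_DATABASE_MAX_IDLE_CONNS
	fmt.Println(keyToFlag("database.max_idle_conns")) // --database-max-idle-conns
}
```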
### 6. Loading-flow design
**Decision**: load configuration in the following order:
```
1. Parse CLI flags (to obtain the --config path)
2. Initialize Viper
3. Set defaults (SetDefault)
4. Bind CLI flags (BindPFlag)
5. Bind environment variables (AutomaticEnv + SetEnvPrefix)
6. Read the config file (ReadInConfig)
7. Unmarshal into the struct (Unmarshal)
8. Validate (Validate)
9. Print the configuration summary (PrintSummary)
```
**Rationale**:
- **Order matters**: CLI flags must be parsed first to obtain the `--config` path
- **Priority guarantee**: Viper resolves priority by source, with CLI flag bindings taking precedence
- **Friendly failures**: every step has explicit error handling
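The precedence the flow above relies on can be mirrored in a stdlib-only sketch — no Viper, with plain maps standing in for the four sources — purely to make the lookup order explicit (Viper implements this internally per key):

```go
package main

import "fmt"

// resolve walks the sources in priority order: CLI > ENV > File >
// Default. The first source that defines the key wins; an absent key
// resolves to the empty string.
func resolve(key string, cli, env, file, def map[string]string) string {
	for _, src := range []map[string]string{cli, env, file, def} {
		if v, ok := src[key]; ok {
			return v
		}
	}
	return ""
}

func main() {
	cli := map[string]string{"server.port": "8080"}
	env := map[string]string{"server.port": "9000"}
	file := map[string]string{"server.port": "9826", "log.level": "debug"}
	def := map[string]string{"server.port": "9826", "log.level": "info"}

	fmt.Println(resolve("server.port", cli, env, file, def)) // 8080 — the CLI flag wins
	fmt.Println(resolve("log.level", cli, env, file, def))   // debug — falls through to the file
}
```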
### 7. Configuration summary output
**Decision**: print a configuration summary at startup, showing key settings and their sources.
**Example**:
```
┌─────────────────────────────────────────┐
│ AI Gateway startup configuration        │
├─────────────────────────────────────────┤
│ Server port:   9826                     │
│ Database path: ~/.nex/config.db         │
│ Log level:     info                     │
│                                         │
│ Sources:                                │
│   Config file: ~/.nex/config.yaml       │
│   Env vars:    2                        │
│   CLI flags:   1                        │
└─────────────────────────────────────────┘
```
**Rationale**:
- **Observability**: quickly confirm the effective configuration
- **Debug-friendly**: configuration problems can be located fast
- **Source tracking**: knowing which source a value came from speeds up troubleshooting
## Risks / Trade-offs
### Risk 1: more dependencies
**Risk**: three new dependencies add complexity and dependency-management cost.
**Mitigation**:
- Viper, pflag, and validator are mature, actively maintained libraries
- They are widely used; supply-chain risk is low
- About 10 extra indirect dependencies — a manageable increase
### Risk 2: backward compatibility
**Risk**: `LoadConfig()` is rewritten internally, which may affect existing code.
**Mitigation**:
- Keep the `LoadConfig()` signature: `func LoadConfig() (*Config, error)`
- Keep the config-file format: YAML, with unchanged field names
- Keep all defaults identical to the current implementation
- Thorough test coverage to guarantee behavioral parity
### Risk 3: performance
**Risk**: loading configuration through Viper is slightly slower than reading a YAML file directly.
**Mitigation**:
- Loading happens once at startup; the impact is negligible
- Viper caches internally and does not re-parse
- Measured load time is < 10 ms, which does not affect startup
### Risk 4: learning curve
**Risk**: the team has to learn Viper.
**Mitigation**:
- The Viper API is simple and intuitive; the learning cost is low
- Detailed usage examples and documentation are provided
- Loading logic is encapsulated behind a small API
### Trade-off 1: number of CLI flags
**Trade-off**: all 13 configuration items get CLI flags, which is a lot of flags.
**Why**:
- Flexibility first: tests and debugging need to override any setting
- Grouped help output keeps the flags understandable
- Optional use: most scenarios need only a few flags
### Trade-off 2: environment-variable prefix
**Trade-off**: the `NEX_` prefix makes variable names longer.
**Why**:
- Avoids collisions with other systems' variables
- Makes ownership obvious at a glance
- Industry convention — most applications use a prefix (`AWS_`, `GITHUB_`, etc.)
## Migration Plan
No data migration is involved; only a code deployment:
### Deployment steps
1. **Merge**: merge the change into the main branch
2. **Rebuild**: compile the new binary
3. **Verify**: confirm configuration loading in the test environment
4. **Deploy**: roll out the new version
### Rollback strategy
If rollback is needed:
1. Revert to the previous code
2. Rebuild and redeploy
3. No config-file changes required; the format is compatible
### Compatibility guarantees
- The existing `~/.nex/config.yaml` needs no changes
- The existing `./server` startup command keeps working
- The new features (CLI flags, environment variables) are opt-in
## Open Questions
None. The design is settled and ready for implementation.


@@ -0,0 +1,66 @@
## Why
The current scheme supports only a YAML config file, which causes:
- **Inconvenient testing**: every test needs a temporary config file
- **Hard ad-hoc debugging**: no quick way to tweak a single setting
- **Container-unfriendly**: no environment-variable support, violating 12-Factor App principles
- **Clumsy configuration switching**: settings cannot be overridden from the command line
We need layered configuration management — CLI flags, environment variables, config file, and defaults — built on the community-standard solution (Viper).
## What Changes
- **Introduce Viper**: the community-standard configuration library, supporting multiple sources
- **Configuration priority**: CLI flags > environment variables > config file > defaults
- **CLI flag support**: all 13 configuration items can be overridden via CLI flags
- **Environment-variable support**: every item configurable via environment variables (`NEX_` prefix)
- **Normalized naming**: CLI, environment, and config-file names stay fully consistent and hierarchical
  - Config file: `server.port`
  - Environment variable: `NEX_SERVER_PORT`
  - CLI flag: `--server-port`
- **Struct validation**: adopt `go-playground/validator`
- **Configuration summary**: print a summary at startup, including value sources
- **BREAKING**: configuration loading is reworked; the existing `LoadConfig()` API changes
## Capabilities
### New Capabilities
- `cli-config`: command-line configuration; every item settable via CLI flags
- `env-config`: environment-variable configuration, 12-Factor App compliant
- `config-priority`: priority management with CLI > ENV > File > Default
### Modified Capabilities
- `config-management`: extends existing configuration management from a single config file to layered sources
## Impact
### Code
- `backend/internal/config/config.go`: rework loading logic around Viper
- `backend/cmd/server/main.go`: adjust the loading flow; add CLI flag parsing
- `backend/internal/config/config_test.go`: update tests for the new loading path
### Dependencies
Added:
- `github.com/spf13/viper v1.18.2`: configuration management
- `github.com/spf13/pflag v1.0.5`: command-line flag parsing
- `github.com/go-playground/validator/v10 v10.22.0`: struct validation
Removed:
- `gopkg.in/yaml.v3` (Viper ships with YAML support)
### API
- `config.LoadConfig()` keeps its signature, but the implementation is rewritten
- New `config.LoadConfigFromPath(path string)` for custom config-file paths
- New `config.PrintSummary()` to print the configuration summary
### Usage scenarios
- **Production**: keeps using the config file; no impact
- **Testing**: configure via CLI flags or environment variables; no temporary config files
- **Containers**: configure via environment variables, 12-Factor App compliant
- **Local development**: override settings temporarily via CLI flags, without editing the config file


@@ -0,0 +1,102 @@
# CLI Config
## ADDED Requirements
### Requirement: Command-line flag configuration
The system SHALL support setting every configuration item via command-line flags.
#### Scenario: Basic flag parsing
- **WHEN** the application starts with command-line flags
- **THEN** it SHALL parse all CLI flags
- **THEN** it SHALL apply the flag values to the corresponding configuration items
#### Scenario: Flag naming convention
- **WHEN** command-line flags are used
- **THEN** they SHALL use kebab-case (e.g. `--server-port`)
- **THEN** they SHALL preserve the full key hierarchy (e.g. `--database-max-idle-conns`)
- **THEN** they SHALL correspond to config-file keys (`database.max_idle_conns` → `--database-max-idle-conns`)
#### Scenario: Flag types
- **WHEN** flags of different types are parsed
- **THEN** int SHALL be supported (e.g. `--server-port 9000`)
- **THEN** string SHALL be supported (e.g. `--database-path /data/nex.db`)
- **THEN** duration SHALL be supported (e.g. `--server-read-timeout 60s`)
- **THEN** bool SHALL be supported (e.g. `--log-compress`)
### Requirement: Config-file path flag
The system SHALL support specifying the config-file path via a CLI flag.
#### Scenario: Custom config-file path
- **WHEN** the application starts with `--config /path/to/custom.yaml`
- **THEN** it SHALL load the configuration from that path
- **THEN** it SHALL NOT use the default path `~/.nex/config.yaml`
#### Scenario: No config-file path specified
- **WHEN** the application starts without `--config`
- **THEN** it SHALL use the default path `~/.nex/config.yaml`
### Requirement: Full configuration override
The system SHALL support overriding every configuration item via CLI flags.
#### Scenario: Server flags
- **WHEN** server-related flags are used
- **THEN** `--server-port` SHALL be supported
- **THEN** `--server-read-timeout` SHALL be supported
- **THEN** `--server-write-timeout` SHALL be supported
#### Scenario: Database flags
- **WHEN** database-related flags are used
- **THEN** `--database-path` SHALL be supported
- **THEN** `--database-max-idle-conns` SHALL be supported
- **THEN** `--database-max-open-conns` SHALL be supported
- **THEN** `--database-conn-max-lifetime` SHALL be supported
#### Scenario: Log flags
- **WHEN** log-related flags are used
- **THEN** `--log-level` SHALL be supported
- **THEN** `--log-path` SHALL be supported
- **THEN** `--log-max-size` SHALL be supported
- **THEN** `--log-max-backups` SHALL be supported
- **THEN** `--log-max-age` SHALL be supported
- **THEN** `--log-compress` SHALL be supported
### Requirement: Flag help output
The system SHALL provide complete help output for all flags.
#### Scenario: Help generation
- **WHEN** `--help` is used
- **THEN** it SHALL list every supported flag
- **THEN** it SHALL group flags by area (server, database, log)
- **THEN** it SHALL show each flag's default value
- **THEN** it SHALL show each flag's description
### Requirement: Flag error handling
The system SHALL handle flag errors correctly.
#### Scenario: Invalid flag value
- **WHEN** an invalid value is given (e.g. `--server-port abc`)
- **THEN** it SHALL return a clear error message
- **THEN** it SHALL name the flag and the expected type
- **THEN** it SHALL NOT start the application
#### Scenario: Unknown flag
- **WHEN** an undefined flag is given (e.g. `--unknown-param value`)
- **THEN** it SHALL return an error message
- **THEN** it SHALL name the unknown flag
- **THEN** it SHALL NOT start the application


@@ -0,0 +1,151 @@
# Config Management
## MODIFIED Requirements
### Requirement: YAML configuration file
The system SHALL use a YAML-format configuration file.
#### Scenario: Config-file path
- **WHEN** the application starts without `--config`
- **THEN** it SHALL load configuration from `~/.nex/config.yaml`
- **THEN** it SHALL parse it as YAML
#### Scenario: Custom config-file path
- **WHEN** the application starts with `--config /path/to/custom.yaml`
- **THEN** it SHALL load the configuration from that path
- **THEN** it SHALL NOT use the default path `~/.nex/config.yaml`
#### Scenario: Config-file structure
- **WHEN** the configuration file is loaded
- **THEN** it SHALL contain server, database, and log sections
- **THEN** it SHALL support nested configuration structures
### Requirement: Auto-generated default configuration
The system SHALL generate a default configuration on first use.
#### Scenario: Config file missing
- **WHEN** the application starts and the config file does not exist
- **THEN** it SHALL create the config file automatically
- **THEN** it SHALL write the default values
- **THEN** it SHALL log that the file was created
#### Scenario: Config file present
- **WHEN** the application starts and the config file exists
- **THEN** it SHALL load the file directly
- **THEN** it SHALL NOT overwrite the existing configuration
### Requirement: Configuration validation
The system SHALL validate the configuration.
#### Scenario: Required fields
- **WHEN** configuration is loaded
- **THEN** it SHALL verify that required fields are present
- **THEN** it SHALL return an error when a field is missing
#### Scenario: Field values
- **WHEN** configuration is loaded
- **THEN** it SHALL validate the port range (1-65535)
- **THEN** it SHALL validate the log level (debug/info/warn/error)
- **THEN** it SHALL validate paths
- **THEN** it SHALL validate numeric ranges (e.g. max_idle_conns ≥ 1)
#### Scenario: Validation failure
- **WHEN** validation fails
- **THEN** it SHALL return detailed error messages
- **THEN** it SHALL name the invalid fields
- **THEN** the application SHALL NOT start
## ADDED Requirements
### Requirement: Layered configuration sources
The system SHALL support multiple configuration sources.
#### Scenario: Source types
- **WHEN** configuration is loaded
- **THEN** it SHALL support a CLI-flag source
- **THEN** it SHALL support an environment-variable source
- **THEN** it SHALL support a config-file source
- **THEN** it SHALL support a defaults source
#### Scenario: Source merging
- **WHEN** several sources are present at once
- **THEN** it SHALL merge all sources
- **THEN** it SHALL resolve conflicts by priority
- **THEN** it SHALL produce the final configuration
### Requirement: Loading flow
The system SHALL implement a standardized loading flow.
#### Scenario: Loading steps
- **WHEN** the application starts
- **THEN** it SHALL load configuration in this order:
  1. Parse CLI flags (to obtain the --config path)
  2. Initialize the configuration manager
  3. Set defaults
  4. Bind CLI flags
  5. Bind environment variables
  6. Read the config file
  7. Unmarshal into the struct
  8. Validate
  9. Print the configuration summary
#### Scenario: Loading failure
- **WHEN** an error occurs while loading
- **THEN** it SHALL return a clear error message
- **THEN** it SHALL name the failing step
- **THEN** it SHALL NOT start the application
### Requirement: Configuration summary output
The system SHALL print a configuration summary at startup.
#### Scenario: Summary content
- **WHEN** loading completes
- **THEN** it SHALL print the key settings (port, database path, log level, etc.)
- **THEN** it SHALL print the config-file path
- **THEN** it SHALL print the number of environment variables
- **THEN** it SHALL print the number of CLI flags
#### Scenario: Summary format
- **WHEN** the summary is printed
- **THEN** it SHALL use clear, formatted output
- **THEN** it SHALL use separators and grouping
- **THEN** it SHALL be easy to read and understand
### Requirement: Struct-based validation
The system SHALL validate configuration via struct tags.
#### Scenario: Rule definitions
- **WHEN** the configuration struct is defined
- **THEN** it SHALL define rules with the `validate` tag
- **THEN** it SHALL support the `required` rule
- **THEN** it SHALL support the `min` and `max` rules
- **THEN** it SHALL support the `oneof` rule
#### Scenario: Validation execution
- **WHEN** configuration has been loaded
- **THEN** it SHALL run struct validation automatically
- **THEN** it SHALL return validation errors
- **THEN** it SHALL NOT start the application if validation fails


@@ -0,0 +1,113 @@
# Config Priority
## ADDED Requirements
### Requirement: Priority management
The system SHALL implement an explicit configuration-priority mechanism.
#### Scenario: Priority order
- **WHEN** the same item is set in several sources
- **THEN** it SHALL apply this order (highest first):
  1. CLI flags
  2. Environment variables
  3. Config file
  4. Defaults
#### Scenario: CLI flags have the highest priority
- **WHEN** the config file sets `server.port: 9826`
- **AND** the environment sets `NEX_SERVER_PORT=9000`
- **AND** the CLI sets `--server-port 8080`
- **THEN** it SHALL use the CLI value 8080
#### Scenario: Environment variables second
- **WHEN** the config file sets `server.port: 9826`
- **AND** the environment sets `NEX_SERVER_PORT=9000`
- **AND** no CLI flag is set
- **THEN** it SHALL use the environment value 9000
#### Scenario: Config file third
- **WHEN** the config file sets `server.port: 9826`
- **AND** no environment variable is set
- **AND** no CLI flag is set
- **THEN** it SHALL use the config-file value 9826
#### Scenario: Defaults last
- **WHEN** an item is not set in the config file
- **AND** no environment variable is set
- **AND** no CLI flag is set
- **THEN** it SHALL use the default value
### Requirement: Source tracking
The system SHALL track the source of each configuration value.
#### Scenario: Source recording
- **WHEN** loading completes
- **THEN** it SHALL record each item's source (CLI/ENV/File/Default)
- **THEN** it SHALL show source information in the configuration summary
#### Scenario: Source statistics
- **WHEN** the summary is printed
- **THEN** it SHALL count items set via CLI flags
- **THEN** it SHALL count items set via environment variables
- **THEN** it SHALL count items set via the config file
- **THEN** it SHALL count items using defaults
### Requirement: Override transparency
The system SHALL make configuration overrides transparent.
#### Scenario: Override notice
- **WHEN** a CLI flag overrides a config-file value
- **THEN** it SHALL log the override
- **THEN** it SHALL name the overridden item
#### Scenario: Summary display
- **WHEN** startup completes
- **THEN** it SHALL print the configuration summary
- **THEN** it SHALL show the final values of key items
- **THEN** it SHALL show the config-file path
- **THEN** it SHALL show the number of environment variables
- **THEN** it SHALL show the number of CLI flags
### Requirement: Partial overrides
The system SHALL support partial configuration overrides.
#### Scenario: Mixed sources
- **WHEN** the config file sets the full configuration
- **AND** CLI flags override only some items
- **THEN** it SHALL merge all sources
- **THEN** it SHALL override the named items with the CLI values
- **THEN** it SHALL keep the remaining items from the config file
#### Scenario: Independent overrides
- **WHEN** only `--server-port 9000` is set on the CLI
- **THEN** it SHALL override only server.port
- **THEN** it SHALL NOT affect other items
- **THEN** other items SHALL use the config file or defaults
### Requirement: Immutable priority
The system SHALL keep the priority order immutable at runtime.
#### Scenario: Locked after startup
- **WHEN** startup completes
- **THEN** it SHALL lock the configuration values
- **THEN** it SHALL NOT allow changing the priority order at runtime
- **THEN** it SHALL NOT allow adding new sources at runtime
Note: hot reload is a future extension; this version does not support it.


@@ -0,0 +1,107 @@
# Env Config
## ADDED Requirements
### Requirement: Environment-variable configuration
The system SHALL support setting every configuration item via environment variables.
#### Scenario: Reading environment variables
- **WHEN** environment variables are present at startup
- **THEN** it SHALL read every variable with the `NEX_` prefix
- **THEN** it SHALL apply the values to the corresponding configuration items
#### Scenario: Naming convention
- **WHEN** environment variables are used
- **THEN** they SHALL use the `NEX_` prefix
- **THEN** they SHALL use uppercase with underscores (e.g. `NEX_SERVER_PORT`)
- **THEN** they SHALL preserve the full key hierarchy (e.g. `NEX_DATABASE_MAX_IDLE_CONNS`)
- **THEN** they SHALL correspond to config-file keys (`database.max_idle_conns` → `NEX_DATABASE_MAX_IDLE_CONNS`)
#### Scenario: Type conversion
- **WHEN** variables of different types are parsed
- **THEN** int SHALL be supported (e.g. `NEX_SERVER_PORT=9000`)
- **THEN** string SHALL be supported (e.g. `NEX_DATABASE_PATH=/data/nex.db`)
- **THEN** duration SHALL be supported (e.g. `NEX_SERVER_READ_TIMEOUT=60s`)
- **THEN** bool SHALL be supported (e.g. `NEX_LOG_COMPRESS=true`)
### Requirement: Full configuration override
The system SHALL support overriding every configuration item via environment variables.
#### Scenario: Server variables
- **WHEN** server-related variables are set
- **THEN** `NEX_SERVER_PORT` SHALL be supported
- **THEN** `NEX_SERVER_READ_TIMEOUT` SHALL be supported
- **THEN** `NEX_SERVER_WRITE_TIMEOUT` SHALL be supported
#### Scenario: Database variables
- **WHEN** database-related variables are set
- **THEN** `NEX_DATABASE_PATH` SHALL be supported
- **THEN** `NEX_DATABASE_MAX_IDLE_CONNS` SHALL be supported
- **THEN** `NEX_DATABASE_MAX_OPEN_CONNS` SHALL be supported
- **THEN** `NEX_DATABASE_CONN_MAX_LIFETIME` SHALL be supported
#### Scenario: Log variables
- **WHEN** log-related variables are set
- **THEN** `NEX_LOG_LEVEL` SHALL be supported
- **THEN** `NEX_LOG_PATH` SHALL be supported
- **THEN** `NEX_LOG_MAX_SIZE` SHALL be supported
- **THEN** `NEX_LOG_MAX_BACKUPS` SHALL be supported
- **THEN** `NEX_LOG_MAX_AGE` SHALL be supported
- **THEN** `NEX_LOG_COMPRESS` SHALL be supported
### Requirement: Environment-variable priority
The system SHALL ensure environment variables outrank the config file but not CLI flags.
#### Scenario: Environment overrides the config file
- **WHEN** the config file sets `server.port: 9826`
- **AND** the environment sets `NEX_SERVER_PORT=9000`
- **THEN** it SHALL use the environment value 9000
#### Scenario: CLI flags override the environment
- **WHEN** the environment sets `NEX_SERVER_PORT=9000`
- **AND** the CLI sets `--server-port 8080`
- **THEN** it SHALL use the CLI value 8080
### Requirement: 12-Factor App compliance
The system SHALL follow the 12-Factor App configuration principles.
#### Scenario: Config separated from code
- **WHEN** the application is deployed to different environments
- **THEN** it SHALL differentiate environments via environment variables
- **THEN** it SHALL NOT require code or config-file changes
#### Scenario: Protecting secrets
- **WHEN** configuration contains secrets (keys, passwords)
- **THEN** they SHALL be passed via environment variables
- **THEN** they SHALL NOT be stored in the config file
### Requirement: Environment-variable error handling
The system SHALL handle environment-variable errors correctly.
#### Scenario: Invalid value
- **WHEN** a variable has an invalid format (e.g. `NEX_SERVER_PORT=abc`)
- **THEN** it SHALL return a clear error message
- **THEN** it SHALL name the variable and the expected type
- **THEN** it SHALL NOT start the application
#### Scenario: Variable missing
- **WHEN** a required item is set neither in the config file nor the environment
- **THEN** it SHALL use the default value
- **THEN** it SHALL start normally


@@ -0,0 +1,52 @@
## 1. Dependency management
- [ ] 1.1 Add Viper, pflag, and validator to go.mod
- [ ] 1.2 Remove gopkg.in/yaml.v3 (Viper ships with YAML support)
- [ ] 1.3 Run go mod tidy to refresh the dependency tree
## 2. Config struct rework
- [ ] 2.1 Add validate-tag rules to the Config struct
- [ ] 2.2 Update Validate() to use the validator library
- [ ] 2.3 Add a PrintSummary() method for the configuration summary
## 3. Loading-logic rework
- [ ] 3.1 Create setupDefaults() to set default values
- [ ] 3.2 Create setupFlags() to define and bind CLI flags
- [ ] 3.3 Create setupEnv() to bind environment variables
- [ ] 3.4 Create setupConfigFile() to read the config file
- [ ] 3.5 Rework LoadConfig() to call the functions above in order
- [ ] 3.6 Add LoadConfigFromPath() for custom config-file paths
## 4. Main program changes
- [ ] 4.1 Add CLI flag parsing to main.go
- [ ] 4.2 Switch the loading flow to the new LoadConfig()
- [ ] 4.3 Call the configuration-summary printer
## 5. Test updates
- [ ] 5.1 Update TestDefaultConfig for the new default-value mechanism
- [ ] 5.2 Update TestConfig_Validate for the new validation rules
- [ ] 5.3 Add CLI flag configuration tests
- [ ] 5.4 Add environment-variable configuration tests
- [ ] 5.5 Add configuration-priority tests
- [ ] 5.6 Add configuration-summary output tests
- [ ] 5.7 Make sure all tests pass
## 6. Documentation
- [ ] 6.1 Update the configuration section of README.md
- [ ] 6.2 Add CLI flag usage examples
- [ ] 6.3 Add environment-variable configuration examples
- [ ] 6.4 Document the configuration priority
## 7. Verification and cleanup
- [ ] 7.1 Run the full test suite and make sure all tests pass
- [ ] 7.2 Local test: start the application with CLI flags
- [ ] 7.3 Local test: start the application with environment variables
- [ ] 7.4 Local test: mix CLI flags and environment variables
- [ ] 7.5 Verify the configuration summary is correct
- [ ] 7.6 Clean up unused functions and imports


@@ -1,288 +0,0 @@
## Context
### Existing architecture
Backend protocol conversion currently pivots on OpenAI types as the internal hub:
```
Anthropic Handler ──▶ anthropic.ConvertRequest() ──▶ openai.ChatCompletionRequest
OpenAI Handler ──────────────────────────────────▶ openai.ChatCompletionRequest
ProviderClient (hard-coded OpenAI Adapter)
upstream OpenAI-compatible API
```
Key files:
- `internal/protocol/openai/types.go` — OpenAI wire-format types, doubling as the internal hub format
- `internal/protocol/anthropic/converter.go` — one-way Anthropic→OpenAI conversion
- `internal/protocol/anthropic/stream_converter.go` — one-way OpenAI chunk→Anthropic SSE streaming conversion
- `internal/handler/openai_handler.go` — OpenAI request handling
- `internal/handler/anthropic_handler.go` — Anthropic request handling, with conversion orchestration inline
- `internal/provider/client.go` — HTTP client, hard-coded to `openai.Adapter` for (de)serialization
### Core limitations
1. **One-way conversion**: only Anthropic→OpenAI; no reverse path
2. **OpenAI lock-in**: upstream communication is OpenAI-only
3. **No passthrough**: even when client == provider, requests go through full (de)serialization
4. **No extensibility**: adding a protocol touches many places; there is no single integration point
5. **Chat only**: just the fixed `/v1/chat/completions` and `/v1/messages` endpoints; no Models/Embeddings/Rerank
### Design references
Three design documents fully define the target architecture and both protocol adaptations:
- `docs/conversion_design.md` — overall architecture (Hub-and-Spoke, Canonical Model, ProtocolAdapter interface, ConversionEngine, streaming pipeline, error handling)
- `docs/conversion_openai.md` — OpenAI adaptation checklist (field mapping, streaming state machine, role merging, etc.)
- `docs/conversion_anthropic.md` — Anthropic adaptation checklist (role constraints, thinking, named streaming events, etc.)
## Goals / Non-Goals
**Goals:**
- Implement the full Hub-and-Spoke conversion architecture with the Canonical Model as the hub
- Support bidirectional request/response conversion between any protocol pair (currently OpenAI ↔ Anthropic)
- Support same-protocol passthrough (zero semantic loss, zero serialization overhead)
- Support bidirectional streaming SSE conversion (including Tool Calling and Thinking)
- Support the Chat core layer + Models/Embeddings/Rerank extension layer + passthrough for unknown endpoints
- ProviderClient supports multi-protocol upstream communication
- Unified proxy entry point: URL routing with protocol prefixes
**Non-Goals:**
- No multimodal support (Image/Audio/Video) in this phase; the Canonical Model only reserves extension points
- No concrete Middleware business logic (interfaces and Chain only)
- No new protocol Adapters beyond OpenAI and Anthropic
- No stateful features (a StatefulMiddleware interface is reserved in the architecture)
- No protocol selection in the frontend management UI
- No frontend changes (the frontend uses the management API; the proxy-route change is transparent to it)
## Decisions
### D1: Implement the Canonical Model as dedicated Go structs, not `interface{}` or `map[string]any`
**Choice**: define strongly typed Go structs for CanonicalRequest, CanonicalResponse, CanonicalStreamEvent, etc.; ContentBlock uses a discriminated-union pattern (a type field plus per-type fields)
**Rationale**:
- Compile-time type safety; IDE completion and refactoring friendly
- Faster than `map[string]any` (no reflection overhead)
- Idiomatic in the Go ecosystem
**Alternatives**:
- `map[string]any` — flexible but type-unsafe; fields are easy to miss during refactors
- Code generation (e.g. protobuf) — new dependency and build step; over-engineered here
**Implementation detail**:
```go
type ContentBlock struct {
Type string `json:"type"`
// Text
Text string `json:"text,omitempty"`
// ToolUse
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// ToolResult
ToolUseID string `json:"tool_use_id,omitempty"`
IsError *bool `json:"is_error,omitempty"`
// Thinking
Thinking string `json:"thinking,omitempty"`
}
```
Using `json.RawMessage` preserves the tool input's original JSON and avoids an unnecessary `map` parse.
### D2: Define ProtocolAdapter as one large interface, not a composition
**Choice**: a single `ProtocolAdapter` interface containing every method (Chat, streaming, extension layer, error encoding), not split into smaller interfaces
**Rationale**:
- Mirrors the definition in `docs/conversion_design.md` §5.2; one interface makes the full contract explicit
- Adapter implementers see every required method at a glance
- Unsupported features simply return false (`supportsInterface`) or use no-op implementations
- `detectInterfaceType` lives in each protocol Adapter, because URL-path conventions differ per protocol
**Alternatives**:
- Interface composition (ChatAdapter + StreamAdapter + ExtendedAdapter) — more type complexity; adapter registration and management become fiddly
- Empty interface + type assertions — loses compile-time checks
- `detectInterfaceType` inside the ConversionEngine — violates the open-closed principle; new protocols would require Engine changes
### D3: StreamDecoder parses the raw SSE byte stream directly
**Choice**: `ProviderClient.SendStream()` returns `<-chan []byte` (raw SSE bytes); `StreamDecoder.processChunk()` splits SSE events and parses the JSON
**Rationale**:
- SSE parsing is tightly coupled to protocol semantics (formats differ: OpenAI uses unnamed `data:` events, Anthropic uses named events `event: xxx\ndata: xxx`)
- Fewer intermediate layers, less memory copying
- ProviderClient stays minimal — HTTP requests and byte-stream reads only
**Alternatives**:
- SSE parsing inside ProviderClient — forces a single SSE format on all upstreams, against the multi-protocol goal
- A standalone SSE parser layer — unnecessary abstraction; the SSE format is part of the Adapter's job
### D4: ProviderClient takes an `HTTPRequestSpec` and returns an `*HTTPResponseSpec`
**Choice**: the ConversionEngine emits `HTTPRequestSpec{URL, Method, Headers, Body []byte}`; the ProviderClient sends the HTTP request and returns `HTTPResponseSpec{StatusCode, Headers, Body []byte}`
**Rationale**:
- ProviderClient is fully protocol-agnostic; it only does HTTP
- The ConversionEngine owns URL building, header building, and body serialization
- For same-protocol passthrough, the Engine forwards body bytes untouched and the Client never (de)serializes
**Interface definition**:
```go
type HTTPRequestSpec struct {
URL string
Method string
Headers map[string]string
Body []byte
}
type HTTPResponseSpec struct {
StatusCode int
Headers map[string]string
Body []byte
}
type ProviderClient interface {
Send(ctx context.Context, spec HTTPRequestSpec) (*HTTPResponseSpec, error)
SendStream(ctx context.Context, spec HTTPRequestSpec) (<-chan StreamEvent, error)
}
```
### D5: Unified proxy entry via the `/{protocol}/v1/...` URL prefix
**Choice**: new route format `/{protocol}/v1/{path}`; the handler extracts the protocol prefix from the URL as clientProtocol
**Rationale**:
- Matches the design in `docs/conversion_design.md` §2.2
- Callers name the protocol explicitly in the URL; no extra configuration
- A single entry point reduces the number of handlers
**Legacy routes**: not kept; clients must migrate to the new route format.
**Alternatives**:
- Keep two separate handlers — contradicts the unified-architecture goal
- Sniff the protocol from the request body — unreliable, and the design doc states that "protocol identification is the caller's responsibility"
### D6: Provider gains a `Protocol` field, stored in the database
**Choice**: add a `protocol TEXT DEFAULT 'openai'` column to the Provider table, identifying the protocol the upstream provider speaks
**Rationale**:
- Upstream providers are either OpenAI-compatible (most) or Anthropic-native
- Routing needs providerProtocol to pick the right Adapter
- The `'openai'` default keeps existing data compatible
### D7: Delete the old `internal/protocol/` package; rebuild from scratch in `internal/conversion/`
**Choice**: remove `internal/protocol/openai/` and `internal/protocol/anthropic/` outright and write all code fresh under `internal/conversion/` against the design documents
**Rationale**:
- The old design (OpenAI types as hub) is fundamentally different from the new architecture; nothing is reusable
- Keeping the old code invites mixing the two patterns, introducing subtle bugs
- Old type definitions are not migrated; they are redefined per the design docs to stay consistent with the new architecture
### D8: Target package layout
```
internal/conversion/
  canonical/
    types.go      # CanonicalRequest/Response/Message/ContentBlock/Tool/ToolChoice/ThinkingConfig/OutputFormat
    stream.go     # CanonicalStreamEvent union + all event types
    extended.go   # CanonicalModelList/ModelInfo/Embedding/Rerank
    errors.go     # ConversionError + ErrorCode enum
    interface.go  # InterfaceType enum
    provider.go   # TargetProvider struct
  adapter.go      # ProtocolAdapter interface + AdapterRegistry interface and implementation
  stream.go       # StreamDecoder/StreamEncoder/StreamConverter interfaces + Passthrough/Canonical implementations
  middleware.go   # ConversionMiddleware interface + MiddlewareChain
  engine.go       # ConversionEngine facade + HTTPRequestSpec/HTTPResponseSpec
  openai/
    types.go          # OpenAI wire-format types (defined fresh against conversion_openai.md)
    adapter.go        # ProtocolAdapter implementation (detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError)
    decoder.go        # decodeRequest/decodeResponse/extension-layer decode methods
    encoder.go        # encodeRequest/encodeResponse/extension-layer encode methods
    stream_decoder.go # OpenAIStreamDecoder (delta-chunk state machine)
    stream_encoder.go # OpenAIStreamEncoder (buffering strategy)
  anthropic/
    types.go          # Anthropic wire-format types (defined fresh against conversion_anthropic.md)
    adapter.go        # ProtocolAdapter implementation (detectInterfaceType/buildUrl/buildHeaders/supportsInterface/encodeError)
    decoder.go        # decodeRequest/decodeResponse/extension-layer decode methods
    encoder.go        # encodeRequest/encodeResponse/extension-layer encode methods
    stream_decoder.go # AnthropicStreamDecoder (named events, 1:1 mapping)
    stream_encoder.go # AnthropicStreamEncoder (direct mapping, no buffering)
```
## Risks / Trade-offs
### R1: Anthropic role-constraint handling is complex
**Risk**: Anthropic requires strict user/assistant alternation, a user first message, and tool_result embedded in user messages. Encoding from Canonical to Anthropic therefore needs message merging, splitting, and injection, which is easy to get wrong.
**Mitigation**:
- Write exhaustive tests covering every edge case (consecutive tool messages, an assistant first message, empty-user-message injection, and so on)
- Isolate role-constraint handling in dedicated functions, separate from the content-encoding logic
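A minimal sketch of the normalization R1 describes, under simplifying assumptions: messages carry plain strings instead of content blocks, and merging is string concatenation (the real encoder merges content blocks). Names are illustrative:

```go
package main

import "fmt"

type Msg struct {
	Role    string
	Content string
}

// normalizeRoles applies the three Anthropic constraints: tool results become
// user turns, consecutive same-role messages are merged to keep strict
// alternation, and an empty user turn is injected if the list would start
// with assistant.
func normalizeRoles(in []Msg) []Msg {
	var out []Msg
	for _, m := range in {
		if m.Role == "tool" { // tool_result must live inside a user message
			m.Role = "user"
		}
		if n := len(out); n > 0 && out[n-1].Role == m.Role {
			out[n-1].Content += "\n" + m.Content // merge to preserve alternation
			continue
		}
		out = append(out, m)
	}
	if len(out) > 0 && out[0].Role == "assistant" {
		// First message must be user: inject an empty user turn.
		out = append([]Msg{{Role: "user", Content: ""}}, out...)
	}
	return out
}

func main() {
	msgs := normalizeRoles([]Msg{
		{Role: "assistant", Content: "call tool"},
		{Role: "tool", Content: "result A"},
		{Role: "tool", Content: "result B"},
	})
	fmt.Println(len(msgs)) // 3: injected user, assistant, merged user
}
```

Keeping this as a pure `[]Msg → []Msg` function makes the edge cases listed above directly table-testable.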
### R2: The OpenAI streaming state machine is complex
**Risk**: OpenAI delta chunks have no explicit lifecycle (no start/stop), so the StreamDecoder must infer block boundaries with a state machine and manage tool-call index mapping and argument accumulation.
**Mitigation**:
- Implement strictly against the pseudocode in `docs/conversion_openai.md` §6.2-§6.3
- Write a dedicated test per delta type (text, tool_calls, reasoning_content, refusal, usage chunks)
- Buffer UTF-8 runes truncated across chunks in `utf8Remainder`
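The `utf8Remainder` buffering can be sketched as a pure helper; the name and the bail-out policy for genuinely invalid bytes are assumptions for illustration:

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// splitValidUTF8 returns the longest valid-UTF-8 prefix of buf plus the
// trailing bytes of an incomplete rune. The remainder is what the decoder
// should buffer and prepend to the next chunk.
func splitValidUTF8(buf []byte) (valid, remainder []byte) {
	for len(buf) > 0 {
		r, size := utf8.DecodeLastRune(buf)
		if r != utf8.RuneError || size > 1 {
			break // last rune is complete
		}
		// RuneError with size 1: the trailing byte is part of a truncated rune.
		remainder = append([]byte{buf[len(buf)-1]}, remainder...)
		buf = buf[:len(buf)-1]
		if len(remainder) >= utf8.UTFMax {
			// More bytes than any rune can hold: genuinely invalid, pass through.
			return append(buf, remainder...), nil
		}
	}
	return buf, remainder
}

func main() {
	chunk := []byte("中文")                   // 6 bytes, two 3-byte runes
	valid, rem := splitValidUTF8(chunk[:5]) // cut mid-rune
	fmt.Println(string(valid), len(rem))    // 中 2
}
```

On the next chunk the decoder would emit `string(append(rem, next...))` after running the same split again, so no text_delta ever carries a broken rune.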
### R3: The full rewrite has a large blast radius
**Risk**: deleting the old code, creating the new package, and reworking handler/provider/domain at the same time could leave the system broken for an extended period.
**Mitigation**:
- Delete the old code only after all tests for the new code pass
- Develop on an isolated Git branch and merge when complete
- The new `/{protocol}/v1/...` routes make the protocol explicit
### R4: Canonical Model field evolution
**Risk**: the Canonical Model's field set reflects the shared semantics of the protocols adapted so far; adding protocols later may require frequent changes.
**Mitigation**:
- The field-promotion rules are defined in `docs/conversion_design.md` appendix C
- `json:"-"` tags control serialized output, so adding optional fields does not affect existing codecs
- Protocol-specific fields stay out of Canonical and are preserved via same-protocol passthrough
### R5: Performance: double serialization overhead
**Risk**: cross-protocol conversion goes through decode → canonical → encode, two serialization passes and one extra copy compared with direct conversion.
**Trade-off**: accepted in exchange for architectural clarity and extensibility. The zero-overhead same-protocol passthrough path compensates, and real LLM API latency (hundreds of milliseconds to seconds) dwarfs the JSON serialization cost (microseconds).
## Migration Plan
### Steps
1. **Create the `internal/conversion/` package**: implement Layers 1-3 (Canonical Model, interface definitions, Engine) without touching existing code
2. **Implement the OpenAI and Anthropic Adapters from scratch** (Layers 4-5): write them anew inside the conversion package against the design docs, reusing none of the old protocol package
3. **Write comprehensive tests**: cover codecs, stream conversion, error handling, and same-protocol passthrough
4. **Rework `domain.Provider`**: add the `Protocol` field
5. **Create the database migration**: `ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'`
6. **Rework `ProviderClient`**: reduce it to an HTTP sender that accepts `HTTPRequestSpec`
7. **Create `ProxyHandler`**: the unified proxy entry point, integrating ConversionEngine
8. **Update `cmd/server/main.go`**: register the Adapters, create the Engine, configure the new routes
9. **Delete the old `internal/protocol/` package**: delete outright with no code migration, after confirming the new architecture fully replaces it
10. **Update README.md**: project layout, API endpoints, routing notes
### Compatibility
- The old routes `/v1/chat/completions` and `/v1/messages` are not kept; clients must migrate
- Existing Provider rows automatically get a protocol via `DEFAULT 'openai'`
- The frontend management API is unaffected
### Rollback
- Git branch isolation: develop on a new branch and test thoroughly before merging
- The old `internal/protocol/` package is deleted only after all tests for the new architecture pass; once deleted it is gone from the tree (still recoverable from git history)
- The database migration is backward compatible (ADD COLUMN only)
## Open Questions
- ~~Should the legacy routes `/v1/chat/completions` and `/v1/messages` get a deprecation window?~~ → **Decided**: do not keep the old routes; clients migrate directly to `/{protocol}/v1/...`
- ~~Should the extended-layer interfaces (Models/Embeddings/Rerank) all be implemented in this phase, or Models first and the rest later?~~ → **Decided**: implement them all now (the field mappings from the three docs are already fully defined in the spec), since the extended-layer codecs are lightweight field mappings and implementing them fully exercises the engine's interface-dispatch logic


@@ -1,45 +0,0 @@
## Why
The backend's current protocol-conversion layer uses OpenAI types as the internal hub: Anthropic requests are converted one-way into OpenAI format before being sent upstream. This design cannot support the reverse OpenAI→Anthropic conversion, cannot talk to Anthropic-protocol upstream vendors, cannot do zero-overhead same-protocol passthrough, and cannot scale out to new protocols. Refactoring to a hub-and-spoke architecture built around a protocol-neutral Canonical Model (see `docs/conversion_design.md`) fixes these problems at the root.
## What Changes
- **Introduce the Canonical Model**: define the protocol-agnostic `CanonicalRequest`, `CanonicalResponse`, `CanonicalStreamEvent`, and related canonical models as the single hub for all protocol conversion
- **Introduce ConversionEngine**: a stateless conversion facade coordinating Adapter registration, interface detection, passthrough decisions, request/response conversion, and stream conversion
- **Introduce the ProtocolAdapter interface**: a unified adapter contract; each protocol implements the full codec (chat request/response, streaming, extended-layer interfaces, error encoding)
- **Implement the OpenAI Adapter**: a complete new OpenAI Adapter written against `docs/conversion_openai.md` (including the state-machine stream decoder/encoder), reusing none of the old `internal/protocol/openai/` code
- **Implement the Anthropic Adapter**: a complete new Anthropic Adapter written against `docs/conversion_anthropic.md` (including the named-event stream decoder/encoder), reusing none of the old `internal/protocol/anthropic/` code
- **Unified proxy handler**: merge `OpenAIHandler` and `AnthropicHandler` into a single `ProxyHandler` routed by the `/{protocol}/v1/...` URL prefix
- **Same-protocol passthrough**: when client == provider, skip Canonical conversion and forward as-is after rebuilding only the headers
- **Interface layering**: the core layer (Chat) goes through deep Canonical conversion; the extended layer (Models/Embeddings/Rerank) uses lightweight mapping; unknown interfaces pass through
- **ProviderClient simplification**: remove the hard-coded OpenAI Adapter and make it a protocol-agnostic HTTP sender
- **Provider gains a Protocol field**: **BREAKING**, the Provider model gains a `protocol` field identifying the upstream protocol type
- **Delete the old protocol package**: remove `internal/protocol/openai/` and `internal/protocol/anthropic/`, reimplemented in `internal/conversion/`
- **URL routing change**: **BREAKING**, the proxy endpoints move from `/v1/chat/completions` plus `/v1/messages` to `/{protocol}/v1/...`, with no legacy routes
## Capabilities
### New Capabilities
- `conversion-engine`: the conversion-engine core: Canonical Model definitions, the ProtocolAdapter interface and registry, the ConversionEngine facade (request/response conversion, stream conversion, interface detection, passthrough decisions), the StreamDecoder/Encoder interfaces, the Middleware interception chain, and the ConversionError error model
- `protocol-adapter-openai`: the OpenAI protocol adapter: a complete ProtocolAdapter implementation (per conversion_openai.md) covering chat request/response codecs, the streaming state-machine decoder (OpenAI delta chunks → CanonicalStreamEvent) and encoder (reverse), extended-layer codecs (Models/Embeddings/Rerank), error encoding, and same-protocol passthrough
- `protocol-adapter-anthropic`: the Anthropic protocol adapter: a complete ProtocolAdapter implementation (per conversion_anthropic.md) covering chat request/response codecs (with role-constraint handling: tool→user merging and user/assistant alternation guarantees), the stream decoder (named SSE events → CanonicalStreamEvent) and encoder (reverse), extended-layer codecs (Models), error encoding, and same-protocol passthrough
- `unified-proxy-handler`: the unified proxy entry point: the OpenAI/Anthropic handlers merged into one ProxyHandler, with `/{protocol}/v1/...` URL-prefix routing and protocol detection
### Modified Capabilities
- `openai-protocol-proxy`: routing moves from the hard-coded `/v1/chat/completions` to the unified `/{protocol}/v1/...` entry point; request handling goes through ConversionEngine instead of calling ProviderClient directly; adds same-protocol passthrough; adds extended-layer (Models/Embeddings/Rerank) proxying
- `anthropic-protocol-proxy`: changes from one-way Anthropic→OpenAI conversion to bidirectional conversion between any protocols; handlers call ConversionEngine instead of a converter directly; adds Anthropic as an upstream-vendor option; adds same-protocol passthrough; adds extended-layer proxying
- `provider-management`: the Provider model gains a `protocol` field (upstream protocol type, default "openai"); a database migration adds the protocol column
- `layered-architecture`: adds a conversion layer (internal/conversion/) between handler and provider; the ProviderClient interface is reduced to a protocol-agnostic HTTP sender
- `error-handling`: adds the ConversionError type and ErrorCode enum; conversion failures are encoded as error responses in the client's protocol format
- `request-validation`: request validation moves from the handler layer into each ProtocolAdapter's decodeRequest; validation rules are defined independently per protocol spec
## Impact
- **Code layout**: adds the `internal/conversion/` package (20+ files, all newly written), deletes the `internal/protocol/` package (removed outright, then rewritten), reworks `internal/handler/` and `internal/provider/`
- **API compatibility**: **BREAKING**, proxy endpoint URLs change (`/v1/chat/completions` → `/openai/v1/chat/completions`, `/v1/messages` → `/anthropic/v1/messages`) with no legacy routes
- **Database**: the providers table gains a `protocol` column; a migration is required
- **Dependencies**: no new external dependencies; reuses the Go standard library and packages already in use
- **Testing**: the conversion package needs comprehensive unit tests covering each Adapter's codecs, stream conversion, error handling, and same-protocol passthrough
- **Docs**: the project-layout and API sections of README.md need updating


@@ -1,83 +0,0 @@
## MODIFIED Requirements
### Requirement: Serve the Anthropic Messages API endpoint
The gateway SHALL expose the Anthropic Messages API endpoint for external applications.
#### Scenario: Successful non-streaming request
- **WHEN** an application POSTs to `/anthropic/v1/messages` with a valid Anthropic-format request (non-streaming)
- **THEN** the gateway SHALL decode the Anthropic request into Canonical form via ConversionEngine
- **THEN** the gateway SHALL encode the Canonical request into the target vendor's protocol format
- **THEN** the gateway SHALL convert the vendor's response back to Anthropic format via ConversionEngine and return it to the application
#### Scenario: Successful streaming request
- **WHEN** an application POSTs to `/anthropic/v1/messages` with `stream: true`
- **THEN** the gateway SHALL create a StreamConverter via ConversionEngine
- **THEN** the gateway SHALL convert the upstream protocol's SSE stream into Anthropic named-event format
- **THEN** the gateway SHALL stream responses to the application using the `event: <type>\ndata: <json>\n\n` framing
#### Scenario: Same-protocol passthrough (Anthropic → Anthropic provider)
- **WHEN** the client speaks Anthropic and the target vendor also speaks Anthropic
- **THEN** the gateway SHALL skip Canonical conversion and forward as-is after rebuilding only the auth headers
- **THEN** the request and response bodies SHALL remain unmodified
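The `event: <type>\ndata: <json>\n\n` framing required above can be sketched with a small helper; `writeSSEEvent` is a hypothetical name, shown only to illustrate the wire format:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// writeSSEEvent frames one Anthropic-style named SSE event: an "event:" line
// naming the type, a "data:" line with the JSON payload, then a blank line.
func writeSSEEvent(eventType string, payload any) (string, error) {
	data, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, data), nil
}

func main() {
	s, _ := writeSSEEvent("message_stop", map[string]string{"type": "message_stop"})
	fmt.Print(s)
}
```

A real encoder would write these frames to the response writer and flush after each event so the client sees them immediately.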
### Requirement: Bidirectional protocol conversion
The gateway SHALL support bidirectional conversion between Anthropic and any registered protocol.
#### Scenario: Anthropic client → OpenAI vendor
- **WHEN** the client speaks Anthropic and the vendor speaks OpenAI
- **THEN** it SHALL decode the Anthropic MessagesRequest into a CanonicalRequest
- **THEN** it SHALL encode the CanonicalRequest into an OpenAI ChatCompletionRequest
- **THEN** it SHALL decode the OpenAI ChatCompletionResponse into a CanonicalResponse
- **THEN** it SHALL encode the CanonicalResponse into an Anthropic MessagesResponse
#### Scenario: OpenAI client → Anthropic vendor
- **WHEN** the client speaks OpenAI and the vendor speaks Anthropic
- **THEN** it SHALL decode the OpenAI ChatCompletionRequest into a CanonicalRequest
- **THEN** it SHALL encode the CanonicalRequest into an Anthropic MessagesRequest
- **THEN** it SHALL decode the Anthropic MessagesResponse into a CanonicalResponse
- **THEN** it SHALL encode the CanonicalResponse into an OpenAI ChatCompletionResponse
### Requirement: Handle requests through the service layer
Handlers SHALL route business logic through the service layer.
#### Scenario: Calling the routing service
- **WHEN** ProxyHandler receives an Anthropic-protocol request
- **THEN** it SHALL call RoutingService.Route() to obtain the routing result
- **THEN** it SHALL read the Provider (including the protocol field) from the routing result
#### Scenario: Calling the stats service
- **WHEN** a request completes successfully
- **THEN** it SHALL call StatsService.Record() to record statistics
- **THEN** statistics SHALL be recorded asynchronously (without blocking the response)
### Requirement: Structured error handling
ProxyHandler SHALL handle errors with ConversionError and Anthropic's encodeError.
#### Scenario: Protocol-conversion errors
- **WHEN** ConversionEngine returns a ConversionError
- **THEN** it SHALL encode the error response with the Anthropic Adapter.encodeError
- **THEN** it SHALL use the Anthropic error format (`{type: "error", error: {type, message}}`)
#### Scenario: Routing errors
- **WHEN** RoutingService returns an error
- **THEN** it SHALL be converted to a ConversionError
- **THEN** it SHALL be returned in the Anthropic error format
#### Scenario: Vendor errors
- **WHEN** ProviderClient returns an error
- **THEN** it SHALL be wrapped as a ConversionError
- **THEN** it SHALL be returned in the Anthropic error format


@@ -1,53 +0,0 @@
## MODIFIED Requirements
### Requirement: Unified error responses
The system SHALL unify error-response formats, now including ConversionError support.
#### Scenario: OpenAI-protocol error responses
- **WHEN** an error occurs on the OpenAI protocol
- **THEN** it SHALL return the standard OpenAI error-response format
- **THEN** it SHALL include the error.message, error.type, and error.code fields
#### Scenario: Anthropic-protocol error responses
- **WHEN** an error occurs on the Anthropic protocol
- **THEN** it SHALL return the standard Anthropic error-response format
- **THEN** it SHALL include the type, error.type, and error.message fields
#### Scenario: Conversion error responses
- **WHEN** ConversionEngine produces a ConversionError during protocol conversion
- **THEN** it SHALL encode the error response with the client protocol's Adapter.encodeError
- **THEN** the error response SHALL use a protocol format the client understands
#### Scenario: Management API error responses
- **WHEN** an error occurs in the management API
- **THEN** it SHALL return the unified error-response format
- **THEN** it SHALL include the code and message fields
- **THEN** it MAY include a details field (validation-error details)
## ADDED Requirements
### Requirement: Define the ConversionError type
The system SHALL define the ConversionError struct and the ErrorCode enum.
#### Scenario: ConversionError structure
- **WHEN** defining the conversion error
- **THEN** it SHALL include a Code (ErrorCode enum) field and a Message field
- **THEN** it MAY include ClientProtocol, ProviderProtocol, InterfaceType, Details, and Cause fields
#### Scenario: ErrorCode enum
- **WHEN** defining error codes
- **THEN** it SHALL include INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED
#### Scenario: Mapping error codes to protocol error types
- **WHEN** encoding an error with encodeError
- **THEN** the ErrorCode SHALL map to each protocol's error-type string
- **THEN** for example INVALID_INPUT → OpenAI "invalid_request_error", Anthropic "invalid_request_error"
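The per-protocol mapping above could be held in a lookup table. Only the INVALID_INPUT row is confirmed by the spec; the other entries and the `api_error` default are plausible assumptions for illustration:

```go
package main

import "fmt"

// errorTypeByCode sketches the ErrorCode → OpenAI error-type mapping.
// All rows except INVALID_INPUT are assumed, not taken from the spec.
var errorTypeByCode = map[string]string{
	"INVALID_INPUT":           "invalid_request_error",
	"MISSING_REQUIRED_FIELD":  "invalid_request_error",
	"JSON_PARSE_ERROR":        "invalid_request_error",
	"INTERFACE_NOT_SUPPORTED": "invalid_request_error",
	"ENCODING_FAILURE":        "api_error",
}

// openAIErrorType resolves a code, falling back to a conservative default
// so that unmapped codes still produce a well-formed error response.
func openAIErrorType(code string) string {
	if t, ok := errorTypeByCode[code]; ok {
		return t
	}
	return "api_error"
}

func main() {
	fmt.Println(openAIErrorType("INVALID_INPUT")) // invalid_request_error
}
```

The Anthropic adapter would own a parallel table targeting its own error-type strings, keeping each mapping next to its encodeError implementation.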


@@ -1,118 +0,0 @@
## MODIFIED Requirements
### Requirement: Three-layer architecture
The system SHALL implement the handler → service → repository architecture, with a new conversion layer between handler and provider.
#### Scenario: Handler-layer responsibilities
- **WHEN** handling HTTP requests
- **THEN** the handler layer SHALL only parse HTTP requests, route URLs, and write responses
- **THEN** the handler layer SHALL call ConversionEngine for protocol conversion
- **THEN** the handler layer SHALL call the service layer for business logic
- **THEN** the handler layer SHALL NOT access the database or perform protocol-conversion logic directly
#### Scenario: Conversion-layer responsibilities
- **WHEN** performing protocol conversion
- **THEN** the conversion layer SHALL contain the Canonical Model definitions
- **THEN** the conversion layer SHALL contain each protocol's ProtocolAdapter implementation
- **THEN** the conversion layer SHALL contain the ConversionEngine facade
- **THEN** the conversion layer SHALL NOT depend on the handler or service layers
#### Scenario: Service-layer responsibilities
- **WHEN** handling business logic
- **THEN** the service layer SHALL contain business rules and validation
- **THEN** the service layer SHALL call the repository layer for data access
- **THEN** the service layer SHALL NOT contain protocol-conversion logic
#### Scenario: Repository-layer responsibilities
- **WHEN** accessing data
- **THEN** the repository layer SHALL be responsible only for data access
- **THEN** the repository layer SHALL encapsulate database operations
- **THEN** the repository layer SHALL NOT contain business or protocol-conversion logic
### Requirement: Core interface definitions
The system SHALL define clear interface boundaries.
#### Scenario: Service interfaces
- **WHEN** defining service interfaces
- **THEN** it SHALL define the ProviderService, ModelService, RoutingService, and StatsService interfaces
- **THEN** it SHALL define clear business-method signatures
- **THEN** it SHALL use domain types for parameters and return values
#### Scenario: Repository interfaces
- **WHEN** defining repository interfaces
- **THEN** it SHALL define the ProviderRepository, ModelRepository, and StatsRepository interfaces
- **THEN** it SHALL define clear data-access method signatures
- **THEN** it SHALL use domain types for parameters and return values
#### Scenario: Provider client interface
- **WHEN** defining the provider client interface
- **THEN** it SHALL define the ProviderClient interface
- **THEN** it SHALL include Send (non-streaming) and SendStream (streaming) methods
- **THEN** it SHALL accept HTTPRequestSpec as the parameter, bound to no specific protocol
- **THEN** it SHALL support mocking through the interface
#### Scenario: Conversion-layer interfaces
- **WHEN** defining conversion-layer interfaces
- **THEN** it SHALL define the ProtocolAdapter, StreamDecoder, StreamEncoder, StreamConverter, and ConversionMiddleware interfaces
- **THEN** it SHALL define AdapterRegistry for Adapter registration and lookup
- **THEN** it SHALL define ConversionEngine as the unified facade
### Requirement: Dependency injection
The system SHALL use manual dependency injection.
#### Scenario: Repository injection
- **WHEN** initializing services
- **THEN** repository dependencies SHALL be injected through constructors
- **THEN** interface types SHALL be used rather than concrete types
#### Scenario: Service injection
- **WHEN** initializing handlers
- **THEN** service dependencies, ConversionEngine, and ProviderClient SHALL be injected through constructors
- **THEN** interface types SHALL be used rather than concrete types
#### Scenario: Conversion assembly
- **WHEN** the application starts
- **THEN** it SHALL create the AdapterRegistry and register all ProtocolAdapters
- **THEN** it SHALL create the ConversionEngine, injecting the registry and middleware chain
- **THEN** it SHALL inject the ConversionEngine into the ProxyHandler
#### Scenario: Main-function assembly
- **WHEN** the application starts
- **THEN** main.go SHALL construct all dependencies in order
- **THEN** it SHALL construct infrastructure first (logger, database)
- **THEN** it SHALL then construct repositories and services
- **THEN** it SHALL then construct the conversion layer (registry → engine)
- **THEN** it SHALL construct handlers last
### Requirement: Domain models
The system SHALL define standalone domain models.
#### Scenario: Domain-model definitions
- **WHEN** defining domain models
- **THEN** they SHALL be defined in the internal/domain/ package
- **THEN** they SHALL include Provider, Model, UsageStats, and related models
- **THEN** Provider SHALL include the Protocol field
- **THEN** they SHALL be separate from the database models
#### Scenario: Domain-model usage
- **WHEN** services and repositories handle data
- **THEN** they SHALL use domain models
- **THEN** they SHALL NOT use database (GORM) models


@@ -1,99 +0,0 @@
## MODIFIED Requirements
### Requirement: Serve the OpenAI Chat Completions API endpoint
The gateway SHALL expose the OpenAI Chat Completions API endpoint for external applications.
#### Scenario: Successful non-streaming request
- **WHEN** an application POSTs to `/openai/v1/chat/completions` with a valid OpenAI-format request (non-streaming)
- **THEN** the gateway SHALL convert the request via ConversionEngine
- **THEN** the gateway SHALL forward the converted request to the configured vendor
- **THEN** the gateway SHALL convert the vendor's response back to OpenAI format via ConversionEngine and return it to the application
#### Scenario: Successful streaming request
- **WHEN** an application POSTs to `/openai/v1/chat/completions` with `stream: true`
- **THEN** the gateway SHALL create a StreamConverter via ConversionEngine
- **THEN** the gateway SHALL stream the converted response to the application over SSE
- **THEN** the gateway SHALL send `data: [DONE]` when the stream completes
#### Scenario: Same-protocol passthrough (OpenAI → OpenAI provider)
- **WHEN** the client speaks OpenAI and the target vendor also speaks OpenAI
- **THEN** the gateway SHALL skip Canonical conversion and forward as-is after rebuilding only the auth headers
- **THEN** the request and response bodies SHALL remain unmodified
### Requirement: Route requests by model name
The gateway SHALL route requests to the appropriate vendor based on the request's `model` field.
#### Scenario: Valid model routing
- **WHEN** the request's `model` field matches a configured model
- **AND** that model is enabled
- **THEN** the gateway SHALL route the request to the vendor associated with that model
- **THEN** the gateway SHALL read providerProtocol from the vendor's `protocol` field
#### Scenario: Model not found
- **WHEN** the request's `model` field matches no configured model
- **THEN** the gateway SHALL return an error response in OpenAI format
#### Scenario: Model disabled
- **WHEN** the request's `model` field refers to a disabled model
- **THEN** the gateway SHALL return an error response in OpenAI format
### Requirement: Transparent proxying for OpenAI-compatible vendors
The gateway SHALL process requests and responses for OpenAI-compatible vendors through ConversionEngine.
#### Scenario: Cross-protocol forwarding
- **WHEN** the gateway receives an OpenAI-protocol request whose target vendor speaks a different protocol
- **THEN** the gateway SHALL convert the request into the target protocol's format via ConversionEngine
- **THEN** the gateway SHALL build the URL and headers with the target protocol's Adapter
#### Scenario: Extended-layer proxying
- **WHEN** the gateway receives a GET request such as `/openai/v1/models`
- **THEN** the gateway SHALL convert the extended-layer response format via ConversionEngine
### Requirement: Handle requests through the service layer
Handlers SHALL route business logic through the service layer.
#### Scenario: Calling the routing service
- **WHEN** ProxyHandler receives a request
- **THEN** it SHALL call RoutingService.Route() to obtain the routing result
- **THEN** it SHALL read the Provider (including the protocol field) from the routing result
#### Scenario: Calling the stats service
- **WHEN** a request completes successfully
- **THEN** it SHALL call StatsService.Record() to record statistics
- **THEN** statistics SHALL be recorded asynchronously (without blocking the response)
### Requirement: Structured error handling
ProxyHandler SHALL handle errors with ConversionError and the matching protocol's encodeError.
#### Scenario: Conversion errors
- **WHEN** ConversionEngine returns a ConversionError
- **THEN** it SHALL encode the error response with the clientProtocol's Adapter.encodeError
- **THEN** it SHALL use the OpenAI error format (`{error: {message, type, code}}`)
#### Scenario: Routing errors
- **WHEN** RoutingService returns an error
- **THEN** it SHALL be converted to a ConversionError
- **THEN** it SHALL be returned in the OpenAI error format
#### Scenario: Vendor errors
- **WHEN** ProviderClient returns an error
- **THEN** it SHALL be wrapped as a ConversionError
- **THEN** it SHALL be returned in the OpenAI error format


@@ -1,73 +0,0 @@
## MODIFIED Requirements
### Requirement: Create provider configurations
The gateway SHALL allow creating new provider configurations through the management API.
#### Scenario: Creating a provider with valid data
- **WHEN** a POST to `/api/providers` carries valid provider data (id, name, api_key, base_url, protocol)
- **THEN** the gateway SHALL create a new provider record in the database
- **THEN** the gateway SHALL return the created provider with status 201
- **THEN** the provider SHALL be enabled by default
- **THEN** the protocol field SHALL default to "openai"
#### Scenario: Creating a provider with a duplicate ID
- **WHEN** a POST to `/api/providers` carries an ID that already exists
- **THEN** the gateway SHALL return an error with status 409 (Conflict)
#### Scenario: Creating a provider with missing required fields
- **WHEN** a POST to `/api/providers` is missing required fields (id, name, api_key, or base_url)
- **THEN** the gateway SHALL return an error with status 400 (Bad Request)
- **THEN** the error SHALL indicate which fields are missing
### Requirement: List all providers
The gateway SHALL allow fetching all provider configurations.
#### Scenario: Listing providers successfully
- **WHEN** a GET is sent to `/api/providers`
- **THEN** the gateway SHALL return the list of all providers
- **THEN** each provider SHALL include id, name, api_key (masked), base_url, protocol, enabled, created_at, updated_at
- **THEN** api_key SHALL be masked (only the last 4 characters shown)
### Requirement: Get a specific provider
The gateway SHALL allow fetching a specific provider by ID.
#### Scenario: Getting an existing provider
- **WHEN** a GET is sent to `/api/providers/:id` with a valid provider ID
- **THEN** the gateway SHALL return the provider details
- **THEN** the response SHALL include the protocol field
- **THEN** api_key SHALL be masked
#### Scenario: Getting a nonexistent provider
- **WHEN** a GET is sent to `/api/providers/:id` with an ID that does not exist
- **THEN** the gateway SHALL return an error with status 404 (Not Found)
### Requirement: Update provider configurations
The gateway SHALL allow updating existing provider configurations.
#### Scenario: Updating a provider with valid data
- **WHEN** a PUT to `/api/providers/:id` carries valid provider data
- **THEN** the gateway SHALL update the provider record in the database
- **THEN** the gateway SHALL return the updated provider
- **THEN** updates SHALL support modifying the protocol field
### Requirement: Delete provider configurations
The gateway SHALL allow deleting provider configurations.
#### Scenario: Deleting an existing provider
- **WHEN** a DELETE is sent to `/api/providers/:id` with a valid provider ID
- **THEN** the gateway SHALL delete the provider record
- **THEN** the gateway SHALL delete all associated models (CASCADE)
- **THEN** the gateway SHALL return status 204 (No Content)


@@ -1,64 +0,0 @@
## MODIFIED Requirements
### Requirement: Validate OpenAI requests
The system SHALL validate the OpenAI ChatCompletionRequest; validation lives inside the ProtocolAdapter's decodeRequest.
#### Scenario: Required-field validation
- **WHEN** the OpenAI Adapter's decodeRequest parses a request
- **THEN** it SHALL validate that the model field is non-empty
- **THEN** it SHALL validate that the messages field is non-empty with at least one message
- **THEN** validation failures SHALL return a ConversionError of type INVALID_INPUT
#### Scenario: Parameter-range validation
- **WHEN** the OpenAI Adapter's decodeRequest parses parameters
- **THEN** it SHALL validate that temperature is within [0, 2]
- **THEN** it SHALL validate that max_tokens is greater than 0
- **THEN** it SHALL validate that top_p is within (0, 1]
#### Scenario: Message-content validation
- **WHEN** validating the messages field
- **THEN** it SHALL validate that each message's role is valid (system, developer, user, assistant, tool)
- **THEN** it SHALL validate that content is non-empty
### Requirement: Validate Anthropic requests
The system SHALL validate the Anthropic MessagesRequest; validation lives inside the ProtocolAdapter's decodeRequest.
#### Scenario: Required-field validation
- **WHEN** the Anthropic Adapter's decodeRequest parses a request
- **THEN** it SHALL validate that the model field is non-empty
- **THEN** it SHALL validate that the messages field is non-empty with at least one message
- **THEN** it SHALL validate that max_tokens is greater than 0 (or apply the default)
#### Scenario: Parameter-range validation
- **WHEN** the Anthropic Adapter's decodeRequest parses parameters
- **THEN** it SHALL validate that temperature is within [0, 1]
- **THEN** it SHALL validate that top_p is within (0, 1]
#### Scenario: Message-content validation
- **WHEN** validating the messages field
- **THEN** it SHALL validate that each message's role is valid (user, assistant)
- **THEN** it SHALL validate that the content array is non-empty
### Requirement: Return friendly validation errors
The system SHALL return friendly validation-error responses.
#### Scenario: Conversion-error format
- **WHEN** decodeRequest validation fails with a ConversionError
- **THEN** ProxyHandler SHALL encode the error response with clientAdapter.encodeError
- **THEN** the error SHALL use the client protocol's format
#### Scenario: Multi-field errors
- **WHEN** multiple fields fail validation
- **THEN** ConversionError.details SHALL contain all validation errors
- **THEN** the error response SHALL include the complete validation-error information


@@ -1,49 +0,0 @@
## 1. 基础类型层 — Canonical Model 和核心类型定义
- [x] 1.1 创建 `internal/conversion/errors.go`:定义 ConversionError 结构体Code, Message, ClientProtocol, ProviderProtocol, InterfaceType, Details, Cause和 ErrorCode 枚举INVALID_INPUT, MISSING_REQUIRED_FIELD, INCOMPATIBLE_FEATURE, FIELD_MAPPING_FAILURE, TOOL_CALL_PARSE_ERROR, JSON_PARSE_ERROR, STREAM_STATE_ERROR, UTF8_DECODE_ERROR, PROTOCOL_CONSTRAINT_VIOLATION, ENCODING_FAILURE, INTERFACE_NOT_SUPPORTED实现 error 接口
- [x] 1.2 创建 `internal/conversion/interface.go`:定义 InterfaceType 枚举CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK
- [x] 1.3 创建 `internal/conversion/provider.go`:定义 TargetProvider 结构体BaseURL, APIKey, ModelName, AdapterConfig map[string]any编写测试
- [x] 1.4 创建 `internal/conversion/canonical/types.go`:定义 CanonicalRequestmodel, system, messages, tools, tool_choice, parameters, thinking, stream, user_id, output_format, parallel_tool_use、CanonicalMessagerole 枚举: system/user/assistant/tool, content []ContentBlock、ContentBlock使用 type 字段的 discriminated uniontext/tool_use/tool_result/thinkingToolInput 使用 json.RawMessage、CanonicalToolname, description, input_schema、ToolChoice 联合体auto/none/any/tool+name、RequestParametersmax_tokens, temperature, top_p, top_k, frequency_penalty, presence_penalty, stop_sequences、ThinkingConfigtype: enabled/disabled/adaptive, budget_tokens, effort、OutputFormatjson_object/json_schema+schema/text、CanonicalResponseid, model, content, stop_reason 枚举, usage、CanonicalUsageinput_tokens, output_tokens, cache_read_tokens, cache_creation_tokens, reasoning_tokens、SystemBlocktext编写构造和序列化测试
- [x] 1.5 创建 `internal/conversion/canonical/stream.go`:定义 CanonicalStreamEvent 联合体message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop, error, ping及各事件的具体结构MessageStartEvent 含 message{id,model,usage}、ContentBlockStartEvent 含 index 和 content_block、ContentBlockDeltaEvent 含 index 和 delta、ContentBlockStopEvent 含 index、MessageDeltaEvent 含 delta{stop_reason} 和 usage、MessageStopEvent、ErrorEvent、PingEventdelta 联合体text_delta, input_json_delta, thinking_deltacontent_block 联合体text, tool_use, thinking编写测试
- [x] 1.6 创建 `internal/conversion/canonical/extended.go`:定义扩展层 Canonical ModelsCanonicalModelList, CanonicalModel, CanonicalModelInfo, CanonicalEmbeddingRequest, CanonicalEmbeddingResponse, CanonicalRerankRequest, CanonicalRerankResponse编写测试
## 2. 接口定义层 — Adapter、Stream、Middleware 接口
- [x] 2.1 创建 `internal/conversion/adapter.go`:定义 ProtocolAdapter 接口protocolName, protocolVersion, supportsPassthrough, detectInterfaceType, buildUrl, buildHeaders, supportsInterface, decodeRequest, encodeRequest, decodeResponse, encodeResponse, createStreamDecoder, createStreamEncoder, encodeError, 扩展层编解码方法decodeModelsResponse/encodeModelsResponse/decodeModelInfoResponse/encodeModelInfoResponse/decodeEmbeddingRequest/encodeEmbeddingRequest/decodeEmbeddingResponse/encodeEmbeddingResponse/decodeRerankRequest/encodeRerankRequest/decodeRerankResponse/encodeRerankResponse定义 AdapterRegistry 接口register, get, listProtocols和 memoryRegistry 实现sync.RWMutex 保护的 map编写 Registry 注册/查询/重复注册测试
- [x] 2.2 创建 `internal/conversion/stream.go`:定义 StreamDecoder 接口processChunk(rawChunk []byte) []CanonicalStreamEvent, flush() []CanonicalStreamEvent、StreamEncoder 接口encodeEvent(event CanonicalStreamEvent) [][]byte, flush() [][]byte、StreamConverter 接口processChunk(rawChunk []byte) [][]byte, flush() [][]byte、PassthroughStreamConverter 实现直接传递原始字节、CanonicalStreamConverter 实现(组合 StreamDecoder + MiddlewareChain + StreamEncoderprocessChunk 内部调用 decoder → middleware → encoder 管道);编写 PassthroughStreamConverter 测试
- [x] 2.3 创建 `internal/conversion/middleware.go`:定义 ConversionMiddleware 接口intercept(canonical, clientProtocol, providerProtocol, context) (CanonicalRequest, error) 和可选的 interceptStreamEvent(event, clientProtocol, providerProtocol, context) (CanonicalStreamEvent, error)、ConversionContext 结构体conversionId, interfaceType, timestamp, metadata、MiddlewareChain 结构体(按注册顺序链式执行,任一返回错误则中断后续);编写链式执行和中断测试
## 3. 引擎层 — ConversionEngine 门面
- [x] 3.1 创建 `internal/conversion/engine.go`:定义 HTTPRequestSpecURL, Method string, Headers map[string]string, Body []byte、HTTPResponseSpecStatusCode int, Headers map[string]string, Body []byte、ConversionEngine structregistry, middlewareChain实现 registerAdapter、use、isPassthrough、convertHttpRequest接口识别 → 透传判断 → clientAdapter.decode → middleware → providerAdapter.encode → providerAdapter.buildUrl + buildHeaders、convertHttpResponse透传判断 → providerAdapter.decodeResponse → clientAdapter.encodeResponse、createStreamConverter透传 → PassthroughStreamConverter否则 → CanonicalStreamConverter、内部 convertBody 分发CHAT 走深度转换,扩展层走轻量映射,默认透传);编写集成测试:使用 mock adapter 测试跨协议转换、同协议透传、未知接口透传
## 4. OpenAI Adapter 实现
- [x] 4.1 创建 `internal/conversion/openai/types.go`:对照 `docs/conversion_openai.md` 全新定义 OpenAI 线路格式类型(不沿用旧 `internal/protocol/openai/types.go`包含完整字段developer role, custom tools, reasoning_effort, reasoning_content, max_completion_tokens, parallel_tool_calls, response_format 的 json_schema 类型, stream_options, 废弃的 functions/function_call编写序列化测试
- [x] 4.2 创建 `internal/conversion/openai/decoder.go`:实现 decodeRequest对照 conversion_openai.md §4.1decodeSystemPrompt 提取 system+developer 消息、decodeMessage 含 tool_calls/refusal/reasoning_content 解码、tool 消息 tool_call_id→tool_use_id、decodeTools 含 function+custom 类型、decodeToolChoice 含 required→any/allowed_tools 降级、decodeParameters 含 max_completion_tokens 优先、decodeOutputFormat、decodeThinking 含 reasoning_effort→ThinkingConfig、废弃字段 functions→tools 兼容、decodeResponse§5.2content/refusal/reasoning_content/tool_calls 解码、finish_reason 映射表、usage 映射含 cached_tokens/reasoning_tokens、扩展层 decodedecodeModelsResponse、decodeEmbeddingRequest/Response、decodeRerankRequest/Response编写完整测试覆盖每类消息和字段映射
- [x] 4.3 创建 `internal/conversion/openai/encoder.go`:实现 encodeRequest对照 conversion_openai.md §4.2provider.model_name 覆盖、system 注入到 messages[0]、encodeMessage 含 tool_calls 编码到 message 顶层、角色交替合并、encodeTools 含 function 包装、encodeToolChoice 含 any→required、encodeParameters 含 max_completion_tokens、encodeOutputFormat、encodeThinking 含 disabled→"none"、encodeResponse§5.3text→content、tool_use→tool_calls、thinking→reasoning_content、finish_reason 反向映射、usage 编码含 prompt_tokens_details、扩展层 encodeencodeModelsResponse、encodeEmbeddingRequest/Response、encodeRerankRequest/Response编写完整测试
- [x] 4.4 创建 `internal/conversion/openai/adapter.go`:实现 OpenAI ProtocolAdapterprotocolName→"openai"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/chat/completions→CHAT、/v1/models→MODELS 等、buildHeaders 含 Authorization+Content-Type+OpenAI-Organization、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO/EMBEDDINGS/RERANK 返回 true、encodeError 含 ErrorCode→OpenAI 错误类型映射),组合 decoder 和 encoder 方法;编写测试覆盖所有路径模式和边界情况
- [x] 4.5 创建 `internal/conversion/openai/stream_decoder.go`:实现 OpenAIStreamDecoder对照 conversion_openai.md §6.2-§6.3processChunk 解析 SSE data 行,维护状态机 messageStarted/openBlocks/toolCallIdMap/toolCallNameMap/toolCallArguments/textBlockStarted/thinkingBlockStarted/utf8Remainder/accumulatedUsage首个 chunk→MessageStartEventdelta.content→text block 生命周期delta.tool_calls→tool_use block 生命周期含索引映射和参数累积delta.reasoning_content→thinking block非标准delta.refusal→text blockfinish_reason→关闭所有 open blocks + MessageDeltaEvent + MessageStopEventusage chunk→MessageDeltaEvent[DONE]→flush 关闭);编写测试覆盖每种 delta 类型和边界情况(空 chunk、多 tool_calls、UTF-8 截断)
- [x] 4.6 创建 `internal/conversion/openai/stream_encoder.go`:实现 OpenAIStreamEncoder对照 conversion_openai.md §6.4encodeEventContentBlockStart 缓冲策略等待首次 ContentBlockDelta 合并输出tool_use id/name 在首次 delta 时合并编码text_delta 直接输出 data: {choices:[{delta:{content}}]}input_json_delta 含 tool_calls 数组编码thinking_delta 含 reasoning_content 字段MessageStartEvent→{choices:[{delta:{role:"assistant"}}]}MessageDeltaEvent→{choices:[{delta:{},finish_reason}]}MessageStopEvent→[DONE]PingEvent/ErrorEvent 丢弃flush 输出缓冲区);编写测试
## 5. Anthropic Adapter 实现(与 Layer 4 并行)
- [x] 5.1 创建 `internal/conversion/anthropic/types.go`:对照 `docs/conversion_anthropic.md` 全新定义 Anthropic 线路格式类型(不沿用旧 `internal/protocol/anthropic/types.go`包含完整字段thinking.type 含 adaptive、output_config.format/effort、disable_parallel_tool_use、metadata.user_id、redacted_thinking、pause_turn/refusal stop_reason、stop_details、container、cache_control编写序列化测试
- [x] 5.2 创建 `internal/conversion/anthropic/decoder.go`:实现 decodeRequest对照 conversion_anthropic.md §4.1decodeSystem 从顶层 system 提取、decodeMessage 含 tool_result 从 user 消息拆分为独立 tool 角色消息、参数直接映射含 top_k、decodeThinking 含 enabled/disabled/adaptive 三种类型、decodeOutputFormat 仅支持 json_schema、公共字段提取含 metadata.user_id/disable_parallel_tool_use 反转/output_config.effort、协议特有字段 redacted_thinking 丢弃/cache_control 忽略、decodeResponse§5.2text/tool_use/thinking 块解码、redacted_thinking 丢弃、stop_reason 映射含 pause_turn/refusal、usage 映射含 cache_read_input_tokens/cache_creation_input_tokens、扩展层 decodedecodeModelsResponse 含 RFC3339→Unix 时间戳转换、decodeModelInfoResponse编写完整测试覆盖角色拆分、thinking 三种类型、时间戳转换
- [x] 5.3 创建 `internal/conversion/anthropic/encoder.go`:实现 encodeRequest对照 conversion_anthropic.md §4.2provider.model_name 覆盖、system 注入为顶层字段、encodeMessages 含 tool→user 合并(优先合并到相邻 user 消息)、首消息 user 保证(自动注入空 user、角色交替合并、encodeThinkingConfig 含 enabled/disabled/adaptive、encodeOutputFormat 含 json_object→空 schema 降级/text 丢弃、公共字段编码含 metadata.user_id/disable_parallel_tool_use 反转/output_config、参数编码含 max_tokens 必填/top_k 直接映射、encodeResponse§5.3text/tool_use/thinking 块直接编码、stop_reason 映射含 content_filter→end_turn 降级、usage 编码含 cache_read_input_tokens/cache_creation_input_tokens、扩展层 encodeencodeModelsResponse 含 Unix→RFC3339 转换和 has_more/first_id/last_id 字段、encodeModelInfoResponse编写完整测试覆盖角色合并、首消息注入、降级处理
- [x] 5.4 创建 `internal/conversion/anthropic/adapter.go`:实现 Anthropic ProtocolAdapterprotocolName→"anthropic"、supportsPassthrough→true、detectInterfaceType 根据正则匹配识别 /v1/messages→CHAT、/v1/models→MODELS 等、buildHeaders 含 x-api-key + anthropic-version + anthropic-beta + Content-Type、buildUrl 按接口类型映射、supportsInterface 对 CHAT/MODELS/MODEL_INFO 返回 true 对 EMBEDDINGS/RERANK 返回 false、encodeError 返回 {type:"error",error:{type,message}});编写测试覆盖所有路径模式和边界情况
- [x] 5.5 创建 `internal/conversion/anthropic/stream_decoder.go`:实现 AnthropicStreamDecoder对照 conversion_anthropic.md §6.2-§6.3:解析命名 SSE 事件 event: message_start/data: {...}1:1 映射到 CanonicalStreamEvent维护状态 messageStarted/redactedBlocks/utf8Remainder/accumulatedUsageredacted_thinking 检测后加入 redactedBlocks 并丢弃后续 delta/stopcitations_delta/signature_delta 直接丢弃server_tool_use 等服务端工具块丢弃UTF-8 跨 chunk 安全处理);编写测试覆盖所有事件类型和 redacted_thinking 丢弃
- [x] 5.6 创建 `internal/conversion/anthropic/stream_encoder.go`:实现 AnthropicStreamEncoder对照 conversion_anthropic.md §6.4:直接映射无缓冲,每个 CanonicalStreamEvent 直接编码为对应的 Anthropic 命名 SSE 事件,格式 event: `<type>`\ndata: `<json>`\n\ndelta 编码 text_delta/input_json_delta/thinking_delta 直接映射);编写测试
## 6. 基础设施改造 — Provider、Handler、Domain
- [x] 6.1 修改 `internal/domain/provider.go`Provider 结构体新增 Protocol string 字段;修改 `internal/config/models.go`GORM Provider 模型同步新增 Protocol 字段gorm:"column:protocol;default:'openai'");修改 `internal/repository/` 中 toDomainProvider 和 toConfigProvider 转换函数同步 Protocol 字段;修改 `internal/handler/provider_handler.go`CreateProvider 和 UpdateProvider 的请求结构体新增 Protocol 字段(可选,默认 "openai"),创建/更新 Provider 时赋值 Protocol 字段List/Get 响应中包含 Protocol 字段;更新 `internal/service/service_test.go` 中所有创建测试 Provider 的地方补充 Protocol 字段;更新 `internal/handler/handler_test.go` 中 Provider CRUD 测试的请求体补充 Protocol 字段;创建数据库迁移文件 `backend/migrations/YYYYMMDDHHMMSS_add_provider_protocol.sql`ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'
- [x] 6.2 重写 `internal/provider/client.go`:定义 HTTPRequestSpec 和 HTTPResponseSpec或引用 conversion 包的定义),简化 ProviderClient 接口为 Send(ctx, HTTPRequestSpec) → (*HTTPResponseSpec, error) 和 SendStream(ctx, HTTPRequestSpec) → (<-chan StreamEvent, error)移除所有旧协议硬编码依赖Send 方法直接使用 http.NewRequest + spec.URL/Headers/BodySendStream 保留现有 readStream goroutine 逻辑但输入改为 HTTPRequestSpec重写 `provider/client_test.go`:删除所有基于旧协议类型的测试用例,基于 HTTPRequestSpec 重写成功/失败/流式测试用例,使用 httptest.Server 验证请求构建和响应解析
- [x] 6.3 创建 `internal/handler/proxy_handler.go`:实现 ProxyHandler struct依赖 ConversionEngine、ProviderClient、RoutingService、StatsService实现 HandleProxy(w, r) 方法:从 URL 提取 clientProtocol仅支持 `/{protocol}/v1/...` 前缀路由,不支持旧路由)、解析请求体 JSON、调用 RoutingService.Route(modelName) 获取路由结果(含 Provider.Protocol 作为 providerProtocol)、构建 TargetProvider、调用 engine.convertHttpRequest、调用 providerClient.Send/SendStream、调用 engine.convertHttpResponse、设置响应 Content-Type 和状态码、流式处理设置 text/event-stream 并用 StreamConverter 逐块转换写入、错误处理使用 clientAdapter.encodeError、异步调用 StatsService.Record编写测试使用 httptest + mock engine/client/service
- [x] 6.4 修改 `cmd/server/main.go`:创建 AdapterRegistry 并注册 OpenAI 和 Anthropic Adapter、创建 ConversionEngine注入 registry)、创建 ProxyHandler注入 engine + providerClient + routingService + statsService)、配置 Gin 路由:新增 `/{protocol}/v1/{path:*}` → ProxyHandler.HandleProxy删除旧路由 `/v1/chat/completions`、`/v1/messages`,移除旧的 OpenAIHandler 和 AnthropicHandler 的路由注册,删除旧 Adapter 创建代码
## 7. 清理和文档
- [x] 7.1 删除旧代码:删除 `internal/protocol/openai/` 目录types.go, adapter.go, adapter_test.go)、删除 `internal/protocol/anthropic/` 目录types.go, converter.go, converter_test.go, stream_converter.go, stream_converter_test.go)、删除 `internal/handler/openai_handler.go`、`internal/handler/anthropic_handler.go`、删除 `internal/handler/handler_test.go` 中旧 OpenAI/Anthropic handler 测试用例和旧 `mockProviderClient`(基于旧协议类型的签名)、重写 `handler_test.go` 为 ProxyHandler 测试(基于新 ProviderClient 接口和 ConversionEngine mock)、删除 `internal/protocol/` 空目录、确认所有编译通过且无残留 import
- [x] 7.2 更新 `README.md`:更新项目结构说明(新增 internal/conversion/、删除 internal/protocol/)、更新 API 接口说明(代理接口变更:`/{protocol}/v1/...`,移除旧路由 `/v1/chat/completions`、`/v1/messages`)、更新配置说明Provider 新增 protocol 字段)
- [x] 7.3 端到端测试:在 `backend/tests/integration/` 中新增 `conversion_test.go`,使用 httptest mock 上游服务器验证完整请求流OpenAI→OpenAI 同协议透传、Anthropic→Anthropic 同协议透传、OpenAI→Anthropic 跨协议非流式、Anthropic→OpenAI 跨协议非流式、4 种方向的流式转换(含 tool_calls 和 thinking)、Models 接口跨协议转换、错误响应格式验证(各协议格式)、旧路由 `/v1/chat/completions`、`/v1/messages` 返回 404复用 `tests/helpers.go` 中的测试数据库和 Provider/Model 创建辅助函数

View File

@@ -7,7 +7,7 @@ context: |
- 涉及模块结构、API、实体等变更时同步更新README.md
- 新增代码优先复用已有组件、工具、依赖库,不引入新依赖
- 新增的逻辑必须编写完善的测试,并保证测试的正确性,不允许跳过任何测试
- backend是使用go开发的后端
- backend是使用go开发的后端阅读backend/README.md了解项目架构优先使用公共组件实现功能逻辑优先级官方库>主流三方库>项目公共工具>自行实现)
- frontend是基于bun+vite+typescript开发的前端严禁使用pnpm、npm
- Git提交: 仅中文; 格式"类型: 简短描述", 类型: feat/fix/refactor/docs/style/test/chore; 多行描述空行后写详细说明
- 禁止创建git操作task
@@ -17,3 +17,7 @@ context: |
rules:
proposal:
- 仔细审查每一个过往spec判断是否存在 Modified Capabilities
design:
- 先前的讨论技术方案要尽可能体现在设计文档中,便于指导实现阶段不偏离已定的技术路线
task:
- 一行一个任务,严禁任务内容跨行

View File

@@ -4,42 +4,47 @@
### Requirement: 支持 Anthropic Messages API 端点
网关 SHALL 提供 Anthropic Messages API 端点 `POST /v1/messages` 供外部应用调用。
网关 SHALL 提供 Anthropic Messages API 端点供外部应用调用。
#### Scenario: 成功的非流式请求
- **WHEN** 应用发送 POST 请求到 `/v1/messages`,携带有效的 Anthropic 请求格式(非流式)
- **THEN** 网关 SHALL 将 Anthropic 请求转换为 OpenAI 格式
- **THEN** 网关 SHALL 将转换后的请求转发到配置的供应商
- **THEN** 网关 SHALL 将 OpenAI 响应转换为 Anthropic 格式
- **THEN** 网关 SHALL 将转换后的响应返回给应用
- **WHEN** 应用发送 POST 请求到 `/anthropic/v1/messages`,携带有效的 Anthropic 请求格式(非流式)
- **THEN** 网关 SHALL 通过 ConversionEngine 将 Anthropic 请求解码为 Canonical 格式
- **THEN** 网关 SHALL 将 Canonical 请求编码为目标供应商协议格式
- **THEN** 网关 SHALL 将供应商的响应通过 ConversionEngine 转换为 Anthropic 格式返回给应用
#### Scenario: 成功的流式请求
- **WHEN** 应用发送 POST 请求到 `/v1/messages`,携带 `stream: true`
- **THEN** 网关 SHALL 将 Anthropic 请求转换为 OpenAI 格式
- **THEN** 网关 SHALL 将转换后的请求转发给供应商
- **THEN** 网关 SHALL 将 OpenAI 流事件转换为 Anthropic 流事件
- **THEN** 网关 SHALL 使用 SSE 格式将转换后的事件流式返回给应用
- **WHEN** 应用发送 POST 请求到 `/anthropic/v1/messages`,携带 `stream: true`
- **THEN** 网关 SHALL 通过 ConversionEngine 创建 StreamConverter
- **THEN** 网关 SHALL 将上游协议的 SSE 流转换为 Anthropic 命名事件格式
- **THEN** 网关 SHALL 使用 `event: <type>\ndata: <json>\n\n` 格式流式返回给应用
**变更说明:** handler 通过 service 层调用,而非直接调用 config 和 provider 包。API 接口保持不变。
#### Scenario: 同协议透传Anthropic → Anthropic Provider
### Requirement: 将 Anthropic 请求转换为 OpenAI 格式
- **WHEN** 客户端使用 Anthropic 协议且目标供应商也是 Anthropic 协议
- **THEN** 网关 SHALL 跳过 Canonical 转换,仅重建认证 Header 后原样转发
- **THEN** 请求和响应 Body SHALL 保持原样
网关 SHALL 将 Anthropic Messages API 请求转换为 OpenAI Chat Completions API 格式。
### Requirement: 双向协议转换
#### Scenario: System 消息转换
网关 SHALL 支持 Anthropic 协议与任意已注册协议间的双向转换。
- **WHEN** Anthropic 请求包含 `system` 字段
- **THEN** 网关 SHALL 将其转换为 `messages` 数组中 `role: "system"` 的消息
#### Scenario: Anthropic 客户端 → OpenAI 供应商
#### Scenario: Messages 转换
- **WHEN** 客户端使用 Anthropic 协议且供应商使用 OpenAI 协议
- **THEN** SHALL 将 Anthropic MessagesRequest 解码为 CanonicalRequest
- **THEN** SHALL 将 CanonicalRequest 编码为 OpenAI ChatCompletionRequest
- **THEN** SHALL 将 OpenAI ChatCompletionResponse 解码为 CanonicalResponse
- **THEN** SHALL 将 CanonicalResponse 编码为 Anthropic MessagesResponse
- **WHEN** Anthropic 请求包含 `messages` 数组
- **THEN** 网关 SHALL 在转换后的 OpenAI 请求中保留这些消息
- **THEN** 网关 SHALL 保留每条消息的 role 和 content
#### Scenario: OpenAI 客户端 → Anthropic 供应商
**变更说明:** 协议转换逻辑保持不变,仅调用方式改为通过 service 层。
- **WHEN** 客户端使用 OpenAI 协议且供应商使用 Anthropic 协议
- **THEN** SHALL 将 OpenAI ChatCompletionRequest 解码为 CanonicalRequest
- **THEN** SHALL 将 CanonicalRequest 编码为 Anthropic MessagesRequest
- **THEN** SHALL 将 Anthropic MessagesResponse 解码为 CanonicalResponse
- **THEN** SHALL 将 CanonicalResponse 编码为 OpenAI ChatCompletionResponse
## ADDED Requirements
@@ -49,9 +54,9 @@ Handler SHALL 通过 service 层处理业务逻辑。
#### Scenario: 调用 routing service
- **WHEN** handler 收到请求并转换为 OpenAI 格式
- **WHEN** ProxyHandler 收到 Anthropic 协议请求
- **THEN** SHALL 调用 RoutingService.Route() 获取路由结果
- **THEN** SHALL 使用路由结果中的供应商信息
- **THEN** SHALL 从路由结果获取 Provider含 protocol 字段)
#### Scenario: 调用 stats service
@@ -61,16 +66,22 @@ Handler SHALL 通过 service 层处理业务逻辑。
### Requirement: 使用结构化错误处理
Handler SHALL 使用结构化错误处理
ProxyHandler SHALL 使用 ConversionError 和 Anthropic 的 encodeError 处理错误
#### Scenario: 协议转换错误
- **WHEN** 协议转换失败
- **THEN** SHALL 返回结构化错误响应
- **THEN** SHALL 包含详细的错误信息
- **WHEN** ConversionEngine 返回 ConversionError
- **THEN** SHALL 使用 Anthropic 的 Adapter.encodeError 编码错误响应
- **THEN** SHALL 使用 Anthropic 错误格式(`{type: "error", error: {type, message}}`
#### Scenario: 路由错误处理
- **WHEN** RoutingService 返回错误
- **THEN** SHALL 转换为对应的 AppError
- **THEN** SHALL 返回统一的错误响应
- **THEN** SHALL 转换为 ConversionError
- **THEN** SHALL 使用 Anthropic 错误格式返回
#### Scenario: 供应商错误处理
- **WHEN** ProviderClient 返回错误
- **THEN** SHALL 包装为 ConversionError
- **THEN** SHALL 使用 Anthropic 错误格式返回

View File

@@ -1,3 +1,5 @@
# Conversion Engine
## ADDED Requirements
### Requirement: 定义 CanonicalRequest 规范模型
@@ -264,7 +266,7 @@ ErrorCode SHALL 包含INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_F
- **THEN** SHALL 使用对应扩展层 Canonical Model 做轻量字段映射
- **THEN** 双方都不支持时 SHALL 走透传逻辑
### Requirement: 定义 TargetProvider 结构体
### Requirement: 定义 TargetProvider 结构体
系统 SHALL 定义 `TargetProvider` 结构体,包含 `base_url``api_key``model_name``adapter_config`
@@ -273,4 +275,4 @@ ErrorCode SHALL 包含INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_F
- **WHEN** Adapter 调用 buildHeaders(provider)
- **THEN** SHALL 从 provider.api_key 提取认证信息
- **THEN** SHALL 从 provider.adapter_config 提取协议专属配置
- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段
- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段

View File

@@ -29,6 +29,12 @@
- **THEN** SHALL 使用 ErrModelNotFound、ErrProviderNotFound 等预定义错误
- **THEN** SHALL 设置 HTTP 状态码为 404
#### Scenario: 转换错误响应
- **WHEN** ConversionEngine 在协议转换过程中产生 ConversionError
- **THEN** SHALL 使用客户端协议的 Adapter.encodeError 编码错误响应
- **THEN** 错误响应 SHALL 使用客户端可理解的协议格式
#### Scenario: 验证错误
- **WHEN** 请求验证失败
@@ -41,6 +47,24 @@
- **THEN** SHALL 使用 ErrInternal 等预定义错误
- **THEN** SHALL 设置 HTTP 状态码为 500
#### Scenario: 请求创建错误
- **WHEN** 创建 HTTP 请求失败
- **THEN** SHALL 使用 ErrRequestCreate 预定义错误
- **THEN** SHALL 设置 HTTP 状态码为 500
#### Scenario: 请求发送错误
- **WHEN** 发送 HTTP 请求失败
- **THEN** SHALL 使用 ErrRequestSend 预定义错误
- **THEN** SHALL 设置 HTTP 状态码为 500
#### Scenario: 响应读取错误
- **WHEN** 读取 HTTP 响应失败
- **THEN** SHALL 使用 ErrResponseRead 预定义错误
- **THEN** SHALL 设置 HTTP 状态码为 500
### Requirement: 支持错误包装
系统 SHALL 支持错误包装。
@@ -120,3 +144,50 @@
- **WHEN** repository 层发生错误
- **THEN** SHALL 包装数据库错误
- **THEN** SHALL 转换为应用错误
### Requirement: 使用类型安全错误判断
系统 SHALL 使用类型安全方式判断错误类型。
#### Scenario: 数据库错误判断
- **WHEN** 判断数据库唯一约束错误
- **THEN** SHALL 使用 errors.Is(err, gorm.ErrDuplicatedKey)
- **THEN** SHALL NOT 使用字符串匹配 err.Error()
#### Scenario: 网络错误判断
- **WHEN** 判断网络错误
- **THEN** SHALL 使用 errors.As(err, &netErr)netErr 为 net.Error判断网络错误
- **THEN** SHALL 使用 errors.As(err, &opErr)opErr 为 *net.OpError判断操作错误
- **THEN** SHALL 使用 errors.Is(opErr.Err, syscall.ECONNRESET) 判断连接重置
- **THEN** SHALL NOT 使用字符串匹配判断错误类型
#### Scenario: 错误链判断
- **WHEN** 判断错误链中的特定错误
- **THEN** SHALL 使用 errors.Is 进行链式判断
- **THEN** SHALL 使用 errors.As 提取特定类型错误
## ADDED Requirements
### Requirement: 定义 ConversionError 错误类型
系统 SHALL 定义 ConversionError 结构体和 ErrorCode 枚举。
#### Scenario: ConversionError 结构
- **WHEN** 定义转换错误
- **THEN** SHALL 包含 CodeErrorCode 枚举)、Message 字段
- **THEN** SHALL 可选包含 ClientProtocol、ProviderProtocol、InterfaceType、Details、Cause 字段
#### Scenario: ErrorCode 枚举
- **WHEN** 定义错误码
- **THEN** SHALL 包含 INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_FEATURE、FIELD_MAPPING_FAILURE、TOOL_CALL_PARSE_ERROR、JSON_PARSE_ERROR、STREAM_STATE_ERROR、UTF8_DECODE_ERROR、PROTOCOL_CONSTRAINT_VIOLATION、ENCODING_FAILURE、INTERFACE_NOT_SUPPORTED
#### Scenario: 错误码到协议错误类型的映射
- **WHEN** 使用 encodeError 编码错误
- **THEN** ErrorCode SHALL 映射为各协议的错误类型字符串
- **THEN** 例如 INVALID_INPUT → OpenAI "invalid_request_error"、Anthropic "invalid_request_error"

View File

@@ -4,28 +4,37 @@
### Requirement: 实现三层架构
系统 SHALL 实现 handler → service → repository 三层架构。
系统 SHALL 实现 handler → service → repository 三层架构,并在 handler 和 provider 之间新增 conversion 层
#### Scenario: Handler 层职责
- **WHEN** 处理 HTTP 请求
- **THEN** handler 层 SHALL 仅负责 HTTP 请求解析和响应
- **THEN** handler 层 SHALL 仅负责 HTTP 请求解析、URL 路由和响应写入
- **THEN** handler 层 SHALL 调用 ConversionEngine 处理协议转换
- **THEN** handler 层 SHALL 调用 service 层处理业务逻辑
- **THEN** handler 层 SHALL NOT 直接访问数据库
- **THEN** handler 层 SHALL NOT 直接访问数据库或执行协议转换逻辑
#### Scenario: Conversion 层职责
- **WHEN** 处理协议转换
- **THEN** conversion 层 SHALL 包含 Canonical Model 定义
- **THEN** conversion 层 SHALL 包含各协议的 ProtocolAdapter 实现
- **THEN** conversion 层 SHALL 包含 ConversionEngine 门面
- **THEN** conversion 层 SHALL NOT 依赖 handler 或 service 层
#### Scenario: Service 层职责
- **WHEN** 处理业务逻辑
- **THEN** service 层 SHALL 包含业务规则和验证
- **THEN** service 层 SHALL 调用 repository 层访问数据
- **THEN** service 层 SHALL 协调多个 repository 的操作
- **THEN** service 层 SHALL NOT 包含协议转换逻辑
#### Scenario: Repository 层职责
- **WHEN** 访问数据
- **THEN** repository 层 SHALL 仅负责数据访问
- **THEN** repository 层 SHALL 封装数据库操作
- **THEN** repository 层 SHALL NOT 包含业务逻辑
- **THEN** repository 层 SHALL NOT 包含业务逻辑或协议转换逻辑
### Requirement: 定义核心接口
@@ -49,9 +58,17 @@
- **WHEN** 定义 provider client 接口
- **THEN** SHALL 定义 ProviderClient 接口
- **THEN** SHALL 包含 SendRequest 和 SendStreamRequest 方法
- **THEN** SHALL 包含 Send(非流式)和 SendStream(流式)方法
- **THEN** SHALL 接受 HTTPRequestSpec 作为参数,不绑定特定协议
- **THEN** SHALL 支持接口 Mock
#### Scenario: Conversion 层接口定义
- **WHEN** 定义 conversion 层接口
- **THEN** SHALL 定义 ProtocolAdapter、StreamDecoder、StreamEncoder、StreamConverter、ConversionMiddleware 接口
- **THEN** SHALL 定义 AdapterRegistry 用于 Adapter 注册和查询
- **THEN** SHALL 定义 ConversionEngine 作为统一门面
### Requirement: 实现依赖注入
系统 SHALL 使用手动依赖注入。
@@ -65,15 +82,24 @@
#### Scenario: Service 注入
- **WHEN** 初始化 handler
- **THEN** SHALL 通过构造函数注入 service 依赖
- **THEN** SHALL 通过构造函数注入 service 依赖、ConversionEngine、ProviderClient
- **THEN** SHALL 使用接口类型而非具体类型
#### Scenario: Conversion 组装
- **WHEN** 应用启动
- **THEN** SHALL 创建 AdapterRegistry 并注册所有 ProtocolAdapter
- **THEN** SHALL 创建 ConversionEngine注入 registry 和 middleware chain)
- **THEN** SHALL 将 ConversionEngine 注入到 ProxyHandler
#### Scenario: 主函数组装
- **WHEN** 应用启动
- **THEN** main.go SHALL 按顺序构造所有依赖
- **THEN** SHALL 先构造基础设施logger、database
- **THEN** SHALL 再构造 repository、service、handler
- **THEN** SHALL 再构造 repository、service
- **THEN** SHALL 再构造 conversion 层registry → engine
- **THEN** SHALL 最后构造 handler
### Requirement: 定义 Domain 模型
@@ -84,6 +110,7 @@
- **WHEN** 定义领域模型
- **THEN** SHALL 在 internal/domain/ 包中定义
- **THEN** SHALL 包含 Provider、Model、UsageStats 等模型
- **THEN** Provider SHALL 包含 Protocol 字段
- **THEN** SHALL 与数据库模型分离
#### Scenario: Domain 模型使用

View File

@@ -4,22 +4,27 @@
### Requirement: 支持 OpenAI Chat Completions API 端点
网关 SHALL 提供 OpenAI Chat Completions API 端点 `POST /v1/chat/completions` 供外部应用调用。
网关 SHALL 提供 OpenAI Chat Completions API 端点供外部应用调用。
#### Scenario: 成功的非流式请求
- **WHEN** 应用发送 POST 请求到 `/v1/chat/completions`,携带有效的 OpenAI 请求格式(非流式)
- **THEN** 网关 SHALL 将请求转发到配置的供应商
- **THEN** 网关 SHALL 将供应商的响应以 OpenAI 格式返回给应用
- **WHEN** 应用发送 POST 请求到 `/openai/v1/chat/completions`,携带有效的 OpenAI 请求格式(非流式)
- **THEN** 网关 SHALL 通过 ConversionEngine 转换请求
- **THEN** 网关 SHALL 将转换后的请求转发到配置的供应商
- **THEN** 网关 SHALL 将供应商的响应通过 ConversionEngine 转换为 OpenAI 格式返回给应用
#### Scenario: 成功的流式请求
- **WHEN** 应用发送 POST 请求到 `/v1/chat/completions`,携带 `stream: true`
- **THEN** 网关 SHALL 将请求转发到配置的供应商
- **THEN** 网关 SHALL 使用 SSE 格式将响应流式返回给应用
- **WHEN** 应用发送 POST 请求到 `/openai/v1/chat/completions`,携带 `stream: true`
- **THEN** 网关 SHALL 通过 ConversionEngine 创建 StreamConverter
- **THEN** 网关 SHALL 使用 SSE 格式将转换后的响应流式返回给应用
- **THEN** 网关 SHALL 在流完成时发送 `data: [DONE]`
**变更说明:** handler 通过 service 层调用,而非直接调用 config 和 provider 包。API 接口保持不变。
#### Scenario: 同协议透传OpenAI → OpenAI Provider
- **WHEN** 客户端使用 OpenAI 协议且目标供应商也是 OpenAI 协议
- **THEN** 网关 SHALL 跳过 Canonical 转换,仅重建认证 Header 后原样转发
- **THEN** 请求和响应 Body SHALL 保持原样
### Requirement: 根据模型名称路由请求
@@ -30,38 +35,32 @@
- **WHEN** 请求包含存在于配置模型中的 `model` 字段
- **AND** 该模型已启用
- **THEN** 网关 SHALL 将请求路由到该模型关联的供应商
- **THEN** 网关 SHALL 从供应商的 `protocol` 字段获取 providerProtocol
#### Scenario: 模型未找到
- **WHEN** 请求包含不存在于配置模型中的 `model` 字段
- **THEN** 网关 SHALL 返回带有适当错误消息的错误响应
- **THEN** 网关 SHALL 使用 OpenAI 格式返回错误响应
#### Scenario: 模型已禁用
- **WHEN** 请求包含已禁用模型的 `model` 字段
- **THEN** 网关 SHALL 返回错误响应,指示模型不可用
**变更说明:** 路由逻辑从 router 包迁移到 RoutingService通过 service 层调用。API 接口保持不变。
- **THEN** 网关 SHALL 使用 OpenAI 格式返回错误响应
### Requirement: 对 OpenAI 兼容供应商透明代理
网关 SHALL 对 OpenAI 兼容供应商的请求和响应进行透明转发,不做修改
网关 SHALL 对 OpenAI 兼容供应商的请求和响应通过 ConversionEngine 进行转换处理
#### Scenario: 请求转发
#### Scenario: 跨协议请求转发
- **WHEN** 网关收到 OpenAI 协议请求
- **AND** 目标供应商是 OpenAI 兼容的
- **THEN** 网关 SHALL 将请求体原样转发给供应商
- **THEN** 网关 SHALL 在 Authorization 头中设置供应商的 API Key
- **THEN** 网关 SHALL 使用供应商的 base URL
- **WHEN** 网关收到 OpenAI 协议请求且目标供应商使用不同协议
- **THEN** 网关 SHALL 通过 ConversionEngine 将请求转换为目标协议格式
- **THEN** 网关 SHALL 使用目标协议的 Adapter 构建 URL 和 Header
#### Scenario: 响应转发
#### Scenario: 扩展层接口代理
- **WHEN** 供应商返回响应
- **THEN** 网关 SHALL 将响应体原样返回给应用
- **THEN** 网关 SHALL 保留所有响应头和状态码
**变更说明:** provider client 通过接口注入到 handler便于测试和替换实现。API 接口保持不变。
- **WHEN** 网关收到 `/openai/v1/models` 等 GET 请求
- **THEN** 网关 SHALL 通过 ConversionEngine 转换扩展层接口的响应格式
## ADDED Requirements
@@ -81,18 +80,40 @@ Handler SHALL 通过 service 层处理业务逻辑。
- **THEN** SHALL 调用 StatsService.Record() 记录统计
- **THEN** SHALL 异步记录统计(不阻塞响应)
### Requirement: 使用 service 层处理请求
Handler SHALL 通过 service 层处理业务逻辑。
#### Scenario: 调用 routing service
- **WHEN** ProxyHandler 收到请求
- **THEN** SHALL 调用 RoutingService.Route() 获取路由结果
- **THEN** SHALL 从路由结果获取 Provider含 protocol 字段)
#### Scenario: 调用 stats service
- **WHEN** 请求成功完成
- **THEN** SHALL 调用 StatsService.Record() 记录统计
- **THEN** SHALL 异步记录统计(不阻塞响应)
### Requirement: 使用结构化错误处理
Handler SHALL 使用结构化错误处理
ProxyHandler SHALL 使用 ConversionError 和协议对应的 encodeError 处理错误
#### Scenario: 转换错误
- **WHEN** ConversionEngine 返回 ConversionError
- **THEN** SHALL 使用 clientProtocol 的 Adapter.encodeError 编码错误响应
- **THEN** SHALL 使用 OpenAI 错误格式(`{error: {message, type, code}}`
#### Scenario: 路由错误处理
- **WHEN** RoutingService 返回错误
- **THEN** SHALL 转换为对应的 AppError
- **THEN** SHALL 返回统一的错误响应
- **THEN** SHALL 转换为 ConversionError
- **THEN** SHALL 使用 OpenAI 错误格式返回
#### Scenario: 供应商错误处理
- **WHEN** ProviderClient 返回错误
- **THEN** SHALL 包装为 AppError
- **THEN** SHALL 包含请求上下文信息
- **THEN** SHALL 包装为 ConversionError
- **THEN** SHALL 使用 OpenAI 错误格式返回

View File

@@ -1,3 +1,5 @@
# Protocol Adapter - Anthropic
## ADDED Requirements
### Requirement: 实现 Anthropic ProtocolAdapter
@@ -266,4 +268,4 @@ Decoder 几乎 1:1 映射,维护最小状态机:
- **WHEN** interfaceType 为 EMBEDDINGS 或 RERANK
- **THEN** supportsInterface SHALL 返回 false
- **THEN** 引擎 SHALL 走透传或返回空响应
- **THEN** 引擎 SHALL 走透传或返回空响应

View File

@@ -1,3 +1,5 @@
# Protocol Adapter - OpenAI
## ADDED Requirements
### Requirement: 实现 OpenAI ProtocolAdapter
@@ -85,7 +87,7 @@
- **WHEN** canonical.system 不为空
- **THEN** SHALL 编码为 messages 数组头部的 role="system" 消息
#### Scenario: Assistant 消息中 tool_calls 编码
#### Scenario: Assistant 消息中 tool_calls 编码
- **WHEN** CanonicalMessage{role: "assistant"} 包含 tool_use 类型 ContentBlock
- **THEN** SHALL 提取到 message.tool_calls 数组({id, type: "function", function: {name, arguments}}
@@ -265,4 +267,4 @@ Encoder SHALL 维护状态:
#### Scenario: /rerank 接口
- **WHEN** 解码/编码 rerank 请求和响应
- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射
- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射

View File

@@ -8,10 +8,11 @@
#### Scenario: 使用有效数据创建供应商
- **WHEN** 向 `/api/providers` 发送 POST 请求携带有效的供应商数据id, name, api_key, base_url
- **WHEN** 向 `/api/providers` 发送 POST 请求携带有效的供应商数据id, name, api_key, base_url, protocol
- **THEN** 网关 SHALL 在数据库中创建新的供应商记录
- **THEN** 网关 SHALL 返回创建的供应商,状态码为 201
- **THEN** 供应商 SHALL 默认启用
- **THEN** protocol 字段 SHALL 默认为 "openai"
#### Scenario: 使用重复 ID 创建供应商
@@ -34,7 +35,7 @@
- **WHEN** 向 `/api/providers` 发送 GET 请求
- **THEN** 网关 SHALL 返回所有供应商的列表
- **THEN** 每个供应商 SHALL 包含 id, name, api_key已掩码, base_url, enabled, created_at, updated_at
- **THEN** 每个供应商 SHALL 包含 id, name, api_key已掩码, base_url, protocol, enabled, created_at, updated_at
- **THEN** api_key SHALL 被掩码(仅显示最后 4 个字符)
**变更说明:** 数据访问从 config 包迁移到 ProviderRepository。API 接口保持不变。
@@ -47,6 +48,7 @@
- **WHEN** 向 `/api/providers/:id` 发送 GET 请求,携带有效的供应商 ID
- **THEN** 网关 SHALL 返回供应商详情
- **THEN** SHALL 包含 protocol 字段
- **THEN** api_key SHALL 被掩码
#### Scenario: 获取不存在的供应商
@@ -65,7 +67,7 @@
- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,携带有效的供应商数据
- **THEN** 网关 SHALL 更新数据库中的供应商记录
- **THEN** 网关 SHALL 返回更新后的供应商
- **THEN** updated_at 时间戳 SHALL 被更新
- **THEN** 更新 SHALL 支持修改 protocol 字段
**变更说明:** 通过 ProviderService 和 ProviderRepository 实现。API 接口保持不变。

View File

@@ -20,43 +20,42 @@
### Requirement: 验证 OpenAI 请求
系统 SHALL 验证 OpenAI ChatCompletionRequest。
系统 SHALL 验证 OpenAI ChatCompletionRequest,验证逻辑位于 ProtocolAdapter 的 decodeRequest 内
#### Scenario: 必需字段验证
- **WHEN** 收到 OpenAI 请求
- **WHEN** OpenAI Adapter 的 decodeRequest 解析请求
- **THEN** SHALL 验证 model 字段不为空
- **THEN** SHALL 验证 messages 字段不为空且至少有一条消息
- **THEN** 验证失败 SHALL 返回 INVALID_INPUT 类型的 ConversionError
#### Scenario: 参数范围验证
- **WHEN** 收到 OpenAI 请求
- **WHEN** OpenAI Adapter 的 decodeRequest 解析参数
- **THEN** SHALL 验证 temperature 范围在 [0, 2]
- **THEN** SHALL 验证 max_tokens 大于 0
- **THEN** SHALL 验证 top_p 范围在 (0, 1]
- **THEN** SHALL 验证 frequency_penalty 范围在 [-2, 2]
- **THEN** SHALL 验证 presence_penalty 范围在 [-2, 2]
#### Scenario: 消息内容验证
- **WHEN** 验证 messages 字段
- **THEN** SHALL 验证每条消息的 role 有效system、user、assistant、tool
- **THEN** SHALL 验证每条消息的 role 有效system、developer、user、assistant、tool
- **THEN** SHALL 验证 content 不为空
### Requirement: 验证 Anthropic 请求
系统 SHALL 验证 Anthropic MessagesRequest。
系统 SHALL 验证 Anthropic MessagesRequest,验证逻辑位于 ProtocolAdapter 的 decodeRequest 内
#### Scenario: 必需字段验证
- **WHEN** 收到 Anthropic 请求
- **WHEN** Anthropic Adapter 的 decodeRequest 解析请求
- **THEN** SHALL 验证 model 字段不为空
- **THEN** SHALL 验证 messages 字段不为空且至少有一条消息
- **THEN** SHALL 验证 max_tokens 大于 0或使用默认值
#### Scenario: 参数范围验证
- **WHEN** 收到 Anthropic 请求
- **WHEN** Anthropic Adapter 的 decodeRequest 解析参数
- **THEN** SHALL 验证 temperature 范围在 [0, 1]
- **THEN** SHALL 验证 top_p 范围在 (0, 1]
@@ -93,26 +92,17 @@
系统 SHALL 返回友好的验证错误响应。
#### Scenario: 错误消息格式
#### Scenario: 转换错误格式
- **WHEN** 验证失败
- **THEN** SHALL 返回 400 状态码
- **THEN** SHALL 返回详细的错误消息
- **THEN** SHALL 指示哪些字段验证失败
- **WHEN** decodeRequest 验证失败返回 ConversionError
- **THEN** ProxyHandler SHALL 使用 clientAdapter.encodeError 编码错误响应
- **THEN** 错误 SHALL 使用客户端协议的格式
#### Scenario: 多字段错误
- **WHEN** 多个字段验证失败
- **THEN** SHALL 返回所有验证错误
- **THEN** SHALL 使用结构化格式(字段名 → 错误消息)
#### Scenario: 国际化支持
- **WHEN** 返回验证错误(未来)
- **THEN** SHALL 支持错误消息国际化
- **THEN** SHALL 使用错误码作为国际化 key
注:当前版本使用中文错误消息。
- **THEN** ConversionError.details SHALL 包含所有验证错误
- **THEN** 错误响应 SHALL 包含完整的验证错误信息
### Requirement: 在 handler 中应用验证
@@ -130,3 +120,29 @@
- **WHEN** 处理请求
- **THEN** SHALL 在 handler 函数开始时验证
- **THEN** SHALL 在验证通过后才执行业务逻辑
### Requirement: 使用标准库解析 JSON
系统 SHALL 使用 encoding/json 标准库解析 JSON 请求。
#### Scenario: 提取 model 字段
- **WHEN** 从请求体提取 model 字段
- **THEN** SHALL 使用 json.Unmarshal 解析到结构体
- **THEN** SHALL NOT 手动扫描字节查找字段
- **THEN** 解析失败 SHALL 返回空字符串(不报错)
#### Scenario: 检测 stream 字段
- **WHEN** 检测请求是否为流式请求
- **THEN** SHALL 使用 json.Unmarshal 解析到结构体
- **THEN** SHALL NOT 手动扫描字节查找字段
- **THEN** 解析失败 SHALL 返回 false非流式
#### Scenario: JSON 解析健壮性
- **WHEN** 解析 JSON 请求体
- **THEN** SHALL 正确处理转义字符
- **THEN** SHALL 正确处理嵌套结构
- **THEN** SHALL 正确处理 Unicode 字符
- **THEN** 解析失败 SHALL 有明确的错误处理
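上述两个场景对应的辅助函数可按如下方式用 encoding/json 实现(函数名为示意,与项目 proxy_handler 中的实现未必完全一致):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// extractModelName 用 json.Unmarshal 提取 model 字段;
// 解析失败返回空字符串(不报错),符合上面的规格要求。
func extractModelName(body []byte) string {
	var req struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return ""
	}
	return req.Model
}

// isStreamRequest 用 json.Unmarshal 检测 stream 字段;
// 解析失败视为非流式,返回 false。
func isStreamRequest(body []byte) bool {
	var req struct {
		Stream bool `json:"stream"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		return false
	}
	return req.Stream
}

func main() {
	// 标准库天然正确处理转义字符、嵌套结构和 Unicode
	body := []byte(`{"model":"claude-sonnet","stream":true,"messages":[{"role":"user","content":"你好\n"}]}`)
	fmt.Println(extractModelName(body)) // claude-sonnet
	fmt.Println(isStreamRequest(body))  // true
}
```

相比手动扫描字节,匿名结构体 + json.Unmarshal 只解析需要的字段,且对转义和嵌套天然健壮。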

View File

@@ -20,6 +20,13 @@
- **THEN** SHALL 支持嵌套字段
- **THEN** SHALL 自动包含时间戳和日志级别
#### Scenario: 日志注入
- **WHEN** 创建需要记录日志的组件
- **THEN** SHALL 通过构造函数注入 *zap.Logger
- **THEN** SHALL 允许 logger 参数为 nil此时使用全局 logger zap.L()
- **THEN** SHALL NOT 直接使用全局 logger zap.L()(除非在构造函数默认值中)
### Requirement: 支持日志滚动
系统 SHALL 支持日志文件滚动,使用 lumberjack。
@@ -122,3 +129,20 @@
- **WHEN** 创建日志文件
- **THEN** SHALL 使用 `nex-YYYY-MM-DD.log` 格式命名
- **THEN** SHALL 按日期创建新文件
### Requirement: ConversionEngine 日志注入
ConversionEngine SHALL 通过依赖注入获取 logger。
#### Scenario: ConversionEngine 构造函数
- **WHEN** 创建 ConversionEngine 实例
- **THEN** 构造函数 SHALL 接受 *zap.Logger 参数
- **THEN** 参数为 nil 时 SHALL 使用 zap.L() 作为默认值
- **THEN** SHALL 将 logger 存储在结构体字段中
#### Scenario: ConversionEngine 日志使用
- **WHEN** ConversionEngine 记录日志
- **THEN** SHALL 使用注入的 logger 字段
- **THEN** SHALL NOT 直接调用 zap.L()

View File

@@ -1,3 +1,5 @@
# Unified Proxy Handler
## ADDED Requirements
### Requirement: 实现统一代理 Handler
@@ -92,11 +94,11 @@ ProxyHandler SHALL 记录请求统计。
ProxyHandler SHALL 支持 GET 请求的扩展层接口代理。
#### Scenario: Models 接口
#### Scenario: Models 接口
- **WHEN** 收到 GET /{protocol}/v1/models 请求
- **THEN** SHALL 执行路由和协议识别
- **THEN** SHALL 调用 engine.convertHttpRequestGET 请求 body 为空)
- **THEN** SHALL 调用 providerClient.Send 发送请求
- **THEN** SHALL 调用 engine.convertHttpResponse 转换响应格式
- **THEN** SHALL 返回转换后的响应
- **THEN** SHALL 返回转换后的响应