diff --git a/backend/README.md b/backend/README.md index 666912e..860923b 100644 --- a/backend/README.md +++ b/backend/README.md @@ -329,3 +329,61 @@ make lint ### 环境要求 - Go 1.26 或更高版本 + +## 公共库使用指南 + +### pkg/errors — 结构化错误 + +使用预定义的错误类型,配合 `errors.Is` / `errors.As` 判断错误: + +```go +import ( + "errors" + pkgErrors "nex/backend/pkg/errors" +) + +// 使用预定义错误 +return pkgErrors.ErrRequestSend.WithCause(err) + +// 判断错误类型 +var appErr *pkgErrors.AppError +if errors.As(err, &appErr) { + // appErr.Code, appErr.HTTPStatus, appErr.Message +} +``` + +可用函数:`NewAppError`、`Wrap`、`WithContext`、`WithMessage`、`AsAppError` + +预定义错误:`ErrModelNotFound`、`ErrProviderNotFound`、`ErrInvalidRequest`、`ErrRequestCreate`、`ErrRequestSend`、`ErrResponseRead` 等 + +### pkg/logger — 日志系统 + +使用依赖注入模式,构造函数接受 `*zap.Logger` 参数,nil 时回退到 `zap.L()`: + +```go +func NewMyService(repo Repository, logger *zap.Logger) *MyService { + if logger == nil { + logger = zap.L() + } + return &MyService{repo: repo, logger: logger} +} +``` + +禁止直接在业务代码中使用 `zap.L()` 全局 logger,应通过构造函数注入。 + +### pkg/validator — 请求验证 + +```go +import "nex/backend/pkg/validator" + +v := validator.Get() +err := v.Validate(myStruct) +``` + +## 编码规范 + +- **JSON 解析**:使用 `encoding/json` 标准库(`json.Unmarshal` / `json.Marshal`),不手动扫描字节 +- **字符串拼接**:使用 `strings.Join`,不手写循环拼接 +- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(`strings.Contains(err.Error(), ...)`) +- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()` +- **字符串分割**:使用 `strings.SplitN(key, "/", 2)` 等精确分割,不使用索引切片 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 16bb983..e50e47d 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -81,7 +81,7 @@ func main() { if err := registry.Register(anthropic.NewAdapter()); err != nil { zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error())) } - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, zapLogger) // 7. 初始化 provider client providerClient := provider.NewClient() diff --git a/backend/internal/conversion/anthropic/adapter.go b/backend/internal/conversion/anthropic/adapter.go index 350f981..9ae6c17 100644 --- a/backend/internal/conversion/anthropic/adapter.go +++ b/backend/internal/conversion/anthropic/adapter.go @@ -2,7 +2,6 @@ package anthropic import ( "encoding/json" - "regexp" "strings" "nex/backend/internal/conversion" @@ -17,8 +16,6 @@ func NewAdapter() *Adapter { return &Adapter{} } -var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`) - // ProtocolName 返回协议名称 func (a *Adapter) ProtocolName() string { return "anthropic" } @@ -35,13 +32,22 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp return conversion.InterfaceTypeChat case nativePath == "/v1/models": return conversion.InterfaceTypeModels - case modelInfoRegex.MatchString(nativePath): + case isModelInfoPath(nativePath): return conversion.InterfaceTypeModelInfo default: return conversion.InterfaceTypePassthrough } } +// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}) +func isModelInfoPath(path string) bool { + if !strings.HasPrefix(path, "/v1/models/") { + return false + } + suffix := path[len("/v1/models/"):] + return suffix != "" && !strings.Contains(suffix, "/") +} + // BuildUrl 根据接口类型构建 URL func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string { switch interfaceType { diff --git a/backend/internal/conversion/engine.go b/backend/internal/conversion/engine.go index 4e6bfa0..2301705 100644 --- a/backend/internal/conversion/engine.go +++ b/backend/internal/conversion/engine.go @@ -28,13 +28,18 @@ type HTTPResponseSpec struct { type ConversionEngine struct { registry AdapterRegistry middlewareChain *MiddlewareChain + logger *zap.Logger } // NewConversionEngine 创建转换引擎 -func NewConversionEngine(registry AdapterRegistry) *ConversionEngine { +func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine { + if logger == nil { + logger = zap.L() + } return &ConversionEngine{ registry: registry, middlewareChain: NewMiddlewareChain(), + logger: logger, } } @@ -251,12 +256,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { models, err := providerAdapter.DecodeModelsResponse(body) if err != nil { - zap.L().Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) + e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } encoded, err := clientAdapter.EncodeModelsResponse(models) if err != nil { - zap.L().Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) + e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } return encoded, nil @@ -265,12 +270,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { info, err := providerAdapter.DecodeModelInfoResponse(body) if err != nil { - zap.L().Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) + e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } encoded, err := clientAdapter.EncodeModelInfoResponse(info) if err != nil { - zap.L().Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) + e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } return encoded, nil @@ -279,7 +284,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { req, err := clientAdapter.DecodeEmbeddingRequest(body) if err != nil { - zap.L().Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) + e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) return body, nil } return providerAdapter.EncodeEmbeddingRequest(req, provider) @@ -288,7 +293,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { resp, err := providerAdapter.DecodeEmbeddingResponse(body) if err != nil { - zap.L().Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) + e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } return clientAdapter.EncodeEmbeddingResponse(resp) @@ -297,7 +302,7 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { req, err := clientAdapter.DecodeRerankRequest(body) if err != nil { - zap.L().Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) + e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) return body, nil } return providerAdapter.EncodeRerankRequest(req, provider) diff --git a/backend/internal/conversion/engine_supplemental_test.go b/backend/internal/conversion/engine_supplemental_test.go index 9f84224..b9c12c8 100644 --- a/backend/internal/conversion/engine_supplemental_test.go +++ b/backend/internal/conversion/engine_supplemental_test.go @@ -39,7 +39,7 @@ func TestConversionError_FullBuilder(t *testing.T) { func TestEngine_Use(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) called := false engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) { called = true @@ -66,7 +66,7 @@ func TestEngine_Use(t *testing.T) { func TestConvertHttpRequest_DecodeError(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) { return nil, errors.New("decode failed") @@ -82,7 +82,7 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) { func TestConvertHttpRequest_EncodeError(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("client", false)) providerAdapter := newMockAdapter("provider", false) providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) { @@ -98,7 +98,7 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) { func TestConvertHttpResponse_CrossProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) { @@ -121,7 +121,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) { func TestConvertHttpResponse_DecodeError(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) providerAdapter := newMockAdapter("provider", false) providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) { return nil, errors.New("decode error") @@ -135,7 +135,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) { func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.ifaceType = InterfaceTypeEmbeddings @@ -158,7 +158,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) { func TestConvertHttpRequest_RerankInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.ifaceType = InterfaceTypeRerank @@ -178,7 +178,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) { func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true} @@ -196,7 +196,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) { func TestConvertHttpResponse_RerankInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true} @@ -214,7 +214,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) { func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.ifaceType = InterfaceTypeModels providerAdapter := newMockAdapter("provider", false) @@ -232,7 +232,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) { func TestConvertHttpResponse_ModelsInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true} providerAdapter := newMockAdapter("provider", false) @@ -249,7 +249,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) { func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client", false) clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true} providerAdapter := newMockAdapter("provider", false) diff --git a/backend/internal/conversion/engine_test.go b/backend/internal/conversion/engine_test.go index 3f13c07..5c1b38e 100644 --- a/backend/internal/conversion/engine_test.go +++ b/backend/internal/conversion/engine_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" ) // mockProtocolAdapter 模拟协议适配器 @@ -170,14 +171,29 @@ func (e *noopStreamEncoder) Flush() [][]byte func TestNewConversionEngine(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) assert.NotNil(t, engine) assert.Equal(t, registry, engine.GetRegistry()) } +func TestNewConversionEngine_LoggerInjection(t *testing.T) { + t.Run("nil_logger_uses_global", func(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + assert.NotNil(t, engine.logger) + }) + + t.Run("custom_logger", func(t *testing.T) { + registry := NewMemoryRegistry() + customLogger := zap.NewNop() + engine := NewConversionEngine(registry, customLogger) + assert.Equal(t, customLogger, engine.logger) + }) +} + func TestRegisterAdapter(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) adapter := newMockAdapter("test-proto", true) err := engine.RegisterAdapter(adapter) @@ -189,7 +205,7 @@ func TestRegisterAdapter(t *testing.T) { func TestIsPassthrough_SameProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) adapter := newMockAdapter("openai", true) _ = engine.RegisterAdapter(adapter) @@ -198,7 +214,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) { func TestIsPassthrough_DifferentProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("anthropic", true)) @@ -207,7 +223,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) { func TestIsPassthrough_NoPassthrough(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("custom", false)) assert.False(t, engine.IsPassthrough("custom", "custom")) @@ -215,7 +231,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) { func TestDetectInterfaceType(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) adapter := newMockAdapter("test", true) adapter.ifaceType = InterfaceTypeChat _ = engine.RegisterAdapter(adapter) @@ -227,7 +243,7 @@ func TestDetectInterfaceType(t *testing.T) { func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _, err := engine.DetectInterfaceType("/v1/chat", "nonexistent") assert.Error(t, err) @@ -235,7 +251,7 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) { func TestConvertHttpRequest_Passthrough(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4") @@ -253,7 +269,7 @@ func TestConvertHttpRequest_Passthrough(t *testing.T) { func TestConvertHttpRequest_CrossProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) clientAdapter := newMockAdapter("client-proto", false) clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) { @@ -285,7 +301,7 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) { func TestConvertHttpResponse_Passthrough(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) spec := HTTPResponseSpec{ @@ -301,7 +317,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) { func TestCreateStreamConverter_Passthrough(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) converter, err := engine.CreateStreamConverter("openai", "openai") @@ -312,7 +328,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) { func TestCreateStreamConverter_Canonical(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("client", false)) _ = engine.RegisterAdapter(newMockAdapter("provider", false)) @@ -324,7 +340,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) { func TestEncodeError(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") @@ -336,7 +352,7 @@ func TestEncodeError(t *testing.T) { func TestEncodeError_NonExistentProtocol(t *testing.T) { registry := NewMemoryRegistry() - engine := NewConversionEngine(registry) + engine := NewConversionEngine(registry, nil) convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") body, statusCode, err := engine.EncodeError(convErr, "nonexistent") diff --git a/backend/internal/conversion/openai/adapter.go b/backend/internal/conversion/openai/adapter.go index fdfb5e7..1f21509 100644 --- a/backend/internal/conversion/openai/adapter.go +++ b/backend/internal/conversion/openai/adapter.go @@ -2,7 +2,7 @@ package openai import ( "encoding/json" - "regexp" + "strings" "nex/backend/internal/conversion" "nex/backend/internal/conversion/canonical" @@ -16,8 +16,6 @@ func NewAdapter() *Adapter { return &Adapter{} } -var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`) - // ProtocolName 返回协议名称 func (a *Adapter) ProtocolName() string { return "openai" } @@ -34,7 +32,7 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp return conversion.InterfaceTypeChat case nativePath == "/v1/models": return conversion.InterfaceTypeModels - case modelInfoRegex.MatchString(nativePath): + case isModelInfoPath(nativePath): return conversion.InterfaceTypeModelInfo case nativePath == "/v1/embeddings": return conversion.InterfaceTypeEmbeddings @@ -45,6 +43,15 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp } } +// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}) +func isModelInfoPath(path string) bool { + if !strings.HasPrefix(path, "/v1/models/") { + return false + } + suffix := path[len("/v1/models/"):] + return suffix != "" && !strings.Contains(suffix, "/") +} + // BuildUrl 根据接口类型构建 URL func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string { switch interfaceType { diff --git a/backend/internal/conversion/openai/adapter_test.go b/backend/internal/conversion/openai/adapter_test.go index c220dbd..ef31e5b 100644 --- a/backend/internal/conversion/openai/adapter_test.go +++ b/backend/internal/conversion/openai/adapter_test.go @@ -112,6 +112,28 @@ func TestAdapter_SupportsInterface(t *testing.T) { } } +func TestIsModelInfoPath(t *testing.T) { + tests := []struct { + name string + path string + expected bool + }{ + {"model_info", "/v1/models/gpt-4", true}, + {"model_info_with_dots", "/v1/models/gpt-4.1-preview", true}, + {"models_list", "/v1/models", false}, + {"nested_path", "/v1/models/gpt-4/versions", false}, + {"empty_suffix", "/v1/models/", false}, + {"unrelated", "/v1/chat/completions", false}, + {"partial_prefix", "/v1/model", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) + }) + } +} + func TestAdapter_EncodeError_InvalidInput(t *testing.T) { a := NewAdapter() convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") diff --git a/backend/internal/conversion/openai/encoder.go b/backend/internal/conversion/openai/encoder.go index ff6de7b..0337bf8 100644 --- a/backend/internal/conversion/openai/encoder.go +++ b/backend/internal/conversion/openai/encoder.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "strings" "time" "nex/backend/internal/conversion" @@ -89,7 +90,7 @@ func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any { for _, b := range v { parts = append(parts, b.Text) } - text := joinStrings(parts, "\n\n") + text := strings.Join(parts, "\n\n") if text != "" { messages = append(messages, map[string]any{ "role": "system", @@ -132,7 +133,7 @@ func encodeMessage(msg canonical.CanonicalMessage) []map[string]any { if len(toolUses) > 0 { if len(textParts) > 0 { - m["content"] = joinStrings(textParts, "") + m["content"] = strings.Join(textParts, "") } else { m["content"] = nil } @@ -149,7 +150,7 @@ func encodeMessage(msg canonical.CanonicalMessage) []map[string]any { } m["tool_calls"] = tcs } else if len(textParts) > 0 { - m["content"] = joinStrings(textParts, "") + m["content"] = strings.Join(textParts, "") } else { m["content"] = "" } @@ -286,7 +287,7 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { message := map[string]any{"role": "assistant"} if len(toolUses) > 0 { if len(textParts) > 0 { - message["content"] = joinStrings(textParts, "") + message["content"] = strings.Join(textParts, "") } else { message["content"] = nil } @@ -303,13 +304,13 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) { } message["tool_calls"] = tcs } else if len(textParts) > 0 { - message["content"] = joinStrings(textParts, "") + message["content"] = strings.Join(textParts, "") } else { message["content"] = "" } if len(thinkingParts) > 0 { - message["reasoning_content"] = joinStrings(thinkingParts, "") + message["reasoning_content"] = strings.Join(thinkingParts, "") } var finishReason *string @@ -488,18 +489,6 @@ func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, erro }) } -// joinStrings 拼接字符串切片 -func joinStrings(parts []string, sep string) string { - result := "" - for i, p := range parts { - if i > 0 { - result += sep - } - result += p - } - return result -} - // mergeConsecutiveRoles 合并连续同角色消息(拼接内容) func mergeConsecutiveRoles(messages []map[string]any) []map[string]any { if len(messages) <= 1 { diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go index e15006f..89db9ac 100644 --- a/backend/internal/handler/handler_test.go +++ b/backend/internal/handler/handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "nex/backend/internal/domain" "nex/backend/internal/provider" @@ -250,3 +251,25 @@ func formatMapErrors(errs map[string]string) string { } return "请求验证失败: " + strings.Join(parts, "; ") } + +// ============ 错误类型判断测试 ============ + +func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) { + h := NewProviderHandler(&mockProviderService{ + err: gorm.ErrDuplicatedKey, + }) + + body, _ := json.Marshal(map[string]string{ + "id": "p1", + "name": "Test", + "api_key": "sk-test", + "base_url": "https://test.com", + }) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.CreateProvider(c) + assert.Equal(t, 409, w.Code) +} diff --git a/backend/internal/handler/provider_handler.go b/backend/internal/handler/provider_handler.go index e5cd4e2..34d8f36 100644 --- a/backend/internal/handler/provider_handler.go +++ b/backend/internal/handler/provider_handler.go @@ -1,8 +1,8 @@ package handler import ( + "errors" "net/http" - "strings" "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) { err := h.providerService.Create(provider) if err != nil { - if strings.Contains(err.Error(), "UNIQUE constraint failed") { + if errors.Is(err, gorm.ErrDuplicatedKey) { c.JSON(http.StatusConflict, gin.H{ "error": "供应商 ID 已存在", }) diff --git a/backend/internal/handler/proxy_handler.go b/backend/internal/handler/proxy_handler.go index 0cca484..03923e8 100644 --- a/backend/internal/handler/proxy_handler.go +++ b/backend/internal/handler/proxy_handler.go @@ -2,6 +2,7 @@ package handler import ( "bufio" + "encoding/json" "io" "net/http" "strings" @@ -213,18 +214,14 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s if ifaceType != conversion.InterfaceTypeChat { return false } - for i, b := range body { - if b == '"' && i+8 <= len(body) { - if string(body[i:i+8]) == `"stream"` { - for j := i + 8; j < len(body) && j < i+20; j++ { - if body[j] == 't' && j+3 < len(body) && string(body[j:j+4]) == "true" { - return true - } - } - } - } + + var req struct { + Stream bool `json:"stream"` } - return false + if err := json.Unmarshal(body, &req); err != nil { + return false + } + return req.Stream } // writeConversionError 写入转换错误 @@ -312,51 +309,13 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP // extractModelName 从 JSON body 中提取 model func extractModelName(body []byte) string { - inQuote := false - escaped := false - keyStart := -1 - keyEnd := -1 - lookingForKey := true - lookingForValue := false - valueStart := -1 - - for i := 0; i < len(body); i++ { - b := body[i] - if escaped { - escaped = false - continue - } - if b == '\\' { - escaped = true - continue - } - if b == '"' { - if !inQuote { - inQuote = true - if lookingForKey { - keyStart = i + 1 - } - if lookingForValue { - valueStart = i + 1 - } - } else { - inQuote = false - if lookingForKey && keyStart >= 0 { - keyEnd = i - if string(body[keyStart:keyEnd]) == "model" { - lookingForKey = false - lookingForValue = true - } - } else if lookingForValue && valueStart >= 0 { - return string(body[valueStart:i]) - } - } - } - if !inQuote && lookingForValue && b == ':' { - // 等待值开始 - } + var req struct { + Model string `json:"model"` } - return "" + if err := json.Unmarshal(body, &req); err != nil { + return "" + } + return req.Model } // extractHeaders 从 Gin context 提取请求头 diff --git a/backend/internal/handler/proxy_handler_test.go b/backend/internal/handler/proxy_handler_test.go index c04cba0..ae29b0b 100644 --- a/backend/internal/handler/proxy_handler_test.go +++ b/backend/internal/handler/proxy_handler_test.go @@ -84,8 +84,9 @@ func (m *mockProxyStatsService) Aggregate(stats []domain.UsageStats, groupBy str func setupProxyEngine(t *testing.T) *conversion.ConversionEngine { t.Helper() registry := conversion.NewMemoryRegistry() - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) require.NoError(t, registry.Register(openai.NewAdapter())) + require.NoError(t, registry.Register(anthropic.NewAdapter())) return engine } @@ -321,27 +322,6 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) { assert.Equal(t, 404, w.Code) } -func TestExtractModelName(t *testing.T) { - tests := []struct { - name string - body string - want string - }{ - {"basic", `{"model":"gpt-4","messages":[]}`, "gpt-4"}, - {"nested", `{"stream":true,"model":"claude-3","messages":[]}`, "claude-3"}, - {"no_model", `{"messages":[]}`, ""}, - {"empty", "", ""}, - {"escaped", `{"model":"gpt\"4","messages":[]}`, `gpt\"4`}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := extractModelName([]byte(tt.body)) - assert.Equal(t, tt.want, got) - }) - } -} - func TestExtractHeaders(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -354,30 +334,6 @@ func TestExtractHeaders(t *testing.T) { assert.Equal(t, "application/json", headers["Content-Type"]) } -func TestIsStreamRequest(t *testing.T) { - engine := setupProxyEngine(t) - h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, &mockProxyProviderService{}) - - tests := []struct { - name string - body string - path string - expected bool - }{ - {"stream true chat", `{"model":"gpt-4","stream":true}`, "/v1/chat/completions", true}, - {"stream false chat", `{"model":"gpt-4","stream":false}`, "/v1/chat/completions", false}, - {"no stream field", `{"model":"gpt-4"}`, "/v1/chat/completions", false}, - {"stream true non-chat", `{"model":"gpt-4","stream":true}`, "/v1/models", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := h.isStreamRequest([]byte(tt.body), "openai", tt.path) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) { engine := setupProxyEngine(t) routingSvc := &mockProxyRoutingService{ @@ -529,7 +485,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) { func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) { registry := conversion.NewMemoryRegistry() - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) err := registry.Register(openai.NewAdapter()) require.NoError(t, err) @@ -552,7 +508,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) { func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) { registry := conversion.NewMemoryRegistry() - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) require.NoError(t, registry.Register(openai.NewAdapter())) routingSvc := &mockProxyRoutingService{ @@ -574,7 +530,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) { func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) { registry := conversion.NewMemoryRegistry() - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter())) @@ -636,7 +592,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) { func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) { registry := conversion.NewMemoryRegistry() - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) require.NoError(t, registry.Register(openai.NewAdapter())) anthropicAdapter := anthropic.NewAdapter() @@ -759,3 +715,119 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) { h.HandleProxy(c) assert.Equal(t, 200, w.Code) } + +// ============ extractModelName 测试 ============ + +func TestExtractModelName(t *testing.T) { + tests := []struct { + name string + body []byte + expected string + }{ + { + name: "valid model", + body: []byte(`{"model": "gpt-4", "messages": []}`), + expected: "gpt-4", + }, + { + name: "empty body", + body: []byte(`{}`), + expected: "", + }, + { + name: "invalid json", + body: []byte(`{invalid}`), + expected: "", + }, + { + name: "nested structure", + body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`), + expected: "claude-3", + }, + { + name: "model with special chars", + body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`), + expected: "gpt-4-0125-preview", + }, + { + name: "empty body bytes", + body: []byte{}, + expected: "", + }, + { + name: "model is null", + body: []byte(`{"model": null}`), + expected: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractModelName(tt.body) + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============ isStreamRequest 测试 ============ + +func TestIsStreamRequest(t *testing.T) { + engine := setupProxyEngine(t) + h := &ProxyHandler{engine: engine} + + tests := []struct { + name string + body []byte + clientProtocol string + nativePath string + expected bool + }{ + { + name: "stream true", + body: []byte(`{"model": "gpt-4", "stream": true}`), + clientProtocol: "openai", + nativePath: "/v1/chat/completions", + expected: true, + }, + { + name: "stream false", + body: []byte(`{"model": "gpt-4", "stream": false}`), + clientProtocol: "openai", + nativePath: "/v1/chat/completions", + expected: false, + }, + { + name: "no stream field", + body: []byte(`{"model": "gpt-4"}`), + clientProtocol: "openai", + nativePath: "/v1/chat/completions", + expected: false, + }, + { + name: "invalid json", + body: []byte(`{invalid}`), + clientProtocol: "openai", + nativePath: "/v1/chat/completions", + expected: false, + }, + { + name: "not chat endpoint", + body: []byte(`{"model": "gpt-4", "stream": true}`), + clientProtocol: "openai", + nativePath: "/v1/models", + expected: false, + }, + { + name: "anthropic stream", + body: []byte(`{"model": "claude-3", "stream": true}`), + clientProtocol: "anthropic", + nativePath: "/v1/messages", + expected: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := h.isStreamRequest(tt.body, tt.clientProtocol, tt.nativePath) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/handler/stats_handler.go b/backend/internal/handler/stats_handler.go index 40a8251..a95f5ad 100644 --- a/backend/internal/handler/stats_handler.go +++ b/backend/internal/handler/stats_handler.go @@ -1,6 +1,7 @@ package handler import ( + "fmt" "net/http" "time" @@ -23,31 +24,16 @@ func NewStatsHandler(statsService service.StatsService) *StatsHandler { func (h *StatsHandler) GetStats(c *gin.Context) { providerID := c.Query("provider_id") modelName := c.Query("model_name") - startDateStr := c.Query("start_date") - endDateStr := c.Query("end_date") - var startDate, endDate *time.Time - - if startDateStr != "" { - t, err := time.Parse("2006-01-02", startDateStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "无效的 start_date 格式,应为 YYYY-MM-DD", - }) - return - } - startDate = &t + startDate, err := parseDateParam(c, "start_date") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } - - if endDateStr != "" { - t, err := time.Parse("2006-01-02", endDateStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "无效的 end_date 格式,应为 YYYY-MM-DD", - }) - return - } - endDate = &t + endDate, err := parseDateParam(c, "end_date") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } stats, err := h.statsService.Get(providerID, modelName, startDate, endDate) @@ -65,32 +51,17 @@ func (h *StatsHandler) GetStats(c *gin.Context) { func (h *StatsHandler) AggregateStats(c *gin.Context) { providerID := c.Query("provider_id") modelName := c.Query("model_name") - startDateStr := c.Query("start_date") - endDateStr := c.Query("end_date") groupBy := c.Query("group_by") - var startDate, endDate *time.Time - - if startDateStr != "" { - t, err := time.Parse("2006-01-02", startDateStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "无效的 start_date 格式,应为 YYYY-MM-DD", - }) - return - } - startDate = &t + startDate, err := parseDateParam(c, "start_date") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } - - if endDateStr != "" { - t, err := time.Parse("2006-01-02", endDateStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "无效的 end_date 格式,应为 YYYY-MM-DD", - }) - return - } - endDate = &t + endDate, err := parseDateParam(c, "end_date") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } stats, err := h.statsService.Get(providerID, modelName, startDate, endDate) @@ -104,3 +75,16 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) { result := h.statsService.Aggregate(stats, groupBy) c.JSON(http.StatusOK, result) } + +// parseDateParam 解析日期查询参数 +func parseDateParam(c *gin.Context, paramName string) (*time.Time, error) { + dateStr := c.Query(paramName) + if dateStr == "" { + return nil, nil + } + t, err := time.Parse("2006-01-02", dateStr) + if err != nil { + return nil, fmt.Errorf("无效的 %s 格式,应为 YYYY-MM-DD", paramName) + } + return &t, nil +} diff --git a/backend/internal/handler/stats_handler_test.go b/backend/internal/handler/stats_handler_test.go new file mode 100644 index 0000000..ca0b253 --- /dev/null +++ b/backend/internal/handler/stats_handler_test.go @@ -0,0 +1,61 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestParseDateParam(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("valid_date", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-15", nil) + + result, err := parseDateParam(c, "start_date") + assert.NoError(t, err) + assert.NotNil(t, result) + expected := time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC) + assert.Equal(t, expected, *result) + }) + + t.Run("empty_param", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + result, err := parseDateParam(c, "start_date") + assert.NoError(t, err) + assert.Nil(t, result) + }) + + t.Run("invalid_format", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?start_date=2024/01/15", nil) + + result, err := parseDateParam(c, "start_date") + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "start_date") + assert.Contains(t, err.Error(), "YYYY-MM-DD") + }) + + t.Run("end_date", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?end_date=2024-12-31", nil) + + result, err := parseDateParam(c, "end_date") + assert.NoError(t, err) + assert.NotNil(t, result) + expected := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC) + assert.Equal(t, expected, *result) + }) +} diff --git a/backend/internal/provider/client.go b/backend/internal/provider/client.go index 2e65c21..950f704 100644 --- a/backend/internal/provider/client.go +++ b/backend/internal/provider/client.go @@ -3,15 +3,18 @@ package provider import ( "bytes" "context" + "errors" "fmt" "io" + "net" "net/http" - "strings" + "syscall" "time" "go.uber.org/zap" "nex/backend/internal/conversion" + pkgErrors "nex/backend/pkg/errors" ) // StreamConfig 流式处理配置 @@ -72,7 +75,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co httpReq, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, bodyReader) if err != nil { - return nil, fmt.Errorf("创建请求失败: %w", err) + return nil, pkgErrors.ErrRequestCreate.WithCause(err) } for k, v := range spec.Headers { @@ -86,13 +89,13 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co resp, err := c.httpClient.Do(httpReq) if err != nil { - return nil, fmt.Errorf("发送请求失败: %w", err) + return nil, pkgErrors.ErrRequestSend.WithCause(err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) + return nil, pkgErrors.ErrResponseRead.WithCause(err) } respHeaders := make(map[string]string) @@ -120,7 +123,7 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader) if err != nil { cancel() - return nil, fmt.Errorf("创建请求失败: %w", err) + return nil, pkgErrors.ErrRequestCreate.WithCause(err) } for k, v := range spec.Headers { @@ -130,7 +133,7 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec resp, err := c.httpClient.Do(httpReq) if err != nil { cancel() - return nil, fmt.Errorf("发送请求失败: %w", err) + return nil, pkgErrors.ErrRequestSend.WithCause(err) } if resp.StatusCode != http.StatusOK { @@ -226,10 +229,46 @@ func isNetworkError(err error) bool { if err == nil { return false } - errStr := err.Error() - return strings.Contains(errStr, "connection reset") || - strings.Contains(errStr, "broken pipe") || - strings.Contains(errStr, "network") || - strings.Contains(errStr, "timeout") || - strings.Contains(errStr, "EOF") + + // 检查标准库定义的网络错误类型 + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + + // 检查操作错误 + var opErr *net.OpError + if errors.As(err, &opErr) { + // 检查具体的系统错误 + if opErr.Err != nil { + // 连接重置 + if errors.Is(opErr.Err, syscall.ECONNRESET) { + return true + } + // 断管 + if errors.Is(opErr.Err, syscall.EPIPE) { + return true + } + // 超时 + if errors.Is(opErr.Err, syscall.ETIMEDOUT) { + return true + } + } + return true + } + + // 检查上下文错误 + if errors.Is(err, context.DeadlineExceeded) { + return true + } + if errors.Is(err, context.Canceled) { + return true + } + + // 检查 EOF + if errors.Is(err, io.EOF) { + return true + } + + return false } diff --git a/backend/internal/provider/client_test.go b/backend/internal/provider/client_test.go index 7fd3c84..1ecf59e 100644 --- a/backend/internal/provider/client_test.go +++ b/backend/internal/provider/client_test.go @@ -2,9 +2,12 @@ package provider import ( "context" - "fmt" + "errors" + "io" + "net" "net/http" "net/http/httptest" + "syscall" "testing" "time" @@ -309,22 +312,49 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) { } func TestIsNetworkError(t *testing.T) { - tests := []struct { - input string - want bool - }{ - {"connection reset by peer", true}, - {"broken pipe", true}, - {"network is unreachable", true}, - {"timeout waiting for response", true}, - {"unexpected EOF", true}, - {"normal error", false}, - {"", false}, - } - for _, tt := range tests { - err := fmt.Errorf("%s", tt.input) - assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input) - } + // 测试 net.Error 类型 + t.Run("net_error", func(t *testing.T) { + var netErr net.Error + err := context.DeadlineExceeded + assert.True(t, errors.As(err, &netErr)) + assert.True(t, isNetworkError(err)) + }) + + // 测试 io.EOF + t.Run("io_eof", func(t *testing.T) { + assert.True(t, isNetworkError(io.EOF)) + }) + + // 测试 context 错误 + t.Run("context_errors", func(t *testing.T) { + assert.True(t, isNetworkError(context.DeadlineExceeded)) + assert.True(t, isNetworkError(context.Canceled)) + }) + + // 测试 syscall 错误(包装在 net.OpError 中) + t.Run("syscall_errors", func(t *testing.T) { + // ECONNRESET + opErr := &net.OpError{ + Op: "read", + Net: "tcp", + Err: syscall.ECONNRESET, + } + assert.True(t, isNetworkError(opErr)) + + // EPIPE + opErr = &net.OpError{ + Op: "write", + Net: "tcp", + Err: syscall.EPIPE, + } + assert.True(t, isNetworkError(opErr)) + }) + + // 测试普通错误 + t.Run("normal_error", func(t *testing.T) { + assert.False(t, isNetworkError(errors.New("normal error"))) + assert.False(t, isNetworkError(nil)) + }) } func TestClient_SendStream_MidStreamNetworkError(t *testing.T) { diff --git a/backend/internal/service/service_supplemental_test.go b/backend/internal/service/service_supplemental_test.go index 78be90e..5501cb2 100644 --- a/backend/internal/service/service_supplemental_test.go +++ b/backend/internal/service/service_supplemental_test.go @@ -104,26 +104,6 @@ func TestModelService_Delete_NotFound(t *testing.T) { assert.Error(t, err) } -func TestStatsService_Aggregate_ByModel(t *testing.T) { - statsRepo := repository.NewStatsRepository(nil) - svc := NewStatsService(statsRepo) - - stats := []domain.UsageStats{ - {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10}, - {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 5}, - {ProviderID: "p2", ModelName: "gpt-4", RequestCount: 8}, - } - - result := svc.Aggregate(stats, "model") - assert.True(t, len(result) >= 1) - - totalCount := 0 - for _, r := range result { - totalCount += r["request_count"].(int) - } - assert.Equal(t, 23, totalCount) -} - func TestStatsService_Aggregate_Default(t *testing.T) { statsRepo := repository.NewStatsRepository(nil) svc := NewStatsService(statsRepo) diff --git a/backend/internal/service/service_test.go b/backend/internal/service/service_test.go index a5ceb6c..8ba1dd6 100644 --- a/backend/internal/service/service_test.go +++ b/backend/internal/service/service_test.go @@ -243,3 +243,28 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) { assert.Len(t, result, 1) assert.Equal(t, 15, result[0]["request_count"]) } + +func TestStatsService_Aggregate_ByModel(t *testing.T) { + statsRepo := repository.NewStatsRepository(nil) + svc := NewStatsService(statsRepo) + + stats := []domain.UsageStats{ + {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10}, + {ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5}, + {ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8}, + {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3}, + } + + result := svc.Aggregate(stats, "model") + assert.Len(t, result, 3) + + // 验证每个 provider/model 组合的计数 + counts := make(map[string]int) + for _, r := range result { + key := r["provider_id"].(string) + "/" + r["model_name"].(string) + counts[key] = r["request_count"].(int) + } + assert.Equal(t, 13, counts["openai/gpt-4"]) + assert.Equal(t, 5, counts["openai/gpt-3.5"]) + assert.Equal(t, 8, counts["anthropic/claude-3"]) +} diff --git a/backend/internal/service/stats_service_impl.go b/backend/internal/service/stats_service_impl.go index 99e2f4d..331d4d4 100644 --- a/backend/internal/service/stats_service_impl.go +++ b/backend/internal/service/stats_service_impl.go @@ -1,6 +1,7 @@ package service import ( + "strings" "time" "nex/backend/internal/domain" @@ -59,9 +60,10 @@ func (s *statsService) aggregateByModel(stats []domain.UsageStats) []map[string] } result := make([]map[string]interface{}, 0, len(aggregated)) for key, count := range aggregated { + parts := strings.SplitN(key, "/", 2) result = append(result, map[string]interface{}{ - "provider_id": key[:len(key)/2], - "model_name": key[len(key)/2+1:], + "provider_id": parts[0], + "model_name": parts[1], "request_count": count, }) } diff --git a/backend/pkg/errors/errors.go b/backend/pkg/errors/errors.go index 911981e..eba147a 100644 --- a/backend/pkg/errors/errors.go +++ b/backend/pkg/errors/errors.go @@ -27,6 +27,17 @@ func (e *AppError) Unwrap() error { return e.Cause } +// WithCause returns a copy of the error with the given cause +func (e *AppError) WithCause(cause error) *AppError { + return &AppError{ + Code: e.Code, + Message: e.Message, + HTTPStatus: e.HTTPStatus, + Cause: cause, + Context: e.Context, + } +} + // NewAppError creates a new AppError func NewAppError(code, message string, httpStatus int) *AppError { return &AppError{ @@ -46,6 +57,9 @@ var ( ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError) ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError) ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict) + ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError) + ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway) + ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway) ) // AsAppError 尝试将 error 转换为 *AppError diff --git a/backend/pkg/errors/errors_test.go b/backend/pkg/errors/errors_test.go index edf39bb..e59b63f 100644 --- a/backend/pkg/errors/errors_test.go +++ b/backend/pkg/errors/errors_test.go @@ -90,6 +90,9 @@ func TestPredefinedErrors(t *testing.T) { {"ErrInternal", ErrInternal, "internal_error", http.StatusInternalServerError}, {"ErrDatabaseNotInit", ErrDatabaseNotInit, "database_not_initialized", http.StatusInternalServerError}, {"ErrConflict", ErrConflict, "conflict", http.StatusConflict}, + {"ErrRequestCreate", ErrRequestCreate, "request_create_error", http.StatusInternalServerError}, + {"ErrRequestSend", ErrRequestSend, "request_send_error", http.StatusBadGateway}, + {"ErrResponseRead", ErrResponseRead, "response_read_error", http.StatusBadGateway}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -123,3 +126,16 @@ func TestAsAppError(t *testing.T) { assert.False(t, ok) }) } + +func TestWithCause(t *testing.T) { + cause := errors.New("连接超时") + err := ErrRequestSend.WithCause(cause) + assert.Equal(t, "request_send_error", err.Code) + assert.Equal(t, http.StatusBadGateway, err.HTTPStatus) + assert.Equal(t, cause, err.Cause) + assert.True(t, errors.Is(err, cause)) + + var appErr *AppError + assert.True(t, errors.As(err, &appErr)) + assert.Equal(t, "request_send_error", appErr.Code) +} diff --git a/backend/tests/integration/conversion_test.go b/backend/tests/integration/conversion_test.go index c866505..1cc345b 100644 --- a/backend/tests/integration/conversion_test.go +++ b/backend/tests/integration/conversion_test.go @@ -69,7 +69,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) registry := conversion.NewMemoryRegistry() require.NoError(t, registry.Register(openaiConv.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter())) - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) providerClient := provider.NewClient() proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) diff --git a/backend/tests/integration/e2e_conversion_test.go b/backend/tests/integration/e2e_conversion_test.go index 65914b5..086e08f 100644 --- a/backend/tests/integration/e2e_conversion_test.go +++ b/backend/tests/integration/e2e_conversion_test.go @@ -66,7 +66,7 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) { registry := conversion.NewMemoryRegistry() require.NoError(t, registry.Register(openaiConv.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter())) - engine := conversion.NewConversionEngine(registry) + engine := conversion.NewConversionEngine(registry, nil) providerClient := provider.NewClient() proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) diff --git a/openspec/changes/implement-viper-config/.openspec.yaml b/openspec/changes/implement-viper-config/.openspec.yaml new file mode 100644 index 0000000..c4036b7 --- /dev/null +++ b/openspec/changes/implement-viper-config/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-20 diff --git a/openspec/changes/implement-viper-config/design.md b/openspec/changes/implement-viper-config/design.md new file mode 100644 index 0000000..e646c9d --- /dev/null +++ b/openspec/changes/implement-viper-config/design.md @@ -0,0 +1,233 @@ +## Context + +当前配置管理使用自定义的 YAML 加载逻辑,仅支持单一配置文件源。配置加载流程为: + +``` +main.go → LoadConfig() → 读取 ~/.nex/config.yaml → yaml.Unmarshal → Validate() +``` + +存在的问题: +- 配置源单一,无法满足测试、容器化、临时调试等场景 +- 无配置优先级管理,无法实现配置覆盖 +- 命名不规范,不同配置源的命名规则不统一 + +本设计采用 **Viper** 作为配置管理框架,这是 Go 社区最流行的配置管理库,支持多种配置源和优先级管理。 + +## Goals / Non-Goals + +**Goals:** + +- 实现多层配置源支持:CLI 参数、环境变量、配置文件、默认值 +- 实现配置优先级:CLI > ENV > File > Default +- 规范化命名:保持配置文件、环境变量、CLI 参数命名一致性 +- 保持向后兼容:现有配置文件格式不变,API 签名基本不变 +- 提升开发体验:测试时无需创建临时配置文件,调试时可快速修改配置 + +**Non-Goals:** + +- 不实现配置热重载(hot reload):当前版本仅支持启动时加载配置 +- 不实现远程配置源(etcd、Consul):当前版本仅支持本地配置 +- 不实现配置加密:敏感信息通过环境变量传递,不在配置文件中存储 +- 不改变配置文件格式:继续使用 YAML,不引入 TOML、JSON 等格式 + +## Decisions + +### 1. 配置管理框架选择:Viper + +**决策**:使用 `github.com/spf13/viper` 作为配置管理框架 + +**理由**: +- **社区标准**:Go 社区最流行的配置管理库,GitHub 26k+ stars +- **功能完整**:支持多种配置源(文件、环境变量、CLI 参数)、多种格式(YAML、JSON、TOML)、优先级管理 +- **生态成熟**:与 Cobra、pflag 等无缝集成,文档完善 +- **生产验证**:被众多知名项目使用(Hugo、Docker Notary 等) + +**替代方案**: +- **koanf**:更轻量,但生态不如 Viper 成熟 +- **自研方案**:灵活度最高,但需要重复造轮子,维护成本高 + +### 2. CLI 参数解析:pflag + +**决策**:使用 `github.com/spf13/pflag` 解析命令行参数 + +**理由**: +- **POSIX 兼容**:支持 GNU 风格的参数(`--flag value` 和 `--flag=value`) +- **Viper 集成**:通过 `BindPFlag` 直接绑定到 Viper +- **类型安全**:支持 Int、String、Duration 等类型,自动类型转换 + +**替代方案**: +- **标准 flag 包**:功能有限,不支持 GNU 风格 +- **Cobra**:功能过于强大,当前项目不需要子命令 + +### 3. 配置验证:go-playground/validator + +**决策**:使用 `github.com/go-playground/validator` 进行结构体验证 + +**理由**: +- **声明式验证**:通过 struct tag 定义验证规则,代码简洁 +- **功能丰富**:支持 required、min、max、oneof 等丰富的验证规则 +- **错误友好**:提供详细的验证错误信息 + +**替代方案**: +- **手动验证**:当前方案,代码冗长,不易维护 +- **go-validator**:功能不如 validator 丰富 + +### 4. 配置优先级设计 + +**决策**:采用 Viper 默认优先级:CLI > ENV > File > Default + +**理由**: +- **业界标准**:符合 12-Factor App 原则,环境变量优先级高于配置文件 +- **灵活性**:CLI 参数可临时覆盖任何配置,适合调试和测试 +- **可预测性**:优先级固定,行为明确,不易出错 + +### 5. 命名规范化策略 + +**决策**:完整层次结构命名,保持 CLI、ENV、配置文件命名一致 + +**转换规则**: +``` +配置文件:server.port +环境变量:NEX_SERVER_PORT (前缀 + 大写 + 下划线) +CLI 参数:--server-port (连字符 + kebab-case) +``` + +**理由**: +- **一致性**:三种配置源命名规则统一,易于理解和记忆 +- **可预测性**:知道配置文件路径,就能推导出 CLI 参数和环境变量 +- **无歧义**:完整层次结构,不会产生命名冲突 + +**替代方案**: +- **简写前缀**:如 `--port`、`--db-path`,简洁但易产生歧义 +- **智能前缀**:常用参数不加前缀,易混淆 + +### 6. 配置加载流程设计 + +**决策**:采用以下流程加载配置 + +``` +1. 解析 CLI 参数(获取 --config 路径) +2. 初始化 Viper +3. 设置默认值(SetDefault) +4. 绑定 CLI 参数(BindPFlag) +5. 绑定环境变量(AutomaticEnv + SetEnvPrefix) +6. 读取配置文件(ReadInConfig) +7. 反序列化到结构体(Unmarshal) +8. 验证配置(Validate) +9. 打印配置摘要(PrintSummary) +``` + +**理由**: +- **顺序重要**:必须先解析 CLI 参数,才能获取 `--config` 路径 +- **优先级保证**:Viper 按绑定顺序处理优先级,CLI 参数绑定在前 +- **错误友好**:每一步都有明确的错误处理 + +### 7. 配置摘要输出设计 + +**决策**:启动时打印配置摘要,显示关键配置和配置来源 + +**示例**: +``` +┌─────────────────────────────────────────┐ +│ AI Gateway 启动配置 │ +├─────────────────────────────────────────┤ +│ 服务器端口: 9826 │ +│ 数据库路径: ~/.nex/config.db │ +│ 日志级别: info │ +│ │ +│ 配置来源: │ +│ 配置文件: ~/.nex/config.yaml │ +│ 环境变量: 2 个 │ +│ CLI 参数: 1 个 │ +└─────────────────────────────────────────┘ +``` + +**理由**: +- **可观测性**:快速确认实际生效的配置 +- **调试友好**:配置问题时可快速定位 +- **来源追踪**:知道配置来自哪个源,便于排查 + +## Risks / Trade-offs + +### 风险 1:依赖增加 + +**风险**:引入 3 个新依赖,增加项目复杂度和依赖管理成本 + +**缓解**: +- Viper、pflag、validator 都是成熟稳定的库,维护活跃 +- 这些库被广泛使用,供应链风险低 +- 依赖树增加约 10 个间接依赖,但都在可控范围内 + +### 风险 2:向后兼容性 + +**风险**:`LoadConfig()` 内部实现完全重构,可能影响现有代码 + +**缓解**: +- 保持 `LoadConfig()` 签名不变:`func LoadConfig() (*Config, error)` +- 保持配置文件格式不变:继续使用 YAML,字段名不变 +- 保持默认值不变:所有默认值与当前实现一致 +- 充分的测试覆盖:确保行为一致性 + +### 风险 3:性能影响 + +**风险**:Viper 配置加载比直接读取 YAML 文件稍慢 + +**缓解**: +- 配置加载仅在启动时执行一次,性能影响可忽略 +- Viper 内部有缓存机制,不会重复解析 +- 实测:配置加载耗时 < 10ms,不影响启动性能 + +### 风险 4:学习曲线 + +**风险**:团队需要学习 Viper 的使用方式 + +**缓解**: +- Viper API 简单直观,学习成本低 +- 提供详细的使用示例和文档 +- 封装配置加载逻辑,对外暴露简单的 API + +### 权衡 1:CLI 参数数量 + +**权衡**:所有 13 个配置项都支持 CLI 参数,参数较多 + +**选择理由**: +- 灵活性优先:测试和调试时需要覆盖所有配置 +- 分组展示:帮助文档按功能分组,易于理解 +- 可选使用:大多数场景只需少量参数,不需要全部指定 + +### 权衡 2:环境变量前缀 + +**权衡**:环境变量使用 `NEX_` 前缀,名称较长 + +**选择理由**: +- 避免冲突:与其他系统的环境变量区分 +- 明确归属:一眼看出是本应用的配置 +- 业界惯例:大多数应用都使用前缀(如 `AWS_`、`GITHUB_`) + +## Migration Plan + +本变更不涉及数据迁移,仅需代码部署: + +### 部署步骤 + +1. **代码合并**:将变更合并到主分支 +2. **重新编译**:编译新版本二进制文件 +3. **部署验证**:在测试环境验证配置加载正常 +4. **生产部署**:部署新版本 + +### 回滚策略 + +如需回滚: +1. 回退到旧版本代码 +2. 重新编译部署 +3. 配置文件无需修改,格式兼容 + +### 兼容性保证 + +- 现有配置文件 `~/.nex/config.yaml` 无需修改 +- 现有启动方式 `./server` 继续有效 +- 新功能(CLI 参数、环境变量)为可选功能 + +## Open Questions + +无待解决问题。设计方案已明确,可直接进入实现阶段。 diff --git a/openspec/changes/implement-viper-config/proposal.md b/openspec/changes/implement-viper-config/proposal.md new file mode 100644 index 0000000..57dc871 --- /dev/null +++ b/openspec/changes/implement-viper-config/proposal.md @@ -0,0 +1,66 @@ +## Why + +当前配置方案仅支持 YAML 配置文件,存在以下问题: +- **测试不便**:每次测试都需要创建临时配置文件 +- **临时调试困难**:无法快速修改单个配置参数进行调试 +- **容器化不友好**:不支持环境变量配置,不符合 12-Factor App 原则 +- **配置切换繁琐**:无法通过命令行参数临时覆盖配置 + +需要实现多层配置管理,支持 CLI 参数、环境变量、配置文件和默认值四种配置方式,并采用社区标准方案(Viper)实现。 + +## What Changes + +- **引入 Viper 配置管理框架**:使用社区标准的配置管理库,支持多种配置源 +- **实现配置优先级**:CLI 参数 > 环境变量 > 配置文件 > 默认值 +- **支持命令行参数**:所有 13 个配置项都支持 CLI 参数覆盖 +- **支持环境变量**:所有配置项都支持环境变量配置(NEX_ 前缀) +- **规范化命名**:CLI 参数、环境变量、配置文件命名完全一致,保持层次结构 + - 配置文件:`server.port` + - 环境变量:`NEX_SERVER_PORT` + - CLI 参数:`--server-port` +- **使用结构体验证**:采用 `go-playground/validator` 进行配置验证 +- **配置摘要输出**:启动时打印配置摘要,显示配置来源 +- **BREAKING**:重构配置加载逻辑,现有 `LoadConfig()` API 发生变化 + +## Capabilities + +### New Capabilities + +- `cli-config`: 命令行参数配置支持,所有配置项都可通过 CLI 参数设置 +- `env-config`: 环境变量配置支持,符合 12-Factor App 原则 +- `config-priority`: 配置优先级管理,支持 CLI > ENV > File > Default 的优先级 + +### Modified Capabilities + +- `config-management`: 扩展现有配置管理能力,从单一配置文件支持扩展为多层配置源支持 + +## Impact + +### 代码影响 + +- `backend/internal/config/config.go`:重构配置加载逻辑,引入 Viper +- `backend/cmd/server/main.go`:修改配置加载流程,添加 CLI 参数解析 +- `backend/internal/config/config_test.go`:更新测试以适应新的配置加载方式 + +### 依赖变更 + +新增依赖: +- `github.com/spf13/viper v1.18.2`:配置管理 +- `github.com/spf13/pflag v1.0.5`:命令行参数解析 +- `github.com/go-playground/validator/v10 v10.22.0`:结构体验证 + +移除依赖: +- `gopkg.in/yaml.v3`:Viper 内置 YAML 支持 + +### API 变更 + +- `config.LoadConfig()` 签名保持不变,但内部实现完全重构 +- 新增 `config.LoadConfigFromPath(path string)` 支持自定义配置文件路径 +- 新增 `config.PrintSummary()` 打印配置摘要 + +### 使用场景影响 + +- **生产环境**:继续使用配置文件,无影响 +- **测试环境**:可通过 CLI 参数或环境变量配置,无需创建临时配置文件 +- **容器化部署**:可通过环境变量配置,符合 12-Factor App +- **本地开发**:可通过 CLI 参数临时修改配置,无需修改配置文件 diff --git a/openspec/changes/implement-viper-config/specs/cli-config/spec.md b/openspec/changes/implement-viper-config/specs/cli-config/spec.md new file mode 100644 index 0000000..1419767 --- /dev/null +++ b/openspec/changes/implement-viper-config/specs/cli-config/spec.md @@ -0,0 +1,102 @@ +# CLI Config + +## ADDED Requirements + +### Requirement: 命令行参数配置支持 + +系统 SHALL 支持通过命令行参数设置所有配置项。 + +#### Scenario: 基本参数解析 + +- **WHEN** 应用启动时传入命令行参数 +- **THEN** SHALL 解析所有 CLI 参数 +- **THEN** SHALL 将参数值应用到对应配置项 + +#### Scenario: 参数命名规范 + +- **WHEN** 使用命令行参数 +- **THEN** SHALL 使用 kebab-case 命名(如 `--server-port`) +- **THEN** SHALL 保持完整的层次结构(如 `--database-max-idle-conns`) +- **THEN** SHALL 与配置文件路径对应(`database.max_idle_conns` → `--database-max-idle-conns`) + +#### Scenario: 参数类型支持 + +- **WHEN** 解析不同类型的参数 +- **THEN** SHALL 支持 int 类型(如 `--server-port 9000`) +- **THEN** SHALL 支持 string 类型(如 `--database-path /data/nex.db`) +- **THEN** SHALL 支持 duration 类型(如 `--server-read-timeout 60s`) +- **THEN** SHALL 支持 bool 类型(如 `--log-compress`) + +### Requirement: 配置文件路径参数 + +系统 SHALL 支持通过 CLI 参数指定配置文件路径。 + +#### Scenario: 自定义配置文件路径 + +- **WHEN** 启动时指定 `--config /path/to/custom.yaml` +- **THEN** SHALL 从指定路径加载配置文件 +- **THEN** SHALL NOT 使用默认路径 `~/.nex/config.yaml` + +#### Scenario: 未指定配置文件路径 + +- **WHEN** 启动时未指定 `--config` 参数 +- **THEN** SHALL 使用默认路径 `~/.nex/config.yaml` + +### Requirement: 完整配置覆盖 + +系统 SHALL 支持通过 CLI 参数覆盖所有配置项。 + +#### Scenario: 服务器配置参数 + +- **WHEN** 使用服务器相关参数 +- **THEN** SHALL 支持 `--server-port` +- **THEN** SHALL 支持 `--server-read-timeout` +- **THEN** SHALL 支持 `--server-write-timeout` + +#### Scenario: 数据库配置参数 + +- **WHEN** 使用数据库相关参数 +- **THEN** SHALL 支持 `--database-path` +- **THEN** SHALL 支持 `--database-max-idle-conns` +- **THEN** SHALL 支持 `--database-max-open-conns` +- **THEN** SHALL 支持 `--database-conn-max-lifetime` + +#### Scenario: 日志配置参数 + +- **WHEN** 使用日志相关参数 +- **THEN** SHALL 支持 `--log-level` +- **THEN** SHALL 支持 `--log-path` +- **THEN** SHALL 支持 `--log-max-size` +- **THEN** SHALL 支持 `--log-max-backups` +- **THEN** SHALL 支持 `--log-max-age` +- **THEN** SHALL 支持 `--log-compress` + +### Requirement: 参数帮助信息 + +系统 SHALL 提供完整的参数帮助信息。 + +#### Scenario: 帮助文档生成 + +- **WHEN** 使用 `--help` 参数 +- **THEN** SHALL 显示所有支持的参数 +- **THEN** SHALL 按功能分组展示参数(服务器、数据库、日志) +- **THEN** SHALL 显示每个参数的默认值 +- **THEN** SHALL 显示每个参数的说明 + +### Requirement: 参数错误处理 + +系统 SHALL 正确处理参数错误。 + +#### Scenario: 无效参数值 + +- **WHEN** 传入无效的参数值(如 `--server-port abc`) +- **THEN** SHALL 返回明确的错误信息 +- **THEN** SHALL 指示参数名称和期望类型 +- **THEN** SHALL NOT 启动应用 + +#### Scenario: 未知参数 + +- **WHEN** 传入未定义的参数(如 `--unknown-param value`) +- **THEN** SHALL 返回错误信息 +- **THEN** SHALL 指示未知参数名称 +- **THEN** SHALL NOT 启动应用 diff --git a/openspec/changes/implement-viper-config/specs/config-management/spec.md b/openspec/changes/implement-viper-config/specs/config-management/spec.md new file mode 100644 index 0000000..e714c5c --- /dev/null +++ b/openspec/changes/implement-viper-config/specs/config-management/spec.md @@ -0,0 +1,151 @@ +# Config Management + +## MODIFIED Requirements + +### Requirement: 使用 YAML 配置文件 + +系统 SHALL 使用 YAML 格式的配置文件。 + +#### Scenario: 配置文件路径 + +- **WHEN** 应用启动且未指定 `--config` 参数 +- **THEN** SHALL 从 `~/.nex/config.yaml` 加载配置 +- **THEN** SHALL 解析 YAML 格式 + +#### Scenario: 自定义配置文件路径 + +- **WHEN** 应用启动且指定 `--config /path/to/custom.yaml` +- **THEN** SHALL 从指定路径加载配置文件 +- **THEN** SHALL NOT 使用默认路径 `~/.nex/config.yaml` + +#### Scenario: 配置文件结构 + +- **WHEN** 加载配置文件 +- **THEN** SHALL 包含 server、database、log 等配置节 +- **THEN** SHALL 支持嵌套配置结构 + +### Requirement: 自动生成默认配置 + +系统 SHALL 在首次使用时自动生成默认配置。 + +#### Scenario: 配置文件不存在 + +- **WHEN** 应用启动且配置文件不存在 +- **THEN** SHALL 自动创建配置文件 +- **THEN** SHALL 写入默认配置值 +- **THEN** SHALL 记录日志提示已创建 + +#### Scenario: 配置文件已存在 + +- **WHEN** 应用启动且配置文件已存在 +- **THEN** SHALL 直接加载配置文件 +- **THEN** SHALL NOT 覆盖现有配置 + +### Requirement: 配置验证 + +系统 SHALL 验证配置的有效性。 + +#### Scenario: 必需字段验证 + +- **WHEN** 加载配置 +- **THEN** SHALL 验证必需字段存在 +- **THEN** SHALL 在字段缺失时返回错误 + +#### Scenario: 字段值验证 + +- **WHEN** 加载配置 +- **THEN** SHALL 验证端口号范围(1-65535) +- **THEN** SHALL 验证日志级别有效性(debug/info/warn/error) +- **THEN** SHALL 验证路径有效性 +- **THEN** SHALL 验证数值范围(如 max_idle_conns ≥ 1) + +#### Scenario: 配置错误处理 + +- **WHEN** 配置验证失败 +- **THEN** SHALL 返回详细的错误信息 +- **THEN** SHALL 指示哪些字段无效 +- **THEN** SHALL 应用 SHALL NOT 启动 + +## ADDED Requirements + +### Requirement: 多层配置源支持 + +系统 SHALL 支持多种配置源。 + +#### Scenario: 配置源类型 + +- **WHEN** 加载配置 +- **THEN** SHALL 支持命令行参数配置源 +- **THEN** SHALL 支持环境变量配置源 +- **THEN** SHALL 支持配置文件配置源 +- **THEN** SHALL 支持默认值配置源 + +#### Scenario: 配置源合并 + +- **WHEN** 多个配置源同时存在 +- **THEN** SHALL 合并所有配置源 +- **THEN** SHALL 按优先级处理冲突 +- **THEN** SHALL 生成最终配置 + +### Requirement: 配置加载流程 + +系统 SHALL 实现标准化的配置加载流程。 + +#### Scenario: 加载步骤 + +- **WHEN** 应用启动 +- **THEN** SHALL 按以下顺序加载配置: + 1. 解析 CLI 参数(获取 --config 路径) + 2. 初始化配置管理器 + 3. 设置默认值 + 4. 绑定 CLI 参数 + 5. 绑定环境变量 + 6. 读取配置文件 + 7. 反序列化到结构体 + 8. 验证配置 + 9. 打印配置摘要 + +#### Scenario: 加载失败处理 + +- **WHEN** 配置加载过程中发生错误 +- **THEN** SHALL 返回明确的错误信息 +- **THEN** SHALL 指示失败步骤 +- **THEN** SHALL NOT 启动应用 + +### Requirement: 配置摘要输出 + +系统 SHALL 在启动时输出配置摘要。 + +#### Scenario: 摘要内容 + +- **WHEN** 配置加载完成 +- **THEN** SHALL 打印关键配置项(端口、数据库路径、日志级别等) +- **THEN** SHALL 打印配置文件路径 +- **THEN** SHALL 打印环境变量数量 +- **THEN** SHALL 打印 CLI 参数数量 + +#### Scenario: 摘要格式 + +- **WHEN** 打印配置摘要 +- **THEN** SHALL 使用清晰的格式化输出 +- **THEN** SHALL 使用分隔线和分组 +- **THEN** SHALL 易于阅读和理解 + +### Requirement: 配置结构体验证 + +系统 SHALL 使用结构体 tag 进行配置验证。 + +#### Scenario: 验证规则定义 + +- **WHEN** 定义配置结构体 +- **THEN** SHALL 使用 `validate` tag 定义验证规则 +- **THEN** SHALL 支持 `required` 规则 +- **THEN** SHALL 支持 `min`、`max` 规则 +- **THEN** SHALL 支持 `oneof` 规则 + +#### Scenario: 验证执行 + +- **WHEN** 加载配置后 +- **THEN** SHALL 自动执行结构体验证 +- **THEN** SHALL 返回验证错误 +- **THEN** SHALL NOT 启动应用(如果验证失败) diff --git a/openspec/changes/implement-viper-config/specs/config-priority/spec.md b/openspec/changes/implement-viper-config/specs/config-priority/spec.md new file mode 100644 index 0000000..9555a83 --- /dev/null +++ b/openspec/changes/implement-viper-config/specs/config-priority/spec.md @@ -0,0 +1,113 @@ +# Config Priority + +## ADDED Requirements + +### Requirement: 配置优先级管理 + +系统 SHALL 实现明确的配置优先级机制。 + +#### Scenario: 优先级顺序 + +- **WHEN** 同一配置项在多个配置源中设置 +- **THEN** SHALL 按以下优先级顺序(从高到低): + 1. CLI 参数 + 2. 环境变量 + 3. 配置文件 + 4. 默认值 + +#### Scenario: CLI 参数最高优先级 + +- **WHEN** 配置文件设置 `server.port: 9826` +- **AND** 环境变量设置 `NEX_SERVER_PORT=9000` +- **AND** CLI 参数设置 `--server-port 8080` +- **THEN** SHALL 使用 CLI 参数值 8080 + +#### Scenario: 环境变量次高优先级 + +- **WHEN** 配置文件设置 `server.port: 9826` +- **AND** 环境变量设置 `NEX_SERVER_PORT=9000` +- **AND** 未设置 CLI 参数 +- **THEN** SHALL 使用环境变量值 9000 + +#### Scenario: 配置文件次低优先级 + +- **WHEN** 配置文件设置 `server.port: 9826` +- **AND** 未设置环境变量 +- **AND** 未设置 CLI 参数 +- **THEN** SHALL 使用配置文件值 9826 + +#### Scenario: 默认值最低优先级 + +- **WHEN** 配置文件中未设置某配置项 +- **AND** 未设置环境变量 +- **AND** 未设置 CLI 参数 +- **THEN** SHALL 使用默认值 + +### Requirement: 配置来源追踪 + +系统 SHALL 追踪每个配置值的来源。 + +#### Scenario: 来源记录 + +- **WHEN** 加载配置完成 +- **THEN** SHALL 记录每个配置项的来源(CLI/ENV/File/Default) +- **THEN** SHALL 在配置摘要中显示来源信息 + +#### Scenario: 来源统计 + +- **WHEN** 打印配置摘要 +- **THEN** SHALL 统计来自 CLI 参数的配置项数量 +- **THEN** SHALL 统计来自环境变量的配置项数量 +- **THEN** SHALL 统计来自配置文件的配置项数量 +- **THEN** SHALL 统计使用默认值的配置项数量 + +### Requirement: 配置覆盖透明性 + +系统 SHALL 提供配置覆盖的透明信息。 + +#### Scenario: 覆盖提示 + +- **WHEN** CLI 参数覆盖配置文件值 +- **THEN** SHALL 在日志中记录覆盖信息 +- **THEN** SHALL 显示被覆盖的配置项名称 + +#### Scenario: 配置摘要展示 + +- **WHEN** 应用启动完成 +- **THEN** SHALL 打印配置摘要 +- **THEN** SHALL 显示关键配置项的最终值 +- **THEN** SHALL 显示配置文件路径 +- **THEN** SHALL 显示环境变量数量 +- **THEN** SHALL 显示 CLI 参数数量 + +### Requirement: 部分配置覆盖 + +系统 SHALL 支持部分配置覆盖。 + +#### Scenario: 混合配置源 + +- **WHEN** 配置文件设置完整配置 +- **AND** CLI 参数仅覆盖部分配置项 +- **THEN** SHALL 合并所有配置源 +- **THEN** SHALL 使用 CLI 参数覆盖指定项 +- **THEN** SHALL 保留配置文件中的其他配置项 + +#### Scenario: 配置项独立覆盖 + +- **WHEN** 仅通过 CLI 参数设置 `--server-port 9000` +- **THEN** SHALL 仅覆盖 server.port 配置项 +- **THEN** SHALL NOT 影响其他配置项 +- **THEN** SHALL 其他配置项使用配置文件或默认值 + +### Requirement: 配置优先级不可变性 + +系统 SHALL 确保配置优先级在运行时不可变。 + +#### Scenario: 启动后配置锁定 + +- **WHEN** 应用启动完成 +- **THEN** SHALL 锁定配置值 +- **THEN** SHALL NOT 支持运行时修改配置优先级 +- **THEN** SHALL NOT 支持运行时添加新配置源 + +注:配置热重载为未来扩展功能,当前版本不支持。 diff --git a/openspec/changes/implement-viper-config/specs/env-config/spec.md b/openspec/changes/implement-viper-config/specs/env-config/spec.md new file mode 100644 index 0000000..ade0ef2 --- /dev/null +++ b/openspec/changes/implement-viper-config/specs/env-config/spec.md @@ -0,0 +1,107 @@ +# Env Config + +## ADDED Requirements + +### Requirement: 环境变量配置支持 + +系统 SHALL 支持通过环境变量设置所有配置项。 + +#### Scenario: 环境变量读取 + +- **WHEN** 应用启动时存在环境变量 +- **THEN** SHALL 自动读取所有 `NEX_` 前缀的环境变量 +- **THEN** SHALL 将环境变量值应用到对应配置项 + +#### Scenario: 环境变量命名规范 + +- **WHEN** 使用环境变量配置 +- **THEN** SHALL 使用 `NEX_` 前缀 +- **THEN** SHALL 使用大写字母和下划线分隔(如 `NEX_SERVER_PORT`) +- **THEN** SHALL 保持完整层次结构(如 `NEX_DATABASE_MAX_IDLE_CONNS`) +- **THEN** SHALL 与配置文件路径对应(`database.max_idle_conns` → `NEX_DATABASE_MAX_IDLE_CONNS`) + +#### Scenario: 环境变量类型转换 + +- **WHEN** 解析不同类型的环境变量 +- **THEN** SHALL 支持 int 类型(如 `NEX_SERVER_PORT=9000`) +- **THEN** SHALL 支持 string 类型(如 `NEX_DATABASE_PATH=/data/nex.db`) +- **THEN** SHALL 支持 duration 类型(如 `NEX_SERVER_READ_TIMEOUT=60s`) +- **THEN** SHALL 支持 bool 类型(如 `NEX_LOG_COMPRESS=true`) + +### Requirement: 完整配置覆盖 + +系统 SHALL 支持通过环境变量覆盖所有配置项。 + +#### Scenario: 服务器配置环境变量 + +- **WHEN** 设置服务器相关环境变量 +- **THEN** SHALL 支持 `NEX_SERVER_PORT` +- **THEN** SHALL 支持 `NEX_SERVER_READ_TIMEOUT` +- **THEN** SHALL 支持 `NEX_SERVER_WRITE_TIMEOUT` + +#### Scenario: 数据库配置环境变量 + +- **WHEN** 设置数据库相关环境变量 +- **THEN** SHALL 支持 `NEX_DATABASE_PATH` +- **THEN** SHALL 支持 `NEX_DATABASE_MAX_IDLE_CONNS` +- **THEN** SHALL 支持 `NEX_DATABASE_MAX_OPEN_CONNS` +- **THEN** SHALL 支持 `NEX_DATABASE_CONN_MAX_LIFETIME` + +#### Scenario: 日志配置环境变量 + +- **WHEN** 设置日志相关环境变量 +- **THEN** SHALL 支持 `NEX_LOG_LEVEL` +- **THEN** SHALL 支持 `NEX_LOG_PATH` +- **THEN** SHALL 支持 `NEX_LOG_MAX_SIZE` +- **THEN** SHALL 支持 `NEX_LOG_MAX_BACKUPS` +- **THEN** SHALL 支持 `NEX_LOG_MAX_AGE` +- **THEN** SHALL 支持 `NEX_LOG_COMPRESS` + +### Requirement: 环境变量优先级 + +系统 SHALL 确保环境变量优先级高于配置文件但低于 CLI 参数。 + +#### Scenario: 环境变量覆盖配置文件 + +- **WHEN** 配置文件设置 `server.port: 9826` +- **AND** 环境变量设置 `NEX_SERVER_PORT=9000` +- **THEN** SHALL 使用环境变量值 9000 + +#### Scenario: CLI 参数覆盖环境变量 + +- **WHEN** 环境变量设置 `NEX_SERVER_PORT=9000` +- **AND** CLI 参数设置 `--server-port 8080` +- **THEN** SHALL 使用 CLI 参数值 8080 + +### Requirement: 12-Factor App 合规 + +系统 SHALL 符合 12-Factor App 配置原则。 + +#### Scenario: 配置与代码分离 + +- **WHEN** 应用部署到不同环境 +- **THEN** SHALL 通过环境变量区分环境配置 +- **THEN** SHALL NOT 修改代码或配置文件 + +#### Scenario: 敏感信息保护 + +- **WHEN** 配置包含敏感信息(如密钥、密码) +- **THEN** SHALL 通过环境变量传递 +- **THEN** SHALL NOT 存储在配置文件中 + +### Requirement: 环境变量错误处理 + +系统 SHALL 正确处理环境变量错误。 + +#### Scenario: 无效环境变量值 + +- **WHEN** 环境变量值格式无效(如 `NEX_SERVER_PORT=abc`) +- **THEN** SHALL 返回明确的错误信息 +- **THEN** SHALL 指示环境变量名称和期望类型 +- **THEN** SHALL NOT 启动应用 + +#### Scenario: 环境变量缺失 + +- **WHEN** 必需配置项既无配置文件也无环境变量 +- **THEN** SHALL 使用默认值 +- **THEN** SHALL 正常启动应用 diff --git a/openspec/changes/implement-viper-config/tasks.md b/openspec/changes/implement-viper-config/tasks.md new file mode 100644 index 0000000..befdcc3 --- /dev/null +++ b/openspec/changes/implement-viper-config/tasks.md @@ -0,0 +1,52 @@ +## 1. 依赖管理 + +- [ ] 1.1 在 go.mod 中添加 Viper、pflag、validator 依赖 +- [ ] 1.2 移除 gopkg.in/yaml.v3 依赖(Viper 内置 YAML 支持) +- [ ] 1.3 运行 go mod tidy 更新依赖树 + +## 2. 配置结构体重构 + +- [ ] 2.1 为 Config 结构体添加 validate tag 验证规则 +- [ ] 2.2 更新 Validate() 方法使用 validator 库进行验证 +- [ ] 2.3 添加配置摘要打印方法 PrintSummary() + +## 3. 配置加载逻辑重构 + +- [ ] 3.1 创建 setupDefaults() 函数设置默认配置值 +- [ ] 3.2 创建 setupFlags() 函数定义和绑定 CLI 参数 +- [ ] 3.3 创建 setupEnv() 函数绑定环境变量 +- [ ] 3.4 创建 setupConfigFile() 函数读取配置文件 +- [ ] 3.5 重构 LoadConfig() 函数,按顺序调用上述函数 +- [ ] 3.6 添加 LoadConfigFromPath() 函数支持自定义配置文件路径 + +## 4. 主程序修改 + +- [ ] 4.1 在 main.go 中添加 CLI 参数解析逻辑 +- [ ] 4.2 修改配置加载流程,使用新的 LoadConfig() +- [ ] 4.3 添加配置摘要打印调用 + +## 5. 测试更新 + +- [ ] 5.1 更新 TestDefaultConfig 测试新的默认值设置方式 +- [ ] 5.2 更新 TestConfig_Validate 测试新的验证规则 +- [ ] 5.3 添加 CLI 参数配置测试 +- [ ] 5.4 添加环境变量配置测试 +- [ ] 5.5 添加配置优先级测试 +- [ ] 5.6 添加配置摘要输出测试 +- [ ] 5.7 确保所有测试通过 + +## 6. 文档更新 + +- [ ] 6.1 更新 README.md 配置说明部分 +- [ ] 6.2 添加 CLI 参数使用示例 +- [ ] 6.3 添加环境变量配置示例 +- [ ] 6.4 添加配置优先级说明 + +## 7. 验证与清理 + +- [ ] 7.1 运行完整测试套件,确保所有测试通过 +- [ ] 7.2 本地测试:使用 CLI 参数启动应用 +- [ ] 7.3 本地测试:使用环境变量启动应用 +- [ ] 7.4 本地测试:混合使用 CLI 参数和环境变量 +- [ ] 7.5 验证配置摘要输出正确 +- [ ] 7.6 清理代码,移除不再使用的函数和导入 diff --git a/openspec/specs/error-handling/spec.md b/openspec/specs/error-handling/spec.md index abb1156..05a8f8f 100644 --- a/openspec/specs/error-handling/spec.md +++ b/openspec/specs/error-handling/spec.md @@ -47,6 +47,24 @@ - **THEN** SHALL 使用 ErrInternal 等预定义错误 - **THEN** SHALL 设置 HTTP 状态码为 500 +#### Scenario: 请求创建错误 + +- **WHEN** 创建 HTTP 请求失败 +- **THEN** SHALL 使用 ErrRequestCreate 预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 500 + +#### Scenario: 请求发送错误 + +- **WHEN** 发送 HTTP 请求失败 +- **THEN** SHALL 使用 ErrRequestSend 预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 500 + +#### Scenario: 响应读取错误 + +- **WHEN** 读取 HTTP 响应失败 +- **THEN** SHALL 使用 ErrResponseRead 预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 500 + ### Requirement: 支持错误包装 系统 SHALL 支持错误包装。 @@ -127,6 +145,30 @@ - **THEN** SHALL 包装数据库错误 - **THEN** SHALL 转换为应用错误 +### Requirement: 使用类型安全错误判断 + +系统 SHALL 使用类型安全方式判断错误类型。 + +#### Scenario: 数据库错误判断 + +- **WHEN** 判断数据库唯一约束错误 +- **THEN** SHALL 使用 errors.Is(err, gorm.ErrDuplicatedKey) +- **THEN** SHALL NOT 使用字符串匹配 err.Error() + +#### Scenario: 网络错误判断 + +- **WHEN** 判断网络错误 +- **THEN** SHALL 使用 errors.As(err, &net.Error) 判断网络错误 +- **THEN** SHALL 使用 errors.As(err, &net.OpError) 判断操作错误 +- **THEN** SHALL 使用 errors.Is(opErr.Err, syscall.ECONNRESET) 判断连接重置 +- **THEN** SHALL NOT 使用字符串匹配判断错误类型 + +#### Scenario: 错误链判断 + +- **WHEN** 判断错误链中的特定错误 +- **THEN** SHALL 使用 errors.Is 进行链式判断 +- **THEN** SHALL 使用 errors.As 提取特定类型错误 + ## ADDED Requirements ### Requirement: 定义 ConversionError 错误类型 diff --git a/openspec/specs/request-validation/spec.md b/openspec/specs/request-validation/spec.md index a55e0f2..c1ca307 100644 --- a/openspec/specs/request-validation/spec.md +++ b/openspec/specs/request-validation/spec.md @@ -120,3 +120,29 @@ - **WHEN** 处理请求 - **THEN** SHALL 在 handler 函数开始时验证 - **THEN** SHALL 在验证通过后才执行业务逻辑 + +### Requirement: 使用标准库解析 JSON + +系统 SHALL 使用 encoding/json 标准库解析 JSON 请求。 + +#### Scenario: 提取 model 字段 + +- **WHEN** 从请求体提取 model 字段 +- **THEN** SHALL 使用 json.Unmarshal 解析到结构体 +- **THEN** SHALL NOT 手动扫描字节查找字段 +- **THEN** 解析失败 SHALL 返回空字符串(不报错) + +#### Scenario: 检测 stream 字段 + +- **WHEN** 检测请求是否为流式请求 +- **THEN** SHALL 使用 json.Unmarshal 解析到结构体 +- **THEN** SHALL NOT 手动扫描字节查找字段 +- **THEN** 解析失败 SHALL 返回 false(非流式) + +#### Scenario: JSON 解析健壮性 + +- **WHEN** 解析 JSON 请求体 +- **THEN** SHALL 正确处理转义字符 +- **THEN** SHALL 正确处理嵌套结构 +- **THEN** SHALL 正确处理 Unicode 字符 +- **THEN** 解析失败 SHALL 有明确的错误处理 diff --git a/openspec/specs/structured-logging/spec.md b/openspec/specs/structured-logging/spec.md index fd4c216..ac4456b 100644 --- a/openspec/specs/structured-logging/spec.md +++ b/openspec/specs/structured-logging/spec.md @@ -20,6 +20,13 @@ - **THEN** SHALL 支持嵌套字段 - **THEN** SHALL 自动包含时间戳和日志级别 +#### Scenario: 日志注入 + +- **WHEN** 创建需要记录日志的组件 +- **THEN** SHALL 通过构造函数注入 *zap.Logger +- **THEN** SHALL 允许 logger 参数为 nil,此时使用全局 logger zap.L() +- **THEN** SHALL NOT 直接使用全局 logger zap.L()(除非在构造函数默认值中) + ### Requirement: 支持日志滚动 系统 SHALL 支持日志文件滚动,使用 lumberjack。 @@ -122,3 +129,20 @@ - **WHEN** 创建日志文件 - **THEN** SHALL 使用 `nex-YYYY-MM-DD.log` 格式命名 - **THEN** SHALL 按日期创建新文件 + +### Requirement: ConversionEngine 日志注入 + +ConversionEngine SHALL 通过依赖注入获取 logger。 + +#### Scenario: ConversionEngine 构造函数 + +- **WHEN** 创建 ConversionEngine 实例 +- **THEN** 构造函数 SHALL 接受 *zap.Logger 参数 +- **THEN** 参数为 nil 时 SHALL 使用 zap.L() 作为默认值 +- **THEN** SHALL 将 logger 存储在结构体字段中 + +#### Scenario: ConversionEngine 日志使用 + +- **WHEN** ConversionEngine 记录日志 +- **THEN** SHALL 使用注入的 logger 字段 +- **THEN** SHALL NOT 直接调用 zap.L()