package conversion

import (
	"encoding/json"
	"testing"

	"nex/backend/internal/conversion/canonical"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

// mockProtocolAdapter is a configurable protocol adapter used by the tests in
// this file. Each optional *Fn field, when non-nil, overrides the default
// behavior of the corresponding method.
type mockProtocolAdapter struct {
	protocolName    string
	passthrough     bool
	ifaceType       InterfaceType
	supportsIface   map[InterfaceType]bool
	decodeReqFn     func([]byte) (*canonical.CanonicalRequest, error)
	encodeReqFn     func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
	decodeRespFn    func([]byte) (*canonical.CanonicalResponse, error)
	encodeRespFn    func(*canonical.CanonicalResponse) ([]byte, error)
	streamDecoderFn func() StreamDecoder
	streamEncoderFn func() StreamEncoder
}

// newMockAdapter returns a mock adapter with the given protocol name and
// passthrough flag. The detected interface type defaults to InterfaceTypeChat
// and the per-interface support overrides start out empty.
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
	return &mockProtocolAdapter{
		protocolName:  name,
		passthrough:   passthrough,
		ifaceType:     InterfaceTypeChat,
		supportsIface: map[InterfaceType]bool{},
	}
}

// ProtocolName returns the configured protocol name.
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }

// ProtocolVersion always reports "1.0".
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }

// SupportsPassthrough returns the configured passthrough flag.
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }

// DetectInterfaceType ignores nativePath and returns the configured ifaceType.
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
	return m.ifaceType
}

// BuildUrl returns nativePath unchanged.
func (m *mockProtocolAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
	return nativePath
}

// BuildHeaders returns a single Authorization header carrying the provider's
// API key as a bearer token.
func (m *mockProtocolAdapter) BuildHeaders(provider *TargetProvider) map[string]string {
	return map[string]string{"Authorization": "Bearer " + provider.APIKey}
}

// SupportsInterface returns the explicit entry from supportsIface when one
// exists; otherwise only InterfaceTypeChat is reported as supported.
func (m *mockProtocolAdapter) SupportsInterface(interfaceType InterfaceType) bool {
	if v, ok := m.supportsIface[interfaceType]; ok {
		return v
	}
	return interfaceType == InterfaceTypeChat
}

// DecodeRequest delegates to decodeReqFn when set. Otherwise it unmarshals raw
// into a fresh CanonicalRequest and always returns a nil error.
func (m *mockProtocolAdapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	if m.decodeReqFn != nil {
		return m.decodeReqFn(raw)
	}
	req := &canonical.CanonicalRequest{}
	// The unmarshal error is ignored: the default mock path never fails, even
	// for bodies that do not match the canonical shape.
	_ = json.Unmarshal(raw, req)
	return req, nil
}

// EncodeRequest delegates to encodeReqFn when set; otherwise it JSON-marshals req.
func (m *mockProtocolAdapter) EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error) {
	if m.encodeReqFn != nil {
		return m.encodeReqFn(req, provider)
	}
	return json.Marshal(req)
}

// DecodeResponse delegates to decodeRespFn when set. Otherwise it unmarshals
// raw into a fresh CanonicalResponse and always returns a nil error.
func (m *mockProtocolAdapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	if m.decodeRespFn != nil {
		return m.decodeRespFn(raw)
	}
	resp := &canonical.CanonicalResponse{}
	// The unmarshal error is ignored: the default mock path never fails.
	_ = json.Unmarshal(raw, resp)
	return resp, nil
}

// EncodeResponse delegates to encodeRespFn when set; otherwise it JSON-marshals resp.
func (m *mockProtocolAdapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	if m.encodeRespFn != nil {
		return m.encodeRespFn(resp)
	}
	return json.Marshal(resp)
}

// CreateStreamDecoder returns the decoder produced by streamDecoderFn when
// set; otherwise a noopStreamDecoder.
func (m *mockProtocolAdapter) CreateStreamDecoder() StreamDecoder {
	if m.streamDecoderFn != nil {
		return m.streamDecoderFn()
	}
	return &noopStreamDecoder{}
}

// CreateStreamEncoder returns the encoder produced by streamEncoderFn when
// set; otherwise a noopStreamEncoder.
func (m *mockProtocolAdapter) CreateStreamEncoder() StreamEncoder {
	if m.streamEncoderFn != nil {
		return m.streamEncoderFn()
	}
	return &noopStreamEncoder{}
}

// EncodeError ignores err and returns a fixed JSON body with status 400.
func (m *mockProtocolAdapter) EncodeError(err *ConversionError) ([]byte, int) {
	return []byte(`{"error":"mock"}`), 400
}

// DecodeModelsResponse ignores raw and returns an empty model list.
func (m *mockProtocolAdapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return &canonical.CanonicalModelList{}, nil
}

// EncodeModelsResponse JSON-marshals list.
func (m *mockProtocolAdapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return json.Marshal(list)
}

// DecodeModelInfoResponse ignores raw and returns an empty model info value.
func (m *mockProtocolAdapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return &canonical.CanonicalModelInfo{}, nil
}

// EncodeModelInfoResponse JSON-marshals info.
func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return json.Marshal(info)
}

// DecodeEmbeddingRequest ignores raw and returns an empty embedding request.
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return &canonical.CanonicalEmbeddingRequest{}, nil
}

// EncodeEmbeddingRequest JSON-marshals req; provider is unused.
func (m *mockProtocolAdapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

// DecodeEmbeddingResponse ignores raw and returns an empty embedding response.
func (m *mockProtocolAdapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return &canonical.CanonicalEmbeddingResponse{}, nil
}

// EncodeEmbeddingResponse JSON-marshals resp.
func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return json.Marshal(resp)
}

// DecodeRerankRequest ignores raw and returns an empty rerank request.
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return &canonical.CanonicalRerankRequest{}, nil
}

// EncodeRerankRequest JSON-marshals req; provider is unused.
func (m *mockProtocolAdapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

// DecodeRerankResponse ignores raw and returns an empty rerank response.
func (m *mockProtocolAdapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return &canonical.CanonicalRerankResponse{}, nil
}

// EncodeRerankResponse JSON-marshals resp.
func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return json.Marshal(resp)
}

// noopStreamDecoder is a stream decoder that never emits events.
type noopStreamDecoder struct{}

// ProcessChunk discards rawChunk and returns nil.
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	return nil
}

// Flush returns nil.
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }

// noopStreamEncoder is a stream encoder that never emits output.
type noopStreamEncoder struct{}

// EncodeEvent discards event and returns nil.
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	return nil
}

// Flush returns nil.
func (e *noopStreamEncoder) Flush() [][]byte { return nil }

// ============ Test cases ============

// TestNewConversionEngine checks that the engine is constructed and exposes
// the registry it was given.
func TestNewConversionEngine(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	assert.NotNil(t, engine)
	assert.Equal(t, registry, engine.GetRegistry())
}

// TestNewConversionEngine_LoggerInjection checks that a nil logger still
// leaves the engine with a non-nil logger, and that a supplied logger is kept.
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
	t.Run("nil_logger_uses_global", func(t *testing.T) {
		registry := NewMemoryRegistry()
		engine := NewConversionEngine(registry, nil)

		assert.NotNil(t, engine.logger)
	})

	t.Run("custom_logger", func(t *testing.T) {
		registry := NewMemoryRegistry()
		customLogger := zap.NewNop()
		engine := NewConversionEngine(registry, customLogger)

		assert.Equal(t, customLogger, engine.logger)
	})
}

// TestRegisterAdapter checks that a registered adapter shows up in the
// registry's protocol list.
func TestRegisterAdapter(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	adapter := newMockAdapter("test-proto", true)
	err := engine.RegisterAdapter(adapter)
	require.NoError(t, err)

	protocols := registry.ListProtocols()
	assert.Contains(t, protocols, "test-proto")
}

// TestIsPassthrough_SameProtocol expects passthrough when client and provider
// use the same passthrough-capable protocol.
func TestIsPassthrough_SameProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))

	assert.True(t, engine.IsPassthrough("openai", "openai"))
}

// TestIsPassthrough_DifferentProtocol expects no passthrough between two
// different protocols.
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("anthropic", true)))

	assert.False(t, engine.IsPassthrough("openai", "anthropic"))
}

// TestIsPassthrough_NoPassthrough expects no passthrough when the adapter
// reports that it does not support it.
func TestIsPassthrough_NoPassthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("custom", false)))

	assert.False(t, engine.IsPassthrough("custom", "custom"))
}

// TestDetectInterfaceType checks that the engine returns the interface type
// reported by the registered adapter.
func TestDetectInterfaceType(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	adapter := newMockAdapter("test", true)
	adapter.ifaceType = InterfaceTypeChat
	require.NoError(t, engine.RegisterAdapter(adapter))

	ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test")
	require.NoError(t, err)
	assert.Equal(t, InterfaceTypeChat, ifaceType)
}

// TestDetectInterfaceType_NonExistentProtocol expects an error for an
// unregistered protocol.
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
	assert.Error(t, err)
}

// TestConvertHttpRequest_Passthrough checks that a same-protocol request keeps
// its body and gets the provider base URL prepended to its path.
func TestConvertHttpRequest_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))

	provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat/completions",
		Method: "POST",
		Body:   []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
	require.NoError(t, err)
	assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
	assert.Equal(t, spec.Body, result.Body)
}

// TestConvertHttpRequest_CrossProtocol checks that a request converted between
// two different protocols targets the provider URL and carries a body.
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	clientAdapter := newMockAdapter("client-proto", false)
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return &canonical.CanonicalRequest{
			Model: "test-model",
			Messages: []canonical.CanonicalMessage{
				{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
			},
		}, nil
	}
	require.NoError(t, engine.RegisterAdapter(clientAdapter))

	providerAdapter := newMockAdapter("provider-proto", false)
	providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
		return json.Marshal(map[string]any{"model": p.ModelName})
	}
	require.NoError(t, engine.RegisterAdapter(providerAdapter))

	provider := NewTargetProvider("https://example.com", "key", "my-model")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat",
		Method: "POST",
		Body:   []byte(`{"model":"test"}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "client-proto", "provider-proto", provider)
	require.NoError(t, err)
	assert.Contains(t, result.URL, "https://example.com")
	assert.NotNil(t, result.Body)
}

// TestConvertHttpResponse_Passthrough checks that a same-protocol response
// keeps its status code and body.
func TestConvertHttpResponse_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))

	spec := HTTPResponseSpec{
		StatusCode: 200,
		Body:       []byte(`{"id":"123"}`),
	}

	result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
	require.NoError(t, err)
	assert.Equal(t, 200, result.StatusCode)
	assert.Equal(t, spec.Body, result.Body)
}

// TestCreateStreamConverter_Passthrough expects a *PassthroughStreamConverter
// for a same-protocol, passthrough-capable pair.
func TestCreateStreamConverter_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))

	converter, err := engine.CreateStreamConverter("openai", "openai")
	require.NoError(t, err)

	_, ok := converter.(*PassthroughStreamConverter)
	assert.True(t, ok, "expected *PassthroughStreamConverter, got %T", converter)
}

// TestCreateStreamConverter_Canonical expects a *CanonicalStreamConverter for
// two different protocols.
func TestCreateStreamConverter_Canonical(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("client", false)))
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("provider", false)))

	converter, err := engine.CreateStreamConverter("client", "provider")
	require.NoError(t, err)

	_, ok := converter.(*CanonicalStreamConverter)
	assert.True(t, ok, "expected *CanonicalStreamConverter, got %T", converter)
}

// TestEncodeError checks that the engine returns the adapter's encoded error
// (status 400 from the mock) for a registered protocol.
func TestEncodeError(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)
	require.NoError(t, engine.RegisterAdapter(newMockAdapter("openai", true)))

	convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
	body, statusCode, err := engine.EncodeError(convErr, "openai")
	require.NoError(t, err)
	assert.Equal(t, 400, statusCode)
	assert.NotNil(t, body)
}

// TestEncodeError_NonExistentProtocol expects a status 500 body containing the
// error message when the protocol has no registered adapter.
func TestEncodeError_NonExistentProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
	body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
	require.NoError(t, err)
	assert.Equal(t, 500, statusCode)
	assert.Contains(t, string(body), "测试错误")
}

// TestRegistry_DuplicateRegistration expects the second registration of the
// same adapter to fail with an "already registered" message.
func TestRegistry_DuplicateRegistration(t *testing.T) {
	registry := NewMemoryRegistry()
	adapter := newMockAdapter("openai", true)

	err := registry.Register(adapter)
	require.NoError(t, err)

	err = registry.Register(adapter)
	// require (not assert) so a nil err stops here instead of panicking on
	// err.Error() below.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "适配器已注册")
}

// TestRegistry_GetNonExistent expects a lookup of an unregistered protocol to
// fail with a "not found" message.
func TestRegistry_GetNonExistent(t *testing.T) {
	registry := NewMemoryRegistry()

	_, err := registry.Get("nonexistent")
	// require (not assert) so a nil err stops here instead of panicking on
	// err.Error() below.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "未找到适配器")
}