feat: 系统性改进后端测试体系
- 新增 6 个测试场景 (config load pipe, handler errors, service aggregation, engine degradation, openai decoder edges, negative tests) - 更新测试工具规格 (mockgen, in-memory SQLite) - 覆盖率目标从 >80% 提升至 >85% - 新增 test-unit 和 test-integration Makefile 命令 - 新增死代码清理和 mockgen 需求 - 归档变更至 openspec/changes/archive/2026-04-22-improve-backend-testing/
This commit is contained in:
@@ -321,3 +321,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
|
||||
}
|
||||
|
||||
var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, errors.New("decode embedding failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"text-embedding","input":"hello"}`)
|
||||
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, errors.New("decode rerank failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
|
||||
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"test":"data"}`)
|
||||
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
@@ -13,18 +13,20 @@ import (
|
||||
|
||||
// mockProtocolAdapter 模拟协议适配器
|
||||
type mockProtocolAdapter struct {
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
|
||||
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
@@ -126,6 +128,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
if m.decodeEmbeddingReqFn != nil {
|
||||
return m.decodeEmbeddingReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -142,6 +147,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
if m.decodeRerankReqFn != nil {
|
||||
return m.decodeRerankReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalRerankRequest{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello back"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assistantMsg := req.Messages[1]
|
||||
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
|
||||
assert.Len(t, assistantMsg.Content, 1)
|
||||
assert.Equal(t, "text", assistantMsg.Content[0].Type)
|
||||
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,19 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -37,7 +44,12 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1"), gomock.Eq(true)).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -110,7 +140,15 @@ func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "p1",
|
||||
@@ -130,9 +168,12 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -148,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -2,119 +2,34 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error {
|
||||
if m.err == nil {
|
||||
model.ID = "mock-uuid-1234"
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
|
||||
return []domain.Model{}, nil
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -127,12 +42,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -142,14 +60,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1"), gomock.Eq(true)).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "***"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -160,10 +81,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -176,12 +99,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -198,9 +124,12 @@ func TestModelHandler_ListModels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -217,7 +146,15 @@ func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
@@ -238,9 +175,13 @@ func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -257,14 +198,15 @@ func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
}, nil)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -275,7 +217,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -286,14 +232,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
}, nil)
|
||||
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
})
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -303,8 +252,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -333,12 +280,13 @@ func formatMapErrors(errs map[string]string) string {
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// ============ 错误类型判断测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
err: appErrors.ErrConflict,
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -354,3 +302,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "nonexistent",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商不存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -50,6 +50,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
|
||||
@@ -2,6 +2,8 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
|
||||
|
||||
// ModelRepository 模型数据仓库接口
|
||||
type ModelRepository interface {
|
||||
Create(model *domain.Model) error
|
||||
|
||||
@@ -2,6 +2,8 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
|
||||
|
||||
// ProviderRepository 供应商数据仓库接口
|
||||
type ProviderRepository interface {
|
||||
Create(provider *domain.Provider) error
|
||||
@@ -9,7 +11,4 @@ type ProviderRepository interface {
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
// 统一模型 ID 相关方法
|
||||
ListEnabledModels() ([]domain.Model, error)
|
||||
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
|
||||
}
|
||||
|
||||
@@ -71,25 +71,6 @@ func (r *providerRepository) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
|
||||
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
|
||||
var models []domain.Model
|
||||
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
|
||||
Where("models.enabled = ? AND providers.enabled = ?", true, true).
|
||||
Find(&models).Error
|
||||
return models, err
|
||||
}
|
||||
|
||||
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
|
||||
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
var model domain.Model
|
||||
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
func toDomainProvider(p *config.Provider) domain.Provider {
|
||||
return domain.Provider{
|
||||
ID: p.ID,
|
||||
|
||||
@@ -5,28 +5,16 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
// 关闭数据库连接以便 TempDir 清理
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
// ============ ProviderRepository 测试 ============
|
||||
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
|
||||
|
||||
err := repo.Update("p1", map[string]interface{}{"name": "New"})
|
||||
require.NoError(t, err)
|
||||
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.GetByID("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -149,9 +141,11 @@ func TestModelRepository_GetByID(t *testing.T) {
|
||||
|
||||
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
@@ -162,9 +156,11 @@ func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||
|
||||
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
// Wrong provider_id
|
||||
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
|
||||
@@ -181,11 +177,14 @@ func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||
|
||||
func TestModelRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
|
||||
|
||||
all, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
@@ -246,9 +245,11 @@ func TestModelRepository_ListEnabled(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
err := repo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
@@ -259,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
|
||||
err := repo.Delete("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -293,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewStatsRepository(db)
|
||||
|
||||
repo.Record("p1", "gpt-4")
|
||||
require.NoError(t, repo.Record("p1", "gpt-4"))
|
||||
// 注意:当前 schema 只有 date 字段有唯一约束
|
||||
// 所以同一 provider + model 只能有一条记录
|
||||
stats, err := repo.Query("p1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
}
|
||||
|
||||
func TestModelRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
result, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
func TestProviderRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
result, err := repo.List()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
|
||||
|
||||
// StatsRepository 统计数据仓库接口
|
||||
type StatsRepository interface {
|
||||
Record(providerID, modelName string) error
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
|
||||
|
||||
// ModelService 模型服务接口
|
||||
type ModelService interface {
|
||||
Create(model *domain.Model) error
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
|
||||
|
||||
// ProviderService 供应商服务接口
|
||||
type ProviderService interface {
|
||||
Create(provider *domain.Provider) error
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
|
||||
|
||||
// RoutingService 路由服务接口
|
||||
type RoutingService interface {
|
||||
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestProviderService_Update(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
|
||||
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
|
||||
require.NoError(t, err)
|
||||
@@ -42,7 +42,7 @@ func TestModelService_Get(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestModelService_Update(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -75,7 +75,7 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestModelService_Delete(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
|
||||
@@ -3,14 +3,15 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -18,18 +19,7 @@ import (
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
@@ -40,9 +30,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 创建供应商和模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
@@ -66,10 +55,9 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 创建启用的供应商和禁用的模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||
@@ -81,10 +69,9 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 创建启用的供应商和模型,然后禁用供应商
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||
@@ -98,7 +85,7 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
@@ -122,7 +109,7 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
@@ -179,8 +166,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
@@ -216,7 +203,7 @@ func TestModelService_Update_Success(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
@@ -272,3 +259,223 @@ func TestProviderService_Update_Success(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByModel 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "multiple providers with same model name",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
|
||||
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty providerID",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByDate 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "normal date grouping",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
|
||||
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
|
||||
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "2024-01-01", "request_count": 15},
|
||||
{"date": "2024-01-02", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero-value time",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Time{}, RequestCount: 10},
|
||||
{Date: time.Time{}, RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "0001-01-01", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["date"] == exp["date"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - isUniqueConstraintError 测试 ============
|
||||
|
||||
func TestProviderService_isUniqueConstraintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint failed",
|
||||
err: errors.New("UNIQUE constraint failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate key value",
|
||||
err: errors.New("duplicate key value"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint case insensitive",
|
||||
err: errors.New("unique constraint violation"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("some other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isUniqueConstraintError(tt.err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - List MaskAPIKey 测试 ============
|
||||
|
||||
func TestProviderService_List_MaskAPIKey(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
||||
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
||||
require.NoError(t, svc.Create(provider1))
|
||||
require.NoError(t, svc.Create(provider2))
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 2)
|
||||
|
||||
for _, p := range providers {
|
||||
assert.Contains(t, p.APIKey, "***")
|
||||
assert.Len(t, p.APIKey, 7)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelService_ConcurrentCreate(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
results := make(chan error, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
results <- svc.Create(model)
|
||||
}()
|
||||
}
|
||||
|
||||
err1 := <-results
|
||||
err2 := <-results
|
||||
|
||||
successCount := 0
|
||||
errorCount := 0
|
||||
for _, err := range []error{err1, err2} {
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, successCount)
|
||||
assert.Equal(t, 1, errorCount)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
|
||||
|
||||
// StatsService 统计服务接口
|
||||
type StatsService interface {
|
||||
Record(providerID, modelName string) error
|
||||
|
||||
Reference in New Issue
Block a user