1
0

feat: 系统性改进后端测试体系

- 新增 6 个测试场景 (config load pipe, handler errors, service aggregation, engine degradation, openai decoder edges, negative tests)
- 更新测试工具规格 (mockgen, in-memory SQLite)
- 覆盖率目标从 >80% 提升至 >85%
- 新增 test-unit 和 test-integration Makefile 命令
- 新增死代码清理和 mockgen 需求
- 归档变更至 openspec/changes/archive/2026-04-22-improve-backend-testing/
This commit is contained in:
2026-04-22 13:18:51 +08:00
parent 59179094ed
commit 4e86adffb7
32 changed files with 3374 additions and 729 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint .PHONY: build run test test-unit test-integration test-coverage clean migrate-up migrate-down migrate-status migrate-create lint generate deps
# 构建 # 构建
build: build:
@@ -12,6 +12,14 @@ run:
test: test:
go test ./... -v go test ./... -v
# 单元测试
test-unit:
go test ./internal/... ./pkg/... -v
# 集成测试
test-integration:
go test ./tests/... -v
# 测试覆盖率 # 测试覆盖率
test-coverage: test-coverage:
go test ./... -coverprofile=coverage.out go test ./... -coverprofile=coverage.out
@@ -38,8 +46,12 @@ migrate-create:
# 代码检查 # 代码检查
lint: lint:
golangci-lint run ./... go tool golangci-lint run ./...
# 安装依赖 # 安装依赖
deps: deps:
go mod tidy go mod tidy
# 生成代码mock 等)
generate:
go generate ./...

View File

@@ -2,6 +2,11 @@ module nex/backend
go 1.26.2 go 1.26.2
tool (
github.com/golangci/golangci-lint/cmd/golangci-lint
go.uber.org/mock/mockgen
)
require ( require (
github.com/gin-gonic/gin v1.12.0 github.com/gin-gonic/gin v1.12.0
github.com/go-playground/validator/v10 v10.30.2 github.com/go-playground/validator/v10 v10.30.2
@@ -11,6 +16,7 @@ require (
github.com/spf13/pflag v1.0.10 github.com/spf13/pflag v1.0.10
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.1 go.uber.org/zap v1.27.1
gopkg.in/lumberjack.v2 v2.0.0 gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
@@ -19,52 +25,211 @@ require (
) )
require ( require (
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
4d63.com/gochecknoglobals v0.2.2 // indirect
github.com/4meepo/tagalign v1.4.2 // indirect
github.com/Abirdcfly/dupword v0.1.3 // indirect
github.com/Antonboom/errname v1.0.0 // indirect
github.com/Antonboom/nilnil v1.0.1 // indirect
github.com/Antonboom/testifylint v1.5.2 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect github.com/BurntSushi/toml v1.6.0 // indirect
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
github.com/alexkohler/prealloc v1.0.0 // indirect
github.com/alingse/asasalint v0.0.11 // indirect
github.com/alingse/nilnesserr v0.1.2 // indirect
github.com/ashanbrown/forbidigo v1.6.0 // indirect
github.com/ashanbrown/makezero v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bkielbasa/cyclop v1.2.3 // indirect
github.com/blizzy78/varnamelen v0.8.0 // indirect
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
github.com/breml/bidichk v0.3.2 // indirect
github.com/breml/errchkjson v0.4.0 // indirect
github.com/butuzov/ireturn v0.3.1 // indirect
github.com/butuzov/mirror v1.3.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/catenacyber/perfsprint v0.8.2 // indirect
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charithe/durationcheck v0.0.10 // indirect
github.com/chavacava/garif v0.1.0 // indirect
github.com/ckaznocha/intrange v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
github.com/curioswitch/go-reassign v0.3.0 // indirect
github.com/daixiang0/gci v0.13.5 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/denis-tingaikin/go-header v0.5.0 // indirect
github.com/ettle/strcase v0.2.0 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/fatih/structtag v1.2.0 // indirect
github.com/firefart/nonamedreturns v1.0.5 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fzipp/gocyclo v0.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.13 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/ghostiam/protogetter v0.3.9 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-critic/go-critic v0.12.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-toolsmith/astcast v1.1.0 // indirect
github.com/go-toolsmith/astcopy v1.1.0 // indirect
github.com/go-toolsmith/astequal v1.2.0 // indirect
github.com/go-toolsmith/astfmt v1.1.0 // indirect
github.com/go-toolsmith/astp v1.1.0 // indirect
github.com/go-toolsmith/strparse v1.1.0 // indirect
github.com/go-toolsmith/typep v1.1.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect github.com/goccy/go-yaml v1.19.2 // indirect
github.com/gofrs/flock v0.12.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
github.com/golangci/go-printf-func-name v0.1.0 // indirect
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
github.com/golangci/golangci-lint v1.64.8 // indirect
github.com/golangci/misspell v0.6.0 // indirect
github.com/golangci/plugin-module-register v0.1.1 // indirect
github.com/golangci/revgrep v0.8.0 // indirect
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/gordonklaus/ineffassign v0.1.0 // indirect
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
github.com/gostaticanalysis/comment v1.5.0 // indirect
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
github.com/hashicorp/go-version v1.7.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jgautheron/goconst v1.7.1 // indirect
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/jjti/go-spancheck v0.6.4 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/julz/importas v0.2.0 // indirect
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
github.com/kisielk/errcheck v1.9.0 // indirect
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/kulti/thelper v0.6.3 // indirect
github.com/kunwardeep/paralleltest v1.0.10 // indirect
github.com/lasiar/canonicalheader v1.1.2 // indirect
github.com/ldez/exptostd v0.4.2 // indirect
github.com/ldez/gomoddirectives v0.6.1 // indirect
github.com/ldez/grignotin v0.9.0 // indirect
github.com/ldez/tagliatelle v0.7.1 // indirect
github.com/ldez/usetesting v0.4.2 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/leonklingele/grouper v1.1.2 // indirect
github.com/macabu/inamedparam v0.1.3 // indirect
github.com/maratori/testableexamples v1.0.0 // indirect
github.com/maratori/testpackage v1.1.1 // indirect
github.com/matoous/godox v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mgechev/revive v1.7.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/moricho/tparallel v0.3.2 // indirect
github.com/nakabonne/nestif v0.3.1 // indirect
github.com/nishanths/exhaustive v0.12.0 // indirect
github.com/nishanths/predeclared v0.2.2 // indirect
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
github.com/quasilyte/gogrep v0.5.0 // indirect
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect
github.com/raeperd/recvcheck v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/ryancurrah/gomodguard v1.3.5 // indirect
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
github.com/securego/gosec/v2 v2.22.2 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sivchari/containedctx v1.0.3 // indirect
github.com/sivchari/tenv v1.12.1 // indirect
github.com/sonatard/noctx v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/sourcegraph/go-diff v0.7.0 // indirect
github.com/spf13/afero v1.15.0 // indirect github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tdakkota/asciicheck v0.4.1 // indirect
github.com/tetafro/godot v1.5.0 // indirect
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
github.com/timonwong/loggercheck v0.10.1 // indirect
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect
github.com/ultraware/funlen v0.2.0 // indirect
github.com/ultraware/whitespace v0.2.0 // indirect
github.com/uudashr/gocognit v1.2.0 // indirect
github.com/uudashr/iface v1.3.1 // indirect
github.com/xen0n/gosmopolitan v1.2.2 // indirect
github.com/yagipy/maintidx v1.0.0 // indirect
github.com/yeya24/promlinter v0.3.0 // indirect
github.com/ykadowak/zerologlint v0.1.5 // indirect
gitlab.com/bosi/decorder v0.4.2 // indirect
go-simpler.org/musttag v0.13.0 // indirect
go-simpler.org/sloglint v0.9.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.49.0 // indirect golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.51.0 // indirect golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/protobuf v1.36.11 // indirect google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
honnef.co/go/tools v0.6.1 // indirect
mvdan.cc/gofumpt v0.7.0 // indirect
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
) )

File diff suppressed because it is too large Load Diff

View File

@@ -321,3 +321,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
} }
var _ = json.Marshal var _ = json.Marshal
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
return nil, errors.New("decode embedding failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"text-embedding","input":"hello"}`)
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertRerankBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
return nil, errors.New("decode rerank failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
providerAdapter := newMockAdapter("provider", false)
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"test":"data"}`)
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}

View File

@@ -25,6 +25,8 @@ type mockProtocolAdapter struct {
streamEncoderFn func() StreamEncoder streamEncoderFn func() StreamEncoder
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error) rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error) rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
} }
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter { func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
@@ -126,6 +128,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
} }
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) { func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
if m.decodeEmbeddingReqFn != nil {
return m.decodeEmbeddingReqFn(raw)
}
return &canonical.CanonicalEmbeddingRequest{}, nil return &canonical.CanonicalEmbeddingRequest{}, nil
} }
@@ -142,6 +147,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
} }
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) { func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
if m.decodeRerankReqFn != nil {
return m.decodeRerankReqFn(raw)
}
return &canonical.CanonicalRerankRequest{}, nil return &canonical.CanonicalRerankRequest{}, nil
} }

View File

@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
} }
assert.True(t, found) assert.True(t, found)
} }
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
body := []byte(`{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": [{"type": "text", "text": "hello back"}]
}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Len(t, req.Messages, 2)
assistantMsg := req.Messages[1]
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
assert.Len(t, assistantMsg.Content, 1)
assert.Equal(t, "text", assistantMsg.Content[0].Type)
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
}

View File

@@ -9,12 +9,19 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/tests/mocks"
) )
func TestProviderHandler_CreateProvider_Success(t *testing.T) { func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(&mockProviderService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"id": "p1", "id": "p1",
@@ -37,7 +44,12 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
} }
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) { func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
h := NewProviderHandler(&mockProviderService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"id": "p1", "id": "p1",
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
} }
func TestProviderHandler_UpdateProvider(t *testing.T) { func TestProviderHandler_UpdateProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{ ctrl := gomock.NewController(t)
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("p1"), gomock.Eq(true)).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"}, nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"name": "Updated"}) body, _ := json.Marshal(map[string]string{"name": "Updated"})
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
} }
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) { func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
h := NewProviderHandler(&mockProviderService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
} }
func TestProviderHandler_DeleteProvider(t *testing.T) { func TestProviderHandler_DeleteProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
} }
func TestModelHandler_DeleteModel(t *testing.T) { func TestModelHandler_DeleteModel(t *testing.T) {
h := NewModelHandler(&mockModelService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -110,7 +140,15 @@ func TestModelHandler_DeleteModel(t *testing.T) {
} }
func TestModelHandler_CreateModel_Success(t *testing.T) { func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"provider_id": "p1", "provider_id": "p1",
@@ -130,9 +168,12 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
} }
func TestModelHandler_GetModel(t *testing.T) { func TestModelHandler_GetModel(t *testing.T) {
h := NewModelHandler(&mockModelService{ ctrl := gomock.NewController(t)
model: &domain.Model{ID: "m1", ModelName: "gpt-4"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -148,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
} }
func TestModelHandler_UpdateModel(t *testing.T) { func TestModelHandler_UpdateModel(t *testing.T) {
h := NewModelHandler(&mockModelService{ ctrl := gomock.NewController(t)
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"}) body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -2,119 +2,34 @@ package handler
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
// ============ Mock 实现 ============
type mockRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockStatsService struct {
err error
stats []domain.UsageStats
aggrResult []map[string]interface{}
}
func (m *mockStatsService) Record(providerID, modelName string) error {
return m.err
}
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return m.stats, nil
}
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
return m.aggrResult
}
type mockProviderService struct {
provider *domain.Provider
providers []domain.Provider
err error
}
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
}
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockProviderService) Delete(id string) error { return m.err }
type mockModelService struct {
model *domain.Model
models []domain.Model
err error
}
func (m *mockModelService) Create(model *domain.Model) error {
if m.err == nil {
model.ID = "mock-uuid-1234"
}
return m.err
}
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
return []domain.Model{}, nil
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockModelService) Delete(id string) error { return m.err }
type mockProviderClient struct {
err error
}
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
return nil, m.err
}
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
return nil, m.err
}
// ============ Provider Handler 测试 ============
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
h := NewProviderHandler(&mockProviderService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "p1"}) body, _ := json.Marshal(map[string]string{"id": "p1"})
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -127,12 +42,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
} }
func TestProviderHandler_ListProviders(t *testing.T) { func TestProviderHandler_ListProviders(t *testing.T) {
h := NewProviderHandler(&mockProviderService{ ctrl := gomock.NewController(t)
providers: []domain.Provider{ defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "P1"}, {ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"}, {ID: "p2", Name: "P2"},
}, }, nil)
}) h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -142,14 +60,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
var result []domain.Provider var result []domain.Provider
json.Unmarshal(w.Body.Bytes(), &result) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Len(t, result, 2) assert.Len(t, result, 2)
} }
func TestProviderHandler_GetProvider(t *testing.T) { func TestProviderHandler_GetProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{ ctrl := gomock.NewController(t)
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("p1"), gomock.Eq(true)).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "***"}, nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -160,10 +81,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
// ============ Model Handler 测试 ============
func TestModelHandler_CreateModel_MissingFields(t *testing.T) { func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
h := NewModelHandler(&mockModelService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "m1"}) body, _ := json.Marshal(map[string]string{"id": "m1"})
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -176,12 +99,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
} }
func TestModelHandler_ListModels(t *testing.T) { func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{ ctrl := gomock.NewController(t)
models: []domain.Model{ defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, {ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"}, {ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
}, }, nil)
}) h := NewModelHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -198,9 +124,12 @@ func TestModelHandler_ListModels(t *testing.T) {
} }
func TestModelHandler_GetModel_UnifiedID(t *testing.T) { func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{ ctrl := gomock.NewController(t)
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -217,7 +146,15 @@ func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
} }
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) { func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"provider_id": "openai", "provider_id": "openai",
@@ -238,9 +175,13 @@ func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
} }
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) { func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{ ctrl := gomock.NewController(t)
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, defer ctrl.Finish()
})
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"enabled": false}) body, _ := json.Marshal(map[string]interface{}{"enabled": false})
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -257,14 +198,15 @@ func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID) assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
} }
// ============ Stats Handler 测试 ============
func TestStatsHandler_GetStats(t *testing.T) { func TestStatsHandler_GetStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{ ctrl := gomock.NewController(t)
stats: []domain.UsageStats{ defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10}, {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
}, }, nil)
}) h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -275,7 +217,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
} }
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) { func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
h := NewStatsHandler(&mockStatsService{}) ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -286,14 +232,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
} }
func TestStatsHandler_AggregateStats(t *testing.T) { func TestStatsHandler_AggregateStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{ ctrl := gomock.NewController(t)
stats: []domain.UsageStats{ defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", RequestCount: 10}, {ProviderID: "p1", RequestCount: 10},
}, }, nil)
aggrResult: []map[string]interface{}{ mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
{"provider_id": "p1", "request_count": 10}, {"provider_id": "p1", "request_count": 10},
},
}) })
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -303,8 +252,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
// ============ writeError 测试 ============
func TestWriteError(t *testing.T) { func TestWriteError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -333,12 +280,13 @@ func formatMapErrors(errs map[string]string) string {
return "请求验证失败: " + strings.Join(parts, "; ") return "请求验证失败: " + strings.Join(parts, "; ")
} }
// ============ 错误类型判断测试 ============
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) { func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{ ctrl := gomock.NewController(t)
err: appErrors.ErrConflict, defer ctrl.Finish()
})
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"id": "p1", "id": "p1",
@@ -354,3 +302,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h.CreateProvider(c) h.CreateProvider(c)
assert.Equal(t, 409, w.Code) assert.Equal(t, 409, w.Code)
} }
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "nonexistent",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商不存在")
}
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 409, w.Code)
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
}
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 500, w.Code)
}
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 404, w.Code)
}
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
}
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 500, w.Code)
}
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 400, w.Code)
}

File diff suppressed because it is too large Load Diff

View File

@@ -50,6 +50,7 @@ type Client struct {
} }
// ProviderClient 供应商客户端接口 // ProviderClient 供应商客户端接口
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface { type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)

View File

@@ -2,6 +2,8 @@ package repository
import "nex/backend/internal/domain" import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
// ModelRepository 模型数据仓库接口 // ModelRepository 模型数据仓库接口
type ModelRepository interface { type ModelRepository interface {
Create(model *domain.Model) error Create(model *domain.Model) error

View File

@@ -2,6 +2,8 @@ package repository
import "nex/backend/internal/domain" import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
// ProviderRepository 供应商数据仓库接口 // ProviderRepository 供应商数据仓库接口
type ProviderRepository interface { type ProviderRepository interface {
Create(provider *domain.Provider) error Create(provider *domain.Provider) error
@@ -9,7 +11,4 @@ type ProviderRepository interface {
List() ([]domain.Provider, error) List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error Update(id string, updates map[string]interface{}) error
Delete(id string) error Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
} }

View File

@@ -71,25 +71,6 @@ func (r *providerRepository) Delete(id string) error {
return nil return nil
} }
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
var models []domain.Model
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
Where("models.enabled = ? AND providers.enabled = ?", true, true).
Find(&models).Error
return models, err
}
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
var model domain.Model
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}
func toDomainProvider(p *config.Provider) domain.Provider { func toDomainProvider(p *config.Provider) domain.Provider {
return domain.Provider{ return domain.Provider{
ID: p.ID, ID: p.ID,

View File

@@ -5,28 +5,16 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config" testHelpers "nex/backend/tests"
"nex/backend/internal/domain" "nex/backend/internal/domain"
) )
func setupTestDB(t *testing.T) *gorm.DB { func setupTestDB(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
dir := t.TempDir() return testHelpers.SetupTestDB(t)
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
// 关闭数据库连接以便 TempDir 清理
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
} }
// ============ ProviderRepository 测试 ============ // ============ ProviderRepository 测试 ============
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
repo := NewProviderRepository(db) repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}) require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
err := repo.Update("p1", map[string]interface{}{"name": "New"}) err := repo.Update("p1", map[string]interface{}{"name": "New"})
require.NoError(t, err) require.NoError(t, err)
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
repo := NewProviderRepository(db) repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Delete("p1") err := repo.Delete("p1")
require.NoError(t, err) require.NoError(t, err)
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
func TestModelRepository_Create(t *testing.T) { func TestModelRepository_Create(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err) require.NoError(t, err)
} }
func TestModelRepository_GetByID(t *testing.T) { func TestModelRepository_GetByID(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.GetByID("m1") result, err := repo.GetByID("m1")
require.NoError(t, err) require.NoError(t, err)
@@ -149,9 +141,11 @@ func TestModelRepository_GetByID(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName(t *testing.T) { func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.FindByProviderAndModelName("p1", "gpt-4") result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
require.NoError(t, err) require.NoError(t, err)
@@ -162,9 +156,11 @@ func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) { func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
// Wrong provider_id // Wrong provider_id
_, err := repo.FindByProviderAndModelName("p2", "gpt-4") _, err := repo.FindByProviderAndModelName("p2", "gpt-4")
@@ -181,11 +177,14 @@ func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
func TestModelRepository_List(t *testing.T) { func TestModelRepository_List(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
all, err := repo.List("") all, err := repo.List("")
require.NoError(t, err) require.NoError(t, err)
@@ -246,9 +245,11 @@ func TestModelRepository_ListEnabled(t *testing.T) {
func TestModelRepository_Update(t *testing.T) { func TestModelRepository_Update(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
err := repo.Update("m1", map[string]interface{}{"enabled": false}) err := repo.Update("m1", map[string]interface{}{"enabled": false})
require.NoError(t, err) require.NoError(t, err)
@@ -259,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
func TestModelRepository_Delete(t *testing.T) { func TestModelRepository_Delete(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db) repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
err := repo.Delete("m1") err := repo.Delete("m1")
require.NoError(t, err) require.NoError(t, err)
@@ -293,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
repo := NewStatsRepository(db) repo := NewStatsRepository(db)
repo.Record("p1", "gpt-4") require.NoError(t, repo.Record("p1", "gpt-4"))
// 注意:当前 schema 只有 date 字段有唯一约束 // 注意:当前 schema 只有 date 字段有唯一约束
// 所以同一 provider + model 只能有一条记录 // 所以同一 provider + model 只能有一条记录
stats, err := repo.Query("p1", "", nil, nil) stats, err := repo.Query("p1", "", nil, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, stats, 1) assert.Len(t, stats, 1)
} }
func TestModelRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
result, err := repo.List("")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}
func TestProviderRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
result, err := repo.List()
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}

View File

@@ -6,6 +6,8 @@ import (
"nex/backend/internal/domain" "nex/backend/internal/domain"
) )
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
// StatsRepository 统计数据仓库接口 // StatsRepository 统计数据仓库接口
type StatsRepository interface { type StatsRepository interface {
Record(providerID, modelName string) error Record(providerID, modelName string) error

View File

@@ -2,6 +2,8 @@ package service
import "nex/backend/internal/domain" import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
// ModelService 模型服务接口 // ModelService 模型服务接口
type ModelService interface { type ModelService interface {
Create(model *domain.Model) error Create(model *domain.Model) error

View File

@@ -2,6 +2,8 @@ package service
import "nex/backend/internal/domain" import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
// ProviderService 供应商服务接口 // ProviderService 供应商服务接口
type ProviderService interface { type ProviderService interface {
Create(provider *domain.Provider) error Create(provider *domain.Provider) error

View File

@@ -2,6 +2,8 @@ package service
import "nex/backend/internal/domain" import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
// RoutingService 路由服务接口 // RoutingService 路由服务接口
type RoutingService interface { type RoutingService interface {
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)

View File

@@ -16,7 +16,7 @@ func TestProviderService_Update(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo) svc := NewProviderService(repo, modelRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
err := svc.Update("p1", map[string]interface{}{"name": "Updated"}) err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
require.NoError(t, err) require.NoError(t, err)
@@ -42,7 +42,7 @@ func TestModelService_Get(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model)) require.NoError(t, svc.Create(model))
@@ -57,7 +57,7 @@ func TestModelService_Update(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model)) require.NoError(t, svc.Create(model))
@@ -75,7 +75,7 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model)) require.NoError(t, svc.Create(model))
@@ -89,7 +89,7 @@ func TestModelService_Delete(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model)) require.NoError(t, svc.Create(model))

View File

@@ -3,14 +3,15 @@ package service
import ( import (
"errors" "errors"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config" testHelpers "nex/backend/tests"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
@@ -18,18 +19,7 @@ import (
func setupServiceTestDB(t *testing.T) *gorm.DB { func setupServiceTestDB(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
dir := t.TempDir() return testHelpers.SetupTestDB(t)
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
} }
// ============ RoutingService - RouteByModelName 测试 ============ // ============ RoutingService - RouteByModelName 测试 ============
@@ -40,9 +30,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo) svc := NewRoutingService(modelRepo, providerRepo)
// 创建供应商和模型 require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
result, err := svc.RouteByModelName("openai", "gpt-4") result, err := svc.RouteByModelName("openai", "gpt-4")
require.NoError(t, err) require.NoError(t, err)
@@ -66,10 +55,9 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo) svc := NewRoutingService(modelRepo, providerRepo)
// 创建启用的供应商和禁用的模型 require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}) require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
_, err := svc.RouteByModelName("openai", "gpt-4") _, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled)) assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
@@ -81,10 +69,9 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo) svc := NewRoutingService(modelRepo, providerRepo)
// 创建启用的供应商和模型,然后禁用供应商 require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}) require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
_, err := svc.RouteByModelName("openai", "gpt-4") _, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled)) assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
@@ -98,7 +85,7 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model) err := svc.Create(model)
@@ -122,7 +109,7 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1) err := svc.Create(model1)
@@ -179,8 +166,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1) err := svc.Create(model1)
@@ -216,7 +203,7 @@ func TestModelService_Update_Success(t *testing.T) {
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo) svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model) err := svc.Create(model)
@@ -272,3 +259,223 @@ func TestProviderService_Update_Success(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "OpenAI Updated", updated.Name) assert.Equal(t, "OpenAI Updated", updated.Name)
} }
// ============ StatsService - Aggregate ByModel 测试 ============
func TestStatsService_Aggregate_ByModel(t *testing.T) {
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "multiple providers with same model name",
stats: []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
},
},
{
name: "empty providerID",
stats: []domain.UsageStats{
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
result := svc.Aggregate(tt.stats, "model")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
}
// ============ StatsService - Aggregate ByDate 测试 ============
func TestStatsService_Aggregate_ByDate(t *testing.T) {
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "normal date grouping",
stats: []domain.UsageStats{
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
},
expected: []map[string]interface{}{
{"date": "2024-01-01", "request_count": 15},
{"date": "2024-01-02", "request_count": 20},
},
},
{
name: "zero-value time",
stats: []domain.UsageStats{
{Date: time.Time{}, RequestCount: 10},
{Date: time.Time{}, RequestCount: 5},
},
expected: []map[string]interface{}{
{"date": "0001-01-01", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
result := svc.Aggregate(tt.stats, "date")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
if r["date"] == exp["date"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
}
// ============ ProviderService - isUniqueConstraintError 测试 ============
func TestProviderService_isUniqueConstraintError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UNIQUE constraint failed",
err: errors.New("UNIQUE constraint failed"),
expected: true,
},
{
name: "duplicate key value",
err: errors.New("duplicate key value"),
expected: true,
},
{
name: "UNIQUE constraint case insensitive",
err: errors.New("unique constraint violation"),
expected: true,
},
{
name: "other error",
err: errors.New("some other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isUniqueConstraintError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ ProviderService - List MaskAPIKey 测试 ============
func TestProviderService_List_MaskAPIKey(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
require.NoError(t, svc.Create(provider1))
require.NoError(t, svc.Create(provider2))
providers, err := svc.List()
require.NoError(t, err)
require.Len(t, providers, 2)
for _, p := range providers {
assert.Contains(t, p.APIKey, "***")
assert.Len(t, p.APIKey, 7)
}
}
func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
results := make(chan error, 2)
for i := 0; i < 2; i++ {
go func() {
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
results <- svc.Create(model)
}()
}
err1 := <-results
err2 := <-results
successCount := 0
errorCount := 0
for _, err := range []error{err1, err2} {
if err == nil {
successCount++
} else {
errorCount++
}
}
assert.Equal(t, 1, successCount)
assert.Equal(t, 1, errorCount)
}

View File

@@ -6,6 +6,8 @@ import (
"nex/backend/internal/domain" "nex/backend/internal/domain"
) )
//go:generate go run go.uber.org/mock/mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
// StatsService 统计服务接口 // StatsService 统计服务接口
type StatsService interface { type StatsService interface {
Record(providerID, modelName string) error Record(providerID, modelName string) error

View File

@@ -0,0 +1,193 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestLoadConfig_DefaultValues(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "info", cfg.Log.Level)
assert.Equal(t, 100, cfg.Log.MaxSize)
assert.Equal(t, 10, cfg.Log.MaxBackups)
assert.Equal(t, 30, cfg.Log.MaxAge)
assert.True(t, cfg.Log.Compress)
}
func TestLoadConfig_EnvOverride(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
t.Setenv("NEX_SERVER_PORT", "9000")
t.Setenv("NEX_LOG_LEVEL", "debug")
t.Setenv("NEX_DATABASE_MAX_IDLE_CONNS", "20")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9000, cfg.Server.Port)
assert.Equal(t, "debug", cfg.Log.Level)
assert.Equal(t, 20, cfg.Database.MaxIdleConns)
}
func TestLoadConfig_YAMLFile(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
yamlContent := `
server:
port: 8080
read_timeout: 60s
write_timeout: 60s
database:
path: /custom/path.db
max_idle_conns: 5
max_open_conns: 50
conn_max_lifetime: 2h
log:
level: warn
path: /custom/log
max_size: 200
max_backups: 5
max_age: 7
compress: false
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, 60*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 60*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "/custom/path.db", cfg.Database.Path)
assert.Equal(t, 5, cfg.Database.MaxIdleConns)
assert.Equal(t, 50, cfg.Database.MaxOpenConns)
assert.Equal(t, 2*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "warn", cfg.Log.Level)
assert.Equal(t, "/custom/log", cfg.Log.Path)
assert.Equal(t, 200, cfg.Log.MaxSize)
assert.Equal(t, 5, cfg.Log.MaxBackups)
assert.Equal(t, 7, cfg.Log.MaxAge)
assert.False(t, cfg.Log.Compress)
}
func TestLoadConfig_PriorityChain(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
yamlContent := `
server:
port: 8080
log:
level: warn
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
t.Setenv("NEX_SERVER_PORT", "9000")
originalArgs := os.Args
defer func() { os.Args = originalArgs }()
os.Args = []string{"test", "--server-port", "9999"}
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9999, cfg.Server.Port, "CLI should override ENV and YAML")
assert.Equal(t, "warn", cfg.Log.Level, "YAML value should be used when no CLI/ENV override")
}
func TestLoadConfig_AutoCreate(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
_, err := os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "config file should not exist before load")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, 9826, cfg.Server.Port, "should load with default values")
}
func TestSaveAndLoadConfig(t *testing.T) {
tmpDir := t.TempDir()
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
nexDir := filepath.Join(homeDir, ".nex")
configPath := filepath.Join(nexDir, "config.yaml")
originalConfig, err := os.ReadFile(configPath)
if err != nil && !os.IsNotExist(err) {
require.NoError(t, err)
}
defer func() {
if originalConfig != nil {
_ = os.WriteFile(configPath, originalConfig, 0644)
}
}()
cfg := &config.Config{
Server: config.ServerConfig{
Port: 7777,
ReadTimeout: 45 * time.Second,
WriteTimeout: 45 * time.Second,
},
Database: config.DatabaseConfig{
Path: filepath.Join(tmpDir, "test.db"),
MaxIdleConns: 15,
MaxOpenConns: 150,
ConnMaxLifetime: 2 * time.Hour,
},
Log: config.LogConfig{
Level: "debug",
Path: filepath.Join(tmpDir, "log"),
MaxSize: 50,
MaxBackups: 3,
MaxAge: 14,
Compress: false,
},
}
err = config.SaveConfig(cfg)
require.NoError(t, err)
loaded, err := config.LoadConfig()
require.NoError(t, err)
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
assert.Equal(t, cfg.Server.ReadTimeout, loaded.Server.ReadTimeout)
assert.Equal(t, cfg.Server.WriteTimeout, loaded.Server.WriteTimeout)
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
assert.Equal(t, cfg.Database.MaxOpenConns, loaded.Database.MaxOpenConns)
assert.Equal(t, cfg.Database.ConnMaxLifetime, loaded.Database.ConnMaxLifetime)
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
assert.Equal(t, cfg.Log.MaxSize, loaded.Log.MaxSize)
assert.Equal(t, cfg.Log.MaxBackups, loaded.Log.MaxBackups)
assert.Equal(t, cfg.Log.MaxAge, loaded.Log.MaxAge)
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
}

View File

@@ -7,49 +7,40 @@ import (
"nex/backend/internal/config" "nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
) )
// SetupTestDB initializes an in-memory SQLite database with auto-migration.
// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the
// same connection, avoiding "database is closed" errors from connection pool.
// Enables foreign key constraints for SQLite.
func SetupTestDB(t *testing.T) *gorm.DB { func SetupTestDB(t *testing.T) *gorm.DB {
t.Helper() t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{})
assert.NoError(t, err, "failed to open test database") require.NoError(t, err, "failed to open test database")
// 限制为单连接,确保 :memory: 数据库不被连接池丢弃
sqlDB, err := db.DB() sqlDB, err := db.DB()
assert.NoError(t, err, "failed to get underlying sql.DB") require.NoError(t, err, "failed to get underlying sql.DB")
sqlDB.SetMaxOpenConns(1) sqlDB.SetMaxOpenConns(1)
sqlDB.SetConnMaxLifetime(0) sqlDB.SetConnMaxLifetime(0)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
assert.NoError(t, err, "failed to auto-migrate test database") require.NoError(t, err, "failed to auto-migrate test database")
return db return db
} }
// CleanupTestDB closes the database after a brief delay to allow async
// goroutines (e.g. stats recording) to finish.
func CleanupTestDB(t *testing.T, db *gorm.DB) { func CleanupTestDB(t *testing.T, db *gorm.DB) {
t.Helper() t.Helper()
// 等待异步 goroutine如 statsService.Record完成
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
sqlDB, err := db.DB() sqlDB, err := db.DB()
assert.NoError(t, err, "failed to get underlying sql.DB") require.NoError(t, err, "failed to get underlying sql.DB")
err = sqlDB.Close() err = sqlDB.Close()
assert.NoError(t, err, "failed to close test database") require.NoError(t, err, "failed to close test database")
} }
// CreateTestProvider creates a test provider and returns it.
func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider { func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
t.Helper() t.Helper()
@@ -62,13 +53,11 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
} }
err := db.Create(&provider).Error err := db.Create(&provider).Error
assert.NoError(t, err, "failed to create test provider") require.NoError(t, err, "failed to create test provider")
return provider return provider
} }
// CreateTestModel creates a test model and returns it.
// Does NOT assert on error - returns the model and error for caller to verify.
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) { func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) {
t.Helper() t.Helper()

View File

@@ -0,0 +1,143 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model_repo.go
//
// Generated by this command:
//
// mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockModelRepository is a mock of ModelRepository interface.
type MockModelRepository struct {
ctrl *gomock.Controller
recorder *MockModelRepositoryMockRecorder
isgomock struct{}
}
// MockModelRepositoryMockRecorder is the mock recorder for MockModelRepository.
type MockModelRepositoryMockRecorder struct {
mock *MockModelRepository
}
// NewMockModelRepository creates a new mock instance.
func NewMockModelRepository(ctrl *gomock.Controller) *MockModelRepository {
mock := &MockModelRepository{ctrl: ctrl}
mock.recorder = &MockModelRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockModelRepository) EXPECT() *MockModelRepositoryMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockModelRepository) Create(model *domain.Model) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", model)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockModelRepositoryMockRecorder) Create(model any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelRepository)(nil).Create), model)
}
// Delete mocks base method.
func (m *MockModelRepository) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockModelRepositoryMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelRepository)(nil).Delete), id)
}
// FindByProviderAndModelName mocks base method.
func (m *MockModelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindByProviderAndModelName", providerID, modelName)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindByProviderAndModelName indicates an expected call of FindByProviderAndModelName.
func (mr *MockModelRepositoryMockRecorder) FindByProviderAndModelName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByProviderAndModelName", reflect.TypeOf((*MockModelRepository)(nil).FindByProviderAndModelName), providerID, modelName)
}
// GetByID mocks base method.
func (m *MockModelRepository) GetByID(id string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", id)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockModelRepositoryMockRecorder) GetByID(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockModelRepository)(nil).GetByID), id)
}
// List mocks base method.
func (m *MockModelRepository) List(providerID string) ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", providerID)
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockModelRepositoryMockRecorder) List(providerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelRepository)(nil).List), providerID)
}
// ListEnabled mocks base method.
func (m *MockModelRepository) ListEnabled() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabled")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabled indicates an expected call of ListEnabled.
func (mr *MockModelRepositoryMockRecorder) ListEnabled() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelRepository)(nil).ListEnabled))
}
// Update mocks base method.
func (m *MockModelRepository) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockModelRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelRepository)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,128 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model_service.go
//
// Generated by this command:
//
// mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockModelService is a mock of ModelService interface.
type MockModelService struct {
ctrl *gomock.Controller
recorder *MockModelServiceMockRecorder
isgomock struct{}
}
// MockModelServiceMockRecorder is the mock recorder for MockModelService.
type MockModelServiceMockRecorder struct {
mock *MockModelService
}
// NewMockModelService creates a new mock instance.
func NewMockModelService(ctrl *gomock.Controller) *MockModelService {
mock := &MockModelService{ctrl: ctrl}
mock.recorder = &MockModelServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockModelService) EXPECT() *MockModelServiceMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockModelService) Create(model *domain.Model) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", model)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockModelServiceMockRecorder) Create(model any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelService)(nil).Create), model)
}
// Delete mocks base method.
func (m *MockModelService) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockModelServiceMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelService)(nil).Delete), id)
}
// Get mocks base method.
func (m *MockModelService) Get(id string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", id)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockModelServiceMockRecorder) Get(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockModelService)(nil).Get), id)
}
// List mocks base method.
func (m *MockModelService) List(providerID string) ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", providerID)
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockModelServiceMockRecorder) List(providerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelService)(nil).List), providerID)
}
// ListEnabled mocks base method.
func (m *MockModelService) ListEnabled() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabled")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabled indicates an expected call of ListEnabled.
func (mr *MockModelServiceMockRecorder) ListEnabled() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelService)(nil).ListEnabled))
}
// Update mocks base method.
func (m *MockModelService) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockModelServiceMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelService)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,73 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: client.go
//
// Generated by this command:
//
// mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
conversion "nex/backend/internal/conversion"
provider "nex/backend/internal/provider"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderClient is a mock of ProviderClient interface.
type MockProviderClient struct {
ctrl *gomock.Controller
recorder *MockProviderClientMockRecorder
isgomock struct{}
}
// MockProviderClientMockRecorder is the mock recorder for MockProviderClient.
type MockProviderClientMockRecorder struct {
mock *MockProviderClient
}
// NewMockProviderClient creates a new mock instance.
func NewMockProviderClient(ctrl *gomock.Controller) *MockProviderClient {
mock := &MockProviderClient{ctrl: ctrl}
mock.recorder = &MockProviderClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderClient) EXPECT() *MockProviderClientMockRecorder {
return m.recorder
}
// Send mocks base method.
func (m *MockProviderClient) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Send", ctx, spec)
ret0, _ := ret[0].(*conversion.HTTPResponseSpec)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Send indicates an expected call of Send.
func (mr *MockProviderClientMockRecorder) Send(ctx, spec any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockProviderClient)(nil).Send), ctx, spec)
}
// SendStream mocks base method.
func (m *MockProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendStream", ctx, spec)
ret0, _ := ret[0].(<-chan provider.StreamEvent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendStream indicates an expected call of SendStream.
func (mr *MockProviderClientMockRecorder) SendStream(ctx, spec any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendStream", reflect.TypeOf((*MockProviderClient)(nil).SendStream), ctx, spec)
}

View File

@@ -0,0 +1,113 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: provider_repo.go
//
// Generated by this command:
//
// mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderRepository is a mock of ProviderRepository interface.
type MockProviderRepository struct {
ctrl *gomock.Controller
recorder *MockProviderRepositoryMockRecorder
isgomock struct{}
}
// MockProviderRepositoryMockRecorder is the mock recorder for MockProviderRepository.
type MockProviderRepositoryMockRecorder struct {
mock *MockProviderRepository
}
// NewMockProviderRepository creates a new mock instance.
func NewMockProviderRepository(ctrl *gomock.Controller) *MockProviderRepository {
mock := &MockProviderRepository{ctrl: ctrl}
mock.recorder = &MockProviderRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderRepository) EXPECT() *MockProviderRepositoryMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockProviderRepository) Create(provider *domain.Provider) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", provider)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockProviderRepositoryMockRecorder) Create(provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderRepository)(nil).Create), provider)
}
// Delete mocks base method.
func (m *MockProviderRepository) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockProviderRepositoryMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderRepository)(nil).Delete), id)
}
// GetByID mocks base method.
func (m *MockProviderRepository) GetByID(id string) (*domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", id)
ret0, _ := ret[0].(*domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockProviderRepositoryMockRecorder) GetByID(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockProviderRepository)(nil).GetByID), id)
}
// List mocks base method.
func (m *MockProviderRepository) List() ([]domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockProviderRepositoryMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderRepository)(nil).List))
}
// Update mocks base method.
func (m *MockProviderRepository) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockProviderRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderRepository)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,143 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: provider_service.go
//
// Generated by this command:
//
// mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderService is a mock of ProviderService interface.
type MockProviderService struct {
ctrl *gomock.Controller
recorder *MockProviderServiceMockRecorder
isgomock struct{}
}
// MockProviderServiceMockRecorder is the mock recorder for MockProviderService.
type MockProviderServiceMockRecorder struct {
mock *MockProviderService
}
// NewMockProviderService creates a new mock instance.
func NewMockProviderService(ctrl *gomock.Controller) *MockProviderService {
mock := &MockProviderService{ctrl: ctrl}
mock.recorder = &MockProviderServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderService) EXPECT() *MockProviderServiceMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockProviderService) Create(provider *domain.Provider) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", provider)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockProviderServiceMockRecorder) Create(provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderService)(nil).Create), provider)
}
// Delete mocks base method.
func (m *MockProviderService) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockProviderServiceMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderService)(nil).Delete), id)
}
// Get mocks base method.
func (m *MockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", id, maskKey)
ret0, _ := ret[0].(*domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockProviderServiceMockRecorder) Get(id, maskKey any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockProviderService)(nil).Get), id, maskKey)
}
// GetModelByProviderAndName mocks base method.
func (m *MockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetModelByProviderAndName", providerID, modelName)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetModelByProviderAndName indicates an expected call of GetModelByProviderAndName.
func (mr *MockProviderServiceMockRecorder) GetModelByProviderAndName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModelByProviderAndName", reflect.TypeOf((*MockProviderService)(nil).GetModelByProviderAndName), providerID, modelName)
}
// List mocks base method.
func (m *MockProviderService) List() ([]domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockProviderServiceMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderService)(nil).List))
}
// ListEnabledModels mocks base method.
func (m *MockProviderService) ListEnabledModels() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabledModels")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabledModels indicates an expected call of ListEnabledModels.
func (mr *MockProviderServiceMockRecorder) ListEnabledModels() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabledModels", reflect.TypeOf((*MockProviderService)(nil).ListEnabledModels))
}
// Update mocks base method.
func (m *MockProviderService) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockProviderServiceMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderService)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,56 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: routing_service.go
//
// Generated by this command:
//
// mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockRoutingService is a mock of RoutingService interface.
type MockRoutingService struct {
ctrl *gomock.Controller
recorder *MockRoutingServiceMockRecorder
isgomock struct{}
}
// MockRoutingServiceMockRecorder is the mock recorder for MockRoutingService.
type MockRoutingServiceMockRecorder struct {
mock *MockRoutingService
}
// NewMockRoutingService creates a new mock instance.
func NewMockRoutingService(ctrl *gomock.Controller) *MockRoutingService {
mock := &MockRoutingService{ctrl: ctrl}
mock.recorder = &MockRoutingServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRoutingService) EXPECT() *MockRoutingServiceMockRecorder {
return m.recorder
}
// RouteByModelName mocks base method.
func (m *MockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteByModelName", providerID, modelName)
ret0, _ := ret[0].(*domain.RouteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RouteByModelName indicates an expected call of RouteByModelName.
func (mr *MockRoutingServiceMockRecorder) RouteByModelName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteByModelName", reflect.TypeOf((*MockRoutingService)(nil).RouteByModelName), providerID, modelName)
}

View File

@@ -0,0 +1,71 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: stats_repo.go
//
// Generated by this command:
//
// mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
gomock "go.uber.org/mock/gomock"
)
// MockStatsRepository is a mock of StatsRepository interface.
type MockStatsRepository struct {
ctrl *gomock.Controller
recorder *MockStatsRepositoryMockRecorder
isgomock struct{}
}
// MockStatsRepositoryMockRecorder is the mock recorder for MockStatsRepository.
type MockStatsRepositoryMockRecorder struct {
mock *MockStatsRepository
}
// NewMockStatsRepository creates a new mock instance.
func NewMockStatsRepository(ctrl *gomock.Controller) *MockStatsRepository {
mock := &MockStatsRepository{ctrl: ctrl}
mock.recorder = &MockStatsRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
return m.recorder
}
// Query mocks base method.
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", providerID, modelName, startDate, endDate)
ret0, _ := ret[0].([]domain.UsageStats)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStatsRepositoryMockRecorder) Query(providerID, modelName, startDate, endDate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStatsRepository)(nil).Query), providerID, modelName, startDate, endDate)
}
// Record mocks base method.
func (m *MockStatsRepository) Record(providerID, modelName string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Record", providerID, modelName)
ret0, _ := ret[0].(error)
return ret0
}
// Record indicates an expected call of Record.
func (mr *MockStatsRepositoryMockRecorder) Record(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsRepository)(nil).Record), providerID, modelName)
}

View File

@@ -0,0 +1,85 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: stats_service.go
//
// Generated by this command:
//
// mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
gomock "go.uber.org/mock/gomock"
)
// MockStatsService is a mock of StatsService interface.
type MockStatsService struct {
ctrl *gomock.Controller
recorder *MockStatsServiceMockRecorder
isgomock struct{}
}
// MockStatsServiceMockRecorder is the mock recorder for MockStatsService.
type MockStatsServiceMockRecorder struct {
mock *MockStatsService
}
// NewMockStatsService creates a new mock instance.
func NewMockStatsService(ctrl *gomock.Controller) *MockStatsService {
mock := &MockStatsService{ctrl: ctrl}
mock.recorder = &MockStatsServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStatsService) EXPECT() *MockStatsServiceMockRecorder {
return m.recorder
}
// Aggregate mocks base method.
func (m *MockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]any {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Aggregate", stats, groupBy)
ret0, _ := ret[0].([]map[string]any)
return ret0
}
// Aggregate indicates an expected call of Aggregate.
func (mr *MockStatsServiceMockRecorder) Aggregate(stats, groupBy any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockStatsService)(nil).Aggregate), stats, groupBy)
}
// Get mocks base method.
func (m *MockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", providerID, modelName, startDate, endDate)
ret0, _ := ret[0].([]domain.UsageStats)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockStatsServiceMockRecorder) Get(providerID, modelName, startDate, endDate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStatsService)(nil).Get), providerID, modelName, startDate, endDate)
}
// Record mocks base method.
func (m *MockStatsService) Record(providerID, modelName string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Record", providerID, modelName)
ret0, _ := ret[0].(error)
return ret0
}
// Record indicates an expected call of Record.
func (mr *MockStatsServiceMockRecorder) Record(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsService)(nil).Record), providerID, modelName)
}

View File

@@ -31,6 +31,57 @@
- **THEN** SHALL 测试请求转换、响应转换、流式转换 - **THEN** SHALL 测试请求转换、响应转换、流式转换
- **THEN** SHALL 验证转换的准确性和完整性 - **THEN** SHALL 验证转换的准确性和完整性
#### Scenario: config 加载管道集成测试
- **WHEN** 运行 config 加载管道的集成测试
- **THEN** SHALL 验证 LoadConfigFromPath 正确加载默认值
- **THEN** SHALL 验证环境变量(`NEX_` 前缀)覆盖默认值
- **THEN** SHALL 验证 YAML 配置文件正确读取
- **THEN** SHALL 验证优先级链CLI 参数 > 环境变量 > YAML 文件 > 默认值
- **THEN** SHALL 验证首次启动自动创建配置文件
- **THEN** SHALL 验证 SaveConfig 后重新 LoadConfig 数据一致
#### Scenario: handler 错误分支测试
- **WHEN** 运行 handler 层的单元测试
- **THEN** SHALL 覆盖 ModelHandler.CreateModel 的所有错误分支ErrProviderNotFound(400)、ErrDuplicateModel(409)、通用错误(500)
- **THEN** SHALL 覆盖 ProviderHandler.UpdateProvider 的所有错误分支ErrRecordNotFound(404)、ErrImmutableField(400)、通用错误(500)
- **THEN** SHALL 验证每个错误分支返回正确的 HTTP 状态码和错误响应格式
#### Scenario: service 聚合逻辑测试
- **WHEN** 运行 service 层的单元测试
- **THEN** SHALL 覆盖 StatsService.Aggregate 的所有分组模式byProvider、byModel、byDate
- **THEN** SHALL 验证 aggregateByModel 正确拼接和拆分 providerID/modelName key
- **THEN** SHALL 验证 aggregateByDate 正确格式化日期并聚合
- **THEN** SHALL 覆盖空结果集、同名模型不同 provider 等边界场景
#### Scenario: provider service 工具方法测试
- **WHEN** 运行 provider service 的单元测试
- **THEN** SHALL 验证 isUniqueConstraintError 正确识别 SQLite 唯一约束冲突错误消息
- **THEN** SHALL 验证 List 方法对每个 provider 调用 MaskAPIKey
#### Scenario: engine 降级路径测试
- **WHEN** 运行 conversion engine 的单元测试
- **THEN** SHALL 验证 convertEmbeddingBody 在 decode 失败时返回原始 bodypassthrough
- **THEN** SHALL 验证 convertRerankBody 在 decode 失败时返回原始 bodypassthrough
- **THEN** SHALL 验证降级过程不 panic、不返回空 body
#### Scenario: openai decoder 边界场景测试
- **WHEN** 运行 openai decoder 的单元测试
- **THEN** SHALL 覆盖 assistant message content 为 JSON 数组格式的解析text/refusal 类型)
- **THEN** SHALL 验证 decodeContentParts 正确提取文本内容和拒绝消息
#### Scenario: 业务逻辑负面测试
- **WHEN** 运行业务逻辑负面测试
- **THEN** SHALL 覆盖 JSON 格式错误请求体的处理
- **THEN** SHALL 覆盖并发创建相同 provider + model 的重复检测
- **THEN** SHALL 覆盖空结果集查询的正确处理
### Requirement: 建立集成测试体系 ### Requirement: 建立集成测试体系
系统 SHALL 建立集成测试体系,覆盖 API 端到端流程。 系统 SHALL 建立集成测试体系,覆盖 API 端到端流程。
@@ -63,26 +114,28 @@
#### Scenario: 测试数据库初始化 #### Scenario: 测试数据库初始化
- **WHEN** 编写需要数据库的测试 - **WHEN** 编写需要数据库的测试
- **THEN** SHALL 提供测试数据库初始化函数 - **THEN** SHALL 提供统一的测试数据库初始化函数 `SetupTestDB`
- **THEN** SHALL 使用临时数据库文件 - **THEN** SHALL 统一使用 SQLite `:memory:` + `MaxOpenConns(1)` 策略
- **THEN** SHALL 在测试结束后自动清理 - **THEN** SHALL 在测试结束后自动清理
- **THEN** 所有测试包 SHALL 通过 `tests.SetupTestDB()` 获取测试数据库,不允许各自独立实现
#### Scenario: Mock 工具 #### Scenario: Mock 工具
- **WHEN** 编写需要 Mock 的测试 - **WHEN** 编写需要 Mock 的测试
- **THEN** SHALL 提供 Mock 接口实现 - **THEN** SHALL 使用 mockgen 自动生成 mock 实现
- **THEN** SHALL 支持常见 Mock 场景 - **THEN** SHALL 在接口定义文件中使用 `//go:generate` 注解标注生成命令
- **THEN** SHALL 易于使用和扩展 - **THEN** 生成的 mock SHALL 放置在 `tests/mocks/` 目录下
- **THEN** SHALL 覆盖 service 和 repository 接口的 mock 生成
### Requirement: 达到测试覆盖率目标 ### Requirement: 达到测试覆盖率目标
系统 SHALL 达到 > 80% 的测试覆盖率。 系统 SHALL 达到 > 85% 的测试覆盖率。
#### Scenario: 总体覆盖率 #### Scenario: 总体覆盖率
- **WHEN** 运行所有测试并生成覆盖率报告 - **WHEN** 运行所有测试并生成覆盖率报告
- **THEN** 总体覆盖率 SHALL 大于 80% - **THEN** 总体覆盖率 SHALL 大于 85%
- **THEN** 核心包覆盖率 SHALL 大于 85% - **THEN** 核心包config、service、handler、conversion、repository覆盖率 SHALL 大于 85%
#### Scenario: 覆盖率报告生成 #### Scenario: 覆盖率报告生成
@@ -102,6 +155,14 @@
- **THEN** SHALL 显示测试结果 - **THEN** SHALL 显示测试结果
- **THEN** SHALL 在测试失败时返回非零退出码 - **THEN** SHALL 在测试失败时返回非零退出码
#### Scenario: 分类测试命令
- **WHEN** 执行 `make test-unit` 命令
- **THEN** SHALL 仅运行 `./internal/...``./pkg/...` 下的单元测试
- **WHEN** 执行 `make test-integration` 命令
- **THEN** SHALL 仅运行 `./tests/...` 下的集成测试
#### Scenario: 覆盖率检查命令 #### Scenario: 覆盖率检查命令
- **WHEN** 执行 `make test-coverage` 命令 - **WHEN** 执行 `make test-coverage` 命令
@@ -170,3 +231,33 @@
- **WHEN** 在 frontend/ 目录执行 E2E 测试命令 - **WHEN** 在 frontend/ 目录执行 E2E 测试命令
- **THEN** SHALL 启动 Playwright 运行 E2E 测试 - **THEN** SHALL 启动 Playwright 运行 E2E 测试
- **THEN** SHALL 在测试失败时返回非零退出码 - **THEN** SHALL 在测试失败时返回非零退出码
### Requirement: 清理 ProviderRepository 死代码
系统 SHALL 移除 ProviderRepository 中未被调用的重复方法。
#### Scenario: 移除死代码方法
- **WHEN** 审查 ProviderRepository 接口
- **THEN** SHALL 移除 `ListEnabledModels()` 方法声明和实现
- **THEN** SHALL 移除 `FindByProviderAndModelName()` 方法声明和实现
- **THEN** SHALL 确保所有现有调用者通过 ModelRepository 访问等效功能
- **THEN** SHALL 不影响任何运行时行为
### Requirement: 使用 mockgen 生成 mock
系统 SHALL 使用 mockgen 为接口自动生成 mock 实现,替代手写 mock。
#### Scenario: mock 生成配置
- **WHEN** 在接口定义文件中添加 `//go:generate mockgen` 注解
- **THEN** SHALL 为 `ProviderService``ModelService``RoutingService``StatsService` 接口生成 mock
- **THEN** SHALL 为 `ModelRepository``ProviderRepository``StatsRepository` 接口生成 mock
- **THEN** SHALL 为 `ProviderClient` 接口生成 mock
- **THEN** 生成的 mock SHALL 输出到 `tests/mocks/` 目录
#### Scenario: 替换手写 mock
- **WHEN** mockgen 生成的 mock 就绪
- **THEN** handler 测试中的手写 mock SHALL 被替换为生成的 mock
- **THEN** 所有测试 SHALL 继续通过,行为不变