1
0

Compare commits

...

79 Commits

Author SHA1 Message Date
b00fa4dcee chore: 调整开源许可证为 Apache 2.0
新增根目录 Apache 2.0 许可证文件,并同步更新仓库文档与前端包元数据声明。
2026-04-27 11:24:33 +08:00
92525b39c3 chore: 将 assets 图标文件迁移到 Git LFS 2026-04-27 10:34:38 +08:00
38a2555c7b fix: Anthropic 流式编码器补全 message_start/message_delta 必填字段
跨协议流式转换时,Anthropic 客户端 Zod 校验因 SSE 事件缺少必填字段报错。
由 Anthropic encoder 层(而非 OpenAI decoder 层)负责补全协议默认值,保持权责分离。

- encodeMessageStart 补全 type/content/stop_reason/stop_sequence,usage nil 时输出零值
- encodeMessageDelta usage nil 时输出零值
- 更新相关测试覆盖新增行为
2026-04-26 23:27:34 +08:00
9622d44aac fix: 完善转换代理行为 2026-04-26 21:48:17 +08:00
155244433f fix: 完善桌面应用图标打包
统一 macOS 图标命名为 icon.icns。\n补充 Linux hicolor 图标资源。\n修复 Windows make 构建兼容性并为 exe 嵌入图标资源。\n清理旧版图标说明与不再使用的 SVG 源文件。
2026-04-26 12:24:03 +08:00
2c043c6cf7 fix: 修正 conversion 代理路径和错误边界 2026-04-25 23:12:54 +08:00
f5c82b6980 chore: 合并 dev-test-config 到 master 2026-04-24 23:23:12 +08:00
9105a36097 feat: 将"关于"从系统托盘原生对话框迁移到前端页面
移除系统托盘右键菜单中的"关于"选项及各平台原生对话框实现,
在前端新增 /about 路由和关于页面展示品牌信息,侧边栏增加关于导航入口
2026-04-24 23:17:22 +08:00
f1ee646ca4 fix: 修复 TestSaveAndLoadConfig 测试隔离问题,使用临时目录替代真实用户配置 2026-04-24 22:59:26 +08:00
b9b487c591 chore: 统一项目编辑器配置,移至仓库根目录 2026-04-24 22:28:01 +08:00
4c62c071fb fix: 修复 macOS 桌面应用打包与元数据
将 macOS 桌面应用改为通用二进制并动态写入最低系统版本,避免 Intel Mac 无法启动。统一桌面应用名称与托盘展示,并补充测试确保相关行为稳定。
2026-04-24 19:35:51 +08:00
b2e9dd8b7f refactor: 合并 macOS 桌面打包流程
将 macOS .app 打包直接并入 desktop-build-mac,减少重复的桌面构建入口。\n\n同时移除未实现或已废弃的 desktop-package-* 命令和独立打包脚本,降低维护成本。
2026-04-24 19:02:21 +08:00
d143c5f3df fix: 补齐前端生成物忽略并消除构建告警
统一 Git、ESLint、Prettier 对测试和构建生成物的忽略规则,避免本地产物导致 frontend-build 失败。

补齐表单 effect 依赖,移除无关告警,让前端构建链路恢复稳定。
2026-04-24 18:53:53 +08:00
4eebdfb8db chore: 合并 dev-code-format-frontend 到 master 2026-04-24 18:21:27 +08:00
b517946585 chore: 合并 dev-code-backend-format 到 master 2026-04-24 18:20:12 +08:00
4ddae6be74 refactor: 优化提示词文档,增强智能合并与规范审查能力 2026-04-24 18:12:23 +08:00
195762ff97 fix: 修复后端配置加载测试失败
- 修复 viper SafeWriteConfig 与 SetConfigFile 不兼容问题
  - 将 SafeWriteConfig() 替换为 SafeWriteConfigAs(configPath)
  - 绕过 viper 的 configPaths 检查
- 调整 Makefile 测试命令分类
  - backend-test: 仅运行后端核心测试
  - backend-test-all: 运行全部后端测试(含 desktop)
  - desktop-test: 单独运行桌面应用测试
- 同步 config-management 和 test-coverage 规范
2026-04-24 14:06:03 +08:00
bcf5ca89e5 refactor: Makefile 前端命令自动安装依赖 2026-04-24 13:50:51 +08:00
365943e4c4 feat: 前端集成 Prettier 代码格式化 2026-04-24 13:40:53 +08:00
4c6b49099d feat: 配置 golangci-lint 静态分析并修复存量违规
- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo)
- 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint
- 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复
- 更新 README 补充编码规范说明
- 归档 backend-code-lint 变更
2026-04-24 13:01:48 +08:00
4c78ab6cc8 chore: 新增 backend-code-lint 变更计划,更新 .gitignore 2026-04-24 00:19:56 +08:00
52007c9461 feat: 前端 ESLint 规则增强,自动检测 LLM 编码违规
- 启用 TanStack Query flat/recommended(7 条规则)
- 新增 no-console(允许 warn/error)、consistent-type-imports(inline 风格)、no-non-null-assertion 规则
- 新增自定义规则 no-hardcoded-color-in-style,检测 JSX style 中硬编码颜色值
- 将 ESLint 检查集成到 build 命令(tsc -b && eslint . && vite build)
- 修复现有代码中的 lint 违规(import 顺序、type import 风格、unused vars)
- 使用 @typescript-eslint/rule-tester 编写自定义规则集成测试
2026-04-23 22:47:32 +08:00
086dd1fed7 refactor: 重构智能合并提示词,优化交互体验并增加语义审查能力
- 三阶段模型:计划审批(1次) → 自动合并(仅异常中断) → 汇总收尾(1次),减少冗余介入
- 新增语义审查:合并后自动分析代码冗余/过时模式/未集成基础设施/风格不一致
- 完善安全锚点体系:全局锚点+分支级锚点,分级回退机制
- 修复多个逻辑缺陷:stash含未跟踪文件、分支匹配过宽、远端分支拉取重名、语义修复交叉污染等
- 明确提问工具选项和交互方式,消除歧义引导
2026-04-23 22:35:24 +08:00
1d7e839b49 Merge branch 'dev-frontend-encrypt' into master 2026-04-23 18:42:42 +08:00
fa7babf13b chore: 归档 fix-windows-desktop-packaging 变更 2026-04-23 18:40:23 +08:00
280099b89c refactor: 后端日志系统重构
- 新增模块化日志器(pkg/logger/module.go)
- 新增 GORM 日志适配器
- 统一日志入口,移除所有 zap.L() 全局 logger 调用
- 字段标准化
- 启动阶段使用结构化日志
- 更新所有相关测试
2026-04-23 18:37:51 +08:00
0a92a25451 feat: 前端生产构建添加代码混淆
- 集成 vite-plugin-javascript-obfuscator 插件
- 配置中等偏高强度混淆策略(变量名、字符串、对象键、数字)
- 仅生产构建时启用,不影响开发体验
- 仅混淆业务代码,排除第三方库
- 不生成 Source Map
- 新增 frontend-obfuscation 规范
2026-04-23 18:23:07 +08:00
8c075194e5 fix: 修复合并后代码质量问题
- 修正 Makefile 迁移目录路径(sqlite3 → sqlite)
- 统一 database.go 日志风格(log.Printf → zapLogger)
- 修复 config.go validator 标签大小写
- 修复 database_test.go 测试使用 nil logger
- 移除未使用的 log 导入
2026-04-23 16:58:01 +08:00
53e477d383 Merge branch 'dev-mysql-support' into master
- 新增 MySQL 数据库驱动支持,支持跨设备数据同步
- 新增 MySQL 专项测试能力(并发、约束、迁移)
- 重构迁移目录结构:migrations/sqlite 和 migrations/mysql
- 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- Makefile 合并:保留完整命令体系 + 新增 MySQL 测试命令
2026-04-23 16:31:29 +08:00
1522c87c74 fix: 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- 使用 GORM clause.OnConflict 替代事务包装
- Record 和 BatchUpdate 方法改用 upsert 模式
- 修复 UsageStats 的 GORM struct tag,确保 AutoMigrate 创建正确的 UNIQUE 约束
- 更新 usage-statistics spec 以反映 upsert 操作

MySQL 并发测试验证:10 并发调用 → request_count = 10
2026-04-23 15:54:56 +08:00
e0d05c9869 refactor: Makefile 命名规范化,新增顶层便捷命令
统一命名规范为 <namespace>-<action>[-<variant>] 格式:
- 重命名 desktop-mac/win/linux → desktop-build-mac/win/linux
- 重命名 backend-migrate-* → backend-db-*
- 重命名 frontend-build-desktop → desktop-prepare-frontend
- 重命名 embedfs-prepare → desktop-prepare-embedfs
- 重命名 package-macos → desktop-package-mac

新增顶层便捷命令:
- dev: 并行启动开发环境
- build: 构建所有产物
- test: 运行所有测试
- lint: 检查所有代码
- clean: 清理所有构建产物
2026-04-23 12:30:02 +08:00
5b401e29cb feat: 新增 MySQL 专项测试能力
- 新增 backend/tests/mysql/ 目录,包含 Docker Compose 配置和测试文件
- 新增 Makefile 命令: test-mysql, test-mysql-up, test-mysql-down, test-mysql-quick
- 使用 build tag 控制测试启用,默认不运行
- 测试覆盖: 迁移正确性、外键约束、UNIQUE 约束、并发写入
- 发现 statsRepo.Record 存在并发 bug(检查-然后-操作竞态条件)
2026-04-23 12:25:55 +08:00
65ac9f740a refactor: 桌面应用对话框代码拆分为平台专用文件
- 新增 dialog_windows.go、dialog_darwin.go、dialog_linux.go
- 使用 Go 构建标签实现条件编译
- 修复跨平台编译错误(syscall.NewLazyDLL 在 macOS/Linux 未定义)
- 实现 Linux 多工具降级策略(zenity → kdialog → notify-send → xmessage → stderr)
- 实现 macOS AppleScript 字符转义
- 更新 messagebox_test.go 构建标签
- 更新 desktop-app spec 新增 Linux 降级策略和 macOS 字符转义规范
2026-04-23 11:47:48 +08:00
58ebcaa299 refactor(scripts): 开发分支初始化脚本从 bash 重构为 Python,增强跨平台支持和用户体验 2026-04-23 10:43:59 +08:00
5b765c8b5e feat: 新增 MySQL 数据库驱动支持,支持跨设备数据同步 2026-04-23 00:43:23 +08:00
b3258e76df perf: 前端打包产物优化——路由级懒加载和 vendor 分包
- 使用 React.lazy() + Suspense 实现路由级代码分割
- 配置 manualChunks 将 react/tdesign/recharts 拆分为独立 vendor chunk
- 页面组件改为 export default 以支持动态导入
- 新增 bundle-optimization 规范,更新 frontend 导航规范
2026-04-23 00:26:54 +08:00
64dc66afa6 fix: Windows 桌面应用打包问题修复
- 删除通用 desktop target,重命名 platform targets 为简短形式 (desktop-mac/win/linux)
- 构建产物文件名统一为 nex-{os}-{arch}[.exe] 格式
- Windows 托盘图标使用 .ico 格式(运行时按平台选择)
- Windows 原生对话框使用 user32.MessageBoxW 替代 msg * 命令
- 更新 README.md 和 package-macos.sh 中的引用
- 添加单元测试覆盖 MessageBoxW 封装和图标选择逻辑
- 同步更新 desktop-app spec 规范文档
2026-04-22 23:20:39 +08:00
15f08ee2ca fix: 桌面应用跨平台编译和单实例锁
- 使用 gofrs/flock 替代 syscall.Flock 以支持 Windows
- 引入 SingletonLock 结构体,支持锁路径参数化(测试与生产隔离)
- 对齐服务初始化流程与 cmd/server(RoutingCache、StatsBuffer)
- 添加 gofrs/flock 依赖
- 重写单例测试,覆盖加锁/解锁/重复加锁场景
- 更新 desktop-app 规范,补充跨平台锁细节
- 新增 cross-platform-singleton 规范
2026-04-22 22:32:55 +08:00
380586afa6 Merge branch 'dev-update-readme' 2026-04-22 19:43:01 +08:00
ebb70809bf Merge branch 'dev-app' 2026-04-22 19:38:29 +08:00
7399afbc5c Merge branch 'dev-frontend-style-optimization' 2026-04-22 19:37:30 +08:00
c0669e4b07 Merge branch 'dev-database-write-optimization' 2026-04-22 19:36:38 +08:00
05c04091b3 Merge branch 'dev-add-review-prompt' 2026-04-22 19:35:47 +08:00
0b05e08705 feat: 新增桌面应用支持
- 新增 desktop 应用入口,将后端与前端打包为单一可执行文件
- 集成系统托盘功能(getlantern/systray)
- 支持单实例锁和端口冲突检测
- 启动时自动打开浏览器显示管理界面
- 新增 embedfs 模块嵌入静态资源
- 新增跨平台构建脚本(macOS/Windows/Linux)
- 新增 macOS .app 打包脚本
- 统一 Makefile,移除 backend/Makefile
- 更新 README 添加桌面应用使用说明
2026-04-22 19:27:27 +08:00
df253559a5 feat(cache): 实现 RoutingCache 和 StatsBuffer 优化数据库写入
- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model
- 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据
- 扩展 StatsRepository.BatchUpdate 支持批量增量更新
- 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存
- 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec
- 新增单元测试覆盖缓存命中/失效/并发场景
2026-04-22 19:24:36 +08:00
669cbb8c51 feat(prompts): 添加 proposal-review 和 apply-review 审查提示词 2026-04-22 19:12:03 +08:00
5ae9d85272 style: 优化前端样式,提升现代化设计感
- ConfigProvider 注入全局配置(动画、表格尺寸)
- CSS Variables 主题微调(页面背景、圆角、字体栈)
- AppLayout Menu 支持 logo/operations/collapsed
- Statistic 组件增加 color/prefix/suffix/animation
- Card 组件启用 hoverShadow/headerBordered
- Table 组件启用 stripe 斑马纹
- Tag 组件使用 variant="light" + shape="round"
- Dialog 居中显示并设置固定宽度
- 布局样式硬编码颜色替换为 TDesign Token
- UsageChart 改用 AreaChart + 渐变填充
- 更新 frontend spec 同步样式体系要求
2026-04-22 18:09:22 +08:00
72aebef625 docs: 更新三份 README 文档以反映实际项目情况 2026-04-22 16:24:38 +08:00
f5e45d032e refactor: 重命名提示词文件为英文 prompt-xxx 格式,优化智能合并提示词 2026-04-22 15:34:41 +08:00
b03e5f809f Merge branch 'dev-initial-scritps' into master 2026-04-22 15:26:14 +08:00
ec563aaa16 docs: 优化 prompts 提示词,面向 AI 精简 token 并新增书写原则 2026-04-22 15:21:43 +08:00
873f09d3bf refactor(scripts): 拆分脚本为 init/ 和 detect/ 子目录,优化 init-llm.sh 2026-04-22 14:57:14 +08:00
5e7267db07 fix(e2e): 修复 10 个被 skip 的 E2E 测试
- 将 playwright.config.ts 的 mkdtemp 替换为固定路径,解决主进程/worker 临时目录不一致问题
- 交换后端 WAL 与迁移执行顺序,确保 sql.js 能读取到完整 schema
- 修复 models.spec.ts 断言使用 exact:true 避免统一模型 ID 列干扰
- 移除全部 10 个 test.skip,26 个 E2E 测试全部通过
2026-04-22 14:32:12 +08:00
7b28cee7a1 Merge branch 'dev-frontend-optimization' 2026-04-22 13:31:37 +08:00
934c8dea77 Merge branch 'dev-testcase-analysis' 2026-04-22 13:23:32 +08:00
7d91fe345e Merge branch 'dev-conversion-docs' 2026-04-22 13:23:15 +08:00
4e86adffb7 feat: 系统性改进后端测试体系
- 新增 6 个测试场景 (config load pipe, handler errors, service aggregation, engine degradation, openai decoder edges, negative tests)
- 更新测试工具规格 (mockgen, in-memory SQLite)
- 覆盖率目标从 >80% 提升至 >85%
- 新增 test-unit 和 test-integration Makefile 命令
- 新增死代码清理和 mockgen 需求
- 归档变更至 openspec/changes/archive/2026-04-22-improve-backend-testing/
2026-04-22 13:18:51 +08:00
5d58acf5a6 fix: 修复供应商管理弹窗交互问题并去掉 API Key 脱敏
- Dialog 设置 lazy={false} 修复首次打开编辑弹窗表单为空
- API Key 改为普通字段(前端去掉 password 类型,后端去掉掩码逻辑)
- 删除模型编辑弹窗中的统一模型 ID 字段
- 简化 ProviderService.Get 签名(去掉 maskKey 参数)
- 删除 domain 和 config 层的 MaskAPIKey() 方法
- 更新前后端测试(107 单元测试 + 16 E2E 全部通过)
- 同步 delta spec 到主 spec
2026-04-22 13:13:25 +08:00
81dcecb723 docs: 补充 bun 作为前端唯一包管理器的说明 2026-04-22 11:37:05 +08:00
141f5f886f fix: 修复供应商管理弹窗交互问题
- 导入 TDesign react-19-adapter 修复 MessagePlugin 在 React 19 下的渲染错误
- Dialog 禁用蒙版点击和 ESC 键关闭,防止误操作丢失表单数据
- 重构弹窗关闭逻辑,使用 mutateAsync 替代 useEffect 监听 isSuccess
- 成功后自动关闭弹窗,失败后保持弹窗打开并显示错误提示
2026-04-22 11:36:16 +08:00
7fa5af483b docs: 更新 conversion 设计文档以匹配代码实现
主要更新内容:
- 新增三车道数据流模型(透传/智能透传/完整转换)
- 补充 ProtocolAdapter 智能透传相关 4 个方法
- 更新 InterfaceType 枚举(移除 AUDIO/IMAGES,增加 PASSTHROUGH)
- 新增 HTTPRequestSpec/HTTPResponseSpec 类型定义
- 更新引擎方法签名(增加 modelOverride 和 interfaceType 参数)
- 明确接口类型分发策略和中间件应用范围
- 新增三种流式转换器变体
- 重写错误处理策略为分层宽容策略
- 标记多模态和扩展点为 Deferred
- 更新附录 B 接口速查和附录 D 协议适配清单
2026-04-22 10:54:30 +08:00
f488b9cc15 fix(e2e): 修复对话框关闭问题,完善 E2E 测试
- 修复 TDesign Dialog onConfirm 不自动关闭的问题
- 使用 useEffect 监听 mutation 状态自动关闭对话框
- 测试使用 waitForResponse 等待 API 响应
- 添加 clearDatabase 函数确保测试隔离
- 归档 e2e-real-backend 变更到 archive/2026-04-22
- 同步 e2e-testing spec 到主 specs
2026-04-22 10:32:57 +08:00
59179094ed feat: E2E 测试集成真实后端
- Playwright 双 webServer 模式自动启动 Go 后端 + Vite 前端
- 后端使用临时 SQLite 数据库隔离,固定端口 19026
- vite.config.ts proxy target 动态读取环境变量
- 新增 sql.js 依赖用于 SQLite 统计数据 seed
- 新增 e2e/fixtures.ts 共享工具模块(API seed + SQLite seed)
- 拆分测试文件 5→7(providers/models/stats/navigation/validation)
- 删除旧文件 crud.spec.ts/sidebar.spec.ts/stats-cards.spec.ts
- E2E 测试尚有部分用例需调试修复
2026-04-22 00:31:35 +08:00
4fc5fb4764 Merge branch 'dev-openai-path-parse' 2026-04-21 20:50:52 +08:00
feff97acbd feat: 前端适配后端新接口
适配后端统一模型 ID、协议字段、UUID 自动生成和结构化错误响应:

- 类型定义:Provider 新增 protocol 字段,Model 新增 unifiedId,CreateModelInput 移除 id
- API 客户端:提取结构化错误响应中的错误码
- 供应商管理:添加协议选择下拉框和表格列
- 模型管理:移除 ID 输入,显示统一模型 ID(只读)
- Hooks:错误码映射为友好中文消息
- 测试:所有组件测试通过,mock 数据适配新字段
- 文档:更新 README 说明协议字段和统一模型 ID
2026-04-21 20:49:37 +08:00
b7e205f4b6 refactor: 优化 URL 路径拼接,修复 /v1 重复问题
## 主要变更

**核心修改**:
- 路由定义:/:protocol/v1/*path → /:protocol/*path
- proxy_handler:nativePath 直接使用 path 参数,不添加 /v1 前缀
- OpenAI 适配器:DetectInterfaceType 和 BuildUrl 去掉 /v1 前缀
- Anthropic 适配器:保持 /v1 前缀(Claude Code 兼容)

**URL 格式变化**:
- OpenAI: /openai/v1/chat/completions → /openai/chat/completions
- Anthropic: /anthropic/v1/messages (保持不变)

**base_url 配置**:
- OpenAI: 配置到版本路径,如 https://api.openai.com/v1
- Anthropic: 不配置版本路径,如 https://api.anthropic.com

## 测试验证

- 所有单元测试通过
- 所有集成测试通过
- 真实 API 测试验证成功
- 跨协议转换正常工作

## 文档更新

- 更新 backend/README.md URL 格式说明
- 同步 OpenSpec 规范文件
2026-04-21 20:21:17 +08:00
24f03595a7 Merge branch 'scripts-test'
合并 API 兼容性检测脚本改进:
- 完善流式测试覆盖并精简用例
- 添加缺失的 parse_sse_events 函数到 core.py
- 补充 OpenAI 枚举参数和边界越界测试
- 完善 API 兼容性测试用例
- 优化兼容性检测脚本
2026-04-21 18:16:50 +08:00
395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00
44d6af026a feat: 完善流式测试覆盖并精简用例
- 提取共享定义(tool_weather, image_url, json_schema_math)到功能块前
- 流式用例精简为代表子集:核心 6-8 个 + 扩展各 1-2 个 + 高级参数代表
- OpenAI: 15 个流式用例(核心 8 + vision/tools/logprobs/json_schema + 高级参数)
- Anthropic: 11 个流式用例(核心 6 + vision/tools/thinking + 高级参数)
- 更新 README:新增流式测试覆盖原则、parse_sse_events 函数说明
2026-04-21 17:18:35 +08:00
6e11ada42c fix: 添加缺失的 parse_sse_events 函数到 core.py 2026-04-21 15:22:08 +08:00
da790db75b feat: 补充 OpenAI 枚举参数和边界越界测试
- service_tier: 补充 flex, priority 测试
- reasoning_effort: 补充 none, minimal 测试
- verbosity: 补充 medium, high 测试
- 边界越界测试: frequency_penalty, presence_penalty, top_p, n
- core.py: http_stream_request 支持 method 参数
- Anthropic: 补充 content_block_start 事件验证
2026-04-21 14:15:27 +08:00
e1af978c56 feat: 完善 API 兼容性测试用例
- 修复 Anthropic Count Tokens 响应验证器,检查嵌套结构
- 补充 OpenAI service_tier: default 测试
- 补充 Anthropic output_config 带 effort 字段测试
- 补充 OpenAI reasoning_effort: low/high 测试
- 补充 Anthropic service_tier: standard_only 测试
- 修复流式响应 choices 数量验证逻辑,跳过空数组
2026-04-21 14:00:39 +08:00
980875ecf3 feat: 优化兼容性检测脚本
- 重命名脚本为 detect_xxx.py 格式
- 移除所有装饰线,精简输出格式
- 请求/响应输出增加 URL/Headers/入参/响应 标题标记
- 为所有正面用例添加响应验证器
- 补充 OpenAI 版缺失的负面测试(max_tokens 负数/0、temperature 越界)
- 移除未使用的 format_validation_errors 导入
- 新增 scripts/README.md 文档
2026-04-21 12:50:49 +08:00
7f0f831226 feat: 抽取 scripts/core.py 公共模块,重构检测脚本
将 anthropic_detect.py 和 openai_detect.py 中的公共功能抽取到
core.py 模块,包括:
- HTTP 请求(普通/流式)及重试逻辑
- SSL 上下文管理
- 测试用例/结果数据结构 (TestCase, TestResult)
- 错误分类 (ErrorType)
- 响应验证辅助函数 (validate_response_structure 等)
- 测试执行框架 (run_test, run_test_suite)

两个检测脚本重构后更聚焦于各自 API 的测试用例定义。
2026-04-21 11:45:21 +08:00
f3a207fa16 docs: 添加 openspec 变更记录,归档已完成变更并添加 unified-model-id 提案 2026-04-21 00:45:39 +08:00
56ecc73d1b docs: 整合 openspec 规范,合并配置和前端相关独立 spec
将 cli-config、config-priority、env-config 合并入 config-management;
将 tdesign-integration、recharts-integration、frontend-config-ui、
frontend-testing、stats-dashboard 合并入新的 frontend/spec.md;
清理其余 spec 中的冗余标记,补充缺失场景。
2026-04-20 19:55:56 +08:00
1ae9336cbe docs: 更新后端文档细节和内容 2026-04-20 19:47:41 +08:00
3fa5827de3 feat: 添加 Anthropic 兼容性检测脚本,OpenAI 脚本增加 --all 参数
- 新增 scripts/anthropic_detect.py,覆盖 Messages/Models/Count Tokens 等 API 的正面与负面测试用例
- OpenAI 脚本新增 --all 快捷 flag 一键开启所有扩展测试
- 更新 .gitignore 补充 Python 常见忽略项
2026-04-20 19:35:47 +08:00
cfb0edf802 feat: 引入 Viper 实现多层配置管理
引入 Viper 配置管理框架,支持 CLI 参数、环境变量、配置文件和默认值四种配置方式。

主要变更:
- 引入 Viper、pflag、validator、mapstructure 依赖
- 实现配置优先级:CLI > ENV > File > Default
- 所有 13 个配置项支持 CLI 参数和环境变量
- 规范化命名:server.port → NEX_SERVER_PORT → --server-port
- 使用结构体验证器进行配置验证
- 添加配置摘要输出功能

新增能力:
- cli-config: 命令行参数配置支持
- env-config: 环境变量配置支持(符合 12-Factor App)
- config-priority: 配置优先级管理

修改能力:
- config-management: 扩展为多层配置源支持

使用示例:
  ./server --server-port 9000 --log-level debug
  export NEX_SERVER_PORT=9000 && ./server
  ./server --config /path/to/custom.yaml
2026-04-20 18:04:42 +08:00
286 changed files with 26544 additions and 6696 deletions

7
.editorconfig Normal file
View File

@@ -0,0 +1,7 @@
root = true
[*]
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8

8
.gitattributes vendored Normal file
View File

@@ -0,0 +1,8 @@
* text=auto eol=lf
assets/*.png filter=lfs diff=lfs merge=lfs -text
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
assets/*.icns filter=lfs diff=lfs merge=lfs -text
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
assets/*.ico filter=lfs diff=lfs merge=lfs -text
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text

89
.gitignore vendored
View File

@@ -317,10 +317,99 @@ Network Trash Folder
Temporary Items
.apdisk
### Python.gitignore ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Environments
.venv/
venv/
ENV/
env/
.python-version
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# Pyre
.pyre/
# pytype
.pytype/
# Cython debug symbols
cython_debug/
# Custom
.claude
.opencode
.codex
openspec/changes/archive
temp
.agents
skills-lock.json
.worktrees
!scripts/build/
# Embedfs generated
embedfs/assets/
embedfs/frontend-dist/
backend/cmd/desktop/rsrc_windows_*.syso

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"files.eol": "\n"
}

184
LICENSE Normal file
View File

@@ -0,0 +1,184 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that is
included in or attached to the work (an example is provided in the Appendix
below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
reproduce, prepare Derivative Works of, publicly display, publicly perform,
sublicense, and distribute the Work and such Derivative Works in Source or
Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of
this License; and
(b) You must cause any modified files to carry prominent notices stating that
You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
any Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

284
Makefile Normal file
View File

@@ -0,0 +1,284 @@
.PHONY: all dev build test lint clean \
backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-clean backend-deps backend-generate \
backend-db-up backend-db-down backend-db-status backend-db-create \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
desktop-dev desktop-test desktop-clean \
desktop-prepare-frontend desktop-prepare-embedfs desktop-prepare-windows-resource
# ============================================
# 顶层便捷命令
# ============================================
dev:
@echo "🚀 Starting development environment..."
@$(MAKE) -j2 backend-dev frontend-dev
build: backend-build frontend-build
@echo "✅ Build complete"
test: backend-test desktop-test frontend-test
@echo "✅ All tests passed"
lint: backend-lint frontend-lint
@echo "✅ Lint complete"
all: build test lint
# ============================================
# 后端
# ============================================
backend-build:
cd backend && go build -o bin/server ./cmd/server
backend-run:
cd backend && go run ./cmd/server
backend-dev:
cd backend && go run ./cmd/server
backend-test:
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
backend-test-all:
cd backend && go test ./... -v
backend-test-unit:
cd backend && go test ./internal/... ./pkg/... -v
backend-test-integration:
cd backend && go test ./tests/... -v
backend-test-coverage:
cd backend && go test ./... -coverprofile=coverage.out
cd backend && go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: backend/coverage.html"
backend-lint:
cd backend && go tool golangci-lint run ./...
backend-clean:
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
backend-deps:
cd backend && go mod tidy
backend-generate:
cd backend && go generate ./...
backend-db-up:
@echo "Running database migration up..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" up
backend-db-down:
@echo "Running database migration down..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" down
backend-db-status:
@echo "Checking database migration status..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" status
backend-db-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations/sqlite create $$name sql; \
cd backend && goose -dir migrations/mysql create $$name sql
# ============================================
# MySQL 专项测试
# ============================================
test-mysql-up:
@echo "Starting MySQL test container..."
cd backend/tests/mysql && docker-compose up -d
@echo "Waiting for MySQL to be ready..."
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
echo "MySQL is ready!"; \
exit 0; \
fi; \
echo "Waiting... ($$i/30)"; \
sleep 1; \
done; \
echo "MySQL failed to start"; \
exit 1
test-mysql-down:
@echo "Stopping MySQL test container..."
cd backend/tests/mysql && docker-compose down -v
test-mysql: test-mysql-up
@echo "Running MySQL tests..."
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
$(MAKE) test-mysql-down
test-mysql-quick:
@echo "Running MySQL tests (without container management)..."
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
# ============================================
# 前端
# ============================================
frontend-install:
cd frontend && bun install
frontend-build: frontend-install
cd frontend && bun run build
frontend-dev: frontend-install
cd frontend && bun dev
frontend-test: frontend-install
cd frontend && bun run test
frontend-test-watch: frontend-install
cd frontend && bun run test:watch
frontend-test-coverage: frontend-install
cd frontend && bun run test:coverage
frontend-test-e2e: frontend-install
cd frontend && bun run test:e2e
frontend-lint: frontend-install
cd frontend && bun run lint
frontend-clean:
rm -rf frontend/dist frontend/.next frontend/node_modules frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
# ============================================
# 桌面应用
# ============================================
desktop-build: desktop-build-mac desktop-build-win desktop-build-linux
@echo "✅ Desktop builds complete for all platforms"
desktop-prepare-frontend:
@echo "📦 Preparing frontend for desktop..."
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
cd frontend && bun install && bun run build
powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
else
cd frontend && cp .env.desktop .env.production.local
cd frontend && bun install && bun run build
rm -f frontend/.env.production.local
endif
desktop-prepare-embedfs:
@echo "📦 Preparing embedded filesystem..."
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
else
rm -rf embedfs/assets embedfs/frontend-dist
cp -r assets embedfs/assets
cp -r frontend/dist embedfs/frontend-dist
endif
desktop-prepare-windows-resource:
@echo "📦 Preparing Windows executable icon..."
ifeq ($(OS),Windows_NT)
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
else
@if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
elif command -v windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
else \
echo "❌ 未找到 windres无法生成 Windows exe 图标资源"; \
exit 1; \
fi
endif
desktop-build-mac: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🍎 Building macOS..."
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-mac-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-mac-amd64 ./cmd/desktop
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
@echo "📦 Packaging macOS .app..."
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
@if [ -f assets/icon.icns ]; then \
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
else \
echo "⚠️ 未找到 assets/icon.icns"; \
fi
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
if [ -z "$$MIN_MACOS_VERSION" ]; then \
echo "❌ 无法读取 macOS 最低系统版本"; \
exit 1; \
fi; \
{ \
printf '%s\n' '<?xml version="1.0" encoding="UTF-8"?>' \
'<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">' \
'<plist version="1.0">' \
'<dict>' \
' <key>CFBundleDevelopmentRegion</key>' \
' <string>zh-Hans</string>' \
' <key>CFBundleExecutable</key>' \
' <string>nex</string>' \
' <key>CFBundleIconFile</key>' \
' <string>icon</string>' \
' <key>CFBundleIdentifier</key>' \
' <string>com.lanyuanxiaoyao.nex</string>' \
' <key>CFBundleInfoDictionaryVersion</key>' \
' <string>6.0</string>' \
' <key>LSApplicationCategoryType</key>' \
' <string>public.app-category.developer-tools</string>' \
' <key>CFBundleName</key>' \
' <string>Nex</string>' \
' <key>CFBundleDisplayName</key>' \
' <string>Nex</string>' \
' <key>CFBundlePackageType</key>' \
' <string>APPL</string>' \
' <key>CFBundleShortVersionString</key>' \
' <string>1.0.0</string>' \
' <key>CFBundleVersion</key>' \
' <string>1.0.0</string>' \
' <key>NSHumanReadableCopyright</key>' \
' <string>Copyright © 2026 Nex</string>' \
' <key>LSMinimumSystemVersion</key>' \
" <string>$$MIN_MACOS_VERSION</string>" \
' <key>LSUIElement</key>' \
' <true/>' \
' <key>NSHighResolutionCapable</key>' \
' <true/>' \
'</dict>' \
'</plist>'; \
} > build/Nex.app/Contents/Info.plist
chmod +x build/Nex.app/Contents/MacOS/nex
@echo "✅ macOS app packaged: build/Nex.app"
desktop-build-win: desktop-prepare-frontend desktop-prepare-embedfs desktop-prepare-windows-resource
@echo "🪟 Building Windows..."
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "-H=windowsgui" -o ../build/nex-win-amd64.exe ./cmd/desktop
else
mkdir -p build
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-win-amd64.exe ./cmd/desktop
endif
desktop-build-linux: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🐧 Building Linux..."
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
@echo "🖥️ Starting desktop app in dev mode..."
cd backend && go run ./cmd/desktop
desktop-test:
cd backend && go test ./cmd/desktop/... -v
desktop-clean:
rm -rf build/ embedfs/assets embedfs/frontend-dist
# ============================================
# 清理
# ============================================
clean: backend-clean frontend-clean desktop-clean
@echo "✅ Clean complete"

221
README.md
View File

@@ -7,13 +7,15 @@
```
nex/
├── backend/ # Go 后端服务(分层架构)
│ ├── cmd/server/ # 主程序入口
│ ├── cmd/
│ │ ├── server/ # CLI 主程序入口
│ │ └── desktop/ # 桌面应用入口
│ ├── internal/
│ │ ├── handler/ # HTTP 处理器 + 中间件
│ │ ├── service/ # 业务逻辑层
│ │ ├── repository/ # 数据访问层
│ │ ├── domain/ # 领域模型
│ │ ├── protocol/ # 协议适配器OpenAI/Anthropic
│ │ ├── conversion/ # 协议转换引擎OpenAI/Anthropic 适配器
│ │ ├── provider/ # 供应商客户端
│ │ └── config/ # 配置管理
│ ├── pkg/ # 公共包logger/errors/validator
@@ -32,16 +34,25 @@ nex/
│ ├── e2e/ # Playwright E2E 测试
│ └── package.json
├── assets/ # 应用资源
│ ├── icon.png # 托盘图标
│ ├── icon.icns # macOS 应用图标
│ └── icon.ico # Windows 应用图标
└── README.md # 本文件
```
## 功能特性
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
- **透明代理**:对 OpenAI 兼容供应商透传请求
- **流式响应**:完整支持 SSE 流式传输
- **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
- **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
- **Function Calling**支持工具调用Tools
- **多供应商管理**:配置和管理多个供应商
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
- **扩展接口**:支持 Embeddings 和 Rerank 接口
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
- **用量统计**:按供应商、模型、日期统计请求数量
- **Web 配置界面**:提供供应商和模型配置管理
@@ -51,12 +62,26 @@ nex/
- **语言**: Go 1.26+
- **HTTP 框架**: Gin
- **ORM**: GORM
- **数据库**: SQLite
- **日志**: zap + lumberjack结构化日志 + 日志轮转)
- **配置**: gopkg.in/yaml.v3
- **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack结构化日志 + 日志轮转 + 模块标识
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
- **迁移**: goose
#### 日志模块标识规范
每个模块通过依赖注入获取带模块标识的 logger日志输出格式为 `[module.name]`
```
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
```
模块命名规范:
- 单一职责包:`database``config`
- 多实体包:`handler.proxy``service.provider`
- 子包:`handler.middleware`
### 前端
- **运行时**: Bun
- **构建工具**: Vite
@@ -71,7 +96,45 @@ nex/
## 快速开始
### 后端
### 桌面应用(推荐)
**构建桌面应用**
```bash
# macOS (arm64 + amd64并打包为 .app)
make desktop-build-mac
# Windows
make desktop-build-win
# Linux
make desktop-build-linux
# 构建所有平台
make desktop-build
```
**使用桌面应用**
- 双击启动应用macOS: Nex.appWindows: nex-win-amd64.exeLinux: nex-linux-amd64
- 系统托盘图标出现,浏览器自动打开管理界面
- 点击托盘图标显示菜单,可打开管理界面或退出
- 关闭浏览器后服务继续运行,可通过托盘重新打开
**注意事项**
- 桌面应用需要 CGO 支持
- macOS: 自带 Xcode Command Line Tools
- Linux: 自带 gcc部分桌面环境需要 `libappindicator3-dev`
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
**Linux 桌面环境兼容性**
- GNOME: 需要 AppIndicator 扩展
- KDE Plasma: 原生支持
- Xfce: 需要 libappindicator
- 其他支持 StatusNotifierItem 规范的环境
### CLI 模式
#### 后端
```bash
cd backend
@@ -99,23 +162,42 @@ bun dev
### 代理接口(对外部应用)
- `POST /v1/chat/completions` - OpenAI Chat Completions API
- `POST /v1/messages` - Anthropic Messages API
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写并保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
**OpenAI 协议**`protocol=openai`
- `POST /openai/v1/chat/completions` - 对话补全
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/v1/embeddings` - 嵌入
- `POST /openai/v1/rerank` - 重排序
**Anthropic 协议**`protocol=anthropic`
- `POST /anthropic/v1/messages` - 消息对话
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions``/v1/models``/v1/embeddings``/v1/rerank`,并在构建上游 URL 时去掉 `/v1`Anthropic adapter 接收 `/v1/messages``/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON``MODEL_NOT_FOUND``CONVERSION_FAILED``UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
### 管理接口(对前端)
#### 供应商管理
- `GET /api/providers` - 列出所有供应商
- `POST /api/providers` - 创建供应商
- `POST /api/providers` - 创建供应商`id` 仅限字母、数字、下划线,长度 1-64
- `GET /api/providers/:id` - 获取供应商
- `PUT /api/providers/:id` - 更新供应商
- `PUT /api/providers/:id` - 更新供应商`id` 不可修改)
- `DELETE /api/providers/:id` - 删除供应商
#### 模型管理
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
- `POST /api/models` - 创建模型
- `GET /api/models/:id` - 获取模型
- `PUT /api/models/:id` - 更新模型
- `POST /api/models` - 创建模型`id` 由系统自动生成 UUID`provider_id` + `model_name` 联合唯一)
- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`
- `PUT /api/models/:id` - 更新模型(不可修改 `id`
- `DELETE /api/models/:id` - 删除模型
#### 统计查询
@@ -126,6 +208,10 @@ bun dev
## 配置
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
### 配置文件
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
```yaml
@@ -135,7 +221,14 @@ server:
write_timeout: 30s
database:
path: ~/.nex/config.db
driver: sqlite # sqlite 或 mysql
path: ~/.nex/config.db # SQLite 数据库文件路径
# --- MySQL 配置driver=mysql 时生效)---
# host: localhost
# port: 3306
# user: nex
# password: ""
# dbname: nex
max_idle_conns: 10
max_open_conns: 100
conn_max_lifetime: 1h
@@ -149,48 +242,88 @@ log:
compress: true
```
数据文件:
### 环境变量
所有配置项支持环境变量,使用 `NEX_` 前缀:
```bash
export NEX_SERVER_PORT=9000
export NEX_DATABASE_PATH=/data/nex.db
export NEX_LOG_LEVEL=debug
# MySQL 模式
export NEX_DATABASE_DRIVER=mysql
export NEX_DATABASE_HOST=db.example.com
export NEX_DATABASE_PORT=3306
export NEX_DATABASE_USER=nex
export NEX_DATABASE_PASSWORD=secret
export NEX_DATABASE_DBNAME=nex
```
命名规则:配置路径转大写 + 下划线(如 `server.port``NEX_SERVER_PORT`)。
### CLI 参数
```bash
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
```
命名规则:配置路径转 kebab-case`server.port``--server-port`)。
### 数据文件
- `~/.nex/config.yaml` - 配置文件
- `~/.nex/config.db` - SQLite 数据库
- `~/.nex/config.db` - SQLite 数据库MySQL 模式下不使用本地数据库文件)
- `~/.nex/log/` - 日志目录
## 测试
### 后端测试
```bash
cd backend
# 顶层便捷命令
make test # 运行所有测试
make test-coverage # 生成覆盖率报告
```
### 端测试
# 端测试
make backend-test # 后端测试
make backend-test-coverage # 后端覆盖率
make backend-test-unit # 后端单元测试
make backend-test-integration # 后端集成测试
```bash
cd frontend
bun run test # 单元测试 + 组件测试
bun run test:watch # 监听模式
bun run test:coverage # 生成覆盖率报告
bun run test:e2e # E2E 测试
# 前端测试
make frontend-test # 前端测试
make frontend-test-e2e # 前端 E2E 测试
make frontend-test-coverage # 前端覆盖率
```
## 开发
### 后端开发
```bash
cd backend
make build # 构建
make lint # 代码检查
make migrate-up # 数据库迁移
```
# 首次克隆后安装 Git hooks
lefthook install
### 前端开发
# 顶层便捷命令
make dev # 启动开发环境(并行启动后端和前端)
make build # 构建所有产物
make lint # 检查所有代码
make clean # 清理所有构建产物
```bash
cd frontend
bun run build # 构建生产版本
bun run lint # 代码检查
# 后端开发
make backend-build # 构建后端
make backend-run # 运行后端
make backend-dev # 后端开发模式
make backend-lint # 后端代码检查
make backend-clean # 清理后端构建产物
# 数据库操作
make backend-db-up # 数据库迁移
make backend-db-down # 数据库回滚
make backend-db-status # 数据库迁移状态
make backend-db-create # 创建新迁移
# 前端开发
make frontend-build # 构建前端
make frontend-dev # 前端开发模式
make frontend-lint # 前端代码检查
make frontend-clean # 清理前端构建产物
```
## 开发规范
@@ -201,4 +334,4 @@ bun run lint # 代码检查
## 许可证
MIT
Apache License 2.0

BIN
assets/icon.icns LFS Normal file

Binary file not shown.

BIN
assets/icon.ico LFS Normal file

Binary file not shown.

BIN
assets/icon.png LFS Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

View File

@@ -1,45 +0,0 @@
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
# 构建
build:
go build -o bin/server ./cmd/server
# 运行
run:
go run ./cmd/server
# 测试
test:
go test ./... -v
# 测试覆盖率
test-coverage:
go test ./... -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# 清理
clean:
rm -rf bin/ coverage.out coverage.html
# 数据库迁移
migrate-up:
goose -dir migrations sqlite3 $(DB_PATH) up
migrate-down:
goose -dir migrations sqlite3 $(DB_PATH) down
migrate-status:
goose -dir migrations sqlite3 $(DB_PATH) status
migrate-create:
@read -p "Migration name: " name; \
goose -dir migrations create $$name sql
# 代码检查
lint:
golangci-lint run ./...
# 安装依赖
deps:
go mod tidy

View File

@@ -4,29 +4,75 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
## 功能特性
- 支持 OpenAI 协议(`/openai/v1/...`
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`
- 支持 Anthropic 协议(`/anthropic/v1/...`
- 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic
- 同协议透传(零语义损失、零序列化开销
- 同协议透传(跳过 Canonical 全量转换,保持协议语义
- 支持流式响应SSE
- 支持 Function Calling / Tools
- 支持 Thinking / Reasoning
- 支持扩展层接口Models、Embeddings、Rerank
- 多供应商配置和路由
- 用量统计
- 结构化日志zap + lumberjack
- 结构化日志zap + lumberjack + 模块标识
- YAML 配置管理
- 请求验证
- 中间件支持(请求 ID、日志、恢复、CORS
## 日志规范
### 模块标识
每个模块通过依赖注入获取带模块标识的 logger
```go
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{
logger: pkglogger.WithModule(logger, "handler.proxy"),
}
}
```
输出格式:
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
### 模块命名规范
| 模块 | 命名 |
|------|------|
| ProxyHandler | `handler.proxy` |
| ProviderHandler | `handler.provider` |
| Provider Client | `provider.client` |
| ConversionEngine | `conversion.engine` |
| RoutingCache | `service.routing_cache` |
| StatsBuffer | `service.stats_buffer` |
| Database | `database` |
### 标准字段
使用 `pkg/logger/field.go` 中定义的字段构造函数:
```go
logger.Info("请求开始",
pkglogger.Method("POST"),
pkglogger.Path("/v1/chat"),
pkglogger.RequestID("xxx"),
)
```
### GORM 日志
GORM 日志自动桥接到 zapSQL 查询映射到 Debug 级别。
## 技术栈
- **语言**: Go 1.26+
- **HTTP 框架**: Gin
- **ORM**: GORM
- **数据库**: SQLite
- **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack
- **配置**: gopkg.in/yaml.v3
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
- **迁移**: goose
@@ -39,7 +85,7 @@ backend/
│ └── main.go # 主程序入口(依赖注入)
├── internal/
│ ├── config/ # 配置管理
│ │ ├── config.go # 配置加载/保存/验证
│ │ ├── config.go # Viper 多层配置加载/验证
│ │ └── models.go # GORM 数据模型
│ ├── domain/ # 领域模型
│ │ ├── provider.go
@@ -105,19 +151,26 @@ backend/
│ │ ├── errors.go
│ │ └── wrap.go
│ ├── logger/ # 日志系统
│ │ ├── logger.go
│ │ ├── rotate.go
│ │ ── context.go
│ │ ├── logger.go # 核心初始化
│ │ ├── field.go # 标准字段定义
│ │ ── module.go # 模块日志器
│ │ ├── context.go # Context 辅助函数
│ │ ├── gorm.go # GORM 适配器
│ │ ├── minimal.go # 最小化 logger
│ │ └── rotate.go # 日志轮转
│ ├── modelid/ # 统一模型 ID 工具包
│ │ ├── model_id.go
│ │ └── model_id_test.go
│ └── validator/ # 验证器
│ └── validator.go
├── migrations/ # 数据库迁移
── 001_initial_schema.sql
│ └── 002_add_indexes.sql
├── tests/ # 测试
│ ├── helpers.go
│ ├── integration/
├── unit/
│ └── testdata/
── 20260421000001_initial_schema.sql
├── tests/ # 集成测试
│ ├── helpers.go # 测试辅助函数
│ ├── config/ # 测试配置
│ ├── integration/ # 集成测试
│ └── e2e_conversion_test.go # E2E 协议转换测试
│ └── mocks/ # Mock 实现
├── Makefile
├── go.mod
└── README.md
@@ -133,6 +186,136 @@ handlerHTTP 请求处理)
→ repository数据访问
```
代理请求通过 ConversionEngine 进行协议转换:
```
Client Request (clientProtocol)
→ ProxyHandler 路由到上游 provider
→ ConversionEngine 请求转换 (clientProtocol → providerProtocol)
→ ProviderClient 发送请求
→ ConversionEngine 响应转换 (providerProtocol → clientProtocol)
→ Client Response
```
同协议时自动透传,跳过序列化开销。
## 协议转换架构
### Canonical Model 中间表示
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
```
OpenAI Request → Canonical Request → Anthropic Request
(中间表示)
OpenAI Response ← Canonical Response ← Anthropic Response
```
**CanonicalRequest 核心字段**
- `Model` - 统一模型 ID
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
- `Tools` - 工具定义
- `Thinking` - 推理配置(`budget_tokens``effort`
- `Parameters` - 通用参数(`max_tokens``temperature``top_p` 等)
### Smart Passthrough 机制
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
```
1. 检测 clientProtocol == providerProtocol
2. 仅改写请求体中的 model 字段unified_id → upstream_model_name
3. 直接转发请求到上游
4. 响应中仅改写 model 字段upstream_model_name → unified_id
```
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
### 流式转换器层次
```
StreamConverter (接口)
├── PassthroughStreamConverter # 直接透传,无任何处理
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
└── CanonicalStreamConverter # 跨协议完整转换decode → encode
```
### InterfaceType 枚举
| 类型 | 说明 |
|------|------|
| `CHAT` | 对话补全chat/completions、messages |
| `MODELS` | 模型列表 |
| `MODEL_INFO` | 模型详情 |
| `EMBEDDINGS` | 嵌入接口 |
| `RERANK` | 重排序接口 |
| `PASSTHROUGH` | 未知接口,直接透传 |
## 协议适配器特性
### OpenAI 适配器
**特有字段支持**
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
- `refusal` - 非标准字段,作为 text 块处理
**废弃字段兼容**
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
**消息处理**
- 合并连续同角色消息Anthropic 不支持连续同角色)
- 工具选择映射:`any``required`
### Anthropic 适配器
**特有字段支持**
- `thinking` - 推理配置(`type: enabled``budget_tokens``effort`
- `output_config` - 结构化输出配置
- `disable_parallel_tool_use` - 禁用并行工具调用
- `container` - 工具容器字段
**不支持的功能**
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
### 跨协议转换注意事项
| 源协议 | 目标协议 | 转换说明 |
|--------|----------|----------|
| OpenAI | Anthropic | `reasoning_effort``thinking`,消息角色合并 |
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
## 错误码
### ConversionError 错误码
| 错误码 | 说明 |
|--------|------|
| `INVALID_INPUT` | 输入数据无效 |
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
| `JSON_PARSE_ERROR` | JSON 解析错误 |
| `STREAM_STATE_ERROR` | 流式状态错误 |
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
| `ENCODING_FAILURE` | 编码失败 |
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings |
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
### AppError 预定义错误
| 错误 | HTTP 状态码 | 说明 |
|------|-------------|------|
| `ErrModelNotFound` | 404 | 模型未找到 |
| `ErrModelDisabled` | 404 | 模型已禁用 |
| `ErrProviderNotFound` | 404 | 供应商未找到 |
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID |
## 运行方式
### 安装依赖
@@ -151,6 +334,10 @@ go run cmd/server/main.go
## 配置
配置支持多种方式:配置文件、环境变量、命令行参数,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
### 配置文件
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
```yaml
@@ -160,7 +347,14 @@ server:
write_timeout: 30s
database:
path: ~/.nex/config.db
driver: sqlite # sqlite 或 mysql
path: ~/.nex/config.db # SQLite 数据库文件路径
# --- MySQL 配置driver=mysql 时生效)---
# host: localhost
# port: 3306
# user: nex
# password: ""
# dbname: nex
max_idle_conns: 10
max_open_conns: 100
conn_max_lifetime: 1h
@@ -174,11 +368,72 @@ log:
compress: true
```
### 环境变量
所有配置项都支持环境变量,使用 `NEX_` 前缀:
```bash
export NEX_SERVER_PORT=9000
export NEX_DATABASE_PATH=/data/nex.db
export NEX_LOG_LEVEL=debug
# MySQL 模式
export NEX_DATABASE_DRIVER=mysql
export NEX_DATABASE_HOST=db.example.com
export NEX_DATABASE_PORT=3306
export NEX_DATABASE_USER=nex
export NEX_DATABASE_PASSWORD=secret
export NEX_DATABASE_DBNAME=nex
```
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port``NEX_SERVER_PORT`)。
### 命令行参数
```bash
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
```
命名规则:配置路径转 kebab-case + `--` 前缀(如 `server.port``--server-port`)。
完整参数列表:
```
服务器: --server-port, --server-read-timeout, --server-write-timeout
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
通用: --config (指定配置文件路径)
```
### 使用示例
```bash
# 默认配置
./server
# 临时修改端口
./server --server-port 9000
# 测试场景
./server --database-path /tmp/test.db --log-level debug
# Docker 部署
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
# MySQL 模式
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
# 自定义配置文件
./server --config /path/to/custom.yaml
```
数据文件:
- `~/.nex/config.yaml` - 配置文件
- `~/.nex/config.db` - SQLite 数据库
- `~/.nex/config.db` - SQLite 数据库MySQL 模式下不使用本地数据库文件)
- `~/.nex/log/` - 日志目录
**MySQL 连接说明**MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
## 测试
```bash
@@ -197,6 +452,9 @@ make migrate-up DB_PATH=~/.nex/config.db
make migrate-down DB_PATH=~/.nex/config.db
make migrate-status DB_PATH=~/.nex/config.db
# 创建新迁移
make migrate-create
# 或直接使用 goose
goose -dir migrations sqlite3 ~/.nex/config.db up
```
@@ -205,9 +463,9 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
### 代理接口
使用 `/{protocol}/v1/{path}` URL 前缀路由
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath由对应 adapter 识别和组合上游 URL。
#### OpenAI 协议代理
#### OpenAI 协议
```
POST /openai/v1/chat/completions
@@ -216,38 +474,26 @@ POST /openai/v1/embeddings
POST /openai/v1/rerank
```
请求示例:
```json
{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": false
}
```
#### Anthropic 协议代理
#### Anthropic 协议
```
POST /anthropic/v1/messages
GET /anthropic/v1/models
```
请求示例:
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough跳过 Canonical 全量转换。
```json
{
"model": "claude-3-opus",
"max_tokens": 1024,
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello"}]}
]
}
```
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
**base_url 约定**
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions`OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
**流式透传边界**:同协议无响应 model 改写时 raw passthrough保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`
### 管理接口
@@ -259,8 +505,6 @@ GET /anthropic/v1/models
- `PUT /api/providers/:id` - 更新供应商
- `DELETE /api/providers/:id` - 删除供应商
创建供应商示例:
```json
{
"id": "openai",
@@ -271,15 +515,15 @@ GET /anthropic/v1/models
}
```
**Protocol 字段说明:**
- `protocol` 标识上游供应商使用的协议类型,可选值:`"openai"`(默认)、`"anthropic"`
- 同协议透传时,请求体和响应体原样转发,零序列化开销
**Protocol 字段**:标识上游供应商使用的协议类型,可选值 `"openai"`(默认)、`"anthropic"`
**重要说明**
- `base_url` 应配置到 API 版本路径,不包含具体端点
- OpenAI: `https://api.openai.com/v1`
- GLM: `https://open.bigmodel.cn/api/paas/v4`
- 其他 OpenAI 兼容供应商根据其文档配置版本路径
**base_url 说明**
- OpenAI 协议:配置到 API 版本路径,如 `https://api.openai.com/v1``https://open.bigmodel.cn/api/paas/v4`
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
**对外 URL 格式**
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions``/openai/v1/models``/openai/v1/embeddings`
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models`
#### 模型管理
@@ -289,76 +533,72 @@ GET /anthropic/v1/models
- `PUT /api/models/:id` - 更新模型
- `DELETE /api/models/:id` - 删除模型
创建模型示例
**创建请求**id 由系统自动生成 UUID
```json
{
"id": "gpt-4",
"provider_id": "openai",
"model_name": "gpt-4"
}
```
**响应示例**
```json
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"provider_id": "openai",
"model_name": "gpt-4",
"unified_id": "openai/gpt-4",
"enabled": true,
"created_at": "2026-04-21T00:00:00Z"
}
```
**统一模型 ID**`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。
#### 统计查询
- `GET /api/stats` - 查询统计
- `GET /api/stats/aggregate` - 聚合统计
查询参数:
查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date
- `provider_id` - 供应商 ID
- `model_name` - 模型名称
- `start_date` - 开始日期YYYY-MM-DD
- `end_date` - 结束日期YYYY-MM-DD
- `group_by` - 聚合维度provider/model/date
#### 健康检查
- `GET /health` - 返回 `{"status": "ok"}`
## 开发
### 构建
```bash
make build
make build # 构建
make lint # 代码检查
make deps # 整理依赖
```
### 代码检查
```bash
make lint
```
### 环境要求
- Go 1.26 或更高版本
环境要求Go 1.26 或更高版本
## 公共库使用指南
### pkg/errors — 结构化错误
使用预定义的错误类型,配合 `errors.Is` / `errors.As` 判断错误:
```go
import (
"errors"
pkgErrors "nex/backend/pkg/errors"
)
// 使用预定义错误
return pkgErrors.ErrRequestSend.WithCause(err)
// 判断错误类型
var appErr *pkgErrors.AppError
if errors.As(err, &appErr) {
// appErr.Code, appErr.HTTPStatus, appErr.Message
}
```
可用函数:`NewAppError``Wrap``WithContext``WithMessage``AsAppError`
预定义错误:`ErrModelNotFound``ErrProviderNotFound``ErrInvalidRequest``ErrRequestCreate``ErrRequestSend``ErrResponseRead`
### pkg/logger — 日志系统
使用依赖注入模式,构造函数接受 `*zap.Logger` 参数nil 时回退到 `zap.L()`
构造函数接受 `*zap.Logger` 参数nil 时回退到 `zap.L()`
```go
func NewMyService(repo Repository, logger *zap.Logger) *MyService {
@@ -369,8 +609,6 @@ func NewMyService(repo Repository, logger *zap.Logger) *MyService {
}
```
禁止直接在业务代码中使用 `zap.L()` 全局 logger应通过构造函数注入。
### pkg/validator — 请求验证
```go
@@ -382,8 +620,9 @@ err := v.Validate(myStruct)
## 编码规范
- **JSON 解析**:使用 `encoding/json` 标准库`json.Unmarshal` / `json.Marshal`,不手动扫描字节
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(`strings.Contains(err.Error(), ...)`
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **字符串分割**:使用 `strings.SplitN(key, "/", 2)` 等精确分割,不使用索引切片
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -0,0 +1,25 @@
//go:build darwin
package main
import (
"fmt"
"os/exec"
"strings"
"go.uber.org/zap"
)
func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title))
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
}
func escapeAppleScript(s string) string {
s = strings.ReplaceAll(s, "\\", "\\\\")
s = strings.ReplaceAll(s, "\"", "\\\"")
return s
}

View File

@@ -0,0 +1,67 @@
//go:build linux
package main
import (
"fmt"
"os/exec"
"sync"
)
type dialogToolType int
const (
toolNone dialogToolType = iota
toolZenity
toolKdialog
toolNotifySend
toolXmessage
)
var (
dialogTool dialogToolType
dialogToolOnce sync.Once
)
func init() {
dialogToolOnce.Do(detectDialogTool)
}
func detectDialogTool() {
tools := []struct {
name string
typ dialogToolType
}{
{"zenity", toolZenity},
{"kdialog", toolKdialog},
{"notify-send", toolNotifySend},
{"xmessage", toolXmessage},
}
for _, tool := range tools {
if _, err := exec.LookPath(tool.name); err == nil {
dialogTool = tool.typ
return
}
}
dialogTool = toolNone
}
func showError(title, message string) {
switch dialogTool {
case toolZenity:
exec.Command("zenity", "--error",
fmt.Sprintf("--title=%s", title),
fmt.Sprintf("--text=%s", message)).Run()
case toolKdialog:
exec.Command("kdialog", "--error", message, "--title", title).Run()
case toolNotifySend:
exec.Command("notify-send", "-u", "critical", title, message).Run()
case toolXmessage:
exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run()
default:
dialogLogger().Error("无法显示错误对话框")
}
}

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -0,0 +1,62 @@
//go:build windows
package main
import (
"errors"
"fmt"
"syscall"
"unsafe"
"go.uber.org/zap"
)
const (
mbIconError = 0x10
mbIconInformation = 0x40
)
var (
user32 = syscall.NewLazyDLL("user32.dll")
procMessageBoxW = user32.NewProc("MessageBoxW")
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
return ret, err
}
)
func showError(title, message string) {
if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
}
}
}
func messageBox(title, message string, flags uint) error {
titlePtr, err := syscall.UTF16PtrFromString(title)
if err != nil {
return err
}
messagePtr, err := syscall.UTF16PtrFromString(message)
if err != nil {
return err
}
ret, callErr := callMessageBoxW(
0,
uintptr(unsafe.Pointer(messagePtr)),
uintptr(unsafe.Pointer(titlePtr)),
uintptr(flags),
)
if ret != 0 {
return nil
}
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
return callErr
}
return fmt.Errorf("MessageBoxW 调用失败")
}

View File

@@ -0,0 +1,33 @@
package main
import (
"runtime"
"testing"
"nex/embedfs"
)
func TestIconSelection_Windows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("图标格式选择测试仅在 Windows 上运行")
}
if err := testIconLoad("assets/icon.ico"); err != nil {
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
}
}
func TestIconSelection_NonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("图标格式选择测试在非 Windows 平台运行")
}
if err := testIconLoad("assets/icon.png"); err != nil {
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
}
}
func testIconLoad(path string) error {
_, err := embedfs.Assets.ReadFile(path)
return err
}

View File

@@ -0,0 +1 @@
1 ICON "../../../assets/icon.ico"

402
backend/cmd/desktop/main.go Normal file
View File

@@ -0,0 +1,402 @@
package main
import (
"context"
"fmt"
"io/fs"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"nex/embedfs"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/database"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
var (
server *http.Server
zapLogger *zap.Logger
shutdownCtx context.Context
shutdownCancel context.CancelFunc
)
func main() {
port := 9826
minimalLogger := pkgLogger.NewMinimal()
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
if err := singleLock.Lock(); err != nil {
minimalLogger.Error("已有 Nex 实例运行")
showError(appName, "已有 Nex 实例运行")
os.Exit(1)
}
defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err))
showError(appName, err.Error())
return
}
cfg, err := config.LoadConfig()
if err != nil {
minimalLogger.Fatal("加载配置失败", zap.Error(err))
}
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
MaxBackups: cfg.Log.MaxBackups,
MaxAge: cfg.Log.MaxAge,
Compress: cfg.Log.Compress,
})
if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
}
defer database.Close(db)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
defer statsBuffer.Stop()
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
}
engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient(zapLogger)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(middleware.RequestID())
r.Use(middleware.Recovery(zapLogger))
r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
setupStaticFiles(r)
server = &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
}
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.Error(err))
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.Error(err))
}
}()
setupSystray(port)
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
providers := r.Group("/api/providers")
{
providers.GET("", providerHandler.ListProviders)
providers.POST("", providerHandler.CreateProvider)
providers.GET("/:id", providerHandler.GetProvider)
providers.PUT("/:id", providerHandler.UpdateProvider)
providers.DELETE("/:id", providerHandler.DeleteProvider)
}
models := r.Group("/api/models")
{
models.GET("", modelHandler.ListModels)
models.POST("", modelHandler.CreateModel)
models.GET("/:id", modelHandler.GetModel)
models.PUT("/:id", modelHandler.UpdateModel)
models.DELETE("/:id", modelHandler.DeleteModel)
}
stats := r.Group("/api/stats")
{
stats.GET("", statsHandler.GetStats)
stats.GET("/aggregate", statsHandler.AggregateStats)
}
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})
}
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
next(c)
}
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
setupStaticFilesWithFS(r, distFS)
}
func frontendDistFS() (fs.FS, error) {
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
if strings.HasSuffix(path, ".png") {
return "image/png"
}
if strings.HasSuffix(path, ".ico") {
return "image/x-icon"
}
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
return "font/woff2"
}
return "application/octet-stream"
}
r.GET("/assets/*filepath", func(c *gin.Context) {
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/openai/") ||
strings.HasPrefix(path, "/anthropic/") ||
path == "/openai" ||
path == "/anthropic" ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
}
func setupSystray(port int) {
systray.Run(func() {
var icon []byte
var err error
if runtime.GOOS == "windows" {
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
} else {
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
}
if err != nil {
zapLogger.Error("无法加载托盘图标", zap.Error(err))
}
systray.SetIcon(icon)
systray.SetTooltip(appTooltip)
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator()
mStatus := systray.AddMenuItem("状态: 运行中", "")
mStatus.Disable()
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
mPort.Disable()
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() {
for {
select {
case <-mOpen.ClickedCh:
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mQuit.ClickedCh:
doShutdown()
systray.Quit()
return
}
}
}()
}, nil)
}
func doShutdown() {
if zapLogger != nil {
zapLogger.Info("正在关闭服务器...")
}
if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
}
if shutdownCancel != nil {
shutdownCancel()
}
}
func checkPortAvailable(port int) error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
}
ln.Close()
return nil
}
type SingletonLock struct {
flock *flock.Flock
}
func NewSingletonLock(lockPath string) *SingletonLock {
return &SingletonLock{
flock: flock.New(lockPath),
}
}
func (s *SingletonLock) Lock() error {
locked, err := s.flock.TryLock()
if err != nil {
return err
}
if !locked {
return fmt.Errorf("已有实例运行")
}
return nil
}
func (s *SingletonLock) Unlock() error {
return s.flock.Unlock()
}
func openBrowser(url string) error {
var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
case "linux":
browsers := []string{"xdg-open", "google-chrome", "firefox"}
for _, browser := range browsers {
if _, err := exec.LookPath(browser); err == nil {
cmd = exec.Command(browser, url)
break
}
}
}
if cmd == nil {
return fmt.Errorf("无法打开浏览器")
}
return cmd.Start()
}

View File

@@ -0,0 +1,61 @@
//go:build windows
package main
import (
"errors"
"syscall"
"testing"
)
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
t.Helper()
old := callMessageBoxW
callMessageBoxW = fn
t.Cleanup(func() {
callMessageBoxW = old
})
}
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
if err == nil {
t.Fatal("包含 NUL 字符时应该返回错误")
}
}
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 1, syscall.Errno(123)
})
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
}
}
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
err := messageBox("测试标题", "测试消息", mbIconInformation)
if !errors.Is(err, syscall.Errno(5)) {
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
}
}
func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
}
}()
showError("测试错误", "这是一条测试错误消息")
}

View File

@@ -0,0 +1,9 @@
package main
const (
appName = "Nex"
appTooltip = appName
appDescription = "AI Gateway - 统一的大模型 API 网关"
// #nosec G101 -- 项目官网地址不是凭据
appWebsite = "https://github.com/nex/gateway"
)

View File

@@ -0,0 +1,13 @@
package main
import "testing"
func TestDesktopMetadata(t *testing.T) {
if appName != "Nex" {
t.Fatalf("appName = %q, want %q", appName, "Nex")
}
if appTooltip != appName {
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
}
}

View File

@@ -0,0 +1,69 @@
package main
import (
"errors"
"net"
"net/http"
"testing"
"time"
)
func TestCheckPortAvailable(t *testing.T) {
port := 19826
err := checkPortAvailable(port)
if err != nil {
t.Fatalf("端口 %d 应该可用: %v", port, err)
}
t.Log("端口可用测试通过")
}
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
defer listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err == nil {
t.Fatal("端口被占用时应该返回错误")
}
t.Log("端口占用检测测试通过")
}
func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828
listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
server := &http.Server{ReadHeaderTimeout: time.Second}
defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err != nil {
t.Fatalf("端口关闭后应该可用: %v", err)
}
t.Log("端口关闭后可用测试通过")
}

View File

@@ -0,0 +1,74 @@
package main
import (
"os"
"path/filepath"
"testing"
)
func TestSingletonLock_FirstLockSuccess(t *testing.T) {
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-first.lock")
defer os.Remove(lockPath)
lock := NewSingletonLock(lockPath)
if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
}
defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
}
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-dup.lock")
defer os.Remove(lockPath)
lock1 := NewSingletonLock(lockPath)
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath)
err := lock2.Lock()
if err == nil {
if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil")
}
}
func TestSingletonLock_UnlockThenRelock(t *testing.T) {
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-relock.lock")
defer os.Remove(lockPath)
lock1 := NewSingletonLock(lockPath)
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err)
}
if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
}

View File

@@ -0,0 +1,213 @@
package main
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
setupStaticFilesWithFS(r, distFS)
t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 404 {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
})
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/anthropic/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
})
t.Run("MIME type for JS", func(t *testing.T) {
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "application/javascript"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
}
})
t.Run("MIME type for CSS", func(t *testing.T) {
req := httptest.NewRequest("GET", "/assets/test.css", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "text/css"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
}
})
t.Log("静态文件服务测试通过")
}
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
var gotPath string
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "openai" {
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
}
if gotPath != "/v1/chat/completions" {
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
}
})
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "anthropic" {
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
}
if gotPath != "/v1/messages" {
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
}
})
t.Run("Static assets are not hijacked", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
}
})
t.Run("SPA path keeps fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
t.Errorf("期望返回 HTML实际 %s", w.Header().Get("Content-Type"))
}
})
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/unknown", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
}
if gotProtocol != "openai" || gotPath != "/unknown" {
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
})
}

View File

@@ -3,26 +3,20 @@ package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/database"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
@@ -32,17 +26,14 @@ import (
)
func main() {
// 1. 加载配置
minimalLogger := pkgLogger.NewMinimal()
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
if err := cfg.Validate(); err != nil {
log.Fatalf("配置验证失败: %v", err)
minimalLogger.Fatal("加载配置失败", zap.Error(err))
}
// 2. 初始化日志
zapLogger, err := pkgLogger.New(pkgLogger.Config{
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
@@ -51,48 +42,57 @@ func main() {
Compress: cfg.Log.Compress,
})
if err != nil {
log.Fatalf("初始化日志失败: %v", err)
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer zapLogger.Sync()
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
// 3. 初始化数据库
db, err := initDatabase(cfg)
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
}
defer closeDB(db)
defer database.Close(db)
// 4. 初始化 repository 层
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
// 5. 初始化 service 层
providerService := service.NewProviderService(providerRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 6. 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
}
engine := conversion.NewConversionEngine(registry, zapLogger)
// 7. 初始化 provider client
providerClient := provider.NewClient()
providerClient := provider.NewClient(zapLogger)
// 8. 初始化 handler 层
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
// 9. 创建 Gin 引擎
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -103,9 +103,8 @@ func main() {
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
// 10. 启动服务器
srv := &http.Server{
Addr: formatAddr(cfg.Server.Port),
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
@@ -114,7 +113,7 @@ func main() {
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
zapLogger.Fatal("服务器启动失败", zap.Error(err))
}
}()
@@ -128,89 +127,17 @@ func main() {
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
}
statsBuffer.Stop()
zapLogger.Info("服务器已关闭")
}
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, err
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
return db, nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("sqlite3")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func closeDB(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func formatAddr(port int) string {
return fmt.Sprintf(":%d", port)
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
// 统一代理入口: /{protocol}/v1/{path}
r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy)
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
// 供应商管理 API
providers := r.Group("/api/providers")
{
providers.GET("", providerHandler.ListProviders)
@@ -220,7 +147,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
providers.DELETE("/:id", providerHandler.DeleteProvider)
}
// 模型管理 API
models := r.Group("/api/models")
{
models.GET("", modelHandler.ListModels)
@@ -230,14 +156,12 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
models.DELETE("/:id", modelHandler.DeleteModel)
}
// 统计查询 API
stats := r.Group("/api/stats")
{
stats.GET("", statsHandler.GetStats)
stats.GET("/aggregate", statsHandler.AggregateStats)
}
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})

View File

@@ -2,58 +2,249 @@ module nex/backend
go 1.26.2
require (
github.com/gin-gonic/gin v1.12.0
github.com/google/uuid v1.6.0
github.com/pressly/goose/v3 v3.27.0
github.com/stretchr/testify v1.11.1
go.uber.org/zap v1.27.1
gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
replace nex/embedfs => ../embedfs
tool (
github.com/golangci/golangci-lint/cmd/golangci-lint
go.uber.org/mock/mockgen
)
require (
github.com/getlantern/systray v1.2.2
github.com/gin-gonic/gin v1.12.0
github.com/go-playground/validator/v10 v10.30.2
github.com/gofrs/flock v0.13.0
github.com/google/uuid v1.6.0
github.com/mitchellh/mapstructure v1.5.0
github.com/pressly/goose/v3 v3.27.0
github.com/spf13/pflag v1.0.10
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.1
gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
nex/embedfs v0.0.0-00010101000000-000000000000
)
require (
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
4d63.com/gochecknoglobals v0.2.2 // indirect
filippo.io/edwards25519 v1.2.0 // indirect
github.com/4meepo/tagalign v1.4.2 // indirect
github.com/Abirdcfly/dupword v0.1.3 // indirect
github.com/Antonboom/errname v1.0.0 // indirect
github.com/Antonboom/nilnil v1.0.1 // indirect
github.com/Antonboom/testifylint v1.5.2 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
github.com/alexkohler/prealloc v1.0.0 // indirect
github.com/alingse/asasalint v0.0.11 // indirect
github.com/alingse/nilnesserr v0.1.2 // indirect
github.com/ashanbrown/forbidigo v1.6.0 // indirect
github.com/ashanbrown/makezero v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bkielbasa/cyclop v1.2.3 // indirect
github.com/blizzy78/varnamelen v0.8.0 // indirect
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
github.com/breml/bidichk v0.3.2 // indirect
github.com/breml/errchkjson v0.4.0 // indirect
github.com/butuzov/ireturn v0.3.1 // indirect
github.com/butuzov/mirror v1.3.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/catenacyber/perfsprint v0.8.2 // indirect
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charithe/durationcheck v0.0.10 // indirect
github.com/chavacava/garif v0.1.0 // indirect
github.com/ckaznocha/intrange v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/curioswitch/go-reassign v0.3.0 // indirect
github.com/daixiang0/gci v0.13.5 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/denis-tingaikin/go-header v0.5.0 // indirect
github.com/ettle/strcase v0.2.0 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/fatih/structtag v1.2.0 // indirect
github.com/firefart/nonamedreturns v1.0.5 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fzipp/gocyclo v0.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/ghostiam/protogetter v0.3.9 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-critic/go-critic v0.12.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.2 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-toolsmith/astcast v1.1.0 // indirect
github.com/go-toolsmith/astcopy v1.1.0 // indirect
github.com/go-toolsmith/astequal v1.2.0 // indirect
github.com/go-toolsmith/astfmt v1.1.0 // indirect
github.com/go-toolsmith/astp v1.1.0 // indirect
github.com/go-toolsmith/strparse v1.1.0 // indirect
github.com/go-toolsmith/typep v1.1.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
github.com/golangci/go-printf-func-name v0.1.0 // indirect
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
github.com/golangci/golangci-lint v1.64.8 // indirect
github.com/golangci/misspell v0.6.0 // indirect
github.com/golangci/plugin-module-register v0.1.1 // indirect
github.com/golangci/revgrep v0.8.0 // indirect
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/gordonklaus/ineffassign v0.1.0 // indirect
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
github.com/gostaticanalysis/comment v1.5.0 // indirect
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
github.com/hashicorp/go-version v1.7.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jgautheron/goconst v1.7.1 // indirect
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jjti/go-spancheck v0.6.4 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/julz/importas v0.2.0 // indirect
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
github.com/kisielk/errcheck v1.9.0 // indirect
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/kulti/thelper v0.6.3 // indirect
github.com/kunwardeep/paralleltest v1.0.10 // indirect
github.com/lasiar/canonicalheader v1.1.2 // indirect
github.com/ldez/exptostd v0.4.2 // indirect
github.com/ldez/gomoddirectives v0.6.1 // indirect
github.com/ldez/grignotin v0.9.0 // indirect
github.com/ldez/tagliatelle v0.7.1 // indirect
github.com/ldez/usetesting v0.4.2 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/leonklingele/grouper v1.1.2 // indirect
github.com/macabu/inamedparam v0.1.3 // indirect
github.com/maratori/testableexamples v1.0.0 // indirect
github.com/maratori/testpackage v1.1.1 // indirect
github.com/matoous/godox v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mgechev/revive v1.7.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/moricho/tparallel v0.3.2 // indirect
github.com/nakabonne/nestif v0.3.1 // indirect
github.com/nishanths/exhaustive v0.12.0 // indirect
github.com/nishanths/predeclared v0.2.2 // indirect
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
github.com/quasilyte/gogrep v0.5.0 // indirect
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/raeperd/recvcheck v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/ryancurrah/gomodguard v1.3.5 // indirect
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
github.com/securego/gosec/v2 v2.22.2 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sivchari/containedctx v1.0.3 // indirect
github.com/sivchari/tenv v1.12.1 // indirect
github.com/sonatard/noctx v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/sourcegraph/go-diff v0.7.0 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tdakkota/asciicheck v0.4.1 // indirect
github.com/tetafro/godot v1.5.0 // indirect
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
github.com/timonwong/loggercheck v0.10.1 // indirect
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
github.com/ultraware/funlen v0.2.0 // indirect
github.com/ultraware/whitespace v0.2.0 // indirect
github.com/uudashr/gocognit v1.2.0 // indirect
github.com/uudashr/iface v1.3.1 // indirect
github.com/xen0n/gosmopolitan v1.2.2 // indirect
github.com/yagipy/maintidx v1.0.0 // indirect
github.com/yeya24/promlinter v0.3.0 // indirect
github.com/ykadowak/zerologlint v0.1.5 // indirect
gitlab.com/bosi/decorder v0.4.2 // indirect
go-simpler.org/musttag v0.13.0 // indirect
go-simpler.org/sloglint v0.9.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
honnef.co/go/tools v0.6.1 // indirect
mvdan.cc/gofumpt v0.7.0 // indirect
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,18 @@
package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/mitchellh/mapstructure"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors"
@@ -13,40 +20,49 @@ import (
// Config 应用配置
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Log LogConfig `yaml:"log"`
Server ServerConfig `yaml:"server" mapstructure:"server" validate:"required"`
Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
Log LogConfig `yaml:"log" mapstructure:"log" validate:"required"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Port int `yaml:"port"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
Port int `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
ReadTimeout time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Path string `yaml:"path"`
MaxIdleConns int `yaml:"max_idle_conns"`
MaxOpenConns int `yaml:"max_open_conns"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
Password string `yaml:"password" mapstructure:"password"`
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
}
// LogConfig 日志配置
type LogConfig struct {
Level string `yaml:"level"`
Path string `yaml:"path"`
MaxSize int `yaml:"max_size"`
MaxBackups int `yaml:"max_backups"`
MaxAge int `yaml:"max_age"`
Compress bool `yaml:"compress"`
Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
Path string `yaml:"path" mapstructure:"path" validate:"required"`
MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"required,min=0"`
MaxAge int `yaml:"max_age" mapstructure:"max_age" validate:"required,min=0"`
Compress bool `yaml:"compress" mapstructure:"compress"`
}
// DefaultConfig returns default config values
func DefaultConfig() *Config {
// Use home dir for default paths
homeDir, _ := os.UserHomeDir()
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
return &Config{
@@ -56,7 +72,13 @@ func DefaultConfig() *Config {
WriteTimeout: 30 * time.Second,
},
Database: DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(nexDir, "config.db"),
Host: "",
Port: 3306,
User: "",
Password: "",
DBName: "nex",
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 1 * time.Hour,
@@ -79,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil {
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
@@ -103,29 +125,179 @@ func GetConfigPath() (string, error) {
return filepath.Join(configDir, "config.yaml"), nil
}
// setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) {
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826)
v.SetDefault("server.read_timeout", "30s")
v.SetDefault("server.write_timeout", "30s")
v.SetDefault("database.driver", "sqlite")
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
v.SetDefault("database.host", "")
v.SetDefault("database.port", 3306)
v.SetDefault("database.user", "")
v.SetDefault("database.password", "")
v.SetDefault("database.dbname", "nex")
v.SetDefault("database.max_idle_conns", 10)
v.SetDefault("database.max_open_conns", 100)
v.SetDefault("database.conn_max_lifetime", "1h")
v.SetDefault("log.level", "info")
v.SetDefault("log.path", filepath.Join(nexDir, "log"))
v.SetDefault("log.max_size", 100)
v.SetDefault("log.max_backups", 10)
v.SetDefault("log.max_age", 30)
v.SetDefault("log.compress", true)
}
// setupFlags 定义和绑定 CLI 参数
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 定义所有配置项的 CLI 参数
// 注意:这里不设置默认值,让 viper 的默认值生效
flagSet.Int("server-port", 0, "服务器端口")
flagSet.Duration("server-read-timeout", 0, "读超时")
flagSet.Duration("server-write-timeout", 0, "写超时")
flagSet.String("database-driver", "", "数据库驱动sqlite/mysql")
flagSet.String("database-path", "", "数据库文件路径")
flagSet.String("database-host", "", "MySQL 主机地址")
flagSet.Int("database-port", 0, "MySQL 端口")
flagSet.String("database-user", "", "MySQL 用户名")
flagSet.String("database-password", "", "MySQL 密码")
flagSet.String("database-dbname", "", "MySQL 数据库名")
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
flagSet.String("log-level", "", "日志级别debug/info/warn/error")
flagSet.String("log-path", "", "日志文件目录")
flagSet.Int("log-max-size", 0, "单个日志文件最大大小 MB")
flagSet.Int("log-max-backups", 0, "保留的旧日志文件最大数量")
flagSet.Int("log-max-age", 0, "保留旧日志文件的最大天数")
flagSet.Bool("log-compress", false, "是否压缩旧日志文件")
// 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
}
// setupEnv 绑定环境变量
func setupEnv(v *viper.Viper) {
v.SetEnvPrefix("NEX")
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
}
// setupConfigFile 读取配置文件
func setupConfigFile(v *viper.Viper, configPath string) error {
v.SetConfigFile(configPath)
v.SetConfigType("yaml")
// 尝试读取配置文件,如果不存在则忽略
if err := v.ReadInConfig(); err != nil {
if !os.IsNotExist(err) {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
// 配置文件不存在,创建默认配置文件
writeErr := v.SafeWriteConfigAs(configPath)
if writeErr == nil {
return nil
}
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
}
return nil
}
// LoadConfig loads config from YAML file, creates default if not exists
func LoadConfig() (*Config, error) {
configPath, err := GetConfigPath()
if err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
return LoadConfigFromPath(configPath)
}
cfg := DefaultConfig()
// LoadConfigFromPath 从指定路径加载配置
func LoadConfigFromPath(configPath string) (*Config, error) {
// 1. 创建 Viper 实例
v := viper.New()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
// Create default config file
if saveErr := SaveConfig(cfg); saveErr != nil {
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
// 2. 定义 CLI 参数
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
flagSet.String("config", configPath, "配置文件路径")
setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
return cfg, nil
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
configPath = configPathFlag
}
// 5. 设置默认值
setupDefaults(v)
// 6. 绑定环境变量
setupEnv(v)
// 7. 读取配置文件
if err := setupConfigFile(v, configPath); err != nil {
return nil, err
}
// 8. 反序列化到结构体
cfg := &Config{}
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
))); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
// 9. 验证配置
if err := cfg.Validate(); err != nil {
return nil, err
}
return cfg, nil
@@ -145,27 +317,41 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0644)
return os.WriteFile(configPath, data, 0o600)
}
// Validate validates the config
func (c *Config) Validate() error {
if c.Server.Port < 1 || c.Server.Port > 65535 {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
validate := validator.New()
if err := validate.Struct(c); err != nil {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("配置验证失败: %v", err))
}
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
if !validLevels[c.Log.Level] {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
}
if c.Database.Path == "" {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
}
return nil
}
// PrintSummary 打印配置摘要
func (c *Config) PrintSummary(logger *zap.Logger) {
logger.Info("AI Gateway 启动配置",
zap.Int("server_port", c.Server.Port),
zap.String("database_driver", c.Database.Driver),
zap.String("log_level", c.Log.Level),
)
if c.Database.Driver == "mysql" {
logger.Info("数据库配置",
zap.String("driver", "mysql"),
zap.String("host", c.Database.Host),
zap.Int("port", c.Database.Port),
zap.String("database", c.Database.DBName),
)
} else {
logger.Info("数据库配置",
zap.String("driver", "sqlite"),
zap.String("path", c.Database.Path),
)
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
@@ -19,6 +20,12 @@ func TestDefaultConfig(t *testing.T) {
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "sqlite", cfg.Database.Driver)
assert.Equal(t, "", cfg.Database.Host)
assert.Equal(t, 3306, cfg.Database.Port)
assert.Equal(t, "", cfg.Database.User)
assert.Equal(t, "", cfg.Database.Password)
assert.Equal(t, "nex", cfg.Database.DBName)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
@@ -46,13 +53,13 @@ func TestConfig_Validate(t *testing.T) {
name: "端口号为0无效",
modify: func(c *Config) { c.Server.Port = 0 },
wantErr: true,
errMsg: "无效的端口号",
errMsg: "配置验证失败",
},
{
name: "端口号超出范围无效",
modify: func(c *Config) { c.Server.Port = 70000 },
wantErr: true,
errMsg: "无效的端口号",
errMsg: "配置验证失败",
},
{
name: "端口号为1有效",
@@ -68,7 +75,7 @@ func TestConfig_Validate(t *testing.T) {
name: "无效日志级别",
modify: func(c *Config) { c.Log.Level = "invalid" },
wantErr: true,
errMsg: "无效的日志级别",
errMsg: "配置验证失败",
},
{
name: "debug级别有效",
@@ -86,10 +93,75 @@ func TestConfig_Validate(t *testing.T) {
wantErr: false,
},
{
name: "数据库路径为空无效",
name: "SQLite模式路径为空无效",
modify: func(c *Config) { c.Database.Path = "" },
wantErr: true,
errMsg: "数据库路径不能为空",
errMsg: "配置验证失败",
},
{
name: "driver值不合法",
modify: func(c *Config) { c.Database.Driver = "postgres" },
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL配置有效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.Port = 3306
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: false,
},
{
name: "MySQL模式host为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = ""
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式user为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = ""
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式dbname为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = "root"
c.Database.DBName = ""
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式忽略path字段",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: false,
},
}
@@ -100,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
err := cfg.Validate()
if tt.wantErr {
assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
@@ -140,7 +214,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
WriteTimeout: 20 * time.Second,
},
Database: DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
Port: 3306,
DBName: "nex",
MaxIdleConns: 5,
MaxOpenConns: 50,
ConnMaxLifetime: 30 * time.Minute,
@@ -159,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644)
err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err)
// 加载配置
@@ -174,3 +251,72 @@ func TestSaveAndLoadConfig(t *testing.T) {
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
}
func TestCLIConfig(t *testing.T) {
// 测试 CLI 参数配置(简化版本)
// 注意:由于 flag.Parse 只能调用一次,这里只测试配置加载流程
t.Run("配置加载流程", func(t *testing.T) {
// 使用默认配置路径测试
cfg := DefaultConfig()
require.NotNil(t, cfg)
// 验证默认值正确
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, "info", cfg.Log.Level)
})
}
func TestEnvConfig(t *testing.T) {
// 测试环境变量配置(简化版本)
t.Run("环境变量前缀", func(t *testing.T) {
// 验证环境变量前缀设置正确
// 实际的环境变量测试需要独立的进程,这里只验证配置结构
cfg := DefaultConfig()
require.NotNil(t, cfg)
assert.Equal(t, 9826, cfg.Server.Port)
})
}
func TestConfigPriority(t *testing.T) {
// 测试配置优先级(简化版本)
t.Run("默认值设置", func(t *testing.T) {
cfg := DefaultConfig()
require.NotNil(t, cfg)
// 验证所有默认值
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "sqlite", cfg.Database.Driver)
assert.Equal(t, 3306, cfg.Database.Port)
assert.Equal(t, "nex", cfg.Database.DBName)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "info", cfg.Log.Level)
assert.Equal(t, 100, cfg.Log.MaxSize)
assert.Equal(t, 10, cfg.Log.MaxBackups)
assert.Equal(t, 30, cfg.Log.MaxAge)
assert.Equal(t, true, cfg.Log.Compress)
})
}
func TestPrintSummary(t *testing.T) {
t.Run("SQLite模式摘要", func(t *testing.T) {
cfg := DefaultConfig()
assert.NotPanics(t, func() {
cfg.PrintSummary(zap.NewNop())
})
})
t.Run("MySQL模式摘要", func(t *testing.T) {
cfg := DefaultConfig()
cfg.Database.Driver = "mysql"
cfg.Database.Host = "db.example.com"
cfg.Database.Port = 3306
cfg.Database.User = "nex"
cfg.Database.DBName = "nex"
assert.NotPanics(t, func() {
cfg.PrintSummary(zap.NewNop())
})
})
}

View File

@@ -17,11 +17,11 @@ type Provider struct {
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}
// Model 模型配置
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
type Model struct {
ID string `gorm:"primaryKey" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}
@@ -29,8 +29,8 @@ type Model struct {
// UsageStats 用量统计
type UsageStats struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
RequestCount int `gorm:"default:0" json:"request_count"`
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
}
@@ -47,12 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string {
return "usage_stats"
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
// 统一模型 ID 相关方法
ExtractUnifiedModelID(nativePath string) (string, error)
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
}
// AdapterRegistry 适配器注册表接口

View File

@@ -2,6 +2,7 @@ package anthropic
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
}
}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
return suffix != ""
}
// BuildUrl 根据接口类型构建 URL
@@ -140,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message,
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode
}
@@ -203,3 +207,82 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
return suffix, nil
}
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return "", nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat:
raw, exists := m["model"]
if !exists {
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
}
var current string
if err := json.Unmarshal(raw, &current); err != nil {
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
default:
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
}
}
// ExtractModelName 从请求体中提取 model 值
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName 最小化改写请求体中的 model 字段
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName 最小化改写响应体中的 model 字段
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
default:
return body, nil
}
}

View File

@@ -2,6 +2,7 @@ package anthropic
import (
"encoding/json"
"errors"
"testing"
"nex/backend/internal/conversion"
@@ -48,6 +49,28 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
}
}
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
a := NewAdapter()
// docs/api_reference/anthropic defines messages and models under /v1.
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/v1/messages", conversion.InterfaceTypeChat},
{"/v1/models", conversion.InterfaceTypeModels},
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
{"/messages", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}

View File

@@ -0,0 +1,263 @@
package anthropic
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
require.NoError(t, err)
assert.Equal(t, "some/deep/nested/model", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
require.NoError(t, err)
assert.Equal(t, "claude-3", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/messages")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", model)
})
t.Run("chat_with_max_tokens", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3-opus", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
assert.Equal(t, 0.7, m["temperature"])
// max_tokens is encoded as float in JSON numbers
maxTokens, ok := m["max_tokens"]
require.True(t, ok)
assert.Equal(t, float64(1024), maxTokens)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
// other fields preserved
_, hasContent := m["content"]
assert.True(t, hasContent)
assert.Equal(t, "end_turn", m["stop_reason"])
})
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"claude-3"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "claude-3", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, float64(1024), m["max_tokens"])
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/claude-3", true},
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"messages_path", "/v1/messages", false},
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages {
decoded := decodeMessage(msg)
decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...)
}
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
}
// decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage {
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role {
case "user":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock
for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
}
return result
return result, nil
case "assistant":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock(""))
}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
}
return nil
return nil, nil
}
// decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock {
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) {
case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any:
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m)
block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil {
blocks = append(blocks, *block)
}
}
}
if len(blocks) > 0 {
return blocks
return blocks, nil
}
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
}
}
// decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
t, _ := m["type"].(string)
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t {
case "text":
text, _ := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text}
text, ok := m["text"].(string)
if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use":
id, _ := m["id"].(string)
name, _ := m["name"].(string)
input, _ := json.Marshal(m["input"])
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
id, ok := m["id"].(string)
if !ok {
id = ""
}
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result":
toolUseID, _ := m["tool_use_id"].(string)
toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false
if ie, ok := m["is_error"].(bool); ok {
isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string:
content = json.RawMessage(fmt.Sprintf("%q", cv))
default:
content, _ = json.Marshal(cv)
encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
}
} else {
content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID,
Content: content,
IsError: &isErr,
}
}, nil
case "thinking":
thinking, _ := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
thinking, ok := m["thinking"].(string)
if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking":
// 丢弃
return nil
return nil, nil
}
return nil
return nil, nil
}
// decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "auto":
return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any":
return canonical.NewToolChoiceAny()
case "tool":
name, _ := v["name"].(string)
name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
}

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
}
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息
foundToolResult := false
for _, m := range msgs {
msgMap := m.(map[string]any)
msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any)
if ok {
for _, c := range content {
block := c.(map[string]any)
block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" {
foundToolResult = true
}
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"])
}
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"])
}
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any)
data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1)
model := data[0].(map[string]any)
model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any)
userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any)
content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2)
}
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning)
}
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk
if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...)
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil
}
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") {
switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") {
case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
}
eventType = ""
eventData = ""
} else if line == "" {
// SSE 事件分隔符
case line == "":
continue
}
}

View File

@@ -51,15 +51,23 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
if event.Message != nil {
msg := map[string]any{
"id": event.Message.ID,
"model": event.Message.Model,
"type": "message",
"role": "assistant",
"content": []any{},
"model": event.Message.Model,
"stop_reason": nil,
"stop_sequence": nil,
}
if event.Message.Usage != nil {
usage := map[string]any{
msg["usage"] = map[string]any{
"input_tokens": event.Message.Usage.InputTokens,
"output_tokens": event.Message.Usage.OutputTokens,
}
msg["usage"] = usage
} else {
msg["usage"] = map[string]any{
"input_tokens": 0,
"output_tokens": 0,
}
}
payload["message"] = msg
}
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
payload["usage"] = map[string]any{
"output_tokens": event.Usage.OutputTokens,
}
} else {
payload["usage"] = map[string]any{
"output_tokens": 0,
}
}
return e.marshalEvent("message_delta", payload)
}

View File

@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
s := string(chunks[0])
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
assert.Contains(t, s, "data: ")
assert.Contains(t, s, "msg_1")
assert.Contains(t, s, "claude-3")
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "msg_1", msg["id"])
assert.Equal(t, "message", msg["type"])
assert.Equal(t, "assistant", msg["role"])
assert.Equal(t, []any{}, msg["content"])
assert.Equal(t, "claude-3", msg["model"])
assert.Nil(t, msg["stop_reason"])
assert.Nil(t, msg["stop_sequence"])
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(0), usage["input_tokens"])
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
e := NewStreamEncoder()
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
chunks := e.EncodeEvent(event)
require.Len(t, chunks, 1)
s := string(chunks[0])
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(50), usage["output_tokens"])
}
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"])
}
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"])
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"])
}
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break
}
}
delta := payload["delta"].(map[string]any)
delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"])
usage, oku := payload["usage"].(map[string]any)
require.True(t, oku, "message_delta SHALL always include usage")
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break
}
}
u := payload["usage"].(map[string]any)
u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"])
}

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
}
func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil)
blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text)
}
func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello")
blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text)
}
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
r, ok := result.(map[string]any)
require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
})
}
}
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any)
tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1)
tool := tools[0].(map[string]any)
tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any)
tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"])
}
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -3,10 +3,14 @@ package conversion
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"nex/backend/internal/conversion/canonical"
pkglogger "nex/backend/pkg/logger"
)
// HTTPRequestSpec HTTP 请求规格
@@ -33,13 +37,10 @@ type ConversionEngine struct {
// NewConversionEngine 创建转换引擎
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
if logger == nil {
logger = zap.L()
}
return &ConversionEngine{
registry: registry,
middlewareChain: NewMiddlewareChain(),
logger: logger,
logger: pkglogger.WithModule(logger, "conversion.engine"),
}
}
@@ -72,18 +73,39 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
// ConvertHttpRequest 转换 HTTP 请求
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
nativePath := spec.URL
nativePath, rawQuery := splitRequestPath(spec.URL)
if e.IsPassthrough(clientProtocol, providerProtocol) {
providerAdapter, err := e.registry.Get(providerProtocol)
if err != nil {
return nil, err
}
// Smart Passthrough: 同协议时最小化改写 model 字段
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
rewrittenBody := spec.Body
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
if len(spec.Body) > 0 && provider.ModelName != "" {
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
zap.Error(err),
zap.String("interface", string(interfaceType)))
rewrittenBody = spec.Body
}
}
}
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
return &HTTPRequestSpec{
URL: provider.BaseURL + nativePath,
URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method,
Headers: providerAdapter.BuildHeaders(provider),
Body: spec.Body,
Body: rewrittenBody,
}, nil
}
@@ -97,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil {
@@ -105,16 +128,34 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl,
URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method,
Headers: providerHeaders,
Body: providerBody,
}, nil
}
// ConvertHttpResponse 转换 HTTP 响应
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
// ConvertHttpResponse 转换 HTTP 响应modelOverride 用于跨协议场景覆写 model 字段
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 {
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if rewriteErr != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(rewriteErr),
zap.String("interface", string(interfaceType)))
} else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
}
}
return &spec, nil
}
@@ -127,7 +168,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
return nil, err
}
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
if err != nil {
return nil, err
}
@@ -139,9 +180,16 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
}, nil
}
// CreateStreamConverter 创建流式转换器
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
// CreateStreamConverter 创建流式转换器modelOverride 用于跨协议场景覆写 model 字段
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" {
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
}
return NewPassthroughStreamConverter(), nil
}
@@ -156,7 +204,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
ctx := ConversionContext{
ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat,
InterfaceType: interfaceType,
Timestamp: time.Now(),
}
@@ -167,6 +215,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
ctx,
clientProtocol,
providerProtocol,
modelOverride,
), nil
}
@@ -192,11 +241,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
}
}
// convertResponseBody 转换响应体
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
// convertResponseBody 转换响应体modelOverride 非空时在 canonical 层面覆写 Model 字段
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
switch interfaceType {
case InterfaceTypeChat:
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
case InterfaceTypeModels:
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
return body, nil
@@ -211,12 +260,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
return body, nil
}
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
case InterfaceTypeRerank:
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
return body, nil
}
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
default:
return body, nil
}
@@ -225,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
canonicalReq, err := clientAdapter.DecodeRequest(body)
if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
return nil, NewRequestJSONParseError("解码请求失败", err)
}
ctx := NewConversionContext(InterfaceTypeChat)
@@ -233,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
if err != nil {
return nil, err
}
if containsUnsupportedMultimodal(canonicalReq) {
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
}
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
if err != nil {
@@ -241,10 +293,13 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
return encoded, nil
}
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
canonicalResp, err := providerAdapter.DecodeResponse(body)
if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
return nil, NewResponseJSONParseError("解码响应失败", err)
}
if modelOverride != "" {
canonicalResp.Model = modelOverride
}
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
if err != nil {
@@ -256,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelsResponse(models)
if err != nil {
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
return encoded, nil
@@ -270,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
return encoded, nil
@@ -284,36 +339,43 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeEmbeddingRequest(req, provider)
}
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeEmbeddingResponse(resp)
}
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeRerankRequest(req, provider)
}
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body)
if err != nil {
return body, nil
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if decodeErr == nil {
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
}
return body, nil
}
// DetectInterfaceType 检测接口类型
@@ -322,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
if err != nil {
return InterfaceTypePassthrough, err
}
nativePath, _ = splitRequestPath(nativePath)
return adapter.DetectInterfaceType(nativePath), nil
}
@@ -335,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error",
},
}
body, _ := json.Marshal(fallback)
body, marshalErr := json.Marshal(fallback)
if marshalErr == nil {
return body, 500, nil
}
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
}
body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil
}
func splitRequestPath(rawPath string) (string, string) {
path, query, found := strings.Cut(rawPath, "?")
if !found {
return rawPath, ""
}
return path, query
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
if strings.Contains(path, "?") {
return path + "&" + rawQuery
}
return path + "?" + rawQuery
}
func joinBaseURL(baseURL, path string) string {
if baseURL == "" {
return path
}
if path == "" {
return baseURL
}
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
if req == nil {
return false
}
for _, msg := range req.Messages {
for _, block := range msg.Content {
switch block.Type {
case "image", "audio", "video", "file":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,63 @@
package conversion_test
import (
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
tests := []struct {
name string
adapter conversion.ProtocolAdapter
clientProtocol string
providerProtocol string
baseURL string
nativePath string
expectedURL string
body []byte
}{
{
name: "openai base url includes version path",
adapter: openai.NewAdapter(),
clientProtocol: "openai",
providerProtocol: "openai",
baseURL: "http://example.com/v1",
nativePath: "/chat/completions",
expectedURL: "http://example.com/v1/chat/completions",
body: []byte(`{"model":"gpt-4","messages":[]}`),
},
{
name: "anthropic native path keeps v1",
adapter: anthropic.NewAdapter(),
clientProtocol: "anthropic",
providerProtocol: "anthropic",
baseURL: "http://example.com",
nativePath: "/v1/messages",
expectedURL: "http://example.com/v1/messages",
body: []byte(`{"model":"claude","messages":[]}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(tt.adapter))
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
URL: tt.nativePath,
Method: "POST",
Body: tt.body,
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
require.NoError(t, err)
assert.Equal(t, tt.expectedURL, out.URL)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConversionError_WithProviderProtocol(t *testing.T) {
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
func TestEngine_Use(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
called := false
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
called = true
@@ -58,7 +59,7 @@ func TestEngine_Use(t *testing.T) {
_ = engine.RegisterAdapter(providerAdapter)
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
require.NoError(t, err)
assert.True(t, called)
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
func TestConvertHttpRequest_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return nil, errors.New("decode failed")
@@ -75,14 +76,14 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("", "", ""))
assert.Error(t, err)
}
func TestConvertHttpRequest_EncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false))
providerAdapter := newMockAdapter("provider", false)
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
@@ -91,14 +92,14 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
_ = engine.RegisterAdapter(providerAdapter)
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
}, "client", "provider", NewTargetProvider("", "", ""))
assert.Error(t, err)
}
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
@@ -113,7 +114,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
}, "client", "provider", InterfaceTypeChat)
}, "client", "provider", InterfaceTypeChat, "")
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Contains(t, string(result.Body), "resp-1")
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return nil, errors.New("decode error")
@@ -129,13 +130,13 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
_ = engine.RegisterAdapter(providerAdapter)
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
assert.Error(t, err)
}
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeEmbeddings
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeRerank
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
@@ -189,14 +190,14 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeEmbeddings)
}, "client", "provider", InterfaceTypeEmbeddings, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
@@ -207,14 +208,14 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeRerank)
}, "client", "provider", InterfaceTypeRerank, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeModels
providerAdapter := newMockAdapter("provider", false)
@@ -224,7 +225,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
body := []byte(`{"object":"list","data":[]}`)
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
URL: "/v1/models", Method: "GET", Body: body,
URL: "/models", Method: "GET", Body: body,
}, "client", "provider", NewTargetProvider("https://example.com", "key", ""))
require.NoError(t, err)
assert.Equal(t, body, result.Body)
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
providerAdapter := newMockAdapter("provider", false)
@@ -242,14 +243,14 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
}, "client", "provider", InterfaceTypeModels)
}, "client", "provider", InterfaceTypeModels, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
providerAdapter := newMockAdapter("provider", false)
@@ -259,7 +260,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
}, "client", "provider", InterfaceTypeModelInfo)
}, "client", "provider", InterfaceTypeModelInfo, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
@@ -321,3 +322,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
}
var _ = json.Marshal
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
return nil, errors.New("decode embedding failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"text-embedding","input":"hello"}`)
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertRerankBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
return nil, errors.New("decode rerank failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
providerAdapter := newMockAdapter("provider", false)
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"test":"data"}`)
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}

View File

@@ -2,6 +2,7 @@ package conversion
import (
"encoding/json"
"strings"
"testing"
"nex/backend/internal/conversion/canonical"
@@ -23,6 +24,10 @@ type mockProtocolAdapter struct {
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
streamDecoderFn func() StreamDecoder
streamEncoderFn func() StreamEncoder
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
}
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
@@ -124,6 +129,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
}
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
if m.decodeEmbeddingReqFn != nil {
return m.decodeEmbeddingReqFn(raw)
}
return &canonical.CanonicalEmbeddingRequest{}, nil
}
@@ -140,6 +148,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
}
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
if m.decodeRerankReqFn != nil {
return m.decodeRerankReqFn(raw)
}
return &canonical.CanonicalRerankRequest{}, nil
}
@@ -155,10 +166,34 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
return json.Marshal(resp)
}
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteReqFn != nil {
return m.rewriteReqFn(body, newModel, ifaceType)
}
return body, nil
}
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteRespFn != nil {
return m.rewriteRespFn(body, newModel, ifaceType)
}
return body, nil
}
// noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器
@@ -171,7 +206,7 @@ func (e *noopStreamEncoder) Flush() [][]byte
func TestNewConversionEngine(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine)
assert.Equal(t, registry, engine.GetRegistry())
}
@@ -179,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
t.Run("nil_logger_uses_global", func(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine.logger)
})
@@ -187,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
registry := NewMemoryRegistry()
customLogger := zap.NewNop()
engine := NewConversionEngine(registry, customLogger)
assert.Equal(t, customLogger, engine.logger)
assert.NotNil(t, engine.logger)
assert.Contains(t, engine.logger.Name(), "conversion.engine")
})
}
func TestRegisterAdapter(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test-proto", true)
err := engine.RegisterAdapter(adapter)
@@ -205,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
func TestIsPassthrough_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("openai", true)
_ = engine.RegisterAdapter(adapter)
@@ -214,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
@@ -223,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
func TestIsPassthrough_NoPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
assert.False(t, engine.IsPassthrough("custom", "custom"))
@@ -231,19 +267,19 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
func TestDetectInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test", true)
adapter.ifaceType = InterfaceTypeChat
_ = engine.RegisterAdapter(adapter)
ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test")
ifaceType, err := engine.DetectInterfaceType("/chat/completions", "test")
require.NoError(t, err)
assert.Equal(t, InterfaceTypeChat, ifaceType)
}
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
assert.Error(t, err)
@@ -251,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
func TestConvertHttpRequest_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
}
_ = engine.RegisterAdapter(openaiAdapter)
provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
spec := HTTPRequestSpec{
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
assert.Equal(t, spec.Body, result.Body)
assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
}
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client-proto", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
@@ -299,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
assert.NotNil(t, result.Body)
}
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `"}`), nil
}
require.NoError(t, registry.Register(openaiAdapter))
anthropicAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("anthropic", false),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/v1/messages"
}
return nativePath
},
}
anthropicAdapter.ifaceType = InterfaceTypeChat
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
require.NoError(t, registry.Register(anthropicAdapter))
t.Run("OpenAI to Anthropic", func(t *testing.T) {
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
spec := HTTPRequestSpec{
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
})
t.Run("Anthropic to OpenAI", func(t *testing.T) {
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
spec := HTTPRequestSpec{
URL: "/v1/messages",
Method: "POST",
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
})
}
type buildURLMockAdapter struct {
*mockProtocolAdapter
buildURLFn func(string, InterfaceType) string
}
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
if m.buildURLFn != nil {
return m.buildURLFn(nativePath, interfaceType)
}
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
}
func TestConvertHttpResponse_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
spec := HTTPResponseSpec{
@@ -309,7 +430,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
Body: []byte(`{"id":"123"}`),
}
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Equal(t, spec.Body, result.Body)
@@ -317,10 +438,10 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai")
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*PassthroughStreamConverter)
assert.True(t, ok)
@@ -328,11 +449,11 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Canonical(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
converter, err := engine.CreateStreamConverter("client", "provider")
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok)
@@ -340,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
func TestEncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
@@ -352,7 +473,7 @@ func TestEncodeError(t *testing.T) {
func TestEncodeError_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
@@ -380,3 +501,233 @@ func TestRegistry_GetNonExistent(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "未找到适配器")
}
// ============ modelOverride 测试 ============
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
return json.Marshal(map[string]any{"model": resp.Model})
}
_ = engine.RegisterAdapter(clientAdapter)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
}
_ = engine.RegisterAdapter(providerAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"model":"native-model"}`),
}
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "provider/gpt-4", resp["model"])
}
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
}
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
assert.Equal(t, "resp-1", resp["id"])
}
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok)
// 验证 SSE frame 中的 data JSON 被改写
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
require.Len(t, chunks, 1)
var resp map[string]interface{}
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
// provider adapter 解码出含 model 的流式事件
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
canonical.NewMessageStopEvent(),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
// client adapter 编码时输出 model 字段
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"type": string(event.Type),
"model": event.Message.Model,
})
return [][]byte{data}
}
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
return [][]byte{data}
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
// 验证类型是 CanonicalStreamConverter
_, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok)
// 处理一个 chunk验证 model 被覆写为统一模型 ID
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
var startEvent map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"model": event.Message.Model,
})
return [][]byte{data}
}
return nil
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
// modelOverride 为空,不应覆写
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err)
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 1)
var resp map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &resp))
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
}
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test
type engineTestStreamDecoder struct {
processFn func([]byte) []canonical.CanonicalStreamEvent
flushFn func() []canonical.CanonicalStreamEvent
}
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
if d.processFn != nil {
return d.processFn(raw)
}
return nil
}
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil {
return d.flushFn()
}
return nil
}
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test
type engineTestStreamEncoder struct {
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
flushFn func() [][]byte
}
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
if e.encodeFn != nil {
return e.encodeFn(event)
}
return nil
}
func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil {
return e.flushFn()
}
return nil
}

View File

@@ -17,6 +17,13 @@ const (
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
)
const (
ErrorDetailPhase = "phase"
ErrorPhaseRequest = "request"
ErrorPhaseResponse = "response"
)
// ConversionError 协议转换错误
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
}
}
// NewRequestJSONParseError 创建请求 JSON 解析错误。
func NewRequestJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
WithCause(cause)
}
// NewResponseJSONParseError 创建响应 JSON 解析错误。
func NewResponseJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
WithCause(cause)
}
// WithClientProtocol 设置客户端协议
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
e.ClientProtocol = protocol

View File

@@ -2,6 +2,7 @@ package openai
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
@@ -43,26 +44,31 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
}
}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
return suffix != ""
}
// BuildUrl 根据接口类型构建 URL
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
switch interfaceType {
case conversion.InterfaceTypeChat:
return "/v1/chat/completions"
return "/chat/completions"
case conversion.InterfaceTypeModels:
return "/v1/models"
return "/models"
case conversion.InterfaceTypeModelInfo:
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
return "/models/" + modelID
}
return nativePath
case conversion.InterfaceTypeEmbeddings:
return "/v1/embeddings"
return "/embeddings"
case conversion.InterfaceTypeRerank:
return "/v1/rerank"
return "/rerank"
default:
return nativePath
}
@@ -137,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code),
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode
}
@@ -216,3 +225,92 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return encodeRerankResponse(resp)
}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
return suffix, nil
}
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return "", nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
raw, exists := m["model"]
if !exists {
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
}
var current string
if err := json.Unmarshal(raw, &current); err != nil {
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
default:
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
}
}
// ExtractModelName 从请求体中提取 model 值
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName 最小化改写请求体中的 model 字段
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName 最小化改写响应体中的 model 字段
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists {
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
}
return json.Marshal(m)
default:
return body, nil
}
}

View File

@@ -30,10 +30,10 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
}{
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
@@ -44,6 +44,27 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
}
}
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
a := NewAdapter()
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/chat/completions", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
{"/embeddings", conversion.InterfaceTypePassthrough},
{"/rerank", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
@@ -53,11 +74,13 @@ func TestAdapter_BuildUrl(t *testing.T) {
interfaceType conversion.InterfaceType
expected string
}{
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
}
for _, tt := range tests {
@@ -118,13 +141,13 @@ func TestIsModelInfoPath(t *testing.T) {
path string
expected bool
}{
{"model_info", "/v1/models/gpt-4", true},
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
{"model_info", "/v1/models/openai/gpt-4", true},
{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
{"models_list", "/v1/models", false},
{"nested_path", "/v1/models/gpt-4/versions", false},
{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
{"empty_suffix", "/v1/models/", false},
{"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/v1/model", false},
{"partial_prefix", "/model", false},
}
for _, tt := range tests {
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
}
}
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("标准路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", modelID)
})
t.Run("复杂路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
})
t.Run("非模型详情路径报错", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
}
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
a := NewAdapter()
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

View File

@@ -0,0 +1,360 @@
package openai
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", model)
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", model)
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
// messages field preserved
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "text-embedding", m["model"])
assert.Equal(t, "hello", m["input"])
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "rerank", m["model"])
assert.Equal(t, "test", m["query"])
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"gpt-4","choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("chat_without_model_field", func(t *testing.T) {
body := []byte(`{"choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("rerank_existing_model", func(t *testing.T) {
body := []byte(`{"model":"rerank","results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/rerank", m["model"])
})
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
body := []byte(`{"results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
_, hasModel := m["model"]
assert.False(t, hasModel, "rerank response without model field should not have one added")
})
t.Run("embedding_existing_model", func(t *testing.T) {
body := []byte(`{"model":"text-embedding","data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
body := []byte(`{"data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"gpt-4"}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "gpt-4", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", extracted)
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "text-embedding", afterRewrite)
})
t.Run("rerank_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/rerank","query":"test"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", extracted)
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "rerank", afterRewrite)
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart
for _, item := range parts {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text})
case "refusal":
refusal, _ := m["refusal"].(string)
refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
}
}
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "function":
if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string)
name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "custom":
if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string)
name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string)
mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" {
return canonical.NewToolChoiceAny()
}

View File

@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
}
assert.True(t, found)
}
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
body := []byte(`{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": [{"type": "text", "text": "hello back"}]
}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Len(t, req.Messages, 2)
assistantMsg := req.Messages[1]
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
assert.Len(t, assistantMsg.Content, 1)
assert.Equal(t, "text", assistantMsg.Content[0].Type)
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
}

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"])
}
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any)
tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"])
}
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
msg := choice["message"].(map[string]any)
choices, ok := result["choices"].([]any)
require.True(t, ok)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"])
}
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okc := result["choices"].([]any)
require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"])
data := result["data"].([]any)
data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2)
}
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"])
}

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any)
choices, okch := payload["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"])
}

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any)
results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1)
}
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"])
})
}

View File

@@ -1,6 +1,11 @@
package conversion
import "nex/backend/internal/conversion/canonical"
import (
"bytes"
"strings"
"nex/backend/internal/conversion/canonical"
)
// StreamDecoder 流式解码器接口
type StreamDecoder interface {
@@ -38,6 +43,65 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
return nil
}
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
// 按 SSE frame 改写 data JSON 中的 model 字段
type SmartPassthroughStreamConverter struct {
adapter ProtocolAdapter
modelOverride string
interfaceType InterfaceType
buffer []byte
}
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
return &SmartPassthroughStreamConverter{
adapter: adapter,
modelOverride: modelOverride,
interfaceType: interfaceType,
}
}
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
if len(rawChunk) == 0 {
return nil
}
c.buffer = append(c.buffer, rawChunk...)
frames, rest := splitSSEFrames(c.buffer)
c.buffer = rest
result := make([][]byte, 0, len(frames))
for _, frame := range frames {
result = append(result, c.rewriteFrame(frame))
}
return result
}
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
payload, ok := sseFrameDataPayload(frame)
if !ok || strings.TrimSpace(payload) == "[DONE]" {
return frame
}
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
if err != nil {
return frame
}
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
}
// Flush 输出未形成完整 frame 的剩余数据
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
if len(c.buffer) == 0 {
return nil
}
frame := append([]byte(nil), c.buffer...)
c.buffer = nil
return [][]byte{c.rewriteFrame(frame)}
}
// CanonicalStreamConverter 跨协议规范流式转换器
type CanonicalStreamConverter struct {
decoder StreamDecoder
@@ -46,6 +110,7 @@ type CanonicalStreamConverter struct {
ctx ConversionContext
clientProtocol string
providerProtocol string
modelOverride string
}
// NewCanonicalStreamConverter 创建规范流式转换器
@@ -57,7 +122,7 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
}
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
return &CanonicalStreamConverter{
decoder: decoder,
encoder: encoder,
@@ -65,10 +130,11 @@ func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder St
ctx: ctx,
clientProtocol: clientProtocol,
providerProtocol: providerProtocol,
modelOverride: modelOverride,
}
}
// ProcessChunk 解码 → 中间件 → 编码管道
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
events := c.decoder.ProcessChunk(rawChunk)
var result [][]byte
@@ -80,6 +146,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
}
events[i] = *processed
}
c.applyModelOverride(&events[i])
chunks := c.encoder.EncodeEvent(events[i])
result = append(result, chunks...)
}
@@ -98,6 +165,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
}
events[i] = *processed
}
c.applyModelOverride(&events[i])
chunks := c.encoder.EncodeEvent(events[i])
result = append(result, chunks...)
}
@@ -105,3 +173,93 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
result = append(result, encoderChunks...)
return result
}
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
if c.modelOverride != "" && event.Message != nil {
event.Message.Model = c.modelOverride
}
}
func splitSSEFrames(data []byte) ([][]byte, []byte) {
var frames [][]byte
for len(data) > 0 {
idx, sepLen := findSSEFrameSeparator(data)
if idx < 0 {
break
}
end := idx + sepLen
frames = append(frames, append([]byte(nil), data[:end]...))
data = data[end:]
}
return frames, data
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
lineEnding, separator := sseLineEnding(frame)
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
out := make([]string, 0, len(lines)+1)
dataWritten := false
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
if !dataWritten {
for _, dataLine := range strings.Split(data, "\n") {
out = append(out, "data: "+dataLine)
}
dataWritten = true
}
continue
}
out = append(out, line)
}
if !dataWritten {
out = append(out, "data: "+data)
}
return []byte(strings.Join(out, lineEnding) + separator)
}
func sseLineEnding(frame []byte) (string, string) {
if bytes.Contains(frame, []byte("\r\n")) {
return "\r\n", "\r\n\r\n"
}
return "\n", "\n\n"
}

View File

@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.ProcessChunk([]byte("raw"))
assert.Len(t, result, 1)
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.ProcessChunk([]byte("raw"))
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.Flush()
assert.Len(t, result, 1)

View File

@@ -0,0 +1,151 @@
package database
import (
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
pkglogger "nex/backend/pkg/logger"
)
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
moduleLogger := pkglogger.WithModule(zapLogger, "database")
db, err := initDB(cfg, moduleLogger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
configurePool(db, cfg, moduleLogger)
return db, nil
}
func Close(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
gormLogger := pkglogger.NewGormLogger(zapLogger)
gormConfig := &gorm.Config{
Logger: gormLogger,
}
switch cfg.Driver {
case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
if zapLogger != nil {
zapLogger.Info("连接 MySQL 数据库",
zap.String("host", cfg.Host),
zap.Int("port", cfg.Port),
zap.String("database", cfg.DBName))
}
return gorm.Open(mysql.Open(dsn), gormConfig)
default:
dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
if zapLogger != nil {
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
}
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
}
}
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
gooseDialect := "sqlite3"
migrationsSubDir := "sqlite"
if driver == "mysql" {
gooseDialect = "mysql"
migrationsSubDir = "mysql"
}
migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
if zapLogger != nil {
zapLogger.Info("执行数据库迁移",
zap.String("dialect", gooseDialect),
zap.String("dir", migrationsSubDir))
}
if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
if cfg.Driver == "sqlite" {
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
if zapLogger != nil {
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
}
}
}
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
if zapLogger != nil {
zapLogger.Info("数据库连接池配置",
zap.Int("max_idle_conns", cfg.MaxIdleConns),
zap.Int("max_open_conns", cfg.MaxOpenConns),
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
}
}
func getMigrationsDir(driver string) string {
_, filename, _, ok := runtime.Caller(0)
if ok {
subDir := "sqlite"
if driver == "mysql" {
subDir = "mysql"
}
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", subDir)
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func BuildDSN(cfg *config.DatabaseConfig) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
}

View File

@@ -0,0 +1,78 @@
package database
import (
"path/filepath"
"testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestInit_SQLite(t *testing.T) {
dir := t.TempDir()
cfg := &config.DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 10,
ConnMaxLifetime: 0,
}
zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err)
require.NotNil(t, db)
defer Close(db)
sqlDB, err := db.DB()
require.NoError(t, err)
require.NotNil(t, sqlDB)
}
func TestClose(t *testing.T) {
dir := t.TempDir()
cfg := &config.DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 10,
ConnMaxLifetime: 0,
}
zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err)
require.NotNil(t, db)
Close(db)
}
func TestBuildDSN(t *testing.T) {
cfg := &config.DatabaseConfig{
Driver: "mysql",
Host: "db.example.com",
Port: 3306,
User: "nexuser",
Password: "secretpass",
DBName: "nexdb",
}
dsn := BuildDSN(cfg)
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}
func TestBuildDSN_EmptyPassword(t *testing.T) {
cfg := &config.DatabaseConfig{
Driver: "mysql",
Host: "localhost",
Port: 3306,
User: "root",
DBName: "nex",
}
dsn := BuildDSN(cfg)
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}

View File

@@ -1,8 +1,12 @@
package domain
import "time"
import (
"time"
// Model 模型领域模型
"nex/backend/pkg/modelid"
)
// Model 模型领域模型id 为 UUID 自动生成)
type Model struct {
ID string `json:"id"`
ProviderID string `json:"provider_id"`
@@ -10,3 +14,8 @@ type Model struct {
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}
// UnifiedModelID 返回统一模型 ID格式provider_id/model_name
func (m *Model) UnifiedModelID() string {
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
}

View File

@@ -13,12 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -6,15 +6,22 @@ import (
"net/http/httptest"
"testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
"go.uber.org/mock/gomock"
)
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -33,11 +40,16 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
var result domain.Provider
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "p1", result.ID)
assert.Contains(t, result.APIKey, "***")
assert.Equal(t, "sk-test", result.APIKey)
}
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
}
func TestProviderHandler_UpdateProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"name": "Updated"})
w := httptest.NewRecorder()
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
}
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
}
func TestProviderHandler_DeleteProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
}
func TestModelHandler_DeleteModel(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -110,10 +140,17 @@ func TestModelHandler_DeleteModel(t *testing.T) {
}
func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "m1",
"provider_id": "p1",
"model_name": "gpt-4",
})
@@ -127,13 +164,16 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
var result domain.Model
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.NotEmpty(t, result.ID)
}
func TestModelHandler_GetModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -149,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
}
func TestModelHandler_UpdateModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
w := httptest.NewRecorder()

View File

@@ -2,21 +2,22 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http/httptest"
"strings"
"testing"
"time"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
)
@@ -24,82 +25,12 @@ func init() {
gin.SetMode(gin.TestMode)
}
// ============ Mock 实现 ============
type mockRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockStatsService struct {
err error
stats []domain.UsageStats
aggrResult []map[string]interface{}
}
func (m *mockStatsService) Record(providerID, modelName string) error {
return m.err
}
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return m.stats, nil
}
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
return m.aggrResult
}
type mockProviderService struct {
provider *domain.Provider
providers []domain.Provider
err error
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
}
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockProviderService) Delete(id string) error { return m.err }
type mockModelService struct {
model *domain.Model
models []domain.Model
err error
}
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockModelService) Delete(id string) error { return m.err }
type mockProviderClient struct {
err error
}
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
return nil, m.err
}
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
return nil, m.err
}
// ============ Provider Handler 测试 ============
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "p1"})
w := httptest.NewRecorder()
@@ -112,12 +43,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
}
func TestProviderHandler_ListProviders(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
providers: []domain.Provider{
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"},
},
})
}, nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -127,14 +61,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
assert.Equal(t, 200, w.Code)
var result []domain.Provider
json.Unmarshal(w.Body.Bytes(), &result)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Len(t, result, 2)
}
func TestProviderHandler_GetProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -145,10 +82,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ Model Handler 测试 ============
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "m1"})
w := httptest.NewRecorder()
@@ -161,12 +100,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
}
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ModelName: "gpt-4"},
{ID: "m2", ModelName: "gpt-3.5"},
},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -174,16 +116,98 @@ func TestModelHandler_ListModels(t *testing.T) {
h.ListModels(c)
assert.Equal(t, 200, w.Code)
var result []modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
require.Len(t, result, 2)
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
}
// ============ Stats Handler 测试 ============
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
h.GetModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 201, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "mock-uuid-1234", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
}
func TestStatsHandler_GetStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
},
})
}, nil)
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -194,7 +218,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
}
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
h := NewStatsHandler(&mockStatsService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -205,14 +233,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
}
func TestStatsHandler_AggregateStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
},
aggrResult: []map[string]interface{}{
}, nil)
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
{"provider_id": "p1", "request_count": 10},
},
})
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -222,8 +253,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ writeError 测试 ============
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -252,12 +281,13 @@ func formatMapErrors(errs map[string]string) string {
return "请求验证失败: " + strings.Join(parts, "; ")
}
// ============ 错误类型判断测试 ============
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
err: gorm.ErrDuplicatedKey,
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -273,3 +303,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h.CreateProvider(c)
assert.Equal(t, 409, w.Code)
}
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "nonexistent",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商不存在")
}
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 409, w.Code)
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
}
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 500, w.Code)
}
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 404, w.Code)
}
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
}
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 500, w.Code)
}
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 400, w.Code)
}

View File

@@ -5,9 +5,10 @@ import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
// Logging 日志中间件
func Logging(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
query := c.Request.URL.RawQuery
requestID, _ := c.Get(RequestIDKey)
var requestIDStr string
if id, ok := requestID.(string); ok {
requestIDStr = id
}
logger.Info("请求开始",
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("client_ip", c.ClientIP()),
zap.Any("request_id", requestID),
pkglogger.Method(c.Request.Method),
pkglogger.Path(path),
pkglogger.Query(query),
pkglogger.ClientIP(c.ClientIP()),
pkglogger.RequestID(requestIDStr),
)
c.Next()
@@ -29,12 +34,12 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
statusCode := c.Writer.Status()
logger.Info("请求结束",
zap.Int("status", statusCode),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.Duration("latency", latency),
zap.Int("body_size", c.Writer.Size()),
zap.Any("request_id", requestID),
pkglogger.StatusCode(statusCode),
pkglogger.Method(c.Request.Method),
pkglogger.Path(path),
pkglogger.Latency(latency),
pkglogger.BodySize(c.Writer.Size()),
pkglogger.RequestID(requestIDStr),
)
}
}

View File

@@ -1,15 +1,16 @@
package handler
import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ModelHandler 模型管理处理器
@@ -22,40 +23,59 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
return &ModelHandler{modelService: modelService}
}
// modelResponse 模型响应 DTO扩展 unified_id 字段
type modelResponse struct {
domain.Model
UnifiedModelID string `json:"unified_id"`
}
// newModelResponse 从 domain.Model 构造响应 DTO
func newModelResponse(m *domain.Model) modelResponse {
return modelResponse{
Model: *m,
UnifiedModelID: m.UnifiedModelID(),
}
}
// CreateModel 创建模型
func (h *ModelHandler) CreateModel(c *gin.Context) {
var req struct {
ID string `json:"id" binding:"required"`
ProviderID string `json:"provider_id" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "缺少必需字段: id, provider_id, model_name",
"error": "缺少必需字段: provider_id, model_name",
})
return
}
model := &domain.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
}
err := h.modelService.Create(model)
if err != nil {
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
c.JSON(http.StatusCreated, model)
c.JSON(http.StatusCreated, newModelResponse(model))
}
// ListModels 列出模型
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
return
}
c.JSON(http.StatusOK, models)
resp := make([]modelResponse, len(models))
for i, m := range models {
resp[i] = newModelResponse(&m)
}
c.JSON(http.StatusOK, resp)
}
// GetModel 获取模型
@@ -77,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// UpdateModel 更新模型
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
err := h.modelService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, appErrors.ErrModelNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": appErrors.ErrDuplicateModel.Message,
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// DeleteModel 删除模型
@@ -135,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})

View File

@@ -4,13 +4,13 @@ import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ProviderHandler 供应商管理处理器
@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider)
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code,
})
return
}
@@ -65,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
return
}
provider.MaskAPIKey()
c.JSON(http.StatusCreated, provider)
}
@@ -84,9 +84,9 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
func (h *ProviderHandler) GetProvider(c *gin.Context) {
id := c.Param("id")
provider, err := h.providerService.Get(id, true)
provider, err := h.providerService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
@@ -113,17 +113,24 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
return
}
if errors.Is(err, appErrors.ErrImmutableField) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrImmutableField.Message,
"code": appErrors.ErrImmutableField.Code,
})
return
}
writeError(c, err)
return
}
provider, err := h.providerService.Get(id, true)
provider, err := h.providerService.Get(id)
if err != nil {
writeError(c, err)
return
@@ -138,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})

View File

@@ -3,17 +3,23 @@ package handler
import (
"bufio"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
appErrors "nex/backend/pkg/errors"
"nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
// ProxyHandler 统一代理处理器
@@ -27,14 +33,14 @@ type ProxyHandler struct {
}
// NewProxyHandler 创建统一代理处理器
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{
engine: engine,
client: client,
routingService: routingService,
providerService: providerService,
statsService: statsService,
logger: zap.L(),
logger: pkglogger.WithModule(logger, "handler.proxy"),
}
}
@@ -43,47 +49,93 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol")
if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
return
}
// 原始路径: /v1/{path}
// 原始路径: /{path}
path := c.Param("path")
if strings.HasPrefix(path, "/") {
path = path[1:]
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
nativePath := path
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
// 获取 client adapter
registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol)
if err != nil {
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
// 检测接口类型
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
// 处理 Models 接口:本地聚合
if ifaceType == conversion.InterfaceTypeModels {
h.handleModelsList(c, clientAdapter)
return
}
// 处理 ModelInfo 接口:本地查询
if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
return
}
h.handleModelInfo(c, unifiedID, clientAdapter)
return
}
nativePath := "/v1/" + path
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
return
}
// 解析 model 名称(从 JSON body 中提取GET 请求无 body
modelName := ""
if len(body) > 0 {
modelName = extractModelName(body)
}
// 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{
URL: nativePath,
URL: requestPath,
Method: c.Request.Method,
Headers: extractHeaders(c),
Body: body,
}
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 路由
routeResult, err := h.routingService.Route(modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
h.writeError(c, err, clientProtocol)
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err != nil {
if isInvalidJSONError(err) {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
return
}
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
if providerID == "" || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
// 路由
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil {
h.writeRouteError(c, err)
return
}
@@ -94,28 +146,53 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
// 构建 TargetProvider
// 注意ModelName 字段用于 Smart Passthrough 场景改写请求体
// 同协议:请求体中的统一 ID 会被改写为 ModelName上游名
// 跨协议:全量转换时 ModelName 会被编码到请求体中
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName,
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID()
if isStream {
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
} else {
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
}
}
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
return true
default:
return false
}
}
func isInvalidJSONError(err error) bool {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
@@ -123,37 +200,32 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
h.logger.Error("发送请求失败", zap.Error(err))
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
// 转换响应
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
// 设置响应头
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
h.writeConvertedResponse(c, *convertedResp)
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
// handleStream 处理流式请求
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
@@ -161,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return
}
// 创建流式转换器
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
// 发送流式请求
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
// 发送流式请求
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -180,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
h.logger.Error("流读取错误", zap.Error(event.Error))
break
}
if event.Done {
// flush 转换器
chunks := streamConverter.Flush()
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
}
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat {
return false
}
@@ -224,34 +328,166 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
return req.Stream
}
// writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
c.Data(statusCode, "application/json", body)
// handleModelsList 处理 GET /v1/models 本地聚合
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
// 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
// 构建 CanonicalModelList
modelList := &canonical.CanonicalModelList{
Models: make([]canonical.CanonicalModel, 0, len(models)),
}
for _, m := range models {
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
ID: m.UnifiedModelID(),
Name: m.ModelName,
Created: m.CreatedAt.Unix(),
OwnedBy: m.ProviderID,
})
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeError 写入路由错误
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
// 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
return
}
// 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil {
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
return
}
// 构建 CanonicalModelInfo
modelInfo := &canonical.CanonicalModelInfo{
ID: model.UnifiedModelID(),
Name: model.ModelName,
Created: model.CreatedAt.Unix(),
OwnedBy: model.ProviderID,
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeConversionError 写入网关层转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
var convErr *conversion.ConversionError
if errors.As(err, &convErr) {
statusCode, code, message := mapConversionError(convErr)
h.writeProxyError(c, statusCode, code, message)
return
}
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
}
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
switch err.Code {
case conversion.ErrorCodeJSONParseError:
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
}
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
case conversion.ErrorCodeInvalidInput,
conversion.ErrorCodeMissingRequiredField,
conversion.ErrorCodeProtocolConstraint:
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
case conversion.ErrorCodeInterfaceNotSupported:
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
case conversion.ErrorCodeUnsupportedMultimodal:
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
default:
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
}
}
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
switch appErr.Code {
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
default:
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
}
return
}
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
h.logger.Error("上游不可达", zap.Error(err))
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": message,
"code": code,
})
}
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range resp.Headers {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, resp.Body)
}
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range filterHopByHopHeaders(resp.Headers) {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
c.Data(resp.StatusCode, contentType, resp.Body)
}
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
providers, err := h.providerService.List()
if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
return
}
@@ -261,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
providerProtocol = "openai"
}
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json"
}
outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL,
URL: joinBaseURL(p.BaseURL, upstreamPath),
Method: inSpec.Method,
Headers: headers,
Body: inSpec.Body,
@@ -286,36 +521,132 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
}
}
if isStream {
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
return
}
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
h.writeConvertedResponse(c, *convertedResp)
}
// extractModelName 从 JSON body 中提取 model
func extractModelName(body []byte) string {
var req struct {
Model string `json:"model"`
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
if err != nil {
h.writeUpstreamUnavailable(c, err)
return
}
if err := json.Unmarshal(body, &req); err != nil {
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("透传流读取错误", zap.Error(event.Error))
break
}
if event.Done {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
}
}
func stripRawQuery(path string) string {
pathOnly, _, _ := strings.Cut(path, "?")
return pathOnly
}
func rawQueryFromPath(path string) string {
_, rawQuery, found := strings.Cut(path, "?")
if !found {
return ""
}
return req.Model
return rawQuery
}
func joinBaseURL(baseURL, path string) string {
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func headerValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.EqualFold(k, key) {
return v
}
}
return ""
}
func filterHopByHopHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return nil
}
hopByHop := map[string]struct{}{
"connection": {},
"transfer-encoding": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"upgrade": {},
}
filtered := make(map[string]string, len(headers))
for k, v := range headers {
if _, skip := hopByHop[strings.ToLower(k)]; skip {
continue
}
filtered[k] = v
}
return filtered
}
// extractHeaders 从 Gin context 提取请求头

File diff suppressed because it is too large Load Diff

View File

@@ -5,9 +5,9 @@ import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
)
// StatsHandler 统计处理器

View File

@@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"strings"
"syscall"
"time"
@@ -15,6 +16,7 @@ import (
"nex/backend/internal/conversion"
pkgErrors "nex/backend/pkg/errors"
pkglogger "nex/backend/pkg/logger"
)
// StreamConfig 流式处理配置
@@ -42,6 +44,14 @@ type StreamEvent struct {
Done bool
}
// StreamResponse 表示上游流式 HTTP 响应。
type StreamResponse struct {
StatusCode int
Headers map[string]string
Body []byte
Events <-chan StreamEvent
}
// Client 协议无关的供应商客户端
type Client struct {
httpClient *http.Client
@@ -50,18 +60,20 @@ type Client struct {
}
// ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
}
// NewClient 创建供应商客户端
func NewClient() *Client {
func NewClient(logger *zap.Logger) *Client {
return &Client{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
logger: zap.L(),
logger: pkglogger.WithModule(logger, "provider.client"),
streamCfg: DefaultStreamConfig(),
}
}
@@ -113,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
}
// SendStream 发送流式请求
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
var bodyReader io.Reader
if len(spec.Body) > 0 {
bodyReader = bytes.NewReader(spec.Body)
@@ -136,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
return nil, pkgErrors.ErrRequestSend.WithCause(err)
}
if resp.StatusCode != http.StatusOK {
respHeaders := extractResponseHeaders(resp.Header)
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close()
cancel()
errBody, _ := io.ReadAll(resp.Body)
if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
}
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Body: errBody,
}, nil
}
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Events: eventChan,
}, nil
}
// readStream 读取 SSE 流
@@ -182,10 +203,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil {
if err != io.EOF {
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.String("error", err.Error()))
c.logger.Error("流读取错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
}
return
@@ -202,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
for {
idx := bytes.Index(dataBuf, []byte("\n\n"))
idx, sepLen := findSSEFrameSeparator(dataBuf)
if idx == -1 {
break
}
rawEvent := dataBuf[:idx]
dataBuf = dataBuf[idx+2:]
frameEnd := idx + sepLen
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
dataBuf = dataBuf[frameEnd:]
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
if isSSEDoneFrame(rawEvent) {
eventChan <- StreamEvent{Data: rawEvent}
eventChan <- StreamEvent{Done: true}
return
}
@@ -219,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
if err == io.EOF {
if len(dataBuf) > 0 {
eventChan <- StreamEvent{Data: dataBuf}
}
return
}
}
}
func isSSEDoneFrame(frame []byte) bool {
payload, ok := sseFrameDataPayload(frame)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func extractResponseHeaders(header http.Header) map[string]string {
respHeaders := make(map[string]string)
for k, vs := range header {
if len(vs) > 0 {
respHeaders[k] = vs[0]
}
}
return respHeaders
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
// isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool {
if err == nil {

View File

@@ -13,12 +13,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
)
func TestNewClient(t *testing.T) {
client := NewClient()
client := NewClient(zap.NewNop())
require.NotNil(t, client)
assert.NotNil(t, client.httpClient)
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
@@ -40,11 +41,12 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -64,11 +66,12 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -82,7 +85,7 @@ func TestClient_Send_ErrorResponse(t *testing.T) {
}
func TestClient_Send_ConnectionError(t *testing.T) {
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: "http://localhost:1/v1/chat/completions",
Method: "POST",
@@ -99,7 +102,7 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -107,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, eventChan)
require.NotNil(t, streamResp)
require.Equal(t, http.StatusOK, streamResp.StatusCode)
require.NotNil(t, streamResp.Events)
for range eventChan {
for range streamResp.Events {
}
}
@@ -121,7 +126,7 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -129,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
Body: []byte(`{}`),
}
_, err := client.SendStream(context.Background(), spec)
assert.Error(t, err)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
}
func TestClient_SendStream_SSEEvents(t *testing.T) {
@@ -139,18 +146,21 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -158,24 +168,73 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range eventChan {
if event.Done {
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataEvents = append(dataEvents, event.Data)
}
}
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
assert.Contains(t, string(dataEvents[0]), "Hello")
assert.Contains(t, string(dataEvents[1]), "World")
assert.Contains(t, string(dataEvents[2]), "[DONE]")
assert.Equal(t, 1, doneEvents)
assert.Contains(t, string(dataEvents[0]), "\n\n")
}
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
}))
defer server.Close()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Body: []byte(`{}`),
}
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
default:
dataEvents = append(dataEvents, event.Data)
}
}
require.Len(t, dataEvents, 2)
assert.Contains(t, string(dataEvents[0]), "plain text")
assert.Contains(t, string(dataEvents[1]), "[DONE]")
assert.Equal(t, 1, doneEvents)
}
@@ -188,7 +247,7 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -196,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(ctx, spec)
streamResp, err := client.SendStream(ctx, spec)
require.NoError(t, err)
cancel()
var gotError bool
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
gotError = true
}
@@ -214,11 +273,12 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`))
_, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/models",
Method: "GET",
@@ -237,16 +297,18 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -254,21 +316,22 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataCount int
var doneCount int
for event := range eventChan {
if event.Done {
for event := range streamResp.Events {
switch {
case event.Done:
doneCount++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataCount++
}
}
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
}
@@ -278,16 +341,18 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -295,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataEvents int
var doneEvents int
for event := range eventChan {
for event := range streamResp.Events {
if event.Done {
doneEvents++
} else {
dataEvents++
}
}
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
assert.Equal(t, 1, doneEvents)
}
@@ -363,19 +428,20 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
conn.Close()
require.NoError(t, conn.Close())
}
}
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -383,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var gotData bool
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
} else if !event.Done {
gotData = true

View File

@@ -2,12 +2,15 @@ package repository
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
// ModelRepository 模型数据仓库接口
type ModelRepository interface {
Create(model *domain.Model) error
GetByID(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
GetByModelName(modelName string) (*domain.Model, error)
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
ListEnabled() ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
return result, nil
}
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
var m config.Model
err := r.db.Where("model_name = ?", modelName).First(&m).Error
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error
if err != nil {
return nil, err
}
@@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error
return &d, nil
}
func (r *modelRepository) ListEnabled() ([]domain.Model, error) {
var models []config.Model
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
Where("models.enabled = ? AND providers.enabled = ?", true, true).
Find(&models).Error
if err != nil {
return nil, err
}
result := make([]domain.Model, len(models))
for i := range models {
result[i] = toDomainModel(&models[i])
}
return result, nil
}
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {

View File

@@ -2,6 +2,8 @@ package repository
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
// ProviderRepository 供应商数据仓库接口
type ProviderRepository interface {
Create(provider *domain.Provider) error

View File

@@ -3,10 +3,11 @@ package repository
import (
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)

View File

@@ -3,30 +3,18 @@ package repository
import (
"testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
testHelpers "nex/backend/tests"
)
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
// 关闭数据库连接以便 TempDir 清理
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
return testHelpers.SetupTestDB(t)
}
// ============ ProviderRepository 测试 ============
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
err := repo.Update("p1", map[string]interface{}{"name": "New"})
require.NoError(t, err)
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Delete("p1")
require.NoError(t, err)
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
func TestModelRepository_Create(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err)
}
func TestModelRepository_GetByID(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.GetByID("m1")
require.NoError(t, err)
@@ -147,24 +139,52 @@ func TestModelRepository_GetByID(t *testing.T) {
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelRepository_GetByModelName(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.GetByModelName("gpt-4")
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "p1", result.ProviderID)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
// Wrong provider_id
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
assert.Error(t, err)
// Wrong model_name
_, err = repo.FindByProviderAndModelName("p1", "gpt-3.5")
assert.Error(t, err)
// Both wrong
_, err = repo.FindByProviderAndModelName("p2", "claude-3")
assert.Error(t, err)
}
func TestModelRepository_List(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
all, err := repo.List("")
require.NoError(t, err)
@@ -175,11 +195,61 @@ func TestModelRepository_List(t *testing.T) {
assert.Len(t, p1Models, 2)
}
func TestModelRepository_ListEnabled(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
modelRepo := NewModelRepository(db)
// Create two providers (both start enabled due to gorm:"default:true")
err := providerRepo.Create(&domain.Provider{
ID: "enabled-provider", Name: "Enabled Provider",
APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true,
})
require.NoError(t, err)
err = providerRepo.Create(&domain.Provider{
ID: "disabled-provider", Name: "Disabled Provider",
APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true,
})
require.NoError(t, err)
// Disable the second provider via Update (GORM default:true skips zero values on Create)
err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false})
require.NoError(t, err)
// Create models (all start enabled due to gorm:"default:true")
err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true})
require.NoError(t, err)
// Disable m2 via Update
err = modelRepo.Update("m2", map[string]interface{}{"enabled": false})
require.NoError(t, err)
// ListEnabled should only return models where both model and provider are enabled:
// - m1: enabled provider + enabled model -> returned
// - m2: enabled provider + disabled model -> filtered out
// - m3: disabled provider + enabled model -> filtered out
// - m4: disabled provider + enabled model -> filtered out
enabled, err := modelRepo.ListEnabled()
require.NoError(t, err)
require.Len(t, enabled, 1)
assert.Equal(t, "m1", enabled[0].ID)
assert.Equal(t, "enabled-provider", enabled[0].ProviderID)
assert.Equal(t, "gpt-4", enabled[0].ModelName)
}
func TestModelRepository_Update(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
err := repo.Update("m1", map[string]interface{}{"enabled": false})
require.NoError(t, err)
@@ -190,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
func TestModelRepository_Delete(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
err := repo.Delete("m1")
require.NoError(t, err)
@@ -224,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
db := setupTestDB(t)
repo := NewStatsRepository(db)
repo.Record("p1", "gpt-4")
require.NoError(t, repo.Record("p1", "gpt-4"))
// 注意:当前 schema 只有 date 字段有唯一约束
// 所以同一 provider + model 只能有一条记录
stats, err := repo.Query("p1", "", nil, nil)
require.NoError(t, err)
assert.Len(t, stats, 1)
}
func TestModelRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
result, err := repo.List("")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}
func TestProviderRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
result, err := repo.List()
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}

View File

@@ -6,8 +6,11 @@ import (
"nex/backend/internal/domain"
)
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
// StatsRepository 统计数据仓库接口
type StatsRepository interface {
Record(providerID, modelName string) error
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
}

View File

@@ -1,13 +1,13 @@
package repository
import (
"errors"
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type statsRepository struct {
@@ -19,28 +19,46 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
}
func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
now := time.Now()
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = config.UsageStats{
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + 1"),
}),
}).Create(&stats).Error
}
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + ?", delta),
}),
}).Create(&stats).Error
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -2,11 +2,14 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
// ModelService 模型服务接口
type ModelService interface {
Create(model *domain.Model) error
Get(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
ListEnabled() ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -1,29 +1,44 @@
package service
import (
appErrors "nex/backend/pkg/errors"
"errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)
type modelService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
}
func (s *modelService) Create(model *domain.Model) error {
// Verify provider exists
_, err := s.providerRepo.GetByID(model.ProviderID)
if err != nil {
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
err := s.modelRepo.Create(model)
if err != nil {
return err
}
s.cache.SetModel(model)
return nil
}
func (s *modelService) Get(id string) (*domain.Model, error) {
@@ -34,17 +49,77 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
return s.modelRepo.List(providerID)
}
func (s *modelService) ListEnabled() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// If updating provider_id, verify new provider exists
if providerID, ok := updates["provider_id"].(string); ok {
_, err := s.providerRepo.GetByID(providerID)
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
if providerID, ok := updates["provider_id"].(string); ok {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
return s.modelRepo.Update(id, updates)
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
}
newModelName := current.ModelName
if v, ok := updates["model_name"].(string); ok {
newModelName = v
}
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
err = s.modelRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
if newProviderID != current.ProviderID || newModelName != current.ModelName {
s.cache.InvalidateModel(newProviderID, newModelName)
}
return nil
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
model, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
err = s.modelRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
return nil
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
// excludeID 用于更新时排除自身
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复
}
return err
}
if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身
}
return appErrors.ErrDuplicateModel
}

View File

@@ -2,11 +2,16 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
// ProviderService 供应商服务接口
type ProviderService interface {
Create(provider *domain.Provider) error
Get(id string, maskKey bool) (*domain.Provider, error)
Get(id string) (*domain.Provider, error)
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
}

View File

@@ -1,49 +1,85 @@
package service
import (
"strings"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors"
)
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
cache *RoutingCache
}
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
return &providerService{providerRepo: providerRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
}
func (s *providerService) Create(provider *domain.Provider) error {
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
return s.providerRepo.Create(provider)
err := s.providerRepo.Create(provider)
if err != nil {
if isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
s.cache.SetProvider(provider)
return nil
}
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
provider, err := s.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return provider, nil
func (s *providerService) Get(id string) (*domain.Provider, error) {
return s.providerRepo.GetByID(id)
}
func (s *providerService) List() ([]domain.Provider, error) {
providers, err := s.providerRepo.List()
if err != nil {
return nil, err
}
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
return s.providerRepo.List()
}
func (s *providerService) Update(id string, updates map[string]interface{}) error {
return s.providerRepo.Update(id, updates)
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
err := s.providerRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
err := s.providerRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
}
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
func isUniqueConstraintError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
}

View File

@@ -0,0 +1,149 @@
package service
import (
"strings"
"sync"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
type RoutingCache struct {
providers sync.Map
models sync.Map
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
logger *zap.Logger
}
func NewRoutingCache(
modelRepo repository.ModelRepository,
providerRepo repository.ProviderRepository,
logger *zap.Logger,
) *RoutingCache {
return &RoutingCache{
modelRepo: modelRepo,
providerRepo: providerRepo,
logger: pkglogger.WithModule(logger, "service.routing_cache"),
}
}
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok {
if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
}
provider, err := c.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if v, ok := c.providers.Load(id); ok {
if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
}
c.providers.Store(id, provider)
return provider, nil
}
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok {
if model, ok := v.(*domain.Model); ok {
return model, nil
}
}
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil, err
}
if v, ok := c.models.Load(key); ok {
if model, ok := v.(*domain.Model); ok {
return model, nil
}
}
c.models.Store(key, model)
return model, nil
}
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
c.providers.Store(provider.ID, provider)
}
func (c *RoutingCache) SetModel(model *domain.Model) {
key := model.ProviderID + "/" + model.ModelName
c.models.Store(key, model)
}
func (c *RoutingCache) InvalidateProvider(id string) {
c.providers.Delete(id)
c.invalidateModelsByProvider(id)
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
}
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
key := providerID + "/" + modelName
c.models.Delete(key)
c.logger.Debug("Model 缓存失效",
zap.String("provider_id", providerID),
zap.String("model_name", modelName))
}
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/"
count := 0
c.models.Range(func(key, value interface{}) bool {
keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key)
count++
}
return true
})
if count > 0 {
c.logger.Debug("清除 Provider 相关 Model 缓存",
zap.String("provider_id", providerID),
zap.Int("count", count))
}
}
func (c *RoutingCache) Preload() error {
providers, err := c.providerRepo.List()
if err != nil {
return err
}
for i := range providers {
c.providers.Store(providers[i].ID, &providers[i])
}
models, err := c.modelRepo.List("")
if err != nil {
return err
}
for i := range models {
key := models[i].ProviderID + "/" + models[i].ModelName
c.models.Store(key, &models[i])
}
c.logger.Info("缓存预热完成",
zap.Int("providers", len(providers)),
zap.Int("models", len(models)))
return nil
}

View File

@@ -0,0 +1,274 @@
package service
import (
"errors"
"sync"
"testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
type mockModelRepo struct {
models map[string]*domain.Model
}
func (m *mockModelRepo) Create(model *domain.Model) error {
return nil
}
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if model, ok := m.models[key]; ok {
return model, nil
}
return nil, errors.New("not found")
}
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
var result []domain.Model
for _, model := range m.models {
if providerID == "" || model.ProviderID == providerID {
result = append(result, *model)
}
}
return result, nil
}
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockModelRepo) Delete(id string) error {
return nil
}
type mockProviderRepo struct {
providers map[string]*domain.Provider
}
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
return nil
}
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
if provider, ok := m.providers[id]; ok {
return provider, nil
}
return nil, errors.New("not found")
}
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
var result []domain.Provider
for _, provider := range m.providers {
result = append(result, *provider)
}
return result, nil
}
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockProviderRepo) Delete(id string) error {
return nil
}
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
provider, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
provider2, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, provider, provider2)
}
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetProvider("notexist")
assert.Error(t, err)
}
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
model, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
model2, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, model, model2)
}
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "notexist")
assert.Error(t, err)
}
func TestRoutingCache_DoubleCheck(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := cache.GetProvider("openai")
assert.NoError(t, err)
}()
}
wg.Wait()
}
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
_, err = cache.GetModel("openai", "gpt-3.5")
require.NoError(t, err)
_, err = cache.GetModel("anthropic", "claude")
require.NoError(t, err)
cache.InvalidateProvider("openai")
var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool {
keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++
}
return true
})
assert.Equal(t, 0, openaiCount)
assert.Equal(t, 1, anthropicCount)
}
func TestRoutingCache_InvalidateModel(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
cache.InvalidateModel("openai", "gpt-4")
var count int
cache.models.Range(func(key, value interface{}) bool {
count++
return true
})
assert.Equal(t, 0, count)
}
func TestRoutingCache_Preload_Success(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
err := cache.Preload()
require.NoError(t, err)
var providerCount, modelCount int
cache.providers.Range(func(key, value interface{}) bool {
providerCount++
return true
})
cache.models.Range(func(key, value interface{}) bool {
modelCount++
return true
})
assert.Equal(t, 1, providerCount)
assert.Equal(t, 1, modelCount)
}
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetProvider("openai")
_, _ = cache.GetModel("openai", "gpt-4")
cache.InvalidateProvider("openai")
cache.InvalidateModel("openai", "gpt-4")
}()
}
wg.Wait()
}

View File

@@ -2,7 +2,9 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
// RoutingService 路由服务接口
type RoutingService interface {
Route(modelName string) (*domain.RouteResult, error)
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
}

View File

@@ -1,23 +1,20 @@
package service
import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
type routingService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewRoutingService(cache *RoutingCache) RoutingService {
return &routingService{cache: cache}
}
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.GetByModelName(modelName)
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.cache.GetModel(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}
@@ -26,7 +23,7 @@ func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
return nil, appErrors.ErrModelDisabled
}
provider, err := s.providerRepo.GetByID(model.ProviderID)
provider, err := s.cache.GetProvider(model.ProviderID)
if err != nil {
return nil, appErrors.ErrProviderNotFound
}

View File

@@ -3,24 +3,27 @@ package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
require.NoError(t, err)
result, err := svc.Get("p1", false)
result, err := svc.Get("p1")
require.NoError(t, err)
assert.Equal(t, "Updated", result.Name)
}
@@ -28,7 +31,9 @@ func TestProviderService_Update(t *testing.T) {
func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -38,43 +43,49 @@ func TestModelService_Get(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelService_Update(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
require.NoError(t, err)
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4o", model.ModelName)
assert.Equal(t, "gpt-4o", result.ModelName)
}
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
assert.Error(t, err)
}
@@ -82,15 +93,17 @@ func TestModelService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Delete("m1")
err := svc.Delete(model.ID)
require.NoError(t, err)
_, err = svc.Get("m1")
_, err = svc.Get(model.ID)
assert.Error(t, err)
}
@@ -98,7 +111,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Delete("nonexistent")
assert.Error(t, err)
@@ -106,7 +120,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
@@ -118,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
totalCount := 0
for _, r := range result {
totalCount += r["request_count"].(int)
count, ok := r["request_count"].(int)
require.True(t, ok)
totalCount += count
}
assert.Equal(t, 15, totalCount)
}
@@ -127,7 +144,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
assert.Error(t, err)

View File

@@ -1,270 +1,507 @@
package service
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
appErrors "nex/backend/pkg/errors"
)
func setupServiceTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
return testHelpers.SetupTestDB(t)
}
// ============ ProviderService 测试 ============
func TestProviderService_Create(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
provider := &domain.Provider{
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
}
err := svc.Create(provider)
require.NoError(t, err)
assert.True(t, provider.Enabled)
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
t.Helper()
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
}
func TestProviderService_Get_MaskKey(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
// ============ RoutingService - RouteByModelName 测试 ============
svc.Create(&domain.Provider{
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
})
result, err := svc.Get("p1", true)
require.NoError(t, err)
assert.Equal(t, "***2345", result.APIKey)
result, err = svc.Get("p1", false)
require.NoError(t, err)
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
}
func TestProviderService_List(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
providers, err := svc.List()
require.NoError(t, err)
assert.Len(t, providers, 2)
assert.Contains(t, providers[0].APIKey, "***")
}
func TestProviderService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
err := svc.Delete("p1")
require.NoError(t, err)
_, err = svc.Get("p1", false)
assert.Error(t, err)
}
// ============ ModelService 测试 ============
func TestModelService_Create(t *testing.T) {
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
result, err := svc.RouteByModelName("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
}
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
}
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
}
// ============ ModelService - Create with UUID 测试 ============
func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
assert.True(t, model.Enabled)
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
assert.Equal(t, "gpt-4", stored.ModelName)
}
func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
assert.Error(t, err)
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
}
func TestModelService_List(t *testing.T) {
// ============ ProviderService - Create with validation 测试 ============
func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
}
models, err := svc.List("p1")
func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
assert.Len(t, models, 2)
assert.Equal(t, "openai", provider.ID)
assert.True(t, provider.Enabled)
}
// ============ RoutingService 测试 ============
// ============ ModelService - Update with duplicate check 测试 ============
func TestRoutingService_Route(t *testing.T) {
func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
result, err := svc.Route("gpt-4")
require.NoError(t, err)
assert.Equal(t, "p1", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
}
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
_, err := svc.Route("nonexistent-model")
assert.Error(t, err)
}
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
// 先创建启用的模型,然后通过 Update 禁用
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
// 先创建启用的 provider然后禁用
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
// ============ StatsService 测试 ============
func TestStatsService_RecordAndGet(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
err := svc.Record("p1", "gpt-4")
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
stats, err := svc.Get("p1", "", nil, nil)
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
err = svc.Create(model2)
require.NoError(t, err)
assert.Len(t, stats, 1)
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
err = svc.Update(model2.ID, map[string]interface{}{
"provider_id": "openai",
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
stats := []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
}
result := svc.Aggregate(stats, "provider")
assert.Len(t, result, 2)
p1Count := 0
p2Count := 0
for _, r := range result {
if r["provider_id"] == "p1" {
p1Count = r["request_count"].(int)
}
if r["provider_id"] == "p2" {
p2Count = r["request_count"].(int)
}
}
assert.Equal(t, 15, p1Count)
assert.Equal(t, 8, p2Count)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestStatsService_Aggregate_ByDate(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
{ProviderID: "p2", RequestCount: 5},
}
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
result := svc.Aggregate(stats, "date")
assert.Len(t, result, 1)
assert.Equal(t, 15, result[0]["request_count"])
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
// 更新 model_name 为不冲突的值
err = svc.Update(model.ID, map[string]interface{}{
"model_name": "gpt-4-turbo",
})
require.NoError(t, err)
updated, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
}
// ============ ProviderService - Update immutable ID 测试 ============
func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
// 尝试更新 id 字段
err = svc.Update("openai", map[string]interface{}{
"id": "new-id",
})
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
}
func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
// 更新 name
err = svc.Update("openai", map[string]interface{}{
"name": "OpenAI Updated",
})
require.NoError(t, err)
updated, err := svc.Get("openai")
require.NoError(t, err)
assert.Equal(t, "OpenAI Updated", updated.Name)
}
// ============ StatsService - Aggregate ByModel 测试 ============
func TestStatsService_Aggregate_ByModel(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "multiple providers with same model name",
stats: []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
},
},
{
name: "empty providerID",
stats: []domain.UsageStats{
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
result := svc.Aggregate(stats, "model")
assert.Len(t, result, 3)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
// 验证每个 provider/model 组合的计数
counts := make(map[string]int)
result := svc.Aggregate(tt.stats, "model")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
counts[key] = r["request_count"].(int)
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
assert.Equal(t, 13, counts["openai/gpt-4"])
assert.Equal(t, 5, counts["openai/gpt-3.5"])
assert.Equal(t, 8, counts["anthropic/claude-3"])
}
// ============ StatsService - Aggregate ByDate 测试 ============
func TestStatsService_Aggregate_ByDate(t *testing.T) {
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "normal date grouping",
stats: []domain.UsageStats{
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
},
expected: []map[string]interface{}{
{"date": "2024-01-01", "request_count": 15},
{"date": "2024-01-02", "request_count": 20},
},
},
{
name: "zero-value time",
stats: []domain.UsageStats{
{Date: time.Time{}, RequestCount: 10},
{Date: time.Time{}, RequestCount: 5},
},
expected: []map[string]interface{}{
{"date": "0001-01-01", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
if r["date"] == exp["date"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
}
// ============ ProviderService - isUniqueConstraintError 测试 ============
func TestProviderService_isUniqueConstraintError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UNIQUE constraint failed",
err: errors.New("UNIQUE constraint failed"),
expected: true,
},
{
name: "duplicate key value",
err: errors.New("duplicate key value"),
expected: true,
},
{
name: "UNIQUE constraint case insensitive",
err: errors.New("unique constraint violation"),
expected: true,
},
{
name: "other error",
err: errors.New("some other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isUniqueConstraintError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ ProviderService - List API Key 测试 ============
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
require.NoError(t, svc.Create(provider1))
require.NoError(t, svc.Create(provider2))
providers, err := svc.List()
require.NoError(t, err)
require.Len(t, providers, 2)
expectedKeys := map[string]string{
"openai": "sk-1234567890",
"anthropic": "sk-anthropic1234",
}
for _, p := range providers {
assert.NotContains(t, p.APIKey, "***")
assert.Equal(t, expectedKeys[p.ID], p.APIKey)
}
}
func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
results := make(chan error, 2)
for i := 0; i < 2; i++ {
go func() {
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
results <- svc.Create(model)
}()
}
err1 := <-results
err2 := <-results
successCount := 0
errorCount := 0
for _, err := range []error{err1, err2} {
if err == nil {
successCount++
} else {
errorCount++
}
}
assert.Equal(t, 1, successCount)
assert.Equal(t, 1, errorCount)
}

View File

@@ -0,0 +1,197 @@
package service
import (
"strings"
"sync"
"sync/atomic"
"time"
"nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
type StatsBuffer struct {
counters sync.Map
flushInterval time.Duration
flushThreshold int
totalCount atomic.Int64
statsRepo repository.StatsRepository
logger *zap.Logger
stopCh chan struct{}
doneCh chan struct{}
}
type StatsBufferOption func(*StatsBuffer)
func WithFlushInterval(d time.Duration) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushInterval = d
}
}
func WithFlushThreshold(threshold int) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushThreshold = threshold
}
}
func NewStatsBuffer(
statsRepo repository.StatsRepository,
logger *zap.Logger,
opts ...StatsBufferOption,
) *StatsBuffer {
b := &StatsBuffer{
statsRepo: statsRepo,
logger: pkglogger.WithModule(logger, "service.stats_buffer"),
flushInterval: 5 * time.Second,
flushThreshold: 100,
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
for _, opt := range opts {
opt(b)
}
return b
}
func (b *StatsBuffer) Increment(providerID, modelName string) {
today := time.Now().Format("2006-01-02")
key := providerID + "/" + modelName + "/" + today
var counter *int64
if v, ok := b.counters.Load(key); ok {
if existing, ok := v.(*int64); ok {
counter = existing
} else {
return
}
} else {
val := int64(0)
counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded {
existing, ok := actual.(*int64)
if !ok {
return
}
counter = existing
}
}
atomic.AddInt64(counter, 1)
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
go b.flush()
}
}
func (b *StatsBuffer) Start() {
go func() {
ticker := time.NewTicker(b.flushInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
b.flush()
case <-b.stopCh:
b.flush()
close(b.doneCh)
return
}
}
}()
}
func (b *StatsBuffer) Stop() {
close(b.stopCh)
<-b.doneCh
}
func (b *StatsBuffer) flush() {
type statEntry struct {
providerID string
modelName string
date string
count int64
}
var entries []statEntry
b.counters.Range(func(key, value interface{}) bool {
keyStr, ok := key.(string)
if !ok {
return true
}
parts := strings.Split(keyStr, "/")
if len(parts) != 3 {
return true
}
counter, ok := value.(*int64)
if !ok {
return true
}
count := atomic.SwapInt64(counter, 0)
if count > 0 {
entries = append(entries, statEntry{
providerID: parts[0],
modelName: parts[1],
date: parts[2],
count: count,
})
}
return true
})
if len(entries) == 0 {
return
}
success := 0
for _, entry := range entries {
date, err := time.Parse("2006-01-02", entry.date)
if err != nil {
b.logger.Error("解析统计日期失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.String("date", entry.date),
zap.Error(err))
continue
}
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil {
b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.Int64("count", entry.count),
zap.Error(err))
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok {
counter, ok := v.(*int64)
if ok {
atomic.AddInt64(counter, entry.count)
}
}
} else {
success++
}
}
b.totalCount.Store(0)
b.logger.Debug("统计刷新完成",
zap.Int("total", len(entries)),
zap.Int("success", success),
zap.Int("failed", len(entries)-success))
}

View File

@@ -0,0 +1,261 @@
package service
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
type mockStatsRepo struct {
records []struct {
providerID string
modelName string
date time.Time
delta int
}
fail bool
mu sync.Mutex
}
func (m *mockStatsRepo) Record(providerID, modelName string) error {
return nil
}
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.fail {
return errors.New("db error")
}
m.records = append(m.records, struct {
providerID string
modelName string
date time.Time
delta int
}{providerID, modelName, date, delta})
return nil
}
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return nil, nil
}
func TestStatsBuffer_Increment(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-3.5")
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter, ok := value.(*int64)
if ok {
count += atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(3), count)
}
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(100), count)
}
func TestStatsBuffer_LoadOrStore(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var counterCount int
buffer.counters.Range(func(key, value interface{}) bool {
counterCount++
return true
})
assert.Equal(t, 1, counterCount)
}
func TestStatsBuffer_FlushByInterval(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(100*time.Millisecond))
buffer.Start()
defer buffer.Stop()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
time.Sleep(200 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushThreshold(10))
buffer.Start()
defer buffer.Stop()
for i := 0; i < 10; i++ {
buffer.Increment("openai", "gpt-4")
}
time.Sleep(50 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_SwapInt64(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter, ok := value.(*int64)
if ok {
beforeCount = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(2), beforeCount)
buffer.flush()
var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter, ok := value.(*int64)
if ok {
afterCount = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(0), afterCount)
}
func TestStatsBuffer_FailRetry(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{fail: true}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.flush()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter)
}
return true
})
assert.Equal(t, int64(2), count)
}
func TestStatsBuffer_Stop(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(10*time.Second))
buffer.Start()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
start := time.Now()
buffer.Stop()
elapsed := time.Since(start)
assert.Less(t, elapsed, 1*time.Second)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(50*time.Millisecond))
buffer.Start()
defer buffer.Stop()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
time.Sleep(100 * time.Millisecond)
statsRepo.mu.Lock()
totalDelta := 0
for _, r := range statsRepo.records {
totalDelta += r.delta
}
statsRepo.mu.Unlock()
assert.Equal(t, 100, totalDelta)
}

View File

@@ -6,6 +6,8 @@ import (
"nex/backend/internal/domain"
)
//go:generate go run go.uber.org/mock/mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
// StatsService 统计服务接口
type StatsService interface {
Record(providerID, modelName string) error

Some files were not shown because too many files have changed in this diff Show More