diff --git a/docs/development/backend.md b/docs/development/backend.md index 807b3e0..256ada2 100644 --- a/docs/development/backend.md +++ b/docs/development/backend.md @@ -80,17 +80,19 @@ middleware.ts 提供 API 参数校验函数: `src/server/ai/registry.ts` 提供: -- `buildProviderRegistry(db)` — 从 DB 查询启用的供应商,构建 Vercel AI SDK Provider Registry -- `testProviderConnection(config)` — 使用 generateText 测试供应商连接 +- `buildProviderRegistry(db)` — 从 DB 查询所有供应商,构建 Vercel AI SDK Provider Registry +- `testProviderConnection(config)` — 先测试 Base URL 可达性,再请求 `/models` 验证 API Key 和模型列表接口 -每次 AI 调用时从 DB 查询 enabled providers,构建 registry 后通过 `registry.languageModel('providerId:modelId')` 获取模型实例。不使用缓存层。模型是否存在、是否启用以及业务能力标签由调用方基于 models 表先行校验,registry 只负责将 providerId/modelId 映射到 AI SDK 模型实例。 +每次 AI 调用时从 DB 查询 providers,构建 registry 后通过 `registry.languageModel('providerId:modelId')` 获取模型实例。不使用缓存层。模型是否存在以及业务能力标签由调用方基于 models 表先行校验,registry 只负责将 providerId/modelId 映射到 AI SDK 模型实例。 ### 供应商连通性测试 供应商连通性测试返回 `{ providerTestResponse: { ok, message } }`,前端根据 `ok` 展示成功或失败提示。 -- `POST /api/providers/:id/test` — 使用已保存供应商配置测试连接 - `POST /api/providers/test` — 使用表单中尚未保存的供应商配置测试连接 +- `POST /api/models/test` — 使用模型关联供应商配置和 modelId 测试模型连接 + +测试连接不会写入数据库,也不会阻止保存。Base URL 不可达或 API Key 无效返回 `ok: false`;Base URL 可达但 `/models` 不支持、非标准或返回非鉴权错误时返回 `ok: true` 并在 `message` 中提示用户可检查 URL 或忽略提醒。 ### 支持的供应商类型 diff --git a/docs/development/frontend.md b/docs/development/frontend.md index 1495e3b..46ba5fe 100644 --- a/docs/development/frontend.md +++ b/docs/development/frontend.md @@ -131,9 +131,9 @@ Sidebar(`src/web/components/Sidebar/index.tsx`)是纯展示/导航组件, Workbench 项目上下文通过 `ProjectContext` 提供,在 `WorkbenchProjectGate` 中从 URL path param 读取 `projectId`,通过 `useProject(projectId)` 加载项目,仅 active 项目渲染工作台布局,不存在或 archived 项目显示"项目不存在或不可访问"。 -模型管理页面(`src/web/pages/models/index.tsx`)属于 Admin 路由 `/models`,通过 antd `Tabs` 在同页组织供应商和模型两个视图。页面使用 `ProviderToolbar`、`ProviderTable`、`ProviderFormModal`、`ModelToolbar`、`ModelTable`、`ModelFormModal` 拆分筛选、表格和表单职责;模型表单和模型表格使用独立 provider 列表查询,不能复用供应商标签页当前分页或搜索结果作为全量选项。 +模型管理页面(`src/web/pages/models/index.tsx`)属于 Admin 路由 `/models`,通过 antd `Tabs` 在同页组织供应商和模型两个视图。页面使用 `ModelsToolbar`、`ProviderTable`、`ProviderFormModal`、`ModelTable`、`ModelFormModal` 拆分筛选、表格和表单职责;模型表单和模型表格必须使用 `GET /api/providers/options` 获取最小供应商选项,不能复用供应商标签页当前分页或搜索结果作为全量选项。 -供应商表单必须支持未保存配置的连通性测试,新建供应商时 type 默认 `openai-compatible`,baseURL 不设默认值。连通性测试返回 `ok: false` 时应展示失败反馈,不得使用成功提示样式。 +供应商表单必须支持未保存配置的连通性测试,新建供应商时 type 默认 `openai-compatible`,baseURL 不设默认值。连通性测试返回 `ok: false` 时应展示失败反馈,不得使用成功提示样式;`/models` 不支持或响应格式不兼容属于可忽略提醒,不得阻止保存。 - 生产入口必须启用 `ErrorBoundary`,运行时渲染异常使用 antd `Result status="500"` 或等价组件展示。 - `ReactQueryDevtools` 仅在 `import.meta.env.DEV` 条件下渲染,不进入生产渲染路径。 diff --git a/docs/user/usage.md b/docs/user/usage.md index 37b5905..9c4f4fe 100644 --- a/docs/user/usage.md +++ b/docs/user/usage.md @@ -52,7 +52,7 @@ bun run dev config.yaml 在 Admin 侧栏进入 `/models` 后,页面通过两个标签页管理 AI 基础配置: -- **供应商**:新增、编辑、启用、禁用、删除 OpenAI、Anthropic 或 OpenAI 兼容供应商。新建供应商时类型默认是 `openai-compatible`,baseURL 和 API Key 由用户填写。 -- **模型**:为已启用供应商新增模型,填写模型显示名称、实际调用用的 modelId、能力标签,以及可选的上下文长度和最大输出 token。 +- **供应商**:新增、编辑、删除 OpenAI、Anthropic 或 OpenAI 兼容供应商。新建供应商时类型默认是 `openai-compatible`,baseURL 和 API Key 由用户填写。 +- **模型**:为供应商新增模型,填写模型显示名称、实际调用用的 modelId、能力标签,以及可选的上下文长度和最大输出 token。 -供应商表格和供应商表单都提供“测试连接”操作。测试连接只返回成功或失败提示,不会阻止保存供应商或模型。删除供应商前必须先删除或迁移其关联模型,否则系统会拒绝删除以避免误删模型配置。 +供应商表单提供“测试连接”操作:系统先测试 Base URL 是否可达,再尝试请求 `/models` 验证 API Key 和模型列表接口。若服务不支持 `/models`,页面会提示接口可达但可能不支持模型列表;该结果只作为提醒,不会阻止保存供应商或模型。删除供应商前必须先删除或迁移其关联模型,否则系统会拒绝删除以避免误删模型配置。 diff --git a/drizzle/0002_remove_model_management_enabled.sql b/drizzle/0002_remove_model_management_enabled.sql new file mode 100644 index 0000000..8019c4f --- /dev/null +++ b/drizzle/0002_remove_model_management_enabled.sql @@ -0,0 +1,3 @@ +ALTER TABLE `providers` DROP COLUMN `enabled`; +--> statement-breakpoint +ALTER TABLE `models` DROP COLUMN `enabled`; diff --git a/src/server/ai/registry.ts b/src/server/ai/registry.ts index ff134be..44b8dfa 100644 --- a/src/server/ai/registry.ts +++ b/src/server/ai/registry.ts @@ -8,10 +8,10 @@ import { createProviderRegistry, generateText } from "ai"; import type { AIProviderConfig } from "./types"; export function buildProviderRegistry(db: Database) { - const enabledProviders = getEnabledProviders(db); + const providers = getProviders(db); const providerEntries: Record> = {}; - for (const p of enabledProviders) { + for (const p of providers) { providerEntries[p.id] = createProvider({ apiKey: p.api_key, baseUrl: p.base_url, @@ -23,24 +23,105 @@ export function buildProviderRegistry(db: Database) { return createProviderRegistry(providerEntries); } -export async function testProviderConnection(config: AIProviderConfig): Promise<{ message: string; ok: boolean }> { +export async function testModelConnection( + config: AIProviderConfig & { modelId: string }, +): Promise<{ message: string; ok: boolean }> { try { const provider = createProvider(config); - const model = provider.languageModel("test"); - await generateText({ - maxOutputTokens: 1, - model, + maxOutputTokens: 10, + model: provider.languageModel(config.modelId), prompt: "Hi", }); - - return { message: "连接成功", ok: true }; + return { message: "模型连接成功", ok: true }; } catch (e: unknown) { const msg = e instanceof Error ? e.message : String(e); - return { message: `连接失败: ${msg}`, ok: false }; + return { message: `模型连接失败:${msg}`, ok: false }; } } +export async function testProviderConnection(config: AIProviderConfig): Promise<{ message: string; ok: boolean }> { + const baseUrlResult = await probeBaseUrl(config.baseUrl); + if (!baseUrlResult.ok) return baseUrlResult; + + const modelsUrl = buildModelsUrl(config.baseUrl); + + try { + const response = await fetch(modelsUrl, { + headers: buildModelsHeaders(config), + signal: AbortSignal.timeout(5000), + }); + + if (response.status === 401 || response.status === 403) { + return { message: "Base URL 可连接,但 API Key 无效或权限不足。", ok: false }; + } + + if ([404, 405, 501].includes(response.status)) { + return { + message: "Base URL 可连接,但可能不支持 /models 接口;可检查 URL 或忽略此提示。", + ok: true, + }; + } + + if (!response.ok) { + return { + message: `Base URL 可连接,但 /models 请求失败(HTTP ${response.status});可检查 URL 或忽略此提示。`, + ok: true, + }; + } + + const body = (await response.json().catch(() => null)) as unknown; + const modelCount = countModels(body); + if (modelCount !== null) { + return { message: `连接成功,/models 返回 ${modelCount} 个模型。`, ok: true }; + } + + return { + message: "Base URL 可连接,但 /models 返回格式不兼容,可能不支持 /models;可检查 URL 或忽略此提示。", + ok: true, + }; + } catch (e: unknown) { + const msg = e instanceof Error ? e.message : String(e); + return { message: `Base URL 可连接,但 /models 请求异常:${msg};可检查 URL 或忽略此提示。`, ok: true }; + } +} + +function buildModelsHeaders(config: AIProviderConfig): HeadersInit { + if (config.type === "anthropic") { + return { + accept: "application/json", + "anthropic-version": "2023-06-01", + "x-api-key": config.apiKey, + }; + } + + return { + accept: "application/json", + authorization: `Bearer ${config.apiKey}`, + }; +} + +function buildModelsUrl(baseUrl: string): string { + const url = new URL(baseUrl); + url.pathname = `${url.pathname.replace(/\/$/, "")}/models`; + url.search = ""; + url.hash = ""; + return url.toString(); +} + +function countModels(body: unknown): null | number { + if (Array.isArray(body)) return body.length; + if (!body || typeof body !== "object") return null; + + const data = (body as { data?: unknown }).data; + if (Array.isArray(data)) return data.length; + + const models = (body as { models?: unknown }).models; + if (Array.isArray(models)) return models.length; + + return null; +} + function createProvider(config: AIProviderConfig) { switch (config.type) { case "anthropic": @@ -56,14 +137,14 @@ function createProvider(config: AIProviderConfig) { } } -function getEnabledProviders(db: Database): Array<{ +function getProviders(db: Database): Array<{ api_key: string; base_url: string; id: string; name: string; type: "anthropic" | "openai" | "openai-compatible"; }> { - const stmt = db.prepare("SELECT id, name, type, base_url, api_key FROM providers WHERE enabled = 1"); + const stmt = db.prepare("SELECT id, name, type, base_url, api_key FROM providers"); return stmt.all() as Array<{ api_key: string; base_url: string; @@ -72,3 +153,16 @@ function getEnabledProviders(db: Database): Array<{ type: "anthropic" | "openai" | "openai-compatible"; }>; } + +async function probeBaseUrl(baseUrl: string): Promise<{ message: string; ok: boolean }> { + try { + await fetch(baseUrl, { + method: "HEAD", + signal: AbortSignal.timeout(5000), + }); + return { message: "Base URL 可连接", ok: true }; + } catch (e: unknown) { + const msg = e instanceof Error ? e.message : String(e); + return { message: `Base URL 不可达:${msg}`, ok: false }; + } +} diff --git a/src/server/db/models.ts b/src/server/db/models.ts index a325a93..329a1aa 100644 --- a/src/server/db/models.ts +++ b/src/server/db/models.ts @@ -36,7 +36,6 @@ export function createModel( capabilities: JSON.stringify(capabilities), contextLength: request.contextLength ?? null, createdAt: now, - enabled: true, id, maxOutputTokens: request.maxOutputTokens ?? null, modelId, @@ -66,32 +65,6 @@ export function deleteModel(raw: Database, id: string): { error: string; status: return { success: true }; } -export function disableModel(raw: Database, id: string): { error: string; status: number } | { model: Model } { - const db = wrap(raw); - const existing = db.select().from(models).where(eq(models.id, id)).get(); - if (!existing) return { error: "模型不存在", status: 404 }; - if (!existing.enabled) return { error: "模型已禁用", status: 409 }; - - const now = new Date().toISOString(); - db.update(models).set({ enabled: false, updatedAt: now }).where(eq(models.id, id)).run(); - - const updated = db.select().from(models).where(eq(models.id, id)).get(); - return { model: toModel(updated!) }; -} - -export function enableModel(raw: Database, id: string): { error: string; status: number } | { model: Model } { - const db = wrap(raw); - const existing = db.select().from(models).where(eq(models.id, id)).get(); - if (!existing) return { error: "模型不存在", status: 404 }; - if (existing.enabled) return { error: "模型已启用", status: 409 }; - - const now = new Date().toISOString(); - db.update(models).set({ enabled: true, updatedAt: now }).where(eq(models.id, id)).run(); - - const updated = db.select().from(models).where(eq(models.id, id)).get(); - return { model: toModel(updated!) }; -} - export function getModel(raw: Database, id: string): { error: string; status: number } | { model: Model } { const db = wrap(raw); const row = db.select().from(models).where(eq(models.id, id)).get(); @@ -222,7 +195,6 @@ function toModel(row: typeof models.$inferSelect): Model { capabilities: JSON.parse(row.capabilities) as ModelCapability[], contextLength: row.contextLength, createdAt: row.createdAt, - enabled: row.enabled, id: row.id, maxOutputTokens: row.maxOutputTokens, modelId: row.modelId, diff --git a/src/server/db/providers.ts b/src/server/db/providers.ts index 84f389e..800c11b 100644 --- a/src/server/db/providers.ts +++ b/src/server/db/providers.ts @@ -3,7 +3,7 @@ import type Database from "bun:sqlite"; import { and, desc, eq, like, sql } from "drizzle-orm"; import { drizzle } from "drizzle-orm/bun-sqlite"; -import type { CreateProviderRequest, Provider, UpdateProviderRequest } from "../../shared/api"; +import type { CreateProviderRequest, Provider, ProviderOption, UpdateProviderRequest } from "../../shared/api"; import { providers } from "./schema"; @@ -30,7 +30,6 @@ export function createProvider( apiKey, baseUrl, createdAt: now, - enabled: true, id, name, type: request.type, @@ -58,32 +57,6 @@ export function deleteProvider(raw: Database, id: string): { error: string; stat return { success: true }; } -export function disableProvider(raw: Database, id: string): { error: string; status: number } | { provider: Provider } { - const db = wrap(raw); - const existing = db.select().from(providers).where(eq(providers.id, id)).get(); - if (!existing) return { error: "供应商不存在", status: 404 }; - if (!existing.enabled) return { error: "供应商已禁用", status: 409 }; - - const now = new Date().toISOString(); - db.update(providers).set({ enabled: false, updatedAt: now }).where(eq(providers.id, id)).run(); - - const updated = db.select().from(providers).where(eq(providers.id, id)).get(); - return { provider: toProvider(updated!) }; -} - -export function enableProvider(raw: Database, id: string): { error: string; status: number } | { provider: Provider } { - const db = wrap(raw); - const existing = db.select().from(providers).where(eq(providers.id, id)).get(); - if (!existing) return { error: "供应商不存在", status: 404 }; - if (existing.enabled) return { error: "供应商已启用", status: 409 }; - - const now = new Date().toISOString(); - db.update(providers).set({ enabled: true, updatedAt: now }).where(eq(providers.id, id)).run(); - - const updated = db.select().from(providers).where(eq(providers.id, id)).get(); - return { provider: toProvider(updated!) }; -} - export function getProvider(raw: Database, id: string): { error: string; status: number } | { provider: Provider } { const db = wrap(raw); const row = db.select().from(providers).where(eq(providers.id, id)).get(); @@ -92,6 +65,17 @@ export function getProvider(raw: Database, id: string): { error: string; status: return { provider: toProvider(row) }; } +export function listProviderOptions(raw: Database): ProviderOption[] { + const db = wrap(raw); + const rows = db + .select({ id: providers.id, name: providers.name, type: providers.type }) + .from(providers) + .orderBy(desc(providers.createdAt)) + .all(); + + return rows; +} + export function listProviders( raw: Database, options: { keyword?: string; page: number; pageSize: number }, @@ -189,7 +173,6 @@ function toProvider(row: typeof providers.$inferSelect): Provider { apiKey: row.apiKey, baseUrl: row.baseUrl, createdAt: row.createdAt, - enabled: row.enabled, id: row.id, name: row.name, type: row.type, diff --git a/src/server/db/schema.ts b/src/server/db/schema.ts index 81d6b5b..94f9d94 100644 --- a/src/server/db/schema.ts +++ b/src/server/db/schema.ts @@ -16,7 +16,6 @@ export const providers = sqliteTable("providers", { apiKey: text("api_key").notNull(), baseUrl: text("base_url").notNull(), createdAt: text("created_at").notNull(), - enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), id: text("id").primaryKey(), name: text("name").notNull().unique(), type: text("type", { enum: ["anthropic", "openai", "openai-compatible"] }) @@ -31,7 +30,6 @@ export const models = sqliteTable( capabilities: text("capabilities").notNull(), contextLength: integer("context_length"), createdAt: text("created_at").notNull(), - enabled: integer("enabled", { mode: "boolean" }).notNull().default(true), id: text("id").primaryKey(), maxOutputTokens: integer("max_output_tokens"), modelId: text("model_id").notNull(), diff --git a/src/server/routes/models/create.ts b/src/server/routes/models/create.ts index 55b5d13..a9e1c6d 100644 --- a/src/server/routes/models/create.ts +++ b/src/server/routes/models/create.ts @@ -38,6 +38,12 @@ export async function handleCreateModel(req: Request, db: Database, mode: Runtim return jsonResponse(createApiError(`Invalid capabilities: ${invalidCaps.join(", ")}`, 400), { mode, status: 400 }); } + const numberError = validateOptionalPositiveInteger("contextLength", body.contextLength); + if (numberError) return jsonResponse(createApiError(numberError, 400), { mode, status: 400 }); + + const tokenError = validateOptionalPositiveInteger("maxOutputTokens", body.maxOutputTokens); + if (tokenError) return jsonResponse(createApiError(tokenError, 400), { mode, status: 400 }); + const result = createModel(db, body); if ("error" in result) { return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); @@ -45,3 +51,9 @@ export async function handleCreateModel(req: Request, db: Database, mode: Runtim return jsonResponse(result, { mode, status: 201 }); } + +function validateOptionalPositiveInteger(field: string, value: null | number | undefined): null | string { + if (value === undefined || value === null) return null; + if (!Number.isInteger(value) || value <= 0) return `${field} must be a positive integer`; + return null; +} diff --git a/src/server/routes/models/disable.ts b/src/server/routes/models/disable.ts deleted file mode 100644 index ba076d9..0000000 --- a/src/server/routes/models/disable.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type Database from "bun:sqlite"; - -import type { RuntimeMode } from "../../../shared/api"; - -import { disableModel } from "../../db/models"; -import { createApiError, jsonResponse } from "../../helpers"; -import { validateIdParam } from "../../middleware"; - -export function handleDisableModel(req: Request, db: Database, mode: RuntimeMode): Response { - const url = new URL(req.url); - const idStr = url.pathname.split("/")[3]; - - const validated = validateIdParam(idStr ?? "", mode); - if (validated instanceof Response) return validated; - - const result = disableModel(db, validated.id); - if ("error" in result) { - return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); - } - - return jsonResponse(result, { mode }); -} diff --git a/src/server/routes/models/enable.ts b/src/server/routes/models/enable.ts deleted file mode 100644 index 264295b..0000000 --- a/src/server/routes/models/enable.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type Database from "bun:sqlite"; - -import type { RuntimeMode } from "../../../shared/api"; - -import { enableModel } from "../../db/models"; -import { createApiError, jsonResponse } from "../../helpers"; -import { validateIdParam } from "../../middleware"; - -export function handleEnableModel(req: Request, db: Database, mode: RuntimeMode): Response { - const url = new URL(req.url); - const idStr = url.pathname.split("/")[3]; - - const validated = validateIdParam(idStr ?? "", mode); - if (validated instanceof Response) return validated; - - const result = enableModel(db, validated.id); - if ("error" in result) { - return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); - } - - return jsonResponse(result, { mode }); -} diff --git a/src/server/routes/models/test.ts b/src/server/routes/models/test.ts new file mode 100644 index 0000000..49e0eff --- /dev/null +++ b/src/server/routes/models/test.ts @@ -0,0 +1,42 @@ +import type Database from "bun:sqlite"; + +import type { RuntimeMode, TestModelRequest } from "../../../shared/api"; + +import { testModelConnection } from "../../ai/registry"; +import { getProvider } from "../../db/providers"; +import { createApiError, jsonResponse } from "../../helpers"; + +export async function handleTestModelConfig(req: Request, db: Database, mode: RuntimeMode): Promise { + let body: TestModelRequest; + try { + body = (await req.json()) as TestModelRequest; + } catch { + return jsonResponse(createApiError("Invalid JSON body", 400), { mode, status: 400 }); + } + + if (!body.providerId || typeof body.providerId !== "string") { + return jsonResponse(createApiError("providerId is required", 400), { mode, status: 400 }); + } + + if (!body.modelId || typeof body.modelId !== "string") { + return jsonResponse(createApiError("modelId is required", 400), { mode, status: 400 }); + } + + const providerResult = getProvider(db, body.providerId); + if ("error" in providerResult) { + return jsonResponse(createApiError(providerResult.error, providerResult.status), { + mode, + status: providerResult.status, + }); + } + + const testResult = await testModelConnection({ + apiKey: providerResult.provider.apiKey, + baseUrl: providerResult.provider.baseUrl, + modelId: body.modelId, + name: providerResult.provider.name, + type: providerResult.provider.type, + }); + + return jsonResponse({ modelTestResponse: testResult }, { mode }); +} diff --git a/src/server/routes/models/update.ts b/src/server/routes/models/update.ts index 5267a94..be0a90c 100644 --- a/src/server/routes/models/update.ts +++ b/src/server/routes/models/update.ts @@ -34,6 +34,12 @@ export async function handleUpdateModel(req: Request, db: Database, mode: Runtim } } + const numberError = validateOptionalPositiveInteger("contextLength", body.contextLength); + if (numberError) return jsonResponse(createApiError(numberError, 400), { mode, status: 400 }); + + const tokenError = validateOptionalPositiveInteger("maxOutputTokens", body.maxOutputTokens); + if (tokenError) return jsonResponse(createApiError(tokenError, 400), { mode, status: 400 }); + const result = updateModel(db, validated.id, body); if ("error" in result) { return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); @@ -41,3 +47,9 @@ export async function handleUpdateModel(req: Request, db: Database, mode: Runtim return jsonResponse(result, { mode }); } + +function validateOptionalPositiveInteger(field: string, value: null | number | undefined): null | string { + if (value === undefined || value === null) return null; + if (!Number.isInteger(value) || value <= 0) return `${field} must be a positive integer`; + return null; +} diff --git a/src/server/routes/providers/disable.ts b/src/server/routes/providers/disable.ts deleted file mode 100644 index 917a9bb..0000000 --- a/src/server/routes/providers/disable.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type Database from "bun:sqlite"; - -import type { RuntimeMode } from "../../../shared/api"; - -import { disableProvider } from "../../db/providers"; -import { createApiError, jsonResponse } from "../../helpers"; -import { validateIdParam } from "../../middleware"; - -export function handleDisableProvider(req: Request, db: Database, mode: RuntimeMode): Response { - const url = new URL(req.url); - const idStr = url.pathname.split("/")[3]; - - const validated = validateIdParam(idStr ?? "", mode); - if (validated instanceof Response) return validated; - - const result = disableProvider(db, validated.id); - if ("error" in result) { - return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); - } - - return jsonResponse(result, { mode }); -} diff --git a/src/server/routes/providers/enable.ts b/src/server/routes/providers/enable.ts deleted file mode 100644 index b9ad432..0000000 --- a/src/server/routes/providers/enable.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type Database from "bun:sqlite"; - -import type { RuntimeMode } from "../../../shared/api"; - -import { enableProvider } from "../../db/providers"; -import { createApiError, jsonResponse } from "../../helpers"; -import { validateIdParam } from "../../middleware"; - -export function handleEnableProvider(req: Request, db: Database, mode: RuntimeMode): Response { - const url = new URL(req.url); - const idStr = url.pathname.split("/")[3]; - - const validated = validateIdParam(idStr ?? "", mode); - if (validated instanceof Response) return validated; - - const result = enableProvider(db, validated.id); - if ("error" in result) { - return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); - } - - return jsonResponse(result, { mode }); -} diff --git a/src/server/routes/providers/options.ts b/src/server/routes/providers/options.ts new file mode 100644 index 0000000..a2a9cae --- /dev/null +++ b/src/server/routes/providers/options.ts @@ -0,0 +1,10 @@ +import type Database from "bun:sqlite"; + +import type { RuntimeMode } from "../../../shared/api"; + +import { listProviderOptions } from "../../db/providers"; +import { jsonResponse } from "../../helpers"; + +export function handleListProviderOptions(db: Database, mode: RuntimeMode): Response { + return jsonResponse({ items: listProviderOptions(db) }, { mode }); +} diff --git a/src/server/routes/providers/test.ts b/src/server/routes/providers/test.ts index c1ab803..68a2724 100644 --- a/src/server/routes/providers/test.ts +++ b/src/server/routes/providers/test.ts @@ -3,35 +3,7 @@ import type Database from "bun:sqlite"; import type { CreateProviderRequest, RuntimeMode } from "../../../shared/api"; import { testProviderConnection } from "../../ai/registry"; -import { getProvider } from "../../db/providers"; import { createApiError, jsonResponse } from "../../helpers"; -import { validateIdParam } from "../../middleware"; - -export async function handleTestProvider(req: Request, db: Database, mode: RuntimeMode): Promise { - const url = new URL(req.url); - const idStr = url.pathname.split("/")[3]; - - const validated = validateIdParam(idStr ?? "", mode); - if (validated instanceof Response) return validated; - - const providerResult = getProvider(db, validated.id); - if ("error" in providerResult) { - return jsonResponse(createApiError(providerResult.error, providerResult.status), { - mode, - status: providerResult.status, - }); - } - - const provider = providerResult.provider; - const testResult = await testProviderConnection({ - apiKey: provider.apiKey, - baseUrl: provider.baseUrl, - name: provider.name, - type: provider.type, - }); - - return jsonResponse({ providerTestResponse: testResult }, { mode }); -} export async function handleTestProviderConfig(req: Request, _db: Database, mode: RuntimeMode): Promise { const validated = await readProviderConfig(req, mode); diff --git a/src/server/server.ts b/src/server/server.ts index c7c56f8..6565c7b 100644 --- a/src/server/server.ts +++ b/src/server/server.ts @@ -67,16 +67,10 @@ export function startServer(options: StartServerOptions) { return handleUpdateModel(req, db, mode); }, }, - "/api/models/:id/disable": { + "/api/models/test": { POST: async (req) => { - const { handleDisableModel } = await import("./routes/models/disable"); - return handleDisableModel(req, db, mode); - }, - }, - "/api/models/:id/enable": { - POST: async (req) => { - const { handleEnableModel } = await import("./routes/models/enable"); - return handleEnableModel(req, db, mode); + const { handleTestModelConfig } = await import("./routes/models/test"); + return handleTestModelConfig(req, db, mode); }, }, "/api/projects": { @@ -139,22 +133,10 @@ export function startServer(options: StartServerOptions) { return handleUpdateProvider(req, db, mode); }, }, - "/api/providers/:id/disable": { - POST: async (req) => { - const { handleDisableProvider } = await import("./routes/providers/disable"); - return handleDisableProvider(req, db, mode); - }, - }, - "/api/providers/:id/enable": { - POST: async (req) => { - const { handleEnableProvider } = await import("./routes/providers/enable"); - return handleEnableProvider(req, db, mode); - }, - }, - "/api/providers/:id/test": { - POST: async (req) => { - const { handleTestProvider } = await import("./routes/providers/test"); - return handleTestProvider(req, db, mode); + "/api/providers/options": { + GET: async () => { + const { handleListProviderOptions } = await import("./routes/providers/options"); + return handleListProviderOptions(db, mode); }, }, "/api/providers/test": { diff --git a/src/shared/api.ts b/src/shared/api.ts index bbcc9e6..0413435 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -40,7 +40,6 @@ export interface Model { capabilities: ModelCapability[]; contextLength: null | number; createdAt: string; - enabled: boolean; id: string; maxOutputTokens: null | number; modelId: string; @@ -81,6 +80,15 @@ export interface ModelResponse { model: Model; } +export interface ModelTestResponse { + message: string; + ok: boolean; +} + +export interface ModelTestResultResponse { + modelTestResponse: ModelTestResponse; +} + export interface Project { archivedAt: null | string; createdAt: string; @@ -108,7 +116,6 @@ export interface Provider { apiKey: string; baseUrl: string; createdAt: string; - enabled: boolean; id: string; name: string; type: ProviderType; @@ -122,6 +129,16 @@ export interface ProviderListResponse { total: number; } +export interface ProviderOption { + id: string; + name: string; + type: ProviderType; +} + +export interface ProviderOptionsResponse { + items: ProviderOption[]; +} + export interface ProviderResponse { provider: Provider; } @@ -139,6 +156,11 @@ export type ProviderType = "anthropic" | "openai" | "openai-compatible"; export type RuntimeMode = "development" | "production" | "test"; +export interface TestModelRequest { + modelId: string; + providerId: string; +} + export interface UpdateModelRequest { capabilities?: ModelCapability[]; contextLength?: null | number; diff --git a/src/web/hooks/use-models.ts b/src/web/hooks/use-models.ts index 07b3498..25a0f6f 100644 --- a/src/web/hooks/use-models.ts +++ b/src/web/hooks/use-models.ts @@ -1,6 +1,15 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import type { CreateModelRequest, Model, ModelListResponse, ModelResponse, UpdateModelRequest } from "../../shared/api"; +import type { + CreateModelRequest, + Model, + ModelListResponse, + ModelResponse, + ModelTestResponse, + ModelTestResultResponse, + TestModelRequest, + UpdateModelRequest, +} from "../../shared/api"; const MODELS_KEY = ["models"] as const; @@ -21,16 +30,6 @@ export async function deleteModel(id: string): Promise { } } -export async function disableModel(id: string): Promise { - const response = await fetch(`/api/models/${id}/disable`, { method: "POST" }); - return handleResponse(response); -} - -export async function enableModel(id: string): Promise { - const response = await fetch(`/api/models/${id}/enable`, { method: "POST" }); - return handleResponse(response); -} - export async function fetchModel(id: string): Promise { const response = await fetch(`/api/models/${id}`); return handleResponse(response); @@ -57,6 +56,20 @@ export async function fetchModelList(params: { return response.json() as Promise; } +export async function testModelConnection(data: TestModelRequest): Promise { + const response = await fetch("/api/models/test", { + body: JSON.stringify(data), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + if (!response.ok) { + const body = (await response.json().catch(() => null)) as null | { error?: string }; + throw new Error(body?.error ?? `HTTP ${response.status}`); + } + const result = (await response.json()) as ModelTestResultResponse; + return result.modelTestResponse; +} + export async function updateModel(id: string, data: UpdateModelRequest): Promise { const response = await fetch(`/api/models/${id}`, { body: JSON.stringify(data), @@ -86,26 +99,6 @@ export function useDeleteModel() { }); } -export function useDisableModel() { - const queryClient = useQueryClient(); - return useMutation({ - mutationFn: disableModel, - onSuccess: () => { - void queryClient.invalidateQueries({ queryKey: MODELS_KEY }); - }, - }); -} - -export function useEnableModel() { - const queryClient = useQueryClient(); - return useMutation({ - mutationFn: enableModel, - onSuccess: () => { - void queryClient.invalidateQueries({ queryKey: MODELS_KEY }); - }, - }); -} - export function useModel(id: string) { return useQuery({ enabled: !!id, @@ -121,6 +114,12 @@ export function useModelList(params: { keyword?: string; page?: number; pageSize }); } +export function useTestModelConnection() { + return useMutation({ + mutationFn: testModelConnection, + }); +} + export function useUpdateModel() { const queryClient = useQueryClient(); return useMutation({ diff --git a/src/web/hooks/use-providers.ts b/src/web/hooks/use-providers.ts index 7ec0a19..13d0d93 100644 --- a/src/web/hooks/use-providers.ts +++ b/src/web/hooks/use-providers.ts @@ -4,6 +4,7 @@ import type { CreateProviderRequest, Provider, ProviderListResponse, + ProviderOptionsResponse, ProviderResponse, ProviderTestResponse, ProviderTestResultResponse, @@ -30,16 +31,6 @@ export async function deleteProvider(id: string): Promise { } } -export async function disableProvider(id: string): Promise { - const response = await fetch(`/api/providers/${id}/disable`, { method: "POST" }); - return handleResponse(response); -} - -export async function enableProvider(id: string): Promise { - const response = await fetch(`/api/providers/${id}/enable`, { method: "POST" }); - return handleResponse(response); -} - export async function fetchProvider(id: string): Promise { const response = await fetch(`/api/providers/${id}`); return handleResponse(response); @@ -64,6 +55,15 @@ export async function fetchProviderList(params: { return response.json() as Promise; } +export async function fetchProviderOptions(): Promise { + const response = await fetch("/api/providers/options"); + if (!response.ok) { + const body = (await response.json().catch(() => null)) as null | { error?: string }; + throw new Error(body?.error ?? `HTTP ${response.status}`); + } + return response.json() as Promise; +} + export async function testProviderConfig(data: CreateProviderRequest): Promise { const response = await fetch("/api/providers/test", { body: JSON.stringify(data), @@ -78,16 +78,6 @@ export async function testProviderConfig(data: CreateProviderRequest): Promise

{ - const response = await fetch(`/api/providers/${id}/test`, { method: "POST" }); - if (!response.ok) { - const body = (await response.json().catch(() => null)) as null | { error?: string }; - throw new Error(body?.error ?? `HTTP ${response.status}`); - } - const data = (await response.json()) as ProviderTestResultResponse; - return data.providerTestResponse; -} - export async function updateProvider(id: string, data: UpdateProviderRequest): Promise { const response = await fetch(`/api/providers/${id}`, { body: JSON.stringify(data), @@ -118,26 +108,6 @@ export function useDeleteProvider() { }); } -export function useDisableProvider() { - const queryClient = useQueryClient(); - return useMutation({ - mutationFn: disableProvider, - onSuccess: () => { - void queryClient.invalidateQueries({ queryKey: PROVIDERS_KEY }); - }, - }); -} - -export function useEnableProvider() { - const queryClient = useQueryClient(); - return useMutation({ - mutationFn: enableProvider, - onSuccess: () => { - void queryClient.invalidateQueries({ queryKey: PROVIDERS_KEY }); - }, - }); -} - export function useProvider(id: string) { return useQuery({ enabled: !!id, @@ -153,15 +123,16 @@ export function useProviderList(params: { keyword?: string; page?: number; pageS }); } -export function useTestProviderConfig() { - return useMutation({ - mutationFn: testProviderConfig, +export function useProviderOptions() { + return useQuery({ + queryFn: fetchProviderOptions, + queryKey: [...PROVIDERS_KEY, "options"], }); } -export function useTestProviderConnection() { +export function useTestProviderConfig() { return useMutation({ - mutationFn: testProviderConnection, + mutationFn: testProviderConfig, }); } diff --git a/src/web/pages/models/components/ModelFormModal.tsx b/src/web/pages/models/components/ModelFormModal.tsx index e1c367f..3401bf8 100644 --- a/src/web/pages/models/components/ModelFormModal.tsx +++ b/src/web/pages/models/components/ModelFormModal.tsx @@ -5,8 +5,9 @@ import type { CreateModelRequest, Model, ModelCapability, - Provider, - ProviderTestResponse, + ModelTestResponse, + ProviderOption, + TestModelRequest, UpdateModelRequest, } from "../../../../shared/api"; @@ -26,11 +27,15 @@ interface ModelFormModalProps { onOpenChange: (open: boolean) => void; onUpdate: (args: { data: UpdateModelRequest; id: string }) => Promise; open: boolean; - providers: Provider[]; + providers: ProviderOption[]; + providersError: Error | null; + providersLoading: boolean; submitting: boolean; - testConnection?: (providerId: string) => Promise; + testModelConnection?: (data: TestModelRequest) => Promise; } +const DEFAULT_CAPABILITIES: ModelCapability[] = ["text", "reasoning"]; + const CAPABILITY_OPTIONS: Array<{ label: string; value: ModelCapability }> = [ { label: "文本", value: "text" }, { label: "推理", value: "reasoning" }, @@ -50,8 +55,10 @@ export function ModelFormModal({ onUpdate, open, providers, + providersError, + providersLoading, submitting, - testConnection, + testModelConnection, }: ModelFormModalProps) { const { message } = AntApp.useApp(); const [form] = Form.useForm(); @@ -70,6 +77,7 @@ export function ModelFormModal({ }); } else { form.resetFields(); + form.setFieldsValue({ capabilities: DEFAULT_CAPABILITIES }); } }, [editingModel, form, open]); @@ -109,15 +117,20 @@ export function ModelFormModal({ }; const handleTest = async () => { - if (!testConnection) return; + if (!testModelConnection) return; const providerId: unknown = form.getFieldValue("providerId"); + const modelId: unknown = form.getFieldValue("modelId"); if (typeof providerId !== "string" || !providerId) { message.warning("请先选择供应商"); return; } + if (typeof modelId !== "string" || !modelId) { + message.warning("请先输入模型 ID"); + return; + } setTesting(true); try { - const result = await testConnection(providerId); + const result = await testModelConnection({ modelId, providerId }); if (result.ok) { message.success(result.message); } else { @@ -130,7 +143,7 @@ export function ModelFormModal({ } }; - const providerOptions = providers.filter((p) => p.enabled).map((p) => ({ label: p.name, value: p.id })); + const providerOptions = providers.map((p) => ({ label: p.name, value: p.id })); return ( - - + {CAPABILITY_OPTIONS.map((opt) => ( - + {opt.label} ))} - - - - - - - {testConnection && ( + + + + + + + + + + + + + {testModelConnection && ( - {record.enabled ? ( - void handleDisable(record.id)} title="确认禁用此模型?"> - - - ) : ( - - )} void handleDelete(record.id)} @@ -152,7 +102,7 @@ export function ModelTable({ ), title: "操作", - width: 220, + width: 180, }; return ( @@ -169,13 +119,6 @@ export function ModelTable({ total: data?.total ?? 0, }} rowKey="id" - scroll={{ x: 1100 }} /> ); } - -function formatDatetime(dateStr: string): string { - const d = new Date(dateStr); - const pad = (n: number) => String(n).padStart(2, "0"); - return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())} ${pad(d.getHours())}:${pad(d.getMinutes())}:${pad(d.getSeconds())}`; -} diff --git a/src/web/pages/models/components/ModelToolbar.tsx b/src/web/pages/models/components/ModelToolbar.tsx deleted file mode 100644 index 9b007c7..0000000 --- a/src/web/pages/models/components/ModelToolbar.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import { PlusOutlined } from "@ant-design/icons"; -import { Button, Flex, Input } from "antd"; -import { useState } from "react"; - -interface ModelToolbarProps { - keyword: string; - onSearch: (value: string) => void; - onSearchClear: () => void; - openCreateDialog: () => void; -} - -export function ModelToolbar({ keyword, onSearch, onSearchClear, openCreateDialog }: ModelToolbarProps) { - const [draftKeyword, setDraftKeyword] = useState(keyword); - - return ( - - setDraftKeyword(event.target.value)} - onClear={() => { - setDraftKeyword(""); - onSearchClear(); - }} - onSearch={(value) => onSearch(value)} - placeholder="搜索模型名称或 ID" - value={draftKeyword} - /> - - - ); -} diff --git a/src/web/pages/models/components/ModelsToolbar.tsx b/src/web/pages/models/components/ModelsToolbar.tsx new file mode 100644 index 0000000..2d17902 --- /dev/null +++ b/src/web/pages/models/components/ModelsToolbar.tsx @@ -0,0 +1,53 @@ +import { PlusOutlined } from "@ant-design/icons"; +import { Button, Flex, Input, Tabs } from "antd"; +import { useState } from "react"; + +interface ModelsToolbarProps { + activeTab: string; + keyword: string; + onSearch: (value: string) => void; + onSearchClear: () => void; + onTabChange: (key: string) => void; + openCreateDialog: () => void; +} + +const TAB_ITEMS = [ + { key: "models", label: "模型" }, + { key: "providers", label: "供应商" }, +]; + +export function ModelsToolbar({ + activeTab, + keyword, + onSearch, + onSearchClear, + onTabChange, + openCreateDialog, +}: ModelsToolbarProps) { + const [draftKeyword, setDraftKeyword] = useState(keyword); + const placeholder = activeTab === "providers" ? "搜索供应商名称" : "搜索模型名称或 ID"; + const createLabel = activeTab === "providers" ? "新建供应商" : "新建模型"; + + return ( + + + + setDraftKeyword(event.target.value)} + onClear={() => { + setDraftKeyword(""); + onSearchClear(); + }} + onSearch={(value) => onSearch(value)} + placeholder={placeholder} + value={draftKeyword} + /> + + + + ); +} diff --git a/src/web/pages/models/components/ProviderTable.tsx b/src/web/pages/models/components/ProviderTable.tsx index 6a80f74..e4a8a58 100644 --- a/src/web/pages/models/components/ProviderTable.tsx +++ b/src/web/pages/models/components/ProviderTable.tsx @@ -1,25 +1,16 @@ import type { ColumnsType } from "antd/es/table"; -import { - CheckCircleOutlined, - DeleteOutlined, - EditOutlined, - StopOutlined, - ThunderboltOutlined, -} from "@ant-design/icons"; -import { App as AntApp, Button, Popconfirm, Space, Table, Tag, Tooltip } from "antd"; +import { DeleteOutlined, EditOutlined } from "@ant-design/icons"; +import { App as AntApp, Button, Popconfirm, Space, Table } from "antd"; -import type { Provider, ProviderListResponse, ProviderTestResponse } from "../../../../shared/api"; +import type { Provider, ProviderListResponse } from "../../../../shared/api"; interface ProviderTableProps { data: ProviderListResponse | undefined; loading: boolean; onDelete: (id: string) => Promise; - onDisable: (id: string) => Promise; onEdit: (provider: Provider) => void; - onEnable: (id: string) => Promise; onPageChange: (page: number, pageSize: number) => void; - onTest: (id: string) => Promise; page: number; pageSize: number; } @@ -31,63 +22,19 @@ const TYPE_LABELS: Record = { }; const COLUMNS: ColumnsType = [ - { dataIndex: "name", ellipsis: true, title: "供应商名称", width: 160 }, + { dataIndex: "name", ellipsis: true, title: "名称", width: 180 }, { - align: "center", dataIndex: "type", render: (value: Provider["type"]) => TYPE_LABELS[value] ?? value, title: "类型", - width: 130, + width: 140, }, { dataIndex: "baseUrl", ellipsis: true, title: "Base URL" }, - { - align: "center", - dataIndex: "enabled", - render: (value: boolean) => (value ? 已启用 : 已禁用), - title: "状态", - width: 100, - }, - { - align: "center", - dataIndex: "createdAt", - render: (_value: unknown, record: Provider) => formatDatetime(record.createdAt), - title: "创建时间", - width: 185, - }, ]; -export function ProviderTable({ - data, - loading, - onDelete, - onDisable, - onEdit, - onEnable, - onPageChange, - onTest, - page, - pageSize, -}: ProviderTableProps) { +export function ProviderTable({ data, loading, onDelete, onEdit, onPageChange, page, pageSize }: ProviderTableProps) { const { message } = AntApp.useApp(); - const handleEnable = async (id: string) => { - try { - await onEnable(id); - message.success("供应商已启用"); - } catch (err) { - message.error((err as Error).message); - } - }; - - const handleDisable = async (id: string) => { - try { - await onDisable(id); - message.success("供应商已禁用"); - } catch (err) { - message.error((err as Error).message); - } - }; - const handleDelete = async (id: string) => { try { await onDelete(id); @@ -97,49 +44,15 @@ export function ProviderTable({ } }; - const handleTest = async (id: string) => { - try { - const result = await onTest(id); - if (result.ok) { - message.success(result.message); - } else { - message.error(result.message); - } - } catch (err) { - message.error((err as Error).message); - } - }; - const operationColumn: ColumnsType[number] = { dataIndex: "op", - fixed: "right", render: (_value: unknown, record: Provider) => ( - - - {record.enabled ? ( - void handleDisable(record.id)} title="确认禁用此供应商?"> - - - ) : ( - - )} void handleDelete(record.id)} title="确认删除此供应商?" > @@ -150,7 +63,7 @@ export function ProviderTable({ ), title: "操作", - width: 280, + width: 180, }; return ( @@ -167,13 +80,6 @@ export function ProviderTable({ total: data?.total ?? 0, }} rowKey="id" - scroll={{ x: 900 }} /> ); } - -function formatDatetime(dateStr: string): string { - const d = new Date(dateStr); - const pad = (n: number) => String(n).padStart(2, "0"); - return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())} ${pad(d.getHours())}:${pad(d.getMinutes())}:${pad(d.getSeconds())}`; -} diff --git a/src/web/pages/models/components/ProviderToolbar.tsx b/src/web/pages/models/components/ProviderToolbar.tsx deleted file mode 100644 index c69a24d..0000000 --- a/src/web/pages/models/components/ProviderToolbar.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import { PlusOutlined } from "@ant-design/icons"; -import { Button, Flex, Input } from "antd"; -import { useState } from "react"; - -interface ProviderToolbarProps { - keyword: string; - onSearch: (value: string) => void; - onSearchClear: () => void; - openCreateDialog: () => void; -} - -export function ProviderToolbar({ keyword, onSearch, onSearchClear, openCreateDialog }: ProviderToolbarProps) { - const [draftKeyword, setDraftKeyword] = useState(keyword); - - return ( - - setDraftKeyword(event.target.value)} - onClear={() => { - setDraftKeyword(""); - onSearchClear(); - }} - onSearch={(value) => onSearch(value)} - placeholder="搜索供应商名称" - value={draftKeyword} - /> - - - ); -} diff --git a/src/web/pages/models/index.tsx b/src/web/pages/models/index.tsx index 11c4d13..90fea1c 100644 --- a/src/web/pages/models/index.tsx +++ b/src/web/pages/models/index.tsx @@ -1,35 +1,31 @@ -import { Flex, Tabs } from "antd"; +import { Flex } from "antd"; import { useState } from "react"; -import type { Model, Provider } from "../../../shared/api"; +import type { Model, Provider, TestModelRequest } from "../../../shared/api"; import { useCreateModel, useDeleteModel, - useDisableModel, - useEnableModel, useModelList, + useTestModelConnection, useUpdateModel, } from "../../hooks/use-models"; import { useCreateProvider, useDeleteProvider, - useDisableProvider, - useEnableProvider, useProviderList, + useProviderOptions, useTestProviderConfig, - useTestProviderConnection, useUpdateProvider, } from "../../hooks/use-providers"; import { ModelFormModal } from "./components/ModelFormModal"; +import { ModelsToolbar } from "./components/ModelsToolbar"; import { ModelTable } from "./components/ModelTable"; -import { ModelToolbar } from "./components/ModelToolbar"; import { ProviderFormModal } from "./components/ProviderFormModal"; import { ProviderTable } from "./components/ProviderTable"; -import { ProviderToolbar } from "./components/ProviderToolbar"; export function ModelsPage() { - const [activeTab, setActiveTab] = useState("providers"); + const [activeTab, setActiveTab] = useState("models"); const [providerPage, setProviderPage] = useState(1); const [providerPageSize, setProviderPageSize] = useState(20); @@ -49,10 +45,12 @@ export function ModelsPage() { pageSize: providerPageSize, }); - const { data: modelProviderData, isLoading: modelProviderLoading } = useProviderList({ - page: 1, - pageSize: 1000, - }); + const { + data: providerOptionsData, + error: providerOptionsError, + isError: providerOptionsIsError, + isLoading: providerOptionsLoading, + } = useProviderOptions(); const { data: modelData, isLoading: modelLoading } = useModelList({ keyword: modelKeyword || undefined, @@ -63,69 +61,81 @@ export function ModelsPage() { const createProviderMutation = useCreateProvider(); const updateProviderMutation = useUpdateProvider(); const deleteProviderMutation = useDeleteProvider(); - const enableProviderMutation = useEnableProvider(); - const disableProviderMutation = useDisableProvider(); - const testProviderMutation = useTestProviderConnection(); const testProviderConfigMutation = useTestProviderConfig(); const createModelMutation = useCreateModel(); const updateModelMutation = useUpdateModel(); const deleteModelMutation = useDeleteModel(); - const enableModelMutation = useEnableModel(); - const disableModelMutation = useDisableModel(); + const testModelMutation = useTestModelConnection(); const isProviderSubmitting = createProviderMutation.isPending || updateProviderMutation.isPending; - const isProviderActionPending = - deleteProviderMutation.isPending || enableProviderMutation.isPending || disableProviderMutation.isPending; + const isProviderActionPending = deleteProviderMutation.isPending; const isModelSubmitting = createModelMutation.isPending || updateModelMutation.isPending; - const isModelActionPending = - deleteModelMutation.isPending || enableModelMutation.isPending || disableModelMutation.isPending; - const modelProviders = modelProviderData?.items ?? []; + const isModelActionPending = deleteModelMutation.isPending; + const modelProviders = providerOptionsData?.items ?? []; + + const currentKeyword = activeTab === "providers" ? providerKeyword : modelKeyword; + + const handleSearch = + activeTab === "providers" + ? (value: string) => { + setProviderKeyword(value); + setProviderPage(1); + } + : (value: string) => { + setModelKeyword(value); + setModelPage(1); + }; + + const handleSearchClear = + activeTab === "providers" + ? () => { + setProviderKeyword(""); + setProviderPage(1); + } + : () => { + setModelKeyword(""); + setModelPage(1); + }; + + const handleOpenCreate = + activeTab === "providers" + ? () => { + setEditingProvider(null); + setProviderDialogOpen(true); + } + : () => { + setEditingModel(null); + setModelDialogOpen(true); + }; return ( - setActiveTab(key)} + setActiveTab(key)} + openCreateDialog={handleOpenCreate} /> {activeTab === "providers" && ( <> - { - setProviderKeyword(value); - setProviderPage(1); - }} - onSearchClear={() => { - setProviderKeyword(""); - setProviderPage(1); - }} - openCreateDialog={() => { - setEditingProvider(null); - setProviderDialogOpen(true); - }} - /> deleteProviderMutation.mutateAsync(id)} - onDisable={(id) => disableProviderMutation.mutateAsync(id)} onEdit={(provider) => { setEditingProvider(provider); setProviderDialogOpen(true); }} - onEnable={(id) => enableProviderMutation.mutateAsync(id)} onPageChange={(p, ps) => { setProviderPage(p); setProviderPageSize(ps); }} - onTest={(id) => testProviderMutation.mutateAsync(id)} page={providerPage} pageSize={providerPageSize} /> @@ -137,38 +147,21 @@ export function ModelsPage() { onTest={(data) => testProviderConfigMutation.mutateAsync(data)} onUpdate={(args) => updateProviderMutation.mutateAsync(args)} open={providerDialogOpen} - submitting={isProviderSubmitting || testProviderConfigMutation.isPending} + submitting={isProviderSubmitting} /> )} {activeTab === "models" && ( <> - { - setModelKeyword(value); - setModelPage(1); - }} - onSearchClear={() => { - setModelKeyword(""); - setModelPage(1); - }} - openCreateDialog={() => { - setEditingModel(null); - setModelDialogOpen(true); - }} - /> deleteModelMutation.mutateAsync(id)} - onDisable={(id) => disableModelMutation.mutateAsync(id)} onEdit={(model) => { setEditingModel(model); setModelDialogOpen(true); }} - onEnable={(id) => enableModelMutation.mutateAsync(id)} onPageChange={(p, ps) => { setModelPage(p); setModelPageSize(ps); @@ -185,8 +178,10 @@ export function ModelsPage() { onUpdate={(args) => updateModelMutation.mutateAsync(args)} open={modelDialogOpen} providers={modelProviders} + providersError={providerOptionsIsError ? providerOptionsError : null} + providersLoading={providerOptionsLoading} submitting={isModelSubmitting} - testConnection={(id: string) => testProviderMutation.mutateAsync(id)} + testModelConnection={(data: TestModelRequest) => testModelMutation.mutateAsync(data)} /> )} diff --git a/tests/server/ai/registry.test.ts b/tests/server/ai/registry.test.ts index 1928c94..66f74e0 100644 --- a/tests/server/ai/registry.test.ts +++ b/tests/server/ai/registry.test.ts @@ -2,8 +2,6 @@ import { describe, expect, mock, test } from "bun:test"; import { createMigratedTestDatabase } from "../../helpers"; -let generateTextImpl: (_opts: unknown) => unknown = () => ({}); - void mock.module("ai", () => ({ createProviderRegistry: (providers: Record unknown }>) => ({ languageModel: (id: string) => { @@ -13,70 +11,116 @@ void mock.module("ai", () => ({ return provider.languageModel(modelId); }, }), - generateText: mock((opts: unknown) => generateTextImpl(opts)), + generateText: () => Promise.resolve({ text: "Hi" }), })); -describe("AI registry", () => { - test("testProviderConnection rejects invalid config", async () => { - generateTextImpl = () => { - throw new Error("Connection failed"); - }; +async function withProviderServer( + modelsResponse: Response, + callback: (baseUrl: string) => Promise, +): Promise { + const server = Bun.serve({ + fetch(request) { + if (request.method === "HEAD") return new Response(null, { status: 200 }); + return modelsResponse; + }, + port: 0, + }); + try { + await callback(`http://127.0.0.1:${server.port}/v1`); + } finally { + await server.stop(true); + } +} +describe("AI registry", () => { + test("testProviderConnection reports unreachable Base URL", async () => { const { testProviderConnection } = await import("../../../src/server/ai/registry"); const result = await testProviderConnection({ apiKey: "bad-key", - baseUrl: "https://0.0.0.0:1", + baseUrl: "http://127.0.0.1:1", name: "Bad", type: "openai-compatible", }); expect(result.ok).toBe(false); - expect(result.message).toContain("连接失败"); - expect(typeof result.message).toBe("string"); + expect(result.message).toContain("Base URL 不可达"); + }); + + test("testProviderConnection rejects invalid config", async () => { + await withProviderServer(new Response(null, { status: 401 }), async (baseUrl) => { + const { testProviderConnection } = await import("../../../src/server/ai/registry"); + + const result = await testProviderConnection({ + apiKey: "bad-key", + baseUrl, + name: "Bad", + type: "openai-compatible", + }); + + expect(result.ok).toBe(false); + expect(result.message).toContain("API Key 无效"); + expect(typeof result.message).toBe("string"); + }); }); test("testProviderConnection return shape is correct", async () => { - generateTextImpl = () => ({}); + await withProviderServer(Response.json({ data: [{ id: "gpt-4o" }] }), async (baseUrl) => { + const { testProviderConnection } = await import("../../../src/server/ai/registry"); - const { testProviderConnection } = await import("../../../src/server/ai/registry"); + const result = await testProviderConnection({ + apiKey: "sk-test", + baseUrl, + name: "Test", + type: "openai", + }); - const result = await testProviderConnection({ - apiKey: "sk-test", - baseUrl: "https://api.openai.com/v1", - name: "Test", - type: "openai", + expect(result.ok).toBe(true); + expect(result.message).toContain("/models 返回 1 个模型"); }); - - expect(result.ok).toBe(true); - expect(result.message).toBe("连接成功"); }); - test("buildProviderRegistry 从 DB 构建包含启用供应商的注册表", async () => { + test("testProviderConnection treats unsupported /models as non-blocking", async () => { + await withProviderServer(new Response(null, { status: 404 }), async (baseUrl) => { + const { testProviderConnection } = await import("../../../src/server/ai/registry"); + + const result = await testProviderConnection({ + apiKey: "sk-test", + baseUrl, + name: "Test", + type: "openai", + }); + + expect(result.ok).toBe(true); + expect(result.message).toContain("可能不支持 /models"); + }); + }); + + test("buildProviderRegistry 从 DB 构建包含所有供应商的注册表", async () => { const handle = createMigratedTestDatabase("registry-build-test"); const now = new Date().toISOString(); handle.db .prepare( - "INSERT INTO providers (id, name, type, base_url, api_key, enabled, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + "INSERT INTO providers (id, name, type, base_url, api_key, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", ) - .run("pv1", "OpenAI", "openai", "https://api.openai.com/v1", "sk-test", 1, now, now); + .run("pv1", "OpenAI", "openai", "https://api.openai.com/v1", "sk-test", now, now); handle.db .prepare( - "INSERT INTO providers (id, name, type, base_url, api_key, enabled, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + "INSERT INTO providers (id, name, type, base_url, api_key, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", ) - .run("pv2", "Disabled", "anthropic", "https://api.anthropic.com", "sk-off", 0, now, now); + .run("pv2", "Anthropic", "anthropic", "https://api.anthropic.com", "sk-off", now, now); const { buildProviderRegistry } = await import("../../../src/server/ai/registry"); const registry = buildProviderRegistry(handle.db); expect(() => registry.languageModel("pv1:gpt-4o")).not.toThrow(); - expect(() => registry.languageModel("pv2:claude-3")).toThrow(); + expect(() => registry.languageModel("pv2:claude-3")).not.toThrow(); handle.cleanup(); }); - test("buildProviderRegistry 无启用供应商时返回空注册表", async () => { + test("buildProviderRegistry 无供应商时返回空注册表", async () => { const handle = createMigratedTestDatabase("registry-empty-test"); const { buildProviderRegistry } = await import("../../../src/server/ai/registry"); @@ -86,4 +130,19 @@ describe("AI registry", () => { handle.cleanup(); }); + + test("testModelConnection 成功返回 ok:true", async () => { + const { testModelConnection } = await import("../../../src/server/ai/registry"); + + const result = await testModelConnection({ + apiKey: "sk-test", + baseUrl: "https://api.openai.com/v1", + modelId: "gpt-4o", + name: "Test", + type: "openai", + }); + + expect(result.ok).toBe(true); + expect(result.message).toContain("模型连接成功"); + }); }); diff --git a/tests/server/db/models.test.ts b/tests/server/db/models.test.ts index b8bca13..da24279 100644 --- a/tests/server/db/models.test.ts +++ b/tests/server/db/models.test.ts @@ -5,8 +5,6 @@ import { describe, expect, test } from "bun:test"; import { createModel, deleteModel, - disableModel, - enableModel, getModel, getModelsByProviderId, listModels, @@ -41,16 +39,12 @@ describe("模型数据访问层", () => { providerId, }); expect("error" in result).toBe(false); - const model = ( - result as { - model: { capabilities: string[]; enabled: boolean; modelId: string; name: string; providerId: string }; - } - ).model; + const model = (result as { model: { capabilities: string[]; modelId: string; name: string; providerId: string } }) + .model; expect(model.name).toBe("GPT-4o"); expect(model.modelId).toBe("gpt-4o"); expect(model.providerId).toBe(providerId); expect(model.capabilities).toEqual(["text", "reasoning"]); - expect(model.enabled).toBe(true); }); }); @@ -150,35 +144,6 @@ describe("模型数据访问层", () => { }); }); - test("启用/禁用模型", () => { - withDb((db) => { - const providerId = seedProvider(db); - const created = createModel(db, { capabilities: ["text"], modelId: "gpt-4o", name: "测试", providerId }); - const id = (created as { model: { id: string } }).model.id; - - const disabled = disableModel(db, id); - expect("error" in disabled).toBe(false); - expect((disabled as { model: { enabled: boolean } }).model.enabled).toBe(false); - - const enabled = enableModel(db, id); - expect("error" in enabled).toBe(false); - expect((enabled as { model: { enabled: boolean } }).model.enabled).toBe(true); - }); - }); - - test("重复禁用失败", () => { - withDb((db) => { - const providerId = seedProvider(db); - const created = createModel(db, { capabilities: ["text"], modelId: "gpt-4o", name: "测试", providerId }); - const id = (created as { model: { id: string } }).model.id; - disableModel(db, id); - - const result = disableModel(db, id); - expect("error" in result).toBe(true); - expect((result as unknown as { status: number }).status).toBe(409); - }); - }); - test("删除模型", () => { withDb((db) => { const providerId = seedProvider(db); diff --git a/tests/server/db/providers.test.ts b/tests/server/db/providers.test.ts index a30d441..f922cec 100644 --- a/tests/server/db/providers.test.ts +++ b/tests/server/db/providers.test.ts @@ -5,9 +5,8 @@ import { describe, expect, test } from "bun:test"; import { createProvider, deleteProvider, - disableProvider, - enableProvider, getProvider, + listProviderOptions, listProviders, updateProvider, } from "../../../src/server/db/providers"; @@ -24,6 +23,16 @@ function withDb(callback: (db: Database) => void): void { } describe("供应商数据访问层", () => { + test("迁移后的供应商和模型表不包含 enabled 字段", () => { + withDb((db) => { + const providerColumns = db.query("PRAGMA table_info(providers)").all() as Array<{ name: string }>; + const modelColumns = db.query("PRAGMA table_info(models)").all() as Array<{ name: string }>; + + expect(providerColumns.map((column) => column.name)).not.toContain("enabled"); + expect(modelColumns.map((column) => column.name)).not.toContain("enabled"); + }); + }); + test("创建供应商", () => { withDb((db) => { const result = createProvider(db, { @@ -33,14 +42,12 @@ describe("供应商数据访问层", () => { type: "openai", }); expect("error" in result).toBe(false); - const provider = ( - result as { provider: { apiKey: string; baseUrl: string; enabled: boolean; name: string; type: string } } - ).provider; + const provider = (result as { provider: { apiKey: string; baseUrl: string; name: string; type: string } }) + .provider; expect(provider.name).toBe("OpenAI"); expect(provider.type).toBe("openai"); expect(provider.baseUrl).toBe("https://api.openai.com/v1"); expect(provider.apiKey).toBe("sk-test"); - expect(provider.enabled).toBe(true); }); }); @@ -121,44 +128,6 @@ describe("供应商数据访问层", () => { }); }); - test("启用/禁用供应商", () => { - withDb((db) => { - const created = createProvider(db, { apiKey: "sk", baseUrl: "https://a.com", name: "测试", type: "openai" }); - const id = (created as { provider: { id: string } }).provider.id; - - const disabled = disableProvider(db, id); - expect("error" in disabled).toBe(false); - expect((disabled as { provider: { enabled: boolean } }).provider.enabled).toBe(false); - - const enabled = enableProvider(db, id); - expect("error" in enabled).toBe(false); - expect((enabled as { provider: { enabled: boolean } }).provider.enabled).toBe(true); - }); - }); - - test("重复禁用失败", () => { - withDb((db) => { - const created = createProvider(db, { apiKey: "sk", baseUrl: "https://a.com", name: "测试", type: "openai" }); - const id = (created as { provider: { id: string } }).provider.id; - disableProvider(db, id); - - const result = disableProvider(db, id); - expect("error" in result).toBe(true); - expect((result as unknown as { status: number }).status).toBe(409); - }); - }); - - test("重复启用失败", () => { - withDb((db) => { - const created = createProvider(db, { apiKey: "sk", baseUrl: "https://a.com", name: "测试", type: "openai" }); - const id = (created as { provider: { id: string } }).provider.id; - - const result = enableProvider(db, id); - expect("error" in result).toBe(true); - expect((result as unknown as { status: number }).status).toBe(409); - }); - }); - test("删除供应商", () => { withDb((db) => { const created = createProvider(db, { apiKey: "sk", baseUrl: "https://a.com", name: "删除测试", type: "openai" }); @@ -192,4 +161,17 @@ describe("供应商数据访问层", () => { expect((result as { provider: { type: string } }).provider.type).toBe("openai-compatible"); }); }); + + test("供应商 options 返回最小字段", () => { + withDb((db) => { + createProvider(db, { apiKey: "sk", baseUrl: "https://a.com", name: "选项", type: "openai" }); + + const options = listProviderOptions(db); + expect(options.length).toBe(1); + expect(typeof options[0]?.id).toBe("string"); + expect(options[0]).toMatchObject({ name: "选项", type: "openai" }); + expect(options[0]).not.toHaveProperty("apiKey"); + expect(options[0]).not.toHaveProperty("enabled"); + }); + }); }); diff --git a/tests/server/routes/models.test.ts b/tests/server/routes/models.test.ts index a302ec9..f539918 100644 --- a/tests/server/routes/models.test.ts +++ b/tests/server/routes/models.test.ts @@ -1,6 +1,6 @@ import type Database from "bun:sqlite"; -import { describe, expect, test } from "bun:test"; +import { describe, expect, mock, test } from "bun:test"; import type { Model, RuntimeMode } from "../../../src/shared/api"; @@ -30,16 +30,6 @@ async function deleteModelViaHandler(req: Request, db: Database): Promise { - const { handleDisableModel: h } = await import("../../../src/server/routes/models/disable"); - return h(req, db, MODE); -} - -async function enableModelViaHandler(req: Request, db: Database): Promise { - const { handleEnableModel: h } = await import("../../../src/server/routes/models/enable"); - return h(req, db, MODE); -} - async function getModelViaHandler(req: Request, db: Database): Promise { const { handleGetModel: h } = await import("../../../src/server/routes/models/get"); return h(req, db, MODE); @@ -53,6 +43,13 @@ async function listModelsViaHandler(req: Request, db: Database): Promise ({ + createProviderRegistry: () => ({ + languageModel: () => ({}), + }), + generateText: () => Promise.resolve({ text: "Hi" }), +})); + function seedProvider(db: Database, name?: string): string { const result = createProvider(db, { apiKey: "sk-test", @@ -64,6 +61,11 @@ function seedProvider(db: Database, name?: string): string { return result.provider.id; } +async function testModelViaHandler(req: Request, db: Database): Promise { + const { handleTestModelConfig: h } = await import("../../../src/server/routes/models/test"); + return h(req, db, MODE); +} + async function updateModelViaHandler(req: Request, db: Database): Promise { const { handleUpdateModel: h } = await import("../../../src/server/routes/models/update"); return h(req, db, MODE); @@ -163,34 +165,6 @@ describe("models API routes", () => { }); }); - test("POST /api/models/:id/enable", async () => { - await withRouteDb(async (db) => { - const model = createTestModel(db, "EnableTest"); - await disableModelViaHandler( - new Request("http://localhost/api/models/" + model.id + "/disable", { method: "POST" }), - db, - ); - - const req = new Request("http://localhost/api/models/" + model.id + "/enable", { method: "POST" }); - const res = await enableModelViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { model: Model }; - expect(body.model.enabled).toBe(true); - }); - }); - - test("POST /api/models/:id/disable", async () => { - await withRouteDb(async (db) => { - const model = createTestModel(db, "DisableTest"); - - const req = new Request("http://localhost/api/models/" + model.id + "/disable", { method: "POST" }); - const res = await disableModelViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { model: Model }; - expect(body.model.enabled).toBe(false); - }); - }); - test("DELETE /api/models/:id", async () => { await withRouteDb(async (db) => { const model = createTestModel(db, "DeleteTest"); @@ -219,4 +193,74 @@ describe("models API routes", () => { expect(res.status).toBe(400); }); }); + + test("invalid numeric fields return 400", async () => { + await withRouteDb(async (db) => { + const providerId = seedProvider(db); + + const createReq = new Request("http://localhost/api/models", { + body: JSON.stringify({ + capabilities: ["text"], + contextLength: 0, + modelId: "test", + name: "Test", + providerId, + }), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + const createRes = await createModelViaHandler(createReq, db); + expect(createRes.status).toBe(400); + + const model = createTestModel(db, "NumericTest", providerId); + const updateReq = new Request("http://localhost/api/models/" + model.id, { + body: JSON.stringify({ maxOutputTokens: 1.5 }), + headers: { "Content-Type": "application/json" }, + method: "PATCH", + }); + const updateRes = await updateModelViaHandler(updateReq, db); + expect(updateRes.status).toBe(400); + }); + }); + + test("POST /api/models/test 成功测试模型连接", async () => { + await withRouteDb(async (db) => { + const providerId = seedProvider(db); + + const req = new Request("http://localhost/api/models/test", { + body: JSON.stringify({ modelId: "gpt-4o", providerId }), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + const res = await testModelViaHandler(req, db); + expect(res.status).toBe(200); + const body = (await res.json()) as { modelTestResponse: { message: string; ok: boolean } }; + expect(body.modelTestResponse.ok).toBe(true); + expect(body.modelTestResponse.message).toContain("模型连接成功"); + }); + }); + + test("POST /api/models/test 缺少 providerId 返回 400", async () => { + await withRouteDb(async (db) => { + const req = new Request("http://localhost/api/models/test", { + body: JSON.stringify({ modelId: "gpt-4o" }), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + const res = await testModelViaHandler(req, db); + expect(res.status).toBe(400); + }); + }); + + test("POST /api/models/test 不存在的供应商返回 404", async () => { + await withRouteDb(async (db) => { + const req = new Request("http://localhost/api/models/test", { + body: JSON.stringify({ modelId: "gpt-4o", providerId: "nonexistent" }), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + const res = await testModelViaHandler(req, db); + expect(res.status).toBe(404); + }); + }); }); diff --git a/tests/server/routes/providers.test.ts b/tests/server/routes/providers.test.ts index 62b4291..a7c018d 100644 --- a/tests/server/routes/providers.test.ts +++ b/tests/server/routes/providers.test.ts @@ -2,7 +2,7 @@ import type Database from "bun:sqlite"; import { describe, expect, mock, test } from "bun:test"; -import type { Provider, RuntimeMode } from "../../../src/shared/api"; +import type { Provider, ProviderOption, RuntimeMode } from "../../../src/shared/api"; import { createModel } from "../../../src/server/db/models"; import { createProvider } from "../../../src/server/db/providers"; @@ -10,13 +10,10 @@ import { createMigratedMemoryTestDatabase } from "../../helpers"; const MODE: RuntimeMode = "test"; -let generateTextImpl: (_opts: unknown) => unknown = () => ({}); - void mock.module("ai", () => ({ createProviderRegistry: () => ({ languageModel: () => ({}), }), - generateText: mock((opts: unknown) => generateTextImpl(opts)), })); async function createProviderViaHandler(req: Request, db: Database): Promise { @@ -24,10 +21,10 @@ async function createProviderViaHandler(req: Request, db: Database): Promise { - const { handleDisableProvider: h } = await import("../../../src/server/routes/providers/disable"); - return h(req, db, MODE); -} - -async function enableProviderViaHandler(req: Request, db: Database): Promise { - const { handleEnableProvider: h } = await import("../../../src/server/routes/providers/enable"); - return h(req, db, MODE); -} - async function getProviderViaHandler(req: Request, db: Database): Promise { const { handleGetProvider: h } = await import("../../../src/server/routes/providers/get"); return h(req, db, MODE); } +async function listProviderOptionsViaHandler(_req: Request, db: Database): Promise { + const { handleListProviderOptions: h } = await import("../../../src/server/routes/providers/options"); + return h(db, MODE); +} + async function listProvidersViaHandler(req: Request, db: Database): Promise { const { handleListProviders: h } = await import("../../../src/server/routes/providers/list"); return h(req, db, MODE); @@ -65,16 +57,29 @@ async function testProviderConfigViaHandler(req: Request, db: Database): Promise return h(req, db, MODE); } -async function testProviderViaHandler(req: Request, db: Database): Promise { - const { handleTestProvider: h } = await import("../../../src/server/routes/providers/test"); - return h(req, db, MODE); -} - async function updateProviderViaHandler(req: Request, db: Database): Promise { const { handleUpdateProvider: h } = await import("../../../src/server/routes/providers/update"); return h(req, db, MODE); } +async function withProviderServer( + modelsResponse: Response, + callback: (baseUrl: string) => Promise, +): Promise { + const server = Bun.serve({ + fetch(request) { + if (request.method === "HEAD") return new Response(null, { status: 200 }); + return modelsResponse; + }, + port: 0, + }); + try { + await callback(`http://127.0.0.1:${server.port}/v1`); + } finally { + await server.stop(true); + } +} + async function withRouteDb(callback: (db: Database) => Promise): Promise { const handle = createMigratedMemoryTestDatabase("route-provider-test"); try { @@ -120,6 +125,22 @@ describe("供应商 API 路由", () => { }); }); + test("GET /api/providers/options 返回最小字段", async () => { + await withRouteDb(async (db) => { + createTestProvider(db, "选项供应商"); + + const req = new Request("http://localhost/api/providers/options"); + const res = await listProviderOptionsViaHandler(req, db); + expect(res.status).toBe(200); + const body = (await res.json()) as { items: ProviderOption[] }; + expect(body.items).toHaveLength(1); + expect(typeof body.items[0]?.id).toBe("string"); + expect(body.items[0]).toMatchObject({ name: "选项供应商", type: "openai" }); + expect(body.items[0]).not.toHaveProperty("apiKey"); + expect(body.items[0]).not.toHaveProperty("enabled"); + }); + }); + test("GET /api/providers/:id 获取详情", async () => { await withRouteDb(async (db) => { const provider = createTestProvider(db, "详情路由"); @@ -148,34 +169,6 @@ describe("供应商 API 路由", () => { }); }); - test("POST /api/providers/:id/enable 启用", async () => { - await withRouteDb(async (db) => { - const provider = createTestProvider(db, "启用测试"); - await disableProviderViaHandler( - new Request(`http://localhost/api/providers/${provider.id}/disable`, { method: "POST" }), - db, - ); - - const req = new Request(`http://localhost/api/providers/${provider.id}/enable`, { method: "POST" }); - const res = await enableProviderViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { provider: Provider }; - expect(body.provider.enabled).toBe(true); - }); - }); - - test("POST /api/providers/:id/disable 禁用", async () => { - await withRouteDb(async (db) => { - const provider = createTestProvider(db, "禁用测试"); - - const req = new Request(`http://localhost/api/providers/${provider.id}/disable`, { method: "POST" }); - const res = await disableProviderViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { provider: Provider }; - expect(body.provider.enabled).toBe(false); - }); - }); - test("DELETE /api/providers/:id 删除供应商", async () => { await withRouteDb(async (db) => { const provider = createTestProvider(db, "删除路由"); @@ -205,40 +198,25 @@ describe("供应商 API 路由", () => { }); }); - test("POST /api/providers/:id/test 返回连通性失败结果", async () => { - await withRouteDb(async (db) => { - generateTextImpl = () => { - throw new Error("bad key"); - }; - const provider = createTestProvider(db, "测试失败供应商"); - - const req = new Request(`http://localhost/api/providers/${provider.id}/test`, { method: "POST" }); - const res = await testProviderViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { providerTestResponse: { message: string; ok: boolean } }; - expect(body.providerTestResponse.ok).toBe(false); - expect(body.providerTestResponse.message).toContain("连接失败"); - generateTextImpl = () => ({}); - }); - }); - test("POST /api/providers/test 使用表单配置测试连通性", async () => { await withRouteDb(async (db) => { - generateTextImpl = () => ({}); - const req = new Request("http://localhost/api/providers/test", { - body: JSON.stringify({ - apiKey: "sk-test", - baseUrl: "https://api.openai.com/v1", - name: "OpenAI", - type: "openai", - }), - headers: { "Content-Type": "application/json" }, - method: "POST", + await withProviderServer(Response.json({ data: [{ id: "gpt-4o" }] }), async (baseUrl) => { + const req = new Request("http://localhost/api/providers/test", { + body: JSON.stringify({ + apiKey: "sk-test", + baseUrl, + name: "OpenAI", + type: "openai", + }), + headers: { "Content-Type": "application/json" }, + method: "POST", + }); + const res = await testProviderConfigViaHandler(req, db); + expect(res.status).toBe(200); + const body = (await res.json()) as { providerTestResponse: { message: string; ok: boolean } }; + expect(body.providerTestResponse.ok).toBe(true); + expect(body.providerTestResponse.message).toContain("/models 返回 1 个模型"); }); - const res = await testProviderConfigViaHandler(req, db); - expect(res.status).toBe(200); - const body = (await res.json()) as { providerTestResponse: { message: string; ok: boolean } }; - expect(body.providerTestResponse).toEqual({ message: "连接成功", ok: true }); }); }); diff --git a/tests/web/components/ModelTable.test.tsx b/tests/web/components/ModelTable.test.tsx index 8a0de50..fb94cfd 100644 --- a/tests/web/components/ModelTable.test.tsx +++ b/tests/web/components/ModelTable.test.tsx @@ -2,38 +2,27 @@ import { fireEvent, screen, waitFor } from "@testing-library/react"; import { describe, expect, mock, test } from "bun:test"; import { createElement } from "react"; -import type { Model, Provider } from "../../../src/shared/api"; +import type { Model, ProviderOption } from "../../../src/shared/api"; import { ModelTable } from "../../../src/web/pages/models/components/ModelTable"; import { renderWithProviders } from "../test-utils"; -const ENABLED_PROVIDER: Provider = { - apiKey: "sk-test", - baseUrl: "https://api.openai.com/v1", - createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, +const OPENAI_PROVIDER: ProviderOption = { id: "pv1", name: "OpenAI", type: "openai", - updatedAt: "2024-01-01T00:00:00.000Z", }; -const DISABLED_PROVIDER: Provider = { - apiKey: "sk-off", - baseUrl: "https://api.deepseek.com/v1", - createdAt: "2024-01-01T00:00:00.000Z", - enabled: false, +const DEEPSEEK_PROVIDER: ProviderOption = { id: "pv2", name: "DeepSeek", type: "openai-compatible", - updatedAt: "2024-01-01T00:00:00.000Z", }; const ENABLED_MODEL: Model = { capabilities: ["text", "reasoning"], contextLength: 128000, createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "m1", maxOutputTokens: 4096, modelId: "gpt-4o", @@ -46,7 +35,6 @@ const DISABLED_MODEL: Model = { capabilities: ["text"], contextLength: null, createdAt: "2024-01-01T00:00:00.000Z", - enabled: false, id: "m2", maxOutputTokens: null, modelId: "deepseek-chat", @@ -67,51 +55,45 @@ describe("ModelTable", () => { data: { items: [ENABLED_MODEL, DISABLED_MODEL], page: 1, pageSize: 20, total: 2 }, loading: false, onDelete: () => Promise.resolve(), - onDisable: () => Promise.resolve(), onEdit: () => undefined, - onEnable: () => Promise.resolve(), onPageChange: () => undefined, page: 1, pageSize: 20, - providers: [ENABLED_PROVIDER, DISABLED_PROVIDER], + providers: [OPENAI_PROVIDER, DEEPSEEK_PROVIDER], }), ); expect(screen.getByText("GPT-4o")).not.toBeNull(); - expect(screen.getByText("gpt-4o")).not.toBeNull(); expect(screen.getByText("DeepSeek Chat")).not.toBeNull(); expect(screen.getByText("OpenAI")).not.toBeNull(); expect(screen.getByText("DeepSeek")).not.toBeNull(); + expect(screen.queryByText("状态")).toBeNull(); + expect(screen.queryByRole("button", { name: /启用|禁用/ })).toBeNull(); }); - test("模型表格操作触发 enable/disable/delete", async () => { - const onDisable = mock(() => Promise.resolve()); - const onEnable = mock(() => Promise.resolve()); + test("模型表格操作触发 edit/delete", async () => { const onDelete = mock(() => Promise.resolve()); + const onEdit = mock(() => undefined); renderWithProviders( createElement(ModelTable, { data: { items: [ENABLED_MODEL, DISABLED_MODEL], page: 1, pageSize: 20, total: 2 }, loading: false, onDelete, - onDisable, - onEdit: () => undefined, - onEnable, + onEdit, onPageChange: () => undefined, page: 1, pageSize: 20, - providers: [ENABLED_PROVIDER, DISABLED_PROVIDER], + providers: [OPENAI_PROVIDER, DEEPSEEK_PROVIDER], }), ); - const disableButtons = screen.getAllByRole("button", { name: /禁用/ }); - fireEvent.click(disableButtons[0]!); - await waitFor(() => expect(screen.getByText("确认禁用此模型?")).not.toBeNull()); - clickLatestConfirmButton(); - await waitFor(() => expect(onDisable).toHaveBeenCalledWith("m1")); + fireEvent.click(screen.getAllByRole("button", { name: /编辑/ })[0]!); + expect(onEdit).toHaveBeenCalledWith(ENABLED_MODEL); - const enableButtons = screen.getAllByRole("button", { name: /启用/ }); - fireEvent.click(enableButtons[0]!); - await waitFor(() => expect(onEnable).toHaveBeenCalledWith("m2")); + fireEvent.click(screen.getAllByRole("button", { name: /删除/ })[0]!); + await waitFor(() => expect(screen.getByText("确认删除此模型?")).not.toBeNull()); + clickLatestConfirmButton(); + await waitFor(() => expect(onDelete).toHaveBeenCalledWith("m1")); }); }); diff --git a/tests/web/components/ProviderTable.test.tsx b/tests/web/components/ProviderTable.test.tsx index 242f878..b55b88b 100644 --- a/tests/web/components/ProviderTable.test.tsx +++ b/tests/web/components/ProviderTable.test.tsx @@ -7,22 +7,20 @@ import type { Provider } from "../../../src/shared/api"; import { ProviderTable } from "../../../src/web/pages/models/components/ProviderTable"; import { renderWithProviders } from "../test-utils"; -const ENABLED_PROVIDER: Provider = { +const OPENAI_PROVIDER: Provider = { apiKey: "sk-test", baseUrl: "https://api.openai.com/v1", createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "pv1", name: "OpenAI", type: "openai", updatedAt: "2024-01-01T00:00:00.000Z", }; -const DISABLED_PROVIDER: Provider = { +const DEEPSEEK_PROVIDER: Provider = { apiKey: "sk-off", baseUrl: "https://api.deepseek.com/v1", createdAt: "2024-01-01T00:00:00.000Z", - enabled: false, id: "pv2", name: "DeepSeek", type: "openai-compatible", @@ -38,14 +36,11 @@ describe("ProviderTable", () => { test("渲染供应商表格数据", () => { renderWithProviders( createElement(ProviderTable, { - data: { items: [ENABLED_PROVIDER, DISABLED_PROVIDER], page: 1, pageSize: 20, total: 2 }, + data: { items: [OPENAI_PROVIDER, DEEPSEEK_PROVIDER], page: 1, pageSize: 20, total: 2 }, loading: false, onDelete: () => Promise.resolve(), - onDisable: () => Promise.resolve(), onEdit: () => undefined, - onEnable: () => Promise.resolve(), onPageChange: () => undefined, - onTest: () => Promise.resolve({ message: "ok", ok: true }), page: 1, pageSize: 20, }), @@ -54,58 +49,33 @@ describe("ProviderTable", () => { expect(screen.getAllByText("OpenAI").length).toBeGreaterThan(0); expect(screen.getByText("DeepSeek")).not.toBeNull(); expect(screen.getByText("https://api.openai.com/v1")).not.toBeNull(); + expect(screen.queryByText("状态")).toBeNull(); + expect(screen.queryByRole("button", { name: "测试连接" })).toBeNull(); + expect(screen.queryByRole("button", { name: /启用|禁用/ })).toBeNull(); }); - test("供应商表格操作触发 enable/disable/delete", async () => { - const onDisable = mock(() => Promise.resolve()); - const onEnable = mock(() => Promise.resolve()); + test("供应商表格操作触发 edit/delete", async () => { const onDelete = mock(() => Promise.resolve()); + const onEdit = mock(() => undefined); renderWithProviders( createElement(ProviderTable, { - data: { items: [ENABLED_PROVIDER, DISABLED_PROVIDER], page: 1, pageSize: 20, total: 2 }, + data: { items: [OPENAI_PROVIDER, DEEPSEEK_PROVIDER], page: 1, pageSize: 20, total: 2 }, loading: false, onDelete, - onDisable, - onEdit: () => undefined, - onEnable, + onEdit, onPageChange: () => undefined, - onTest: () => Promise.resolve({ message: "ok", ok: true }), page: 1, pageSize: 20, }), ); - const disableButtons = screen.getAllByRole("button", { name: /禁用/ }); - fireEvent.click(disableButtons[0]!); - await waitFor(() => expect(screen.getByText("确认禁用此供应商?")).not.toBeNull()); + fireEvent.click(screen.getAllByRole("button", { name: /编辑/ })[0]!); + expect(onEdit).toHaveBeenCalledWith(OPENAI_PROVIDER); + + fireEvent.click(screen.getAllByRole("button", { name: /删除/ })[0]!); + await waitFor(() => expect(screen.getByText("确认删除此供应商?")).not.toBeNull()); clickLatestConfirmButton(); - await waitFor(() => expect(onDisable).toHaveBeenCalledWith("pv1")); - - const enableButtons = screen.getAllByRole("button", { name: /启用/ }); - fireEvent.click(enableButtons[0]!); - await waitFor(() => expect(onEnable).toHaveBeenCalledWith("pv2")); - }); - - test("供应商表格操作触发连接测试", async () => { - const onTest = mock(() => Promise.resolve({ message: "连接失败", ok: false })); - - renderWithProviders( - createElement(ProviderTable, { - data: { items: [ENABLED_PROVIDER], page: 1, pageSize: 20, total: 1 }, - loading: false, - onDelete: () => Promise.resolve(), - onDisable: () => Promise.resolve(), - onEdit: () => undefined, - onEnable: () => Promise.resolve(), - onPageChange: () => undefined, - onTest, - page: 1, - pageSize: 20, - }), - ); - - fireEvent.click(screen.getByRole("button", { name: "测试连接" })); - await waitFor(() => expect(onTest).toHaveBeenCalledWith("pv1")); + await waitFor(() => expect(onDelete).toHaveBeenCalledWith("pv1")); }); }); diff --git a/tests/web/hooks/use-models.test.ts b/tests/web/hooks/use-models.test.ts index fd885ec..3d0fc6b 100644 --- a/tests/web/hooks/use-models.test.ts +++ b/tests/web/hooks/use-models.test.ts @@ -3,10 +3,9 @@ import { describe, expect, test } from "bun:test"; import { createModel, deleteModel, - disableModel, - enableModel, fetchModel, fetchModelList, + testModelConnection, updateModel, } from "../../../src/web/hooks/use-models"; import { installFetchMock, jsonResponse } from "../test-utils"; @@ -15,7 +14,6 @@ const MODEL = { capabilities: ["text"] as Array<"text">, contextLength: null, createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "m1", maxOutputTokens: null, modelId: "gpt-4o", @@ -50,7 +48,7 @@ describe("use-models request helpers", () => { expect(calls[0]?.url).toContain("keyword=GPT"); }); - test("模型 CRUD 与 enable/disable 使用正确 method、URL 与 body", async () => { + test("模型 CRUD 使用正确 method、URL 与 body", async () => { const calls = installFetchMock((call) => { if (call.method === "DELETE") return new Response(null, { status: 204 }); return jsonResponse( @@ -66,16 +64,12 @@ describe("use-models request helpers", () => { providerId: "pv1", }); await updateModel("m1", { name: "GPT-4o Mini" }); - await enableModel("m1"); - await disableModel("m1"); await deleteModel("m1"); await fetchModel("m1"); expect(calls.map((call) => `${call.method} ${call.url}`)).toEqual([ "POST /api/models", "PATCH /api/models/m1", - "POST /api/models/m1/enable", - "POST /api/models/m1/disable", "DELETE /api/models/m1", "GET /api/models/m1", ]); @@ -102,4 +96,16 @@ describe("use-models request helpers", () => { await expectRejectsWithMessage(() => fetchModel("m-missing"), "HTTP 500"); }); + + test("testModelConnection 调用正确 URL 和 body", async () => { + const calls = installFetchMock(() => jsonResponse({ modelTestResponse: { message: "模型连接成功", ok: true } })); + + const result = await testModelConnection({ modelId: "gpt-4o", providerId: "pv1" }); + + expect(result.ok).toBe(true); + expect(result.message).toBe("模型连接成功"); + expect(calls[0]?.method).toBe("POST"); + expect(calls[0]?.url).toBe("/api/models/test"); + expect(jsonBody(calls[0]?.body)).toEqual({ modelId: "gpt-4o", providerId: "pv1" }); + }); }); diff --git a/tests/web/hooks/use-providers.test.ts b/tests/web/hooks/use-providers.test.ts index 56b6916..f816b50 100644 --- a/tests/web/hooks/use-providers.test.ts +++ b/tests/web/hooks/use-providers.test.ts @@ -3,12 +3,10 @@ import { describe, expect, test } from "bun:test"; import { createProvider, deleteProvider, - disableProvider, - enableProvider, fetchProvider, fetchProviderList, + fetchProviderOptions, testProviderConfig, - testProviderConnection, updateProvider, } from "../../../src/web/hooks/use-providers"; import { installFetchMock, jsonResponse } from "../test-utils"; @@ -17,7 +15,6 @@ const PROVIDER = { apiKey: "sk-test", baseUrl: "https://api.openai.com/v1", createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "pv1", name: "OpenAI", type: "openai" as const, @@ -49,7 +46,7 @@ describe("use-providers request helpers", () => { expect(calls[0]?.url).toBe("/api/providers?page=1&pageSize=20&keyword=OpenAI"); }); - test("CRUD and enable/disable use correct method, URL and body", async () => { + test("CRUD uses correct method, URL and body", async () => { const calls = installFetchMock((call) => { if (call.method === "DELETE") return new Response(null, { status: 204 }); return jsonResponse( @@ -60,16 +57,12 @@ describe("use-providers request helpers", () => { await createProvider({ apiKey: "sk-test", baseUrl: "https://api.openai.com/v1", name: "OpenAI", type: "openai" }); await updateProvider("pv1", { name: "New OpenAI" }); - await enableProvider("pv1"); - await disableProvider("pv1"); await deleteProvider("pv1"); await fetchProvider("pv1"); expect(calls.map((c) => c.method + " " + c.url)).toEqual([ "POST /api/providers", "PATCH /api/providers/pv1", - "POST /api/providers/pv1/enable", - "POST /api/providers/pv1/disable", "DELETE /api/providers/pv1", "GET /api/providers/pv1", ]); @@ -82,12 +75,14 @@ describe("use-providers request helpers", () => { expect(jsonBody(calls[1]?.body)).toEqual({ name: "New OpenAI" }); }); - test("testProviderConnection uses correct URL and parses response", async () => { - installFetchMock(() => jsonResponse({ providerTestResponse: { message: "ok", ok: true } })); + test("fetchProviderOptions uses dedicated minimal endpoint", async () => { + const calls = installFetchMock(() => jsonResponse({ items: [{ id: "pv1", name: "OpenAI", type: "openai" }] })); - const result = await testProviderConnection("pv1"); + const result = await fetchProviderOptions(); - expect(result).toEqual({ message: "ok", ok: true }); + expect(result.items).toEqual([{ id: "pv1", name: "OpenAI", type: "openai" }]); + expect(calls[0]?.method).toBe("GET"); + expect(calls[0]?.url).toBe("/api/providers/options"); }); test("testProviderConfig posts form config and parses response", async () => { diff --git a/tests/web/routes/models.test.tsx b/tests/web/routes/models.test.tsx index 2e1f411..6d552da 100644 --- a/tests/web/routes/models.test.tsx +++ b/tests/web/routes/models.test.tsx @@ -12,7 +12,6 @@ const ENABLED_PROVIDER: Provider = { apiKey: "sk-test", baseUrl: "https://api.openai.com/v1", createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "pv1", name: "OpenAI", type: "openai", @@ -23,7 +22,6 @@ const DISABLED_PROVIDER: Provider = { apiKey: "sk-off", baseUrl: "https://api.deepseek.com/v1", createdAt: "2024-01-01T00:00:00.000Z", - enabled: false, id: "pv2", name: "DeepSeek", type: "openai-compatible", @@ -34,7 +32,6 @@ const ENABLED_MODEL: Model = { capabilities: ["text", "reasoning"], contextLength: 128000, createdAt: "2024-01-01T00:00:00.000Z", - enabled: true, id: "m1", maxOutputTokens: 4096, modelId: "gpt-4o", @@ -165,8 +162,9 @@ describe("ModelFormModal", () => { }, open: true, providers: [ENABLED_PROVIDER, DISABLED_PROVIDER], + providersError: null, + providersLoading: false, submitting: false, - testConnection: () => Promise.resolve({ message: "连接成功", ok: true }), }), ); @@ -190,8 +188,9 @@ describe("ModelFormModal", () => { onUpdate: () => Promise.resolve(), open: true, providers: [ENABLED_PROVIDER], + providersError: null, + providersLoading: false, submitting: false, - testConnection: () => Promise.resolve({ message: "连接成功", ok: true }), }), ); @@ -200,9 +199,7 @@ describe("ModelFormModal", () => { expect(onCreate).not.toHaveBeenCalled(); }); - test("新建模型时可测试所选供应商连接", async () => { - const testConnection = mock(() => Promise.resolve({ message: "连接成功", ok: true })); - + test("新建模型默认选中文本和推理能力", async () => { renderWithProviders( createElement(ModelFormModal, { editingModel: null, @@ -212,16 +209,111 @@ describe("ModelFormModal", () => { onUpdate: () => Promise.resolve(), open: true, providers: [ENABLED_PROVIDER], + providersError: null, + providersLoading: false, submitting: false, - testConnection, }), ); - await waitFor(() => expect(screen.getByText("测试连接")).not.toBeNull()); + await waitFor(() => expect(screen.getByLabelText("文本")).not.toBeNull()); + const textCheckbox = screen.getByLabelText("文本"); + const reasoningCheckbox = screen.getByLabelText("推理"); + expect((textCheckbox as { checked?: boolean }).checked).toBe(true); + expect((reasoningCheckbox as { checked?: boolean }).checked).toBe(true); + }); + + test("新建模型展示供应商 options 列表", async () => { + renderWithProviders( + createElement(ModelFormModal, { + editingModel: null, + onCancel: () => undefined, + onCreate: () => Promise.resolve(), + onOpenChange: () => undefined, + onUpdate: () => Promise.resolve(), + open: true, + providers: [ENABLED_PROVIDER, DISABLED_PROVIDER], + providersError: null, + providersLoading: false, + submitting: false, + }), + ); + + await waitFor(() => expect(screen.getByPlaceholderText("请输入模型名称")).not.toBeNull()); fireEvent.mouseDown(screen.getByRole("combobox")); - fireEvent.click(await screen.findByText("OpenAI")); + + expect(await screen.findByText("OpenAI")).not.toBeNull(); + expect(await screen.findByText("DeepSeek")).not.toBeNull(); + }); + + test("供应商下拉展示加载错误提示", async () => { + renderWithProviders( + createElement(ModelFormModal, { + editingModel: null, + onCancel: () => undefined, + onCreate: () => Promise.resolve(), + onOpenChange: () => undefined, + onUpdate: () => Promise.resolve(), + open: true, + providers: [], + providersError: new Error("options failed"), + providersLoading: false, + submitting: false, + }), + ); + + await waitFor(() => expect(screen.getByPlaceholderText("请输入模型名称")).not.toBeNull()); + fireEvent.mouseDown(screen.getByRole("combobox")); + + expect(await screen.findByText("供应商加载失败:options failed")).not.toBeNull(); + }); + + test("编辑模型时可测试模型连接", async () => { + const testModelConnection = mock(() => Promise.resolve({ message: "模型连接成功", ok: true })); + + renderWithProviders( + createElement(ModelFormModal, { + editingModel: ENABLED_MODEL, + onCancel: () => undefined, + onCreate: () => Promise.resolve(), + onOpenChange: () => undefined, + onUpdate: () => Promise.resolve(), + open: true, + providers: [ENABLED_PROVIDER], + providersError: null, + providersLoading: false, + submitting: false, + testModelConnection, + }), + ); + + await waitFor(() => expect(screen.getByRole("button", { name: "测试连接" })).not.toBeNull()); fireEvent.click(screen.getByRole("button", { name: "测试连接" })); - await waitFor(() => expect(testConnection).toHaveBeenCalledWith("pv1")); + await waitFor(() => + expect(testModelConnection).toHaveBeenCalledWith({ + modelId: "gpt-4o", + providerId: "pv1", + }), + ); + }); + + test("新建模型也显示测试连接按钮", async () => { + renderWithProviders( + createElement(ModelFormModal, { + editingModel: null, + onCancel: () => undefined, + onCreate: () => Promise.resolve(), + onOpenChange: () => undefined, + onUpdate: () => Promise.resolve(), + open: true, + providers: [ENABLED_PROVIDER], + providersError: null, + providersLoading: false, + submitting: false, + testModelConnection: () => Promise.resolve({ message: "ok", ok: true }), + }), + ); + + await waitFor(() => expect(screen.getByRole("button", { name: "测试连接" })).not.toBeNull()); }); }); diff --git a/tests/web/test-utils.tsx b/tests/web/test-utils.tsx index c02f29b..9430131 100644 --- a/tests/web/test-utils.tsx +++ b/tests/web/test-utils.tsx @@ -7,6 +7,8 @@ import { MemoryRouter } from "react-router"; import { ErrorBoundary } from "../../src/web/components/ErrorBoundary"; +const REAL_FETCH = globalThis.fetch.bind(globalThis); + // Mock recharts BEFORE any component imports void mock.module("recharts", () => ({ Area: () => null, @@ -34,6 +36,7 @@ export function installFetchMock(handler: (call: FetchMockCall) => Promise { const request = input instanceof Request ? input : undefined; const url = request?.url ?? (typeof input === "string" ? input : input instanceof URL ? input.href : input.url); + if (url.startsWith("http://") || url.startsWith("https://")) return REAL_FETCH(input, init); const call: FetchMockCall = { body: init?.body ?? null, method: init?.method ?? request?.method ?? "GET",