diff --git a/.gitignore b/.gitignore index c71e7e6..89f5560 100644 --- a/.gitignore +++ b/.gitignore @@ -415,6 +415,7 @@ data/ backend/bin backend/server backend/desktop +!src/**/* # Embedfs generated embedfs/assets/ diff --git a/src/server/db/conversations.ts b/src/server/db/conversations.ts index 24530fe..bc8bd15 100644 --- a/src/server/db/conversations.ts +++ b/src/server/db/conversations.ts @@ -2,7 +2,7 @@ import type Database from "bun:sqlite"; import { desc, eq } from "drizzle-orm"; -import type { Conversation, Message } from "../../shared/api"; +import type { Conversation, Message, UpdateConversationRequest } from "../../shared/api"; import { paginateQuery, wrap } from "./connection"; import { conversations, messages, models } from "./schema"; @@ -150,6 +150,33 @@ export function listMessages( }); } +export function updateConversation( + raw: Database, + id: string, + data: UpdateConversationRequest, +): { conversation: Conversation } | { error: string; status: number } { + const db = wrap(raw); + const existing = db.select().from(conversations).where(eq(conversations.id, id)).get(); + if (!existing) return { error: "会话不存在", status: 404 }; + + const updates: { modelId?: string; title?: string; updatedAt: string } = { updatedAt: new Date().toISOString() }; + + if (data.modelId !== undefined) { + const model = db.select().from(models).where(eq(models.id, data.modelId)).get(); + if (!model) return { error: "模型不存在", status: 400 }; + updates.modelId = data.modelId; + } + + if (data.title !== undefined) { + updates.title = data.title; + } + + db.update(conversations).set(updates).where(eq(conversations.id, id)).run(); + + const row = db.select().from(conversations).where(eq(conversations.id, id)).get(); + return { conversation: toConversation(row!) }; +} + export function updateConversationTimestamp(raw: Database, id: string): void { const db = wrap(raw); db.update(conversations).set({ updatedAt: new Date().toISOString() }).where(eq(conversations.id, id)).run(); diff --git a/src/server/routes/chat/update.ts b/src/server/routes/chat/update.ts new file mode 100644 index 0000000..c97e318 --- /dev/null +++ b/src/server/routes/chat/update.ts @@ -0,0 +1,47 @@ +import type Database from "bun:sqlite"; + +import type { RuntimeMode, UpdateConversationRequest } from "../../../shared/api"; + +import { getConversation, updateConversation } from "../../db/conversations"; +import { createApiError, jsonResponse } from "../../helpers"; +import { validateIdParam } from "../../middleware"; + +export async function handleUpdateConversation(req: Request, db: Database, mode: RuntimeMode): Promise { + const url = new URL(req.url); + const parts = url.pathname.split("/"); + const projectId = parts[3]; + const conversationId = parts[5]; + + const validatedProject = validateIdParam(projectId ?? "", mode); + if (validatedProject instanceof Response) return validatedProject; + + const validatedConv = validateIdParam(conversationId ?? "", mode); + if (validatedConv instanceof Response) return validatedConv; + + const existing = getConversation(db, validatedConv.id); + if ("error" in existing) { + return jsonResponse(createApiError(existing.error, existing.status), { mode, status: existing.status }); + } + + if (existing.conversation.projectId !== validatedProject.id) { + return jsonResponse(createApiError("会话不属于该项目", 403), { mode, status: 403 }); + } + + let body: UpdateConversationRequest; + try { + body = (await req.json()) as UpdateConversationRequest; + } catch { + return jsonResponse(createApiError("Invalid JSON body", 400), { mode, status: 400 }); + } + + if (body.modelId === undefined && body.title === undefined) { + return jsonResponse(createApiError("至少需要传 modelId 或 title", 400), { mode, status: 400 }); + } + + const result = updateConversation(db, validatedConv.id, body); + if ("error" in result) { + return jsonResponse(createApiError(result.error, result.status), { mode, status: result.status }); + } + + return jsonResponse({ conversation: result.conversation }, { mode }); +} diff --git a/src/server/server.ts b/src/server/server.ts index fea93a0..4572d80 100644 --- a/src/server/server.ts +++ b/src/server/server.ts @@ -201,6 +201,14 @@ export function startServer(options: StartServerOptions) { mode, logger, ), + PATCH: withErrorHandler( + async (req) => { + const { handleUpdateConversation } = await import("./routes/chat/update"); + return handleUpdateConversation(req, db, mode); + }, + mode, + logger, + ), }, "/api/projects/:id/conversations/:cid/messages": { GET: withErrorHandler( diff --git a/src/shared/api.ts b/src/shared/api.ts index 0ff88f8..8aeb850 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -42,11 +42,6 @@ export interface CreateProjectRequest { name: string; } -// ========================================== -// 在此定义你的业务类型 -// 前后端共享的类型都放在这个文件中 -// ========================================== - export interface CreateProviderRequest { apiKey: string; baseUrl: string; @@ -54,6 +49,11 @@ export interface CreateProviderRequest { type: ProviderType; } +// ========================================== +// 在此定义你的业务类型 +// 前后端共享的类型都放在这个文件中 +// ========================================== + export interface Message { content: string; conversationId: string; @@ -104,6 +104,11 @@ export interface SendMessageRequest { messages: Array<{ content: string; role: "assistant" | "system" | "user" }>; } +export interface UpdateConversationRequest { + modelId?: string; + title?: string; +} + export const MODEL_CAPABILITIES: readonly ModelCapability[] = [ "audio-generation", "audio-recognition", diff --git a/src/web/consoles/workbench/components/chat/ChatPanel.tsx b/src/web/consoles/workbench/components/chat/ChatPanel.tsx index 5bbf264..45d96b4 100644 --- a/src/web/consoles/workbench/components/chat/ChatPanel.tsx +++ b/src/web/consoles/workbench/components/chat/ChatPanel.tsx @@ -1,11 +1,13 @@ import { useChat } from "@ai-sdk/react"; import { DefaultChatTransport, type UIMessage } from "ai"; -import { App, Button, Card, Collapse, Empty, Flex, Input, Spin, Typography } from "antd"; -import { useCallback, useEffect, useRef, useState } from "react"; -import { Streamdown } from "streamdown"; +import { App, Button, Card, Empty, Flex, Input, Select, Spin } from "antd"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { fetchMessages } from "../../../../hooks/use-conversations"; -import { ToolCallCard } from "./ToolCallCard"; +import { fetchConversation, fetchMessages, updateConversation } from "../../../../hooks/use-conversations"; +import { useModelList } from "../../../../hooks/use-models"; +import { ReasoningPart } from "./parts/ReasoningPart"; +import { TextPart } from "./parts/TextPart"; +import { ToolPart } from "./parts/ToolPart"; interface ChatPanelProps { conversationId: null | string; @@ -16,9 +18,18 @@ export function ChatPanel({ conversationId, projectId }: ChatPanelProps) { const { message } = App.useApp(); const [input, setInput] = useState(""); const [loadingHistory, setLoadingHistory] = useState(false); + const [selectedModelId, setSelectedModelId] = useState(null); const fetchRef = useRef(fetchMessages); const scrollRef = useRef(null); + const { data: modelsData } = useModelList({ pageSize: 200 }); + const textModels = useMemo( + () => (modelsData?.items ?? []).filter((m) => m.capabilities.includes("text")), + [modelsData], + ); + + const modelOptions = useMemo(() => textModels.map((m) => ({ label: m.name, value: m.id })), [textModels]); + const { messages, sendMessage, setMessages, status, stop } = useChat({ onError: (err) => { void message.error(`发送失败:${err.message}`); @@ -45,8 +56,21 @@ export function ChatPanel({ conversationId, projectId }: ChatPanelProps) { setInput(""); setMessages([]); try { - const data = await fetchRef.current(projectId, conversationId); + const convPromise = fetchConversation(projectId, conversationId); + const msgPromise = fetchRef.current(projectId, conversationId); + + const conv = await convPromise; + const data = await msgPromise; if (cancelled) return; + + const firstTextId = textModels[0]?.id; + if (firstTextId && textModels.every((m) => m.id !== conv.modelId)) { + setSelectedModelId(firstTextId); + void updateConversation(projectId, conversationId, { modelId: firstTextId }); + } else { + setSelectedModelId(conv.modelId); + } + const history = data.items .filter((m: { role: string }) => m.role === "user" || m.role === "assistant") .reverse() @@ -71,7 +95,19 @@ export function ChatPanel({ conversationId, projectId }: ChatPanelProps) { return () => { cancelled = true; }; - }, [conversationId, projectId, setMessages, message]); + }, [conversationId, projectId, setMessages, message, textModels]); + + const displayModelId = conversationId ? selectedModelId : (textModels[0]?.id ?? null); + + const handleModelChange = useCallback( + (value: string) => { + setSelectedModelId(value); + if (conversationId) { + void updateConversation(projectId, conversationId, { modelId: value }); + } + }, + [projectId, conversationId], + ); const handleSend = useCallback(() => { if (!input.trim() || !conversationId) return; @@ -118,7 +154,7 @@ export function ChatPanel({ conversationId, projectId }: ChatPanelProps) { )} - +
- {isLoading ? ( - - ) : ( - - )} - + +