diff --git a/deeptutor/api/routers/sessions.py b/deeptutor/api/routers/sessions.py index 0d1787d8a..a5db87c82 100644 --- a/deeptutor/api/routers/sessions.py +++ b/deeptutor/api/routers/sessions.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, field_validator from deeptutor.services.session import get_session_store, get_sqlite_session_store +from deeptutor.services.storage.attachment_store import get_attachment_store logger = logging.getLogger(__name__) @@ -106,9 +107,32 @@ async def delete_session(session_id: str): deleted = await store.delete_session(session_id) if not deleted: raise HTTPException(status_code=404, detail="Session not found") + try: + await get_attachment_store().delete_session(session_id) + except Exception: + logger.exception("failed to clean up attachments for session %s", session_id) return {"deleted": True, "session_id": session_id} +@router.delete("/{session_id}/messages/{message_id}") +async def delete_turn_by_message(session_id: str, message_id: int): + store = get_sqlite_session_store() + result = await store.delete_turn_by_message(session_id, message_id) + if result["was_running"]: + raise HTTPException( + status_code=409, detail="Cannot delete a message while its turn is running" + ) + if not result["deleted"]: + raise HTTPException(status_code=404, detail="Message not found") + attachment_store = get_attachment_store() + for aid in result["attachment_ids"]: + try: + await attachment_store.delete_attachment(session_id, aid) + except Exception: + logger.exception("failed to delete attachment %s for session %s", aid, session_id) + return result + + @router.post("/{session_id}/quiz-results") async def record_quiz_results(session_id: str, payload: QuizResultsRequest): if not payload.answers: diff --git a/deeptutor/services/session/sqlite_store.py b/deeptutor/services/session/sqlite_store.py index 27c39574b..9e763e84c 100644 --- a/deeptutor/services/session/sqlite_store.py +++ b/deeptutor/services/session/sqlite_store.py @@ -648,6 +648,125 @@ def _delete_message_sync(self, message_id: int | str) -> bool: async def delete_message(self, message_id: int | str) -> bool: return await self._run(self._delete_message_sync, message_id) + def _delete_turn_by_message_sync(self, session_id: str, message_id: int) -> dict[str, Any]: + with self._connect() as conn: + msg = conn.execute( + """ + SELECT id, session_id, role, attachments_json, created_at + FROM messages + WHERE id = ? + """, + (int(message_id),), + ).fetchone() + if msg is None or msg["session_id"] != session_id: + return { + "deleted": False, + "attachment_ids": [], + "turn_id": None, + "was_running": False, + } + + role = msg["role"] + paired_msg = None + if role == "user": + paired_msg = conn.execute( + """ + SELECT id, session_id, role, attachments_json, created_at + FROM messages + WHERE session_id = ? AND role = 'assistant' AND id > ? + ORDER BY id ASC + LIMIT 1 + """, + (session_id, int(message_id)), + ).fetchone() + elif role == "assistant": + paired_msg = conn.execute( + """ + SELECT id, session_id, role, attachments_json, created_at + FROM messages + WHERE session_id = ? AND role = 'user' AND id < ? + ORDER BY id DESC + LIMIT 1 + """, + (session_id, int(message_id)), + ).fetchone() + + user_msg = msg if role == "user" else paired_msg + turn_id = None + was_running = False + if user_msg is not None: + user_created_at = user_msg["created_at"] + turn_row = conn.execute( + """ + SELECT id, status + FROM turns + WHERE session_id = ? AND created_at >= ? + ORDER BY created_at ASC + LIMIT 1 + """, + (session_id, user_created_at), + ).fetchone() + if turn_row is not None: + turn_id = turn_row["id"] + was_running = turn_row["status"] == "running" + + if was_running: + return { + "deleted": False, + "attachment_ids": [], + "turn_id": turn_id, + "was_running": True, + } + + attachment_ids: list[str] = [] + for m in [msg, paired_msg]: + if m is not None: + atts = _json_loads(m["attachments_json"], []) + for att in atts: + aid = att.get("id") or att.get("attachment_id") + if aid: + attachment_ids.append(aid) + + if turn_id is not None: + conn.execute("DELETE FROM turn_events WHERE turn_id = ?", (turn_id,)) + conn.execute("DELETE FROM turns WHERE id = ?", (turn_id,)) + + ids_to_delete = [int(message_id)] + if paired_msg is not None: + ids_to_delete.append(int(paired_msg["id"])) + conn.execute( + f"DELETE FROM messages WHERE id IN ({','.join('?' * len(ids_to_delete))})", # nosec B608 + tuple(ids_to_delete), + ) + + session_row = conn.execute( + "SELECT summary_up_to_msg_id FROM sessions WHERE id = ?", + (session_id,), + ).fetchone() + if session_row is not None: + summary_up_to = int(session_row["summary_up_to_msg_id"]) + if any(mid <= summary_up_to for mid in ids_to_delete): + conn.execute( + "UPDATE sessions SET summary_up_to_msg_id = 0 WHERE id = ?", + (session_id,), + ) + + conn.execute( + "UPDATE sessions SET updated_at = ? WHERE id = ?", + (time.time(), session_id), + ) + conn.commit() + + return { + "deleted": True, + "attachment_ids": attachment_ids, + "turn_id": turn_id, + "was_running": was_running, + } + + async def delete_turn_by_message(self, session_id: str, message_id: int) -> dict[str, Any]: + return await self._run(self._delete_turn_by_message_sync, session_id, message_id) + def _get_last_message_sync( self, session_id: str, role: str | None = None ) -> dict[str, Any] | None: diff --git a/deeptutor/services/storage/attachment_store.py b/deeptutor/services/storage/attachment_store.py index 2826394bd..d1eb89adc 100644 --- a/deeptutor/services/storage/attachment_store.py +++ b/deeptutor/services/storage/attachment_store.py @@ -88,6 +88,9 @@ async def put( async def delete_session(self, session_id: str) -> None: """Best-effort cleanup of all attachments for *session_id*.""" + async def delete_attachment(self, session_id: str, attachment_id: str) -> None: + """Best-effort cleanup of a single attachment identified by *attachment_id*.""" + def resolve_path(self, *, session_id: str, attachment_id: str, filename: str) -> Path | None: """Return the absolute path on disk for an attachment, or ``None`` if it does not exist or escapes the storage root. @@ -186,6 +189,13 @@ async def delete_session(self, session_id: str) -> None: loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._rmtree_sync, session_dir) + async def delete_attachment(self, session_id: str, attachment_id: str) -> None: + session_dir = self._session_dir(session_id) + if not session_dir.exists(): + return + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._delete_attachment_sync, session_dir, attachment_id) + @staticmethod def _rmtree_sync(path: Path) -> None: import shutil @@ -195,6 +205,21 @@ def _rmtree_sync(path: Path) -> None: except OSError as exc: logger.warning("failed to clean up attachment dir %s: %s", path, exc) + @staticmethod + def _delete_attachment_sync(session_dir: Path, attachment_id: str) -> None: + prefix = f"{attachment_id}_" + for entry in session_dir.iterdir(): + if entry.name.startswith(prefix): + try: + entry.unlink() + except OSError as exc: + logger.warning("failed to delete attachment file %s: %s", entry, exc) + try: + if session_dir.exists() and not any(session_dir.iterdir()): + session_dir.rmdir() + except OSError as exc: + logger.warning("failed to remove empty attachment dir %s: %s", session_dir, exc) + def resolve_path(self, *, session_id: str, attachment_id: str, filename: str) -> Path | None: stored = self._stored_filename(attachment_id, filename) target = self._safe_join(session_id, stored) diff --git a/web/app/(workspace)/chat/[[...sessionId]]/page.tsx b/web/app/(workspace)/chat/[[...sessionId]]/page.tsx index 02af2acbe..efb090e70 100644 --- a/web/app/(workspace)/chat/[[...sessionId]]/page.tsx +++ b/web/app/(workspace)/chat/[[...sessionId]]/page.tsx @@ -310,6 +310,7 @@ export default function ChatPage() { sendMessage, cancelStreamingTurn, regenerateLastMessage, + deleteTurn, newSession, loadSession, } = useUnifiedChat(); @@ -1645,6 +1646,7 @@ export default function ChatPage() { onRegenerateMessage={handleRegenerateMessage} onConfirmOutline={handleConfirmOutline} onPreviewAttachment={handlePreviewMessageAttachment} + onDeleteTurn={deleteTurn} />
diff --git a/web/components/chat/home/ChatMessages.tsx b/web/components/chat/home/ChatMessages.tsx index d8bb3d797..a26eddfd6 100644 --- a/web/components/chat/home/ChatMessages.tsx +++ b/web/components/chat/home/ChatMessages.tsx @@ -12,6 +12,7 @@ import { RefreshCcw, Wand2, X, + Trash2, Zap, type LucideIcon, } from "lucide-react"; @@ -52,6 +53,7 @@ const VisualizationViewer = dynamic( ); interface ChatMessageItem { + id?: number; role: "user" | "assistant" | "system"; content: string; capability?: string; @@ -336,17 +338,57 @@ const UserMessage = memo(function UserMessage({ msg, index, onPreviewAttachment, + onDeleteTurn, }: { msg: ChatMessageItem; index: number; onPreviewAttachment?: (attachment: MessageAttachment) => void; + onDeleteTurn?: (messageId: number) => void; }) { const { t } = useTranslation(); + const [confirmDelete, setConfirmDelete] = useState(false); if (msg.content.startsWith("[Quiz Performance]")) return null; return ( -
+
+
+
+ {!confirmDelete ? ( + + ) : ( +
+ + {t("Delete this turn?")} + + + +
+ )} +
+
{t(getModeBadgeLabel(msg.capability))} @@ -686,6 +728,7 @@ export const ChatMessageList = memo(function ChatMessageList({ onConfirmOutline, onPreviewAttachment, onSwitchToManualMode, + onDeleteTurn, }: { messages: ChatMessageItem[]; isStreaming: boolean; @@ -707,6 +750,7 @@ export const ChatMessageList = memo(function ChatMessageList({ // button after a terminal failure. Optional so non-auto chat surfaces don't // have to wire it. onSwitchToManualMode?: () => void; + onDeleteTurn?: (messageId: number) => void; }) { const { t } = useTranslation(); const outlineStatusByIndex = useMemo(() => { @@ -783,6 +827,7 @@ export const ChatMessageList = memo(function ChatMessageList({ msg={msg} index={i} onPreviewAttachment={onPreviewAttachment} + onDeleteTurn={onDeleteTurn} /> ); } diff --git a/web/context/UnifiedChatContext.tsx b/web/context/UnifiedChatContext.tsx index cb5ad05d4..8652cb11f 100644 --- a/web/context/UnifiedChatContext.tsx +++ b/web/context/UnifiedChatContext.tsx @@ -19,7 +19,11 @@ import { } from "@/context/app-shell-storage"; import type { StreamEvent, ChatMessage, LLMSelection } from "@/lib/unified-ws"; import { UnifiedWSClient } from "@/lib/unified-ws"; -import { getSession, type SessionMessage } from "@/lib/session-api"; +import { + getSession, + deleteMessage, + type SessionMessage, +} from "@/lib/session-api"; import { normalizeMarkdownForDisplay } from "@/lib/markdown-display"; import { normalizeMessageContent } from "@/lib/message-content"; import { shouldAppendEventContent } from "@/lib/stream"; @@ -117,6 +121,7 @@ export interface MessageRequestSnapshot { } export interface MessageItem { + id?: number; role: "user" | "assistant" | "system"; content: string; capability?: string; @@ -182,6 +187,7 @@ type Action = llmSelection?: LLMSelection | null; language?: string; } + | { type: "DELETE_TURN"; key: string; messageId: number } | { type: "NEW_SESSION"; key: string }; function createSessionEntry( @@ -269,6 +275,7 @@ function reducer(state: ProviderState, action: Action): ProviderState { messages: [ ...session.messages, { + id: -Date.now(), role: "user", content: action.content, capability: action.capability || "", @@ -343,6 +350,7 @@ function reducer(state: ProviderState, action: Action): ProviderState { messages: [ ...(state.sessions[action.key]?.messages ?? []), { + id: -Date.now(), role: "assistant", content: "", events: [], @@ -362,6 +370,7 @@ function reducer(state: ProviderState, action: Action): ProviderState { let last = msgs[msgs.length - 1]; if (last?.role !== "assistant") { msgs.push({ + id: -Date.now(), role: "assistant", content: "", events: [], @@ -480,6 +489,42 @@ function reducer(state: ProviderState, action: Action): ProviderState { }, }; } + case "DELETE_TURN": { + const session = state.sessions[action.key]; + if (!session) return state; + const idx = session.messages.findIndex((m) => m.id === action.messageId); + if (idx === -1) return state; + const msg = session.messages[idx]; + const toRemove = new Set(); + toRemove.add(idx); + if (msg.role === "user") { + if ( + idx + 1 < session.messages.length && + session.messages[idx + 1].role === "assistant" + ) { + toRemove.add(idx + 1); + } + } else if (msg.role === "assistant") { + if (idx - 1 >= 0 && session.messages[idx - 1].role === "user") { + toRemove.add(idx - 1); + } + } + const nextMessages = session.messages.filter((_, i) => !toRemove.has(i)); + return { + ...state, + sessions: { + ...state.sessions, + [action.key]: { + ...session, + messages: nextMessages, + isStreaming: false, + status: "idle", + updatedAt: Date.now(), + }, + }, + sidebarRefreshToken: state.sidebarRefreshToken + 1, + }; + } case "NEW_SESSION": { const MAX_CACHED_SESSIONS = 20; let nextSessions = { @@ -531,6 +576,7 @@ interface ChatContextValue { ) => void; cancelStreamingTurn: () => void; regenerateLastMessage: () => void; + deleteTurn: (messageId: number) => Promise; newSession: () => void; loadSession: (sessionId: string) => Promise; selectedSessionId: string | null; @@ -714,6 +760,7 @@ export function UnifiedChatProvider({ attachments, ); return { + id: message.id, role: message.role, content: message.role === "assistant" @@ -1216,6 +1263,38 @@ export function UnifiedChatProvider({ dispatch({ type: "NEW_SESSION", key: makeDraftKey() }); }, [makeDraftKey]); + const deleteTurn = useCallback( + async (messageId: number) => { + const currentState = stateRef.current; + const key = currentState.selectedKey; + if (!key) return; + const session = currentState.sessions[key]; + if (!session || !session.sessionId) return; + if (session.isStreaming) return; + let effectiveId = messageId; + if (messageId < 0) { + const origIdx = session.messages.findIndex((m) => m.id === messageId); + if (origIdx === -1) return; + try { + await loadSession(session.sessionId); + } catch { + return; + } + const refreshed = stateRef.current.sessions[key]; + const realId = refreshed?.messages[origIdx]?.id; + if (realId == null || realId < 0) return; + effectiveId = realId; + } + try { + await deleteMessage(session.sessionId, effectiveId); + dispatch({ type: "DELETE_TURN", key, messageId: effectiveId }); + } catch (err) { + console.error("Failed to delete turn:", err); + } + }, + [loadSession], + ); + const value: ChatContextValue = { state: derivedState, setTools, @@ -1226,6 +1305,7 @@ export function UnifiedChatProvider({ sendMessage, cancelStreamingTurn, regenerateLastMessage, + deleteTurn, newSession, loadSession, selectedSessionId: derivedState.sessionId, diff --git a/web/lib/session-api.ts b/web/lib/session-api.ts index 7a0c62780..02aa0326d 100644 --- a/web/lib/session-api.ts +++ b/web/lib/session-api.ts @@ -189,3 +189,14 @@ export async function recordQuizResults( ); await expectJson<{ recorded: boolean }>(response); } + +export async function deleteMessage( + sessionId: string, + messageId: number, +): Promise { + const response = await apiFetch( + apiUrl(`/api/v1/sessions/${sessionId}/messages/${messageId}`), + { method: "DELETE" }, + ); + await expectJson<{ deleted: boolean }>(response); +}