Skip to content
Merged
186 changes: 122 additions & 64 deletions app/lib/workflows/__tests__/runAgentStep.test.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { streamText } from "ai";
import { streamText, createUIMessageStream } from "ai";
import { runAgentStep } from "@/app/lib/workflows/runAgentStep";
import { persistAssistantMessage } from "@/lib/chat/persistAssistantMessage";

// Partially mock the `ai` package: keep every real export, but replace
// `streamText` and `createUIMessageStream` with bare mocks so each test can
// install its own implementation through `vi.mocked(...)`.
// A single return is required here — an earlier `return` that only replaced
// `streamText` would leave `createUIMessageStream` as the real function and
// make the later `mockImplementation` calls on it fail.
vi.mock("ai", async () => {
  const actual = await vi.importActual<typeof import("ai")>("ai");
  return { ...actual, streamText: vi.fn(), createUIMessageStream: vi.fn() };
});

// Keep the real gateway / fetch surface out of the test run: the stubbed
// `gateway` simply echoes the requested model id inside a tagged object.
vi.mock("@ai-sdk/gateway", () => {
  const gateway = vi.fn((modelId: string) => {
    return { modelId, __mock: "gateway" };
  });
  return { gateway };
});

// Replace the persistence helper with a bare mock so tests can assert on the
// arguments it is called with.
vi.mock("@/lib/chat/persistAssistantMessage", () => {
  return { persistAssistantMessage: vi.fn() };
});

// Shape of the options object the tests read back from the mocked
// createUIMessageStream. Every member is optional; the tests invoke
// onStepFinish / onFinish / execute / generateId by hand.
type CreateOpts = {
  generateId?: () => string;
  onStepFinish?: (event: { responseMessage: unknown }) => unknown;
  onFinish?: (event: { responseMessage: unknown }) => unknown;
  execute?: (args: {
    writer: {
      write: () => void;
      merge: () => void;
      onError: undefined;
    };
  }) => void;
};

// Most recent options object recorded by the createUIMessageStream mock.
let capturedCreateOpts: CreateOpts;

/**
 * Builds a stand-in for the value `streamText` is mocked to return.
 *
 * The returned object exposes:
 *  - `toUIMessageStream`: a mock that records the `messageMetadata` and
 *    `generateMessageId` options it receives, then returns an async generator
 *    yielding a `start` chunk followed by a `finish` chunk.
 *  - `finishReason`: a promise resolving to `"stop"`.
 *
 * @param opts.metadataCalls   Array that receives each `messageMetadata` option.
 * @param opts.generateIdCalls Array that receives each `generateMessageId` option.
 * @param opts.onFinishCalls          Still accepted so existing call sites type-check;
 *                                    not read by this helper.
 * @param opts.emittedResponseMessage Still accepted so existing call sites type-check;
 *                                    not read by this helper.
 *
 * Previously the object literal declared `toUIMessageStream` twice; only the
 * last definition was ever reachable, so the shadowed one (and the local that
 * fed it) has been removed.
 */
function makeStreamResult(opts?: {
  metadataCalls?: Array<unknown>;
  onFinishCalls?: Array<unknown>;
  emittedResponseMessage?: unknown;
  generateIdCalls?: Array<unknown>;
}) {
  const metadataCalls = opts?.metadataCalls ?? [];
  const generateIdCalls = opts?.generateIdCalls ?? [];
  return {
    toUIMessageStream: vi.fn(
      (streamOpts: { messageMetadata?: unknown; generateMessageId?: unknown }) => {
        // Record the callbacks so tests can inspect them.
        metadataCalls.push(streamOpts.messageMetadata);
        generateIdCalls.push(streamOpts.generateMessageId);
        return (async function* () {
          yield { type: "start" };
          yield { type: "finish" };
        })();
      },
    ),
    finishReason: Promise.resolve("stop"),
  };
}
Expand All @@ -59,6 +68,7 @@ const baseInput = {
},
],
modelId: "anthropic/claude-haiku-4.5",
chatId: "chat-1",
agentContext: {
sandbox: { state: { type: "vercel" }, workingDirectory: "/sandbox/mono" },
},
Expand All @@ -68,6 +78,20 @@ const baseInput = {
describe("runAgentStep", () => {
beforeEach(() => {
  vi.clearAllMocks();

  // Default createUIMessageStream behaviour for every test:
  //  1. remember the options it was given,
  //  2. call `execute` with a no-op writer so toUIMessageStream (and the
  //     messageMetadata callback handed to it) is exercised,
  //  3. hand back a stream that is closed as soon as it starts, so piping
  //     it completes right away.
  const createStreamMock = vi.mocked(createUIMessageStream);
  createStreamMock.mockImplementation((opts: never) => {
    capturedCreateOpts = opts as CreateOpts;

    const writer = { write: () => {}, merge: () => {}, onError: undefined };
    capturedCreateOpts.execute?.({ writer });

    const closedStream = new ReadableStream({
      start: (controller) => {
        controller.close();
      },
    });
    return closedStream as never;
  });
});

it("wires a messageMetadata callback into toUIMessageStream", async () => {
Expand Down Expand Up @@ -103,8 +127,7 @@ describe("runAgentStep", () => {
});

it("includes cwd from agentContext.sandbox in the system prompt", async () => {
const captured: unknown[] = [];
vi.mocked(streamText).mockReturnValue(makeStreamResult({ metadataCalls: captured }) as never);
vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
const { stream } = makeWritable();

await runAgentStep({
Expand All @@ -125,8 +148,7 @@ describe("runAgentStep", () => {
});

it("wraps tools with anthropic cacheControl on the last tool before passing to streamText", async () => {
const captured: unknown[] = [];
vi.mocked(streamText).mockReturnValue(makeStreamResult({ metadataCalls: captured }) as never);
vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
const { stream } = makeWritable();

await runAgentStep({ ...baseInput, writable: stream } as never);
Expand All @@ -148,8 +170,7 @@ describe("runAgentStep", () => {
});

it("wires a prepareStep callback that marks the last message with cacheControl", async () => {
const captured: unknown[] = [];
vi.mocked(streamText).mockReturnValue(makeStreamResult({ metadataCalls: captured }) as never);
vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
const { stream } = makeWritable();

await runAgentStep({ ...baseInput, writable: stream } as never);
Expand Down Expand Up @@ -189,55 +210,49 @@ describe("runAgentStep", () => {
expect(cb({ part: { type: "start" } })).toBeUndefined();
});

// Drives the onStepFinish callback captured from the createUIMessageStream
// options and checks the message is persisted under baseInput's chatId.
// (A stale, never-closed `it(` header and assertions on `onFinishCalls` —
// which the toUIMessageStream mock no longer records — were removed.)
it("persists the assistant message on each step via onStepFinish", async () => {
  vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
  const { stream } = makeWritable();

  await runAgentStep({ ...baseInput, writable: stream } as never);

  const msg = { id: "a1", role: "assistant", parts: [{ type: "text", text: "partial" }] };
  await capturedCreateOpts.onStepFinish?.({ responseMessage: msg });

  expect(persistAssistantMessage).toHaveBeenCalledWith("chat-1", msg);
});

// Drives the onFinish callback captured from the createUIMessageStream
// options and checks the final message is persisted under baseInput's chatId.
// (A stale, never-closed `it(` header that relied on `emittedResponseMessage`
// — an option the stream-result helper does not act on — was removed; the
// returned-responseMessage behaviour has its own test further down.)
it("persists the final assistant message via onFinish", async () => {
  vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
  const { stream } = makeWritable();

  await runAgentStep({ ...baseInput, writable: stream } as never);

  const msg = { id: "a1", role: "assistant", parts: [{ type: "text", text: "done" }] };
  await capturedCreateOpts.onFinish?.({ responseMessage: msg });

  expect(persistAssistantMessage).toHaveBeenCalledWith("chat-1", msg);
});

// Checks that the generateMessageId option handed to toUIMessageStream is a
// function returning the assistantMessageId supplied by the caller.
// (A stale, never-closed `it(` header and its `result` assertions were
// removed; the "onFinish never fires" case has its own test further down.)
it("forwards assistantMessageId into toUIMessageStream's generateMessageId (stable row id)", async () => {
  const generateIdCalls: unknown[] = [];
  vi.mocked(streamText).mockReturnValue(makeStreamResult({ generateIdCalls }) as never);
  const { stream } = makeWritable();

  await runAgentStep({
    ...baseInput,
    writable: stream,
    assistantMessageId: "asst-from-workflow-xyz",
  } as never);

  expect(generateIdCalls).toHaveLength(1);
  const gen = generateIdCalls[0] as () => string;
  expect(typeof gen).toBe("function");
  expect(gen()).toBe("asst-from-workflow-xyz");
});

it("forwards input.assistantMessageId into toUIMessageStream's generateMessageId", async () => {
const generateMessageIdCalls: unknown[] = [];
const streamResult = makeStreamResult();
// Spy on the options passed to toUIMessageStream to grab the generateMessageId fn.
const originalToUIMessageStream = streamResult.toUIMessageStream;
streamResult.toUIMessageStream = vi.fn((streamOpts: { generateMessageId?: unknown }) => {
generateMessageIdCalls.push(streamOpts.generateMessageId);
return (originalToUIMessageStream as unknown as (o: unknown) => AsyncGenerator<unknown>)(
streamOpts,
);
}) as never;
vi.mocked(streamText).mockReturnValue(streamResult as never);
it("sets a stable generateId on the createUIMessageStream", async () => {
vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
const { stream } = makeWritable();

await runAgentStep({
Expand All @@ -246,9 +261,52 @@ describe("runAgentStep", () => {
assistantMessageId: "asst-from-workflow-xyz",
} as never);

expect(generateMessageIdCalls).toHaveLength(1);
const gen = generateMessageIdCalls[0] as () => string;
expect(typeof gen).toBe("function");
expect(gen()).toBe("asst-from-workflow-xyz");
expect(typeof capturedCreateOpts.generateId).toBe("function");
expect(capturedCreateOpts.generateId!()).toBe("asst-from-workflow-xyz");
});

// The stubbed model result resolves finishReason to "stop"; runAgentStep is
// expected to surface that value on its own result.
it("returns the finishReason from the model result", async () => {
  const modelResult = makeStreamResult();
  vi.mocked(streamText).mockReturnValue(modelResult as never);
  const writable = makeWritable().stream;

  const stepResult = await runAgentStep({ ...baseInput, writable } as never);

  expect(stepResult.finishReason).toBe("stop");
});

// Overrides the default createUIMessageStream mock so that onFinish is
// invoked with a known message, then checks runAgentStep returns that message.
it("returns the responseMessage captured from onFinish (so the workflow can charge credits)", async () => {
  const finalMessage = {
    id: "asst-test-id",
    role: "assistant",
    parts: [{ type: "text", text: "Hello" }],
    metadata: { totalMessageCost: 0.05 },
  };
  const noopWriter = { write: () => {}, merge: () => {}, onError: undefined };

  vi.mocked(streamText).mockReturnValue(makeStreamResult() as never);
  vi.mocked(createUIMessageStream).mockImplementation((opts: never) => {
    const createOpts = opts as CreateOpts;
    createOpts.execute?.({ writer: noopWriter });
    // Fire onFinish by hand so the final message is reported.
    void createOpts.onFinish?.({ responseMessage: finalMessage });
    const closedStream = new ReadableStream({
      start: (controller) => {
        controller.close();
      },
    });
    return closedStream as never;
  });
  const writable = makeWritable().stream;

  const stepResult = await runAgentStep({ ...baseInput, writable } as never);

  expect(stepResult.responseMessage).toEqual(finalMessage);
});

// With the default createUIMessageStream mock, onFinish is never called, so
// no response message should be reported.
it("returns responseMessage: undefined when onFinish never fires", async () => {
  const modelResult = makeStreamResult();
  vi.mocked(streamText).mockReturnValue(modelResult as never);
  const writable = makeWritable().stream;

  const stepResult = await runAgentStep({ ...baseInput, writable } as never);

  expect(stepResult.responseMessage).toBeUndefined();
});
});
36 changes: 2 additions & 34 deletions app/lib/workflows/__tests__/runAgentWorkflow.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { runAgentStep } from "@/app/lib/workflows/runAgentStep";
import { clearChatActiveStream } from "@/lib/chat/clearChatActiveStream";
import { closeChatStream } from "@/app/lib/workflows/closeChatStream";
import { generateAssistantMessageId } from "@/app/lib/workflows/generateAssistantMessageId";
import { persistAssistantMessage } from "@/lib/chat/persistAssistantMessage";
import { handleChatCredits } from "@/lib/credits/handleChatCredits";
import { autoCommitChatTurn } from "@/lib/chat/auto-commit/autoCommitChatTurn";

Expand All @@ -20,9 +19,6 @@ vi.mock("@/app/lib/workflows/closeChatStream", () => ({
// Swap the id generator for a bare mock so each test controls its return value.
vi.mock("@/app/lib/workflows/generateAssistantMessageId", () => {
  return { generateAssistantMessageId: vi.fn() };
});
vi.mock("@/lib/chat/persistAssistantMessage", () => ({
persistAssistantMessage: vi.fn(),
}));
// Swap the credits handler for a bare mock so tests can assert on its calls.
vi.mock("@/lib/credits/handleChatCredits", () => {
  return { handleChatCredits: vi.fn() };
});
Expand Down Expand Up @@ -114,43 +110,15 @@ describe("runAgentWorkflow", () => {
expect(closeChatStream).toHaveBeenCalledWith(writableStub);
});

it("persists the assistant message when runAgentStep returns one", async () => {
const responseMessage = {
id: "assistant-msg-xyz",
role: "assistant",
parts: [{ type: "text", text: "Hello!" }],
};
vi.mocked(runAgentStep).mockResolvedValue({
finishReason: "stop",
responseMessage: responseMessage as never,
});

await runAgentWorkflow(baseInput);

expect(persistAssistantMessage).toHaveBeenCalledTimes(1);
expect(persistAssistantMessage).toHaveBeenCalledWith("chat-1", responseMessage);
});

it("does NOT call persistAssistantMessage when runAgentStep returns no responseMessage", async () => {
it("forwards chatId to runAgentStep so it can persist the assistant message per step", async () => {
vi.mocked(runAgentStep).mockResolvedValue({
finishReason: "stop",
responseMessage: undefined,
});

await runAgentWorkflow(baseInput);

expect(persistAssistantMessage).not.toHaveBeenCalled();
});

it("does NOT call persistAssistantMessage when runAgentStep throws (no message to persist)", async () => {
vi.mocked(runAgentStep).mockRejectedValue(new Error("model exploded"));

await expect(runAgentWorkflow(baseInput)).rejects.toThrow("model exploded");

expect(persistAssistantMessage).not.toHaveBeenCalled();
// But cleanup steps still run via the try/finally
expect(clearChatActiveStream).toHaveBeenCalledTimes(1);
expect(closeChatStream).toHaveBeenCalledTimes(1);
expect(runAgentStep).toHaveBeenCalledWith(expect.objectContaining({ chatId: "chat-1" }));
});

it("generates a fresh assistantMessageId via the step and forwards it to runAgentStep", async () => {
Expand Down
Loading
Loading