diff --git a/.env.example b/.env.example index 2d54c20..67938d6 100644 --- a/.env.example +++ b/.env.example @@ -12,8 +12,8 @@ CODERAG_GEMINI_API_KEY=your_api_key_here # Compatibility alias also accepted: CODERAG_GEMINI_AI_KEY # Optional: Override the default Gemini embedding model -# Default: models/gemini-embedding-001 -CODERAG_GEMINI_MODEL=models/gemini-embedding-001 +# Default: models/gemini-embedding-2 +CODERAG_GEMINI_MODEL=models/gemini-embedding-2 # ============================================ # EMBEDDING CONFIGURATION diff --git a/README.md b/README.md index 110efc8..7822174 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ Supported environment overrides: - `CODERAG_CUSTOM_HTTP_FORMAT` - `CODERAG_LLM_HEADERS` -When `embedding.provider` is `gemini`, CodeRag defaults to `models/gemini-embedding-001` and requests 768-dimensional vectors explicitly so the stored embedding fingerprint matches the vectors written to LanceDB. It accepts either `CODERAG_GEMINI_API_KEY` or the compatibility alias `CODERAG_GEMINI_AI_KEY`. +When `embedding.provider` is `gemini`, CodeRag defaults to `models/gemini-embedding-2` and requests 768-dimensional vectors explicitly so the stored embedding fingerprint matches the vectors written to LanceDB. It accepts either `CODERAG_GEMINI_API_KEY` or the compatibility alias `CODERAG_GEMINI_AI_KEY`. When `embedding.provider` is `onnx`, CodeRag uses `Xenova/gte-small` (384-dim, ~33MB) running locally via `@xenova/transformers`. No API key or external server needed. The model must be downloaded to `/Xenova/gte-small/` (default `.coderag-models/models/Xenova/gte-small/`). diff --git a/src/cli/setup-wizard.ts b/src/cli/setup-wizard.ts index f6032c2..7d4dda6 100644 --- a/src/cli/setup-wizard.ts +++ b/src/cli/setup-wizard.ts @@ -68,7 +68,7 @@ export const runSetupWizard = async (cwd: string, logger?: Logger): Promise (items.length > 0 ? 
items.map((item) => `- ${item}`).join("\n") : EMPTY_LIST); @@ -259,9 +260,10 @@ export const buildIndexedDocuments = async ( embeddingText = [doc, sourceText].filter(Boolean).join("\n\n"); } - // Truncate to save memory — embedding models cap at ~512 tokens anyway - if (embeddingText.length > MAX_EMBEDDING_CHARS) { - embeddingText = embeddingText.slice(0, MAX_EMBEDDING_CHARS); + // Truncate to fit the model's token limit + const maxChars = embeddingProvider.maxInputTokens * CHARS_PER_TOKEN; + if (embeddingText.length > maxChars) { + embeddingText = embeddingText.slice(0, maxChars); } preparedForChunk.push({ diff --git a/src/indexer/embedder.ts b/src/indexer/embedder.ts index 5b96595..167f0b3 100644 --- a/src/indexer/embedder.ts +++ b/src/indexer/embedder.ts @@ -5,6 +5,8 @@ export class LocalHashEmbeddingProvider implements EmbeddingProvider { readonly name = "local-hash"; readonly model = "local-hash"; readonly dimensions: number; + /** Unlimited — hash-based embedding has no token limit. 
*/ + readonly maxInputTokens = Infinity; constructor(dimensions = 256) { this.dimensions = dimensions; diff --git a/src/indexer/gemini-embedder.ts b/src/indexer/gemini-embedder.ts index fde2923..3835536 100644 --- a/src/indexer/gemini-embedder.ts +++ b/src/indexer/gemini-embedder.ts @@ -2,7 +2,7 @@ import type { EmbeddingProvider } from "../types.js"; import { ConfigurationError } from "../errors/index.js"; const GEMINI_API_BASE = "https://generativelanguage.googleapis.com/v1beta"; -const DEFAULT_MODEL = "models/gemini-embedding-001"; +const DEFAULT_MODEL = "models/gemini-embedding-2"; const DEFAULT_DIMENSIONS = 768; const MAX_BATCH_SIZE = 100; const GEMINI_API_KEY_ENV = "CODERAG_GEMINI_API_KEY"; @@ -23,6 +23,7 @@ export class GeminiEmbeddingProvider implements EmbeddingProvider { readonly name = "gemini"; readonly dimensions = DEFAULT_DIMENSIONS; readonly maxBatchSize = MAX_BATCH_SIZE; + readonly maxInputTokens = 8192; readonly model: string; private readonly apiKey: string; private readonly timeoutMs: number; diff --git a/src/indexer/onnx-embedder.ts b/src/indexer/onnx-embedder.ts index b4f2917..02bc2a2 100644 --- a/src/indexer/onnx-embedder.ts +++ b/src/indexer/onnx-embedder.ts @@ -98,6 +98,7 @@ export class OnnxEmbeddingProvider implements EmbeddingProvider { readonly model = DEFAULT_MODEL; readonly dimensions = DEFAULT_DIMENSIONS; readonly maxBatchSize = 1; // One at a time to minimize memory pressure + readonly maxInputTokens = 256; // Sequence-length cap in tokens. NOTE(review): README documents the ONNX model as Xenova/gte-small, not all-MiniLM-L6-v2 — confirm 256 is the intended limit private readonly modelDir: string; private readonly logger?: Logger; diff --git a/src/llm/context-builder.ts b/src/llm/context-builder.ts index 158fbdb..58c33d9 100644 --- a/src/llm/context-builder.ts +++ b/src/llm/context-builder.ts @@ -1,6 +1,13 @@ import type { BlueprintNode } from "@abhinav2203/codeflow-core/schema"; -import type { ContextPackage, GraphSnapshot, IndexedNodeDocument, RetrievedNodeContext, RetrievalConfig } from "../types.js"; +import type { + ContextPackage, 
GraphSnapshot, + IndexedNodeDocument, + RetrievedNodeContext, + RetrievalConfig +} from "../types.js"; +import type { SectionLimits } from "./prompt.js"; import { FileCache } from "../store/file-cache.js"; import { createRetrievedNodeContext } from "../retrieval/page-index.js"; @@ -25,6 +32,36 @@ const buildGraphSummary = ( return parts.join(" "); }; +/** + * Derives per-section char limits from retrieval config. + * + * Defaults are proportional to maxContextChars so they scale automatically. + * Explicit overrides (when the user sets primaryDocLimit, etc.) always take precedence. + * + * Default distribution for a 16K baseline: + * primaryDoc -> 1,200 (7.5%) + * primaryFile -> 4,000 (25%) + * relatedDoc -> 320 (2%) + * relatedFile -> 1,200 (7.5%) + * The related limits apply to each related node, so the ~58% left for structural overhead (headers, warnings, graph summary) only holds with a single related node and shrinks as more are included. + */ +export const deriveSectionLimits = (retrieval: RetrievalConfig): SectionLimits => { + const mcc = retrieval.maxContextChars; + + // Proportional defaults relative to a 16,000 baseline. + const primaryDocDefault = Math.max(1, Math.round((mcc / 16000) * 1200)); + const primaryFileDefault = Math.max(1, Math.round((mcc / 16000) * 4000)); + const relatedDocDefault = Math.max(1, Math.round((mcc / 16000) * 320)); + const relatedFileDefault = Math.max(1, Math.round((mcc / 16000) * 1200)); + + return { + primaryDoc: retrieval.primaryDocLimit ?? primaryDocDefault, + primaryFile: retrieval.primaryFileLimit ?? primaryFileDefault, + relatedDoc: retrieval.relatedDocLimit ?? relatedDocDefault, + relatedFile: retrieval.relatedFileLimit ?? relatedFileDefault + }; +}; + const truncateContext = (context: RetrievedNodeContext, maxChars: number, warnings: string[]): RetrievedNodeContext => { if (context.fullFileContent.length <= maxChars) { return context; @@ -115,6 +152,8 @@ const buildRelatedContextPromises = ( /** * Builds the final context package passed to the LLM or returned directly to the caller. 
+ * + * The caller receives `limits` so it can pass them through to `buildMessages()`. */ export const buildContextPackage = async ( question: string, @@ -127,7 +166,7 @@ export const buildContextPackage = async ( dependencies: BlueprintNode[], dependents: BlueprintNode[], answerMode: ContextPackage["answerMode"] -): Promise => { +): Promise<{ context: ContextPackage; limits: SectionLimits }> => { const primaryDocument = primaryNode ? documents[primaryNode.id] : undefined; const primaryContext = primaryDocument ? await createRetrievedNodeContext(repoPath, fileCache, snapshot, primaryDocument, "primary") @@ -138,13 +177,18 @@ export const buildContextPackage = async ( const primaryResult = fitPrimaryContext(primaryContext, retrieval.maxContextChars); const relatedResult = fitRelatedContexts(resolvedRelatedContexts, primaryResult.remainingBudget); + const limits = deriveSectionLimits(retrieval); + return { - question, - answerMode, - retrievalMode: "single" as const, - primaryNode: primaryResult.primaryContext, - relatedNodes: relatedResult.relatedContexts, - graphSummary: buildGraphSummary(primaryNode, dependencies, dependents), - warnings: [...primaryResult.warnings, ...relatedResult.warnings] + context: { + question, + answerMode, + retrievalMode: "single" as const, + primaryNode: primaryResult.primaryContext, + relatedNodes: relatedResult.relatedContexts, + graphSummary: buildGraphSummary(primaryNode, dependencies, dependents), + warnings: [...primaryResult.warnings, ...relatedResult.warnings] + }, + limits }; -}; +}; \ No newline at end of file diff --git a/src/llm/multi-hop-context-builder.ts b/src/llm/multi-hop-context-builder.ts index ea665e4..fa35010 100644 --- a/src/llm/multi-hop-context-builder.ts +++ b/src/llm/multi-hop-context-builder.ts @@ -5,9 +5,11 @@ import type { GraphSnapshot, IndexedNodeDocument, MultiHopRetrievalResult, + RetrievedNodeContext, RetrievalConfig } from "../types.js"; -import type { RetrievedNodeContext } from "../types.js"; +import 
type { SectionLimits } from "./prompt.js"; +import { deriveSectionLimits } from "./context-builder.js"; import { FileCache } from "../store/file-cache.js"; import { createRetrievedNodeContext } from "../retrieval/page-index.js"; @@ -77,6 +79,8 @@ const buildRelatedNodeContexts = async ( * Unlike the single-node path, there is no single primary node. * The first retrieved node is promoted to "primary" for display purposes, * and all remaining nodes are listed as related. + * + * Returns both the context and the derived section limits for prompt building. */ export const buildMultiHopContextPackage = async ( question: string, @@ -87,10 +91,9 @@ export const buildMultiHopContextPackage = async ( documents: Record, retrieval: RetrievalConfig, fileCache: FileCache -): Promise => { +): Promise<{ context: ContextPackage; limits: SectionLimits }> => { const allNodes = retrievalResult.deduplicatedNodes; - // Build RetrievedNodeContext for all deduplicated nodes const allContexts = await buildRelatedNodeContexts( allNodes, repoPath, @@ -99,14 +102,12 @@ export const buildMultiHopContextPackage = async ( documents ); - // Promote the first node to "primary" for display const firstCtx = allContexts[0]; const primaryContext: RetrievedNodeContext | null = firstCtx ? Object.assign({}, firstCtx, { relationship: "primary" as const, subQuestionIndex: undefined }) : null; const relatedContexts: RetrievedNodeContext[] = allContexts.length > 1 ? 
allContexts.slice(1) : []; - // Apply context budgeting const warnings: string[] = []; let remainingBudget = retrieval.maxContextChars; @@ -125,6 +126,7 @@ export const buildMultiHopContextPackage = async ( const fittedRelated: RetrievedNodeContext[] = []; for (const ctx of relatedContexts) { if (remainingBudget <= 0) { + warnings.push(`Dropped file content for ${ctx.filePath} because the context budget was exhausted.`); fittedRelated.push({ ...ctx, fullFileContent: "" }); continue; } @@ -149,15 +151,20 @@ export const buildMultiHopContextPackage = async ( filesReferenced: meta.filesReferenced })); + const limits = deriveSectionLimits(retrieval); + return { - question, - answerMode: "llm" as const, - retrievalMode: "multi-hop" as const, - primaryNode: fittedPrimary, - relatedNodes: fittedRelated, - graphSummary: buildMultiHopGraphSummary(subQuestions, retrievalResult, snapshot), - warnings, - subQuestions, - subQuestionResults + context: { + question, + answerMode: "llm" as const, + retrievalMode: "multi-hop" as const, + primaryNode: fittedPrimary, + relatedNodes: fittedRelated, + graphSummary: buildMultiHopGraphSummary(subQuestions, retrievalResult, snapshot), + warnings, + subQuestions, + subQuestionResults + }, + limits }; -}; +}; \ No newline at end of file diff --git a/src/llm/prompt.ts b/src/llm/prompt.ts index 4ff5956..1bc3530 100644 --- a/src/llm/prompt.ts +++ b/src/llm/prompt.ts @@ -1,11 +1,11 @@ import type { ContextPackage, RetrievedNodeContext, LlmRequest } from "../types.js"; -const PRIMARY_DOC_CHAR_LIMIT = 1_200; -const PRIMARY_FILE_CHAR_LIMIT = 4_000; -const RELATED_DOC_CHAR_LIMIT = 320; -const RELATED_FILE_CHAR_LIMIT = 1_200; -const WARNING_CHAR_LIMIT = 160; -const MAX_WARNING_COUNT = 4; +export interface SectionLimits { + primaryDoc: number; + primaryFile: number; + relatedDoc: number; + relatedFile: number; +} const truncateText = (value: string, maxChars: number): string => { if (value.length <= maxChars) { @@ -35,11 +35,11 @@ const 
formatFileSection = (value: string, maxChars: number): string => const joinSections = (sections: string[]): string => sections.filter(Boolean).join("\n\n"); -const formatPrimaryNode = (node: RetrievedNodeContext): string => +const formatPrimaryNode = (node: RetrievedNodeContext, limits: SectionLimits): string => joinSections([ `Primary node:\n${formatNodeHeader(node)}`, - formatDocSection("Primary doc", node.doc, PRIMARY_DOC_CHAR_LIMIT), - formatFileSection(node.fullFileContent, PRIMARY_FILE_CHAR_LIMIT) + formatDocSection("Primary doc", node.doc, limits.primaryDoc), + formatFileSection(node.fullFileContent, limits.primaryFile) ]); const shouldIncludeRelatedFile = ( @@ -51,7 +51,8 @@ const shouldIncludeRelatedFile = ( const formatRelatedNode = ( node: RetrievedNodeContext, primaryNode: RetrievedNodeContext | null, - includedFiles: Set + includedFiles: Set, + limits: SectionLimits ): string => { const includeFile = shouldIncludeRelatedFile(node, primaryNode, includedFiles); if (includeFile) { @@ -60,14 +61,15 @@ const formatRelatedNode = ( return joinSections([ formatNodeHeader(node), - formatDocSection("Related doc", node.doc, RELATED_DOC_CHAR_LIMIT), - includeFile ? formatFileSection(node.fullFileContent, RELATED_FILE_CHAR_LIMIT) : "" + formatDocSection("Related doc", node.doc, limits.relatedDoc), + includeFile ? formatFileSection(node.fullFileContent, limits.relatedFile) : "" ]); }; const formatRelatedNodes = ( relatedNodes: RetrievedNodeContext[], - primaryNode: RetrievedNodeContext | null + primaryNode: RetrievedNodeContext | null, + limits: SectionLimits ): string => { if (relatedNodes.length === 0) { return "Related nodes:\nnone"; @@ -75,7 +77,7 @@ const formatRelatedNodes = ( const includedFiles = new Set(primaryNode ? [primaryNode.filePath] : []); const entries = relatedNodes.map((node, index) => - `${index + 1}. ${formatRelatedNode(node, primaryNode, includedFiles)}` + `${index + 1}. 
${formatRelatedNode(node, primaryNode, includedFiles, limits)}` ); return `Related nodes:\n${entries.join("\n\n")}`; }; @@ -85,17 +87,19 @@ const formatWarnings = (warnings: string[]): string => { return ""; } + const MAX_WARNING_COUNT = 4; + const WARNING_CHAR_LIMIT = 160; const entries = warnings .slice(0, MAX_WARNING_COUNT) .map((warning, index) => `${index + 1}. ${truncateText(warning, WARNING_CHAR_LIMIT)}`); return `Warnings:\n${entries.join("\n")}`; }; -const summarizeContext = (context: ContextPackage): string => +const summarizeContext = (context: ContextPackage, limits: SectionLimits): string => joinSections([ `Graph summary:\n${context.graphSummary}`, - context.primaryNode ? formatPrimaryNode(context.primaryNode) : "Primary node:\nnone", - formatRelatedNodes(context.relatedNodes, context.primaryNode), + context.primaryNode ? formatPrimaryNode(context.primaryNode, limits) : "Primary node:\nnone", + formatRelatedNodes(context.relatedNodes, context.primaryNode, limits), formatWarnings(context.warnings) ]); @@ -107,14 +111,18 @@ export const buildSystemPrompt = (): string => "Do not invent functions, files, or behavior that is not present in the retrieved context." 
].join(" "); -export const buildMessages = (question: string, context: ContextPackage): LlmRequest["messages"] => [ +export const buildMessages = ( + question: string, + context: ContextPackage, + limits: SectionLimits +): LlmRequest["messages"] => [ { role: "system", content: buildSystemPrompt() }, { role: "user", - content: `Question:\n${question}\n\n${summarizeContext(context)}` + content: `Question:\n${question}\n\n${summarizeContext(context, limits)}` } ]; @@ -129,7 +137,8 @@ const MULTI_HOP_SYSTEM_PROMPT = [ const formatSubQuestionSection = ( index: number, question: string, - nodes: RetrievedNodeContext[] + nodes: RetrievedNodeContext[], + limits: SectionLimits ): string => { if (nodes.length === 0) { return `Sub-question ${index + 1}: "${question}"\nNo matching code found.`; @@ -148,7 +157,8 @@ const formatSubQuestionSection = ( */ export const buildMultiHopMessages = ( question: string, - context: ContextPackage + context: ContextPackage, + limits: SectionLimits ): LlmRequest["messages"] => { const subQuestions = context.subQuestions ?? []; const nodeByIndex = new Map(); @@ -163,7 +173,7 @@ export const buildMultiHopMessages = ( const subQuestionSections = subQuestions .map((sq, i) => { const nodes = nodeByIndex.get(i) ?? []; - return formatSubQuestionSection(i, sq, nodes); + return formatSubQuestionSection(i, sq, nodes, limits); }) .join("\n\n"); diff --git a/src/service/coderag.ts b/src/service/coderag.ts index d25b68b..b45c6fa 100644 --- a/src/service/coderag.ts +++ b/src/service/coderag.ts @@ -268,7 +268,7 @@ export class CodeRag { const { dependencies, dependents } = primaryNode ? 
traverseDependencies(snapshot, primaryNode.id, depth) : { dependencies: [], dependents: [] }; - const context = await buildContextPackage( + const { context, limits } = await buildContextPackage( question, this.config.repoPath, snapshot, @@ -297,7 +297,7 @@ export class CodeRag { model: this.config.llm.model, stream: Boolean(options.onToken), context, - messages: buildMessages(question, context) + messages: buildMessages(question, context, limits) }, options.onToken ); @@ -348,7 +348,7 @@ export class CodeRag { ); // Stage 3: Context assembly + synthesis - const context = await buildMultiHopContextPackage( + const { context, limits } = await buildMultiHopContextPackage( question, subQuestions, retrievalResult, @@ -365,7 +365,7 @@ export class CodeRag { model: this.config.llm.model, stream: Boolean(options.onToken), context, - messages: buildMultiHopMessages(question, context) + messages: buildMultiHopMessages(question, context, limits) }, options.onToken ); diff --git a/src/service/config.ts b/src/service/config.ts index 99d31b6..b98878a 100644 --- a/src/service/config.ts +++ b/src/service/config.ts @@ -154,7 +154,11 @@ export const loadSerializableConfig = async (cwd: string, configPath?: string): ...baseConfig.retrieval, topK: parseNumber(process.env.CODERAG_TOP_K) ?? baseConfig.retrieval.topK, rerankK: parseNumber(process.env.CODERAG_RERANK_K) ?? baseConfig.retrieval.rerankK, - maxContextChars: parseNumber(process.env.CODERAG_MAX_CONTEXT_CHARS) ?? baseConfig.retrieval.maxContextChars + maxContextChars: parseNumber(process.env.CODERAG_MAX_CONTEXT_CHARS) ?? baseConfig.retrieval.maxContextChars, + primaryDocLimit: parseNumber(process.env.CODERAG_PRIMARY_DOC_LIMIT) ?? baseConfig.retrieval.primaryDocLimit, + primaryFileLimit: parseNumber(process.env.CODERAG_PRIMARY_FILE_LIMIT) ?? baseConfig.retrieval.primaryFileLimit, + relatedDocLimit: parseNumber(process.env.CODERAG_RELATED_DOC_LIMIT) ?? 
baseConfig.retrieval.relatedDocLimit, + relatedFileLimit: parseNumber(process.env.CODERAG_RELATED_FILE_LIMIT) ?? baseConfig.retrieval.relatedFileLimit }, multiHop: { ...baseConfig.multiHop, @@ -206,7 +210,7 @@ export const resolveRuntimeConfig = (config: SerializableCodeRagConfig, cwd: str const embeddingConfig = config.embedding ?? { provider: "local-hash" as const, dimensions: 256, - geminiModel: "models/gemini-embedding-001", + geminiModel: "models/gemini-embedding-2", timeoutMs: 30000 }; diff --git a/src/test/context-builder.test.ts b/src/test/context-builder.test.ts index ff0091e..652eefd 100644 --- a/src/test/context-builder.test.ts +++ b/src/test/context-builder.test.ts @@ -89,7 +89,7 @@ describe("context builder", () => { related: createDocument("related", "related", "src/related.ts", "related doc") }; - const context = await buildContextPackage( + const { context, limits } = await buildContextPackage( "what calls related", repoPath, snapshot, @@ -105,6 +105,12 @@ describe("context builder", () => { expect(context.primaryNode?.fullFileContent).toBe("PRIMARY-CONTENT"); expect(context.relatedNodes[0]?.callSiteLines).toEqual([3]); expect(context.warnings).toContain("Truncated src/related.ts to stay within the context budget."); + expect(limits).toEqual({ + primaryDoc: 1, + primaryFile: 5, + relatedDoc: 1, + relatedFile: 1 + }); await cleanupPaths([repoPath]); }); @@ -148,7 +154,7 @@ describe("context builder", () => { related: createDocument("related", "related", "src/related.ts", "related doc") }; - const context = await buildContextPackage( + const { context, limits } = await buildContextPackage( "missing", repoPath, snapshot, @@ -165,6 +171,12 @@ describe("context builder", () => { expect(context.graphSummary).toContain("No matching node"); expect(context.relatedNodes[0]?.fullFileContent).toBe(""); expect(context.warnings).toContain("Dropped file content for src/related.ts because the context budget was exhausted."); + expect(limits).toEqual({ + primaryDoc: 
1, + primaryFile: 1, + relatedDoc: 1, + relatedFile: 1 + }); await cleanupPaths([repoPath]); }); @@ -208,7 +220,7 @@ describe("context builder", () => { primary: createDocument("primary", "primary", "src/primary.ts", "primary doc") }; - const context = await buildContextPackage( + const { context, limits } = await buildContextPackage( "primary", repoPath, snapshot, @@ -222,6 +234,12 @@ describe("context builder", () => { ); expect(context.graphSummary).toBe("Primary node: primary."); + expect(limits).toEqual({ + primaryDoc: 5, + primaryFile: 16, + relatedDoc: 1, + relatedFile: 5 + }); await cleanupPaths([repoPath]); }); }); diff --git a/src/test/gemini-embedder.test.ts b/src/test/gemini-embedder.test.ts index 055df4a..4ab6dd4 100644 --- a/src/test/gemini-embedder.test.ts +++ b/src/test/gemini-embedder.test.ts @@ -73,7 +73,7 @@ describe("GeminiEmbeddingProvider", () => { await expect(provider.embed("hello from alias env")).resolves.toEqual([9, 8, 7]); expect(fetchSpy).toHaveBeenCalledWith( - "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent?key=alias-key", + "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2:embedContent?key=alias-key", expect.any(Object) ); }); @@ -85,9 +85,9 @@ describe("GeminiEmbeddingProvider", () => { const provider = new GeminiEmbeddingProvider({ apiKey: "config-key" }); await expect(provider.embed("default model")).resolves.toEqual([7]); - expect(provider.model).toBe("models/gemini-embedding-001"); + expect(provider.model).toBe("models/gemini-embedding-2"); expect(fetchSpy).toHaveBeenCalledWith( - "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent?key=config-key", + "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2:embedContent?key=config-key", expect.any(Object) ); }); diff --git a/src/test/prompt.test.ts b/src/test/prompt.test.ts index 4fd4de0..157f10c 100644 --- a/src/test/prompt.test.ts +++ 
b/src/test/prompt.test.ts @@ -48,8 +48,21 @@ const context = { }; describe("prompt builder", () => { + const testLimits = { + primaryDoc: 1200, + primaryFile: 4000, + relatedDoc: 320, + relatedFile: 1200 + }; + const tinyLimits = { + primaryDoc: 20, + primaryFile: 24, + relatedDoc: 22, + relatedFile: 26 + }; + it("builds the system prompt and a compact user context", () => { - const userMessage = buildMessages("where is auth handled?", context)[1]?.content ?? ""; + const userMessage = buildMessages("where is auth handled?", context, testLimits)[1]?.content ?? ""; expect(buildSystemPrompt()).toContain("Only use the provided repository context."); expect(userMessage).toContain("Graph summary:"); @@ -60,9 +73,21 @@ describe("prompt builder", () => { expect(userMessage).toContain("Warnings:"); expect(userMessage).not.toContain("\"graphSummary\""); expect(userMessage).not.toContain("DUPLICATE FILE CONTENT"); + expect(userMessage).not.toContain("...[truncated]"); expect(userMessage.length).toBeLessThan(1_500); }); + it("applies caller-provided limits when truncating docs and file excerpts", () => { + const userMessage = buildMessages("where is auth handled?", context, tinyLimits)[1]?.content ?? ""; + + expect(userMessage).toContain("Primary doc:\nrequi\n...[truncated]"); + expect(userMessage).toContain("File excerpt:\nexport fu\n...[truncated]"); + expect(userMessage).toContain("Related doc:\nverifyT\n...[truncated]"); + expect(userMessage).toContain("Related doc:\ngetSess\n...[truncated]"); + expect(userMessage).toContain("File excerpt:\nexport func\n...[truncated]"); + expect(userMessage).not.toContain("DUPLICATE FILE CONTENT"); + }); + it("renders a missing primary node without related entries", () => { const userMessage = buildMessages("where is auth handled?", { @@ -70,11 +95,12 @@ describe("prompt builder", () => { primaryNode: null, relatedNodes: [], warnings: [] - })[1]?.content ?? ""; + }, testLimits)[1]?.content ?? 
""; expect(userMessage).toContain("Primary node:\nnone"); expect(userMessage).toContain("Related nodes:\nnone"); expect(userMessage).not.toContain("Warnings:"); + expect(userMessage).not.toContain("...[truncated]"); }); it("omits blank related docs and file excerpts when there is no primary node", () => { @@ -97,11 +123,12 @@ describe("prompt builder", () => { } ], warnings: new Array(6).fill("warning message that should appear only in the capped warning list") - })[1]?.content ?? ""; + }, testLimits)[1]?.content ?? ""; expect(userMessage).toContain("1. name=blankNode | relationship=calls | kind=function | file=src/blank.ts:1-1 | callSites=none"); expect(userMessage).not.toContain("Related doc:"); expect(userMessage).not.toContain("File excerpt:"); expect(userMessage.match(/warning message/g)?.length).toBe(4); + expect(userMessage).not.toContain("...[truncated]"); }); }); diff --git a/src/types.ts b/src/types.ts index 71c8f36..7e75317 100644 --- a/src/types.ts +++ b/src/types.ts @@ -21,7 +21,12 @@ export type MultiHopConfig = z.infer; export const retrievalConfigSchema = z.object({ topK: z.number().int().positive().default(6), rerankK: z.number().int().positive().default(3), - maxContextChars: z.number().int().positive().default(16000) + maxContextChars: z.number().int().positive().default(16000), + /** Per-section char limits. All optional — defaults scale with maxContextChars. 
*/ + primaryDocLimit: z.number().int().positive().optional(), + primaryFileLimit: z.number().int().positive().optional(), + relatedDocLimit: z.number().int().positive().optional(), + relatedFileLimit: z.number().int().positive().optional() }); export type RetrievalConfig = z.infer; @@ -61,7 +66,7 @@ export type LlmConfig = SerializableLlmConfig; export const embeddingConfigSchema = z.object({ provider: embeddingProviderKindSchema.default("local-hash"), dimensions: z.number().int().positive().default(256), - geminiModel: z.string().min(1).default("models/gemini-embedding-001"), + geminiModel: z.string().min(1).default("models/gemini-embedding-2"), timeoutMs: z.number().int().positive().default(30000), onnxModelDir: z.string().min(1).default(".coderag-models/models") }); @@ -73,7 +78,7 @@ export const serializableConfigSchema = z.object({ embedding: embeddingConfigSchema.default({ provider: "local-hash", dimensions: 256, - geminiModel: "models/gemini-embedding-001", + geminiModel: "models/gemini-embedding-2", timeoutMs: 30000, onnxModelDir: ".coderag-models/models" }), @@ -414,6 +419,8 @@ export interface EmbeddingProvider { readonly model: string; readonly dimensions: number; readonly maxBatchSize?: number; + /** Maximum input tokens the model accepts. The indexer multiplies this by CHARS_PER_TOKEN to cap the length of text sent for embedding. */ + readonly maxInputTokens: number; embed(text: string): Promise; embedBatch?(texts: string[]): Promise; }