From e943e84bd24f9e0e128452c5f16140b0fb830b96 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 19:13:57 +0200 Subject: [PATCH] feat(rag): harden Tier 1 retrieval observability and stability - Add --rag-debug flag to show retrieved chunk names and similarity scores - Add explicit fallback notices when RAG indexing/query embedding fails - Log RAG index/query metrics (duration, scores, top hit, token estimate) - Normalize and cap chunk content for more stable prompt shape on small models - Add hypothesis-continuity instruction for follow-up prompts - Add retrieval scoring API and new tests for truncation/fallback/debug paths --- src/tai/cli.py | 186 ++++++++++++++++++++++++++++++++---- src/tai/prompt_builder.py | 8 ++ src/tai/rag_retriever.py | 54 ++++++++--- tests/test_cli.py | 97 +++++++++++++++++++ tests/test_rag_retriever.py | 27 +++++- 5 files changed, 337 insertions(+), 35 deletions(-) diff --git a/src/tai/cli.py b/src/tai/cli.py index b39649d..0d3ee51 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from time import perf_counter from typing import Annotated import typer @@ -24,7 +25,7 @@ from tai.prompt_builder import ( build_system_prompt, build_user_message, ) -from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve +from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored from tai.session_log import SessionLogger from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession @@ -117,6 +118,13 @@ def run( help="Embedding model for RAG. Must be pulled in Ollama on the AI host.", ), ] = DEFAULT_EMBED_MODEL, + rag_debug: Annotated[ + bool, + typer.Option( + "--rag-debug/--no-rag-debug", + help="Print retrieved chunk names/scores and log per-question retrieval metrics.", + ), + ] = False, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -168,6 +176,7 @@ def run( interactive=interactive, ai_config=ai_config, no_rag=no_rag, + rag_debug=rag_debug, logger=logger, ) ) @@ -191,6 +200,7 @@ async def _async_main( interactive: bool, ai_config: AIConfig, no_rag: bool, + rag_debug: bool, logger: SessionLogger | None, ) -> None: """Open a single SSH session and run probe / collection / analysis through it.""" @@ -241,7 +251,15 @@ async def _async_main( _run_analysis(ai_config, req.issue, report, logger=logger) if interactive: - await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger) + await _interactive_loop( + session, + req, + ai_config, + report, + no_rag=no_rag, + rag_debug=rag_debug, + logger=logger, + ) async def _interactive_loop( @@ -251,6 +269,7 @@ async def _interactive_loop( report: CollectionReport | None, *, no_rag: bool = False, + rag_debug: bool = False, logger: SessionLogger | None, ) -> None: """Run a follow-up loop for collecting and conversational analysis.""" @@ -269,9 +288,33 @@ async def _interactive_loop( ai_embed = AIClient(ai_config) if not no_rag and report is not None: - embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + embedded_chunks, index_error, index_ms = await asyncio.to_thread( + _try_embed_report, report, ai_embed + ) if embedded_chunks is not None: console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]") + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "ok", + "chunk_count": len(embedded_chunks), + "duration_ms": round(index_ms, 2), + }, + ) + else: + console.print( + "[yellow]RAG unavailable (indexing failed); using full-context fallback.[/yellow]" + ) + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "fallback", + "error": index_error, + "duration_ms": round(index_ms, 2), + }, + ) while True: try: @@ -312,11 +355,36 @@ async def _interactive_loop( report = await collect_from_plan(session, plan) _handle_collection_report(report) if not no_rag: - embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + embedded_chunks, index_error, index_ms = await asyncio.to_thread( + _try_embed_report, report, ai_embed + ) if embedded_chunks is not None: console.print( f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" ) + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "ok", + "chunk_count": len(embedded_chunks), + "duration_ms": round(index_ms, 2), + }, + ) + else: + console.print( + "[yellow]RAG unavailable (indexing failed); " + "using full-context fallback.[/yellow]" + ) + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "fallback", + "error": index_error, + "duration_ms": round(index_ms, 2), + }, + ) if logger is not None: logger.log_event( "collection_summary", @@ -344,6 +412,7 @@ async def _interactive_loop( "Provide an updated diagnosis from the current diagnostics.", prior_questions, embedded_chunks=embedded_chunks, + rag_debug=rag_debug, logger=logger, ) prior_questions.append("/analyze") @@ -357,11 +426,36 @@ async def _interactive_loop( report = await collect_from_plan(session, plan) _handle_collection_report(report) if not no_rag: - embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + embedded_chunks, index_error, index_ms = await asyncio.to_thread( + _try_embed_report, report, ai_embed + ) if embedded_chunks is not None: console.print( f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" ) + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "ok", + "chunk_count": len(embedded_chunks), + "duration_ms": round(index_ms, 2), + }, + ) + else: + console.print( + "[yellow]RAG unavailable (indexing failed); " + "using full-context fallback.[/yellow]" + ) + if logger is not None: + logger.log_event( + "rag_index", + { + "status": "fallback", + "error": index_error, + "duration_ms": round(index_ms, 2), + }, + ) if report is None: console.print("[red]No diagnostics available to analyze.[/red]") @@ -374,6 +468,7 @@ async def _interactive_loop( command, prior_questions, embedded_chunks=embedded_chunks, + rag_debug=rag_debug, logger=logger, ) prior_questions.append(command) @@ -382,21 +477,23 @@ async def _interactive_loop( def _try_embed_report( - report: CollectionReport, ai: AIClient -) -> list[EmbeddedChunk] | None: - """Embed all diagnostic chunks from *report*; returns None on any failure. + report: CollectionReport, + ai: AIClient, +) -> tuple[list[EmbeddedChunk] | None, str | None, float]: + """Embed all diagnostic chunks from *report*. - Failures are expected when the embedding model is not yet pulled or the - AI backend is unavailable — in those cases the caller falls back to - sending the full report as context. + Returns (chunks, error_message, duration_ms). On failure, chunks is None + and callers should fall back to non-RAG full-context prompts. """ + start = perf_counter() try: chunks = chunk_report(report) if not chunks: - return None - return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks] - except Exception: # noqa: BLE001 - return None + return None, "no eligible chunks to index", (perf_counter() - start) * 1000.0 + embedded = [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks] + return embedded, None, (perf_counter() - start) * 1000.0 + except Exception as exc: # noqa: BLE001 + return None, str(exc), (perf_counter() - start) * 1000.0 def _handle_probe_result(result: SSHCommandResult) -> None: @@ -473,6 +570,11 @@ def _run_analysis( raise typer.Exit(code=1) from exc +def _estimate_tokens(text: str) -> int: + """Rough token estimate for metrics and tuning; assumes ~4 chars/token.""" + return max(1, len(text) // 4) + + def _run_followup_analysis( ai_config: AIConfig, issue: str, @@ -481,13 +583,14 @@ def _run_followup_analysis( prior_questions: list[str], *, embedded_chunks: list[EmbeddedChunk] | None = None, + rag_debug: bool = False, logger: SessionLogger | None, ) -> str: """Run grounded follow-up analysis re-anchored to current diagnostics. - When *embedded_chunks* is provided the question is embedded and the top-5 - most relevant chunks are used instead of the full report, reducing token - usage. Falls back to full-context on any embedding failure. + When *embedded_chunks* is provided, the question is embedded and top-k + relevant chunks are selected. If retrieval fails, a clear fallback message + is emitted and full diagnostic context is used. """ console.print() console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan")) @@ -496,18 +599,59 @@ def _run_followup_analysis( system_prompt = build_system_prompt() user_message: str + retrieved_names: list[str] = [] + retrieved_scores: list[float] = [] + retrieval_ms = 0.0 + fallback_reason: str | None = None + if embedded_chunks is not None: + retrieval_start = perf_counter() try: q_embedding = ai.embed(question) - retrieved = retrieve(q_embedding, embedded_chunks, top_k=5) + scored = retrieve_scored(q_embedding, embedded_chunks, top_k=5) + retrieval_ms = (perf_counter() - retrieval_start) * 1000.0 + retrieved_names = [chunk.name for chunk, _score in scored] + retrieved_scores = [round(score, 4) for _chunk, score in scored] user_message = build_message_with_chunks( - issue, report.host, retrieved, question, prior_questions + issue, + report.host, + [chunk for chunk, _score in scored], + question, + prior_questions, + ) + if rag_debug: + pairs = ", ".join( + f"{name}={score:.3f}" + for name, score in zip(retrieved_names, retrieved_scores, strict=False) + ) + console.print(f"[dim]RAG retrieve:[/dim] {pairs or 'no matches'}") + except Exception as exc: # noqa: BLE001 + retrieval_ms = (perf_counter() - retrieval_start) * 1000.0 + fallback_reason = str(exc) + console.print( + "[yellow]RAG unavailable (query embedding failed); using full-context " + "fallback.[/yellow]" ) - except Exception: # noqa: BLE001 user_message = build_followup_message(issue, report, question, prior_questions) else: + fallback_reason = "rag not indexed" user_message = build_followup_message(issue, report, question, prior_questions) + if logger is not None: + logger.log_event( + "rag_query", + { + "question": question, + "retrieved_chunk_names": retrieved_names, + "scores": retrieved_scores, + "retrieval_ms": round(retrieval_ms, 2), + "top_score": retrieved_scores[0] if retrieved_scores else None, + "used_fallback": fallback_reason is not None, + "fallback_reason": fallback_reason, + "estimated_prompt_tokens": _estimate_tokens(system_prompt + user_message), + }, + ) + try: chunks: list[str] = [] for chunk in ai.stream(system_prompt, user_message): diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py index 68164fa..6094123 100644 --- a/src/tai/prompt_builder.py +++ b/src/tai/prompt_builder.py @@ -99,6 +99,10 @@ def build_followup_message( "\nAnswer strictly from the collected diagnostics above. " "If evidence is insufficient, explicitly say so." ) + lines.append( + "Keep hypothesis continuity across turns: retain the previous leading " + "hypothesis unless newly retrieved evidence directly contradicts it." + ) return "\n".join(lines) @@ -137,4 +141,8 @@ def build_message_with_chunks( "\nAnswer strictly from the retrieved diagnostics above. " "If evidence is insufficient, explicitly say so." ) + lines.append( + "Keep hypothesis continuity across turns: retain the previous leading " + "hypothesis unless newly retrieved evidence directly contradicts it." + ) return "\n".join(lines) diff --git a/src/tai/rag_retriever.py b/src/tai/rag_retriever.py index f1cd061..56f8d83 100644 --- a/src/tai/rag_retriever.py +++ b/src/tai/rag_retriever.py @@ -12,6 +12,8 @@ from dataclasses import dataclass from tai.collectors import CollectionReport +DEFAULT_MAX_CHUNK_CHARS = 1800 + @dataclass(slots=True) class Chunk: @@ -29,11 +31,25 @@ class EmbeddedChunk: embedding: list[float] -def chunk_report(report: CollectionReport) -> list[Chunk]: +def _normalize_text(text: str, *, max_chars: int) -> str: + """Normalize whitespace and cap text length with a truncation marker.""" + compact = text.strip() + if len(compact) <= max_chars: + return compact + clipped = compact[:max_chars].rstrip() + return f"{clipped}\n...[truncated for RAG]" + + +def chunk_report( + report: CollectionReport, + *, + max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS, +) -> list[Chunk]: """Split a CollectionReport into one Chunk per diagnostic item. Items that SSH could not execute at all (exit 255, no output) are dropped — - they carry no diagnostic signal. + they carry no diagnostic signal. Chunk text is normalized and capped so the + prompt shape stays more stable on smaller local models. """ chunks: list[Chunk] = [] for item in report.items: @@ -46,13 +62,14 @@ def chunk_report(report: CollectionReport) -> list[Chunk]: f"Exit code: {result.exit_code}", ] if result.stdout: - parts.append(f"stdout:\n{result.stdout.strip()}") + parts.append(f"stdout:\n{_normalize_text(result.stdout, max_chars=max_chunk_chars)}") if result.stderr: - parts.append(f"stderr:\n{result.stderr.strip()}") + parts.append(f"stderr:\n{_normalize_text(result.stderr, max_chars=max_chunk_chars)}") if not result.stdout and not result.stderr: parts.append("(no output)") - chunks.append(Chunk(name=item.name, content="\n".join(parts))) + content = _normalize_text("\n".join(parts), max_chars=max_chunk_chars) + chunks.append(Chunk(name=item.name, content=content)) return chunks @@ -66,17 +83,13 @@ def _cosine_similarity(a: list[float], b: list[float]) -> float: return dot / (norm_a * norm_b) -def retrieve( +def retrieve_scored( question_embedding: list[float], embedded_chunks: list[EmbeddedChunk], *, top_k: int = 5, -) -> list[Chunk]: - """Return the *top_k* chunks most similar to *question_embedding*. - - Chunks are ranked by cosine similarity in descending order. - If *embedded_chunks* is empty or *top_k* is zero, returns an empty list. - """ +) -> list[tuple[Chunk, float]]: + """Return top-k retrieved chunks with similarity scores.""" if not embedded_chunks or top_k <= 0: return [] scored: list[tuple[float, Chunk]] = [ @@ -84,4 +97,19 @@ def retrieve( for ec in embedded_chunks ] scored.sort(key=lambda x: x[0], reverse=True) - return [chunk for _, chunk in scored[:top_k]] + return [(chunk, score) for score, chunk in scored[:top_k]] + + +def retrieve( + question_embedding: list[float], + embedded_chunks: list[EmbeddedChunk], + *, + top_k: int = 5, +) -> list[Chunk]: + """Return the *top_k* chunks most similar to *question_embedding*.""" + scored = retrieve_scored( + question_embedding, + embedded_chunks, + top_k=top_k, + ) + return [chunk for chunk, _score in scored] diff --git a/tests/test_cli.py b/tests/test_cli.py index b13fa58..9ba8233 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,7 @@ from typer.testing import CliRunner from tai.cli import app from tai.collectors import CollectedItem, CollectionReport +from tai.rag_retriever import Chunk, EmbeddedChunk from tai.ssh_client import SSHCommandResult @@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: assert result.exit_code == 0 assert "AI Response" in result.stdout assert "Check logs." in result.stdout + + +def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + ], + ) + + commands = iter(["what should I check next?", "/quit"]) + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0)) + monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."])) + monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands)) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--interactive", + ], + ) + + assert result.exit_code == 0 + assert "RAG unavailable (indexing failed)" in result.stdout + assert "AI Response" in result.stdout + + +def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + ], + ) + + commands = iter(["what should I check next?", "/quit"]) + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + monkeypatch.setattr( + "tai.cli._try_embed_report", + lambda *_args: ( + [EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])], + None, + 1.0, + ), + ) + monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0]) + monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."])) + monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands)) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--interactive", + "--rag-debug", + ], + ) + + assert result.exit_code == 0 + assert "RAG retrieve:" in result.stdout diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py index 87f510f..083c52d 100644 --- a/tests/test_rag_retriever.py +++ b/tests/test_rag_retriever.py @@ -3,7 +3,14 @@ from __future__ import annotations from tai.collectors import CollectedItem, CollectionReport -from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve +from tai.rag_retriever import ( + Chunk, + EmbeddedChunk, + _cosine_similarity, + chunk_report, + retrieve, + retrieve_scored, +) from tai.ssh_client import SSHCommandResult @@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None: assert "(no output)" in chunks[0].content +def test_chunk_report_caps_large_content() -> None: + report = _report(("huge", "x" * 5000, 0)) + chunks = chunk_report(report, max_chunk_chars=200) + assert len(chunks[0].content) <= 230 + assert "...[truncated for RAG]" in chunks[0].content + + # --------------------------------------------------------------------------- # _cosine_similarity # --------------------------------------------------------------------------- @@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None: assert result[1].name == "mid" +def test_retrieve_scored_includes_scores() -> None: + chunks = [ + _embedded("close", [1.0, 0.0]), + _embedded("far", [0.0, 1.0]), + ] + result = retrieve_scored([1.0, 0.0], chunks, top_k=2) + assert len(result) == 2 + assert result[0][0].name == "close" + assert result[0][1] > result[1][1] + + def test_retrieve_respects_top_k_larger_than_pool() -> None: chunks = [_embedded("only", [1.0, 0.0])] result = retrieve([1.0, 0.0], chunks, top_k=10)