feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s

- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
2026-05-04 19:13:57 +02:00
parent 5529960e79
commit e943e84bd2
5 changed files with 337 additions and 35 deletions

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from time import perf_counter
from typing import Annotated from typing import Annotated
import typer import typer
@@ -24,7 +25,7 @@ from tai.prompt_builder import (
build_system_prompt, build_system_prompt,
build_user_message, build_user_message,
) )
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
from tai.session_log import SessionLogger from tai.session_log import SessionLogger
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
@@ -117,6 +118,13 @@ def run(
help="Embedding model for RAG. Must be pulled in Ollama on the AI host.", help="Embedding model for RAG. Must be pulled in Ollama on the AI host.",
), ),
] = DEFAULT_EMBED_MODEL, ] = DEFAULT_EMBED_MODEL,
rag_debug: Annotated[
bool,
typer.Option(
"--rag-debug/--no-rag-debug",
help="Print retrieved chunk names/scores and log per-question retrieval metrics.",
),
] = False,
) -> None: ) -> None:
"""Start an interactive troubleshooting session scaffold.""" """Start an interactive troubleshooting session scaffold."""
try: try:
@@ -168,6 +176,7 @@ def run(
interactive=interactive, interactive=interactive,
ai_config=ai_config, ai_config=ai_config,
no_rag=no_rag, no_rag=no_rag,
rag_debug=rag_debug,
logger=logger, logger=logger,
) )
) )
@@ -191,6 +200,7 @@ async def _async_main(
interactive: bool, interactive: bool,
ai_config: AIConfig, ai_config: AIConfig,
no_rag: bool, no_rag: bool,
rag_debug: bool,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> None:
"""Open a single SSH session and run probe / collection / analysis through it.""" """Open a single SSH session and run probe / collection / analysis through it."""
@@ -241,7 +251,15 @@ async def _async_main(
_run_analysis(ai_config, req.issue, report, logger=logger) _run_analysis(ai_config, req.issue, report, logger=logger)
if interactive: if interactive:
await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger) await _interactive_loop(
session,
req,
ai_config,
report,
no_rag=no_rag,
rag_debug=rag_debug,
logger=logger,
)
async def _interactive_loop( async def _interactive_loop(
@@ -251,6 +269,7 @@ async def _interactive_loop(
report: CollectionReport | None, report: CollectionReport | None,
*, *,
no_rag: bool = False, no_rag: bool = False,
rag_debug: bool = False,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> None:
"""Run a follow-up loop for collecting and conversational analysis.""" """Run a follow-up loop for collecting and conversational analysis."""
@@ -269,9 +288,33 @@ async def _interactive_loop(
ai_embed = AIClient(ai_config) ai_embed = AIClient(ai_config)
if not no_rag and report is not None: if not no_rag and report is not None:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) embedded_chunks, index_error, index_ms = await asyncio.to_thread(
_try_embed_report, report, ai_embed
)
if embedded_chunks is not None: if embedded_chunks is not None:
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]") console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "ok",
"chunk_count": len(embedded_chunks),
"duration_ms": round(index_ms, 2),
},
)
else:
console.print(
"[yellow]RAG unavailable (indexing failed); using full-context fallback.[/yellow]"
)
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "fallback",
"error": index_error,
"duration_ms": round(index_ms, 2),
},
)
while True: while True:
try: try:
@@ -312,11 +355,36 @@ async def _interactive_loop(
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
if not no_rag: if not no_rag:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) embedded_chunks, index_error, index_ms = await asyncio.to_thread(
_try_embed_report, report, ai_embed
)
if embedded_chunks is not None: if embedded_chunks is not None:
console.print( console.print(
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
) )
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "ok",
"chunk_count": len(embedded_chunks),
"duration_ms": round(index_ms, 2),
},
)
else:
console.print(
"[yellow]RAG unavailable (indexing failed); "
"using full-context fallback.[/yellow]"
)
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "fallback",
"error": index_error,
"duration_ms": round(index_ms, 2),
},
)
if logger is not None: if logger is not None:
logger.log_event( logger.log_event(
"collection_summary", "collection_summary",
@@ -344,6 +412,7 @@ async def _interactive_loop(
"Provide an updated diagnosis from the current diagnostics.", "Provide an updated diagnosis from the current diagnostics.",
prior_questions, prior_questions,
embedded_chunks=embedded_chunks, embedded_chunks=embedded_chunks,
rag_debug=rag_debug,
logger=logger, logger=logger,
) )
prior_questions.append("/analyze") prior_questions.append("/analyze")
@@ -357,11 +426,36 @@ async def _interactive_loop(
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
if not no_rag: if not no_rag:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) embedded_chunks, index_error, index_ms = await asyncio.to_thread(
_try_embed_report, report, ai_embed
)
if embedded_chunks is not None: if embedded_chunks is not None:
console.print( console.print(
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
) )
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "ok",
"chunk_count": len(embedded_chunks),
"duration_ms": round(index_ms, 2),
},
)
else:
console.print(
"[yellow]RAG unavailable (indexing failed); "
"using full-context fallback.[/yellow]"
)
if logger is not None:
logger.log_event(
"rag_index",
{
"status": "fallback",
"error": index_error,
"duration_ms": round(index_ms, 2),
},
)
if report is None: if report is None:
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
@@ -374,6 +468,7 @@ async def _interactive_loop(
command, command,
prior_questions, prior_questions,
embedded_chunks=embedded_chunks, embedded_chunks=embedded_chunks,
rag_debug=rag_debug,
logger=logger, logger=logger,
) )
prior_questions.append(command) prior_questions.append(command)
@@ -382,21 +477,23 @@ async def _interactive_loop(
def _try_embed_report( def _try_embed_report(
report: CollectionReport, ai: AIClient report: CollectionReport,
) -> list[EmbeddedChunk] | None: ai: AIClient,
"""Embed all diagnostic chunks from *report*; returns None on any failure. ) -> tuple[list[EmbeddedChunk] | None, str | None, float]:
"""Embed all diagnostic chunks from *report*.
Failures are expected when the embedding model is not yet pulled or the Returns (chunks, error_message, duration_ms). On failure, chunks is None
AI backend is unavailable — in those cases the caller falls back to and callers should fall back to non-RAG full-context prompts.
sending the full report as context.
""" """
start = perf_counter()
try: try:
chunks = chunk_report(report) chunks = chunk_report(report)
if not chunks: if not chunks:
return None return None, "no eligible chunks to index", (perf_counter() - start) * 1000.0
return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks] embedded = [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
except Exception: # noqa: BLE001 return embedded, None, (perf_counter() - start) * 1000.0
return None except Exception as exc: # noqa: BLE001
return None, str(exc), (perf_counter() - start) * 1000.0
def _handle_probe_result(result: SSHCommandResult) -> None: def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -473,6 +570,11 @@ def _run_analysis(
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc
def _estimate_tokens(text: str) -> int:
"""Rough token estimate for metrics and tuning; assumes ~4 chars/token."""
return max(1, len(text) // 4)
def _run_followup_analysis( def _run_followup_analysis(
ai_config: AIConfig, ai_config: AIConfig,
issue: str, issue: str,
@@ -481,13 +583,14 @@ def _run_followup_analysis(
prior_questions: list[str], prior_questions: list[str],
*, *,
embedded_chunks: list[EmbeddedChunk] | None = None, embedded_chunks: list[EmbeddedChunk] | None = None,
rag_debug: bool = False,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> str: ) -> str:
"""Run grounded follow-up analysis re-anchored to current diagnostics. """Run grounded follow-up analysis re-anchored to current diagnostics.
When *embedded_chunks* is provided the question is embedded and the top-5 When *embedded_chunks* is provided, the question is embedded and top-k
most relevant chunks are used instead of the full report, reducing token relevant chunks are selected. If retrieval fails, a clear fallback message
usage. Falls back to full-context on any embedding failure. is emitted and full diagnostic context is used.
""" """
console.print() console.print()
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan")) console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
@@ -496,18 +599,59 @@ def _run_followup_analysis(
system_prompt = build_system_prompt() system_prompt = build_system_prompt()
user_message: str user_message: str
retrieved_names: list[str] = []
retrieved_scores: list[float] = []
retrieval_ms = 0.0
fallback_reason: str | None = None
if embedded_chunks is not None: if embedded_chunks is not None:
retrieval_start = perf_counter()
try: try:
q_embedding = ai.embed(question) q_embedding = ai.embed(question)
retrieved = retrieve(q_embedding, embedded_chunks, top_k=5) scored = retrieve_scored(q_embedding, embedded_chunks, top_k=5)
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
retrieved_names = [chunk.name for chunk, _score in scored]
retrieved_scores = [round(score, 4) for _chunk, score in scored]
user_message = build_message_with_chunks( user_message = build_message_with_chunks(
issue, report.host, retrieved, question, prior_questions issue,
report.host,
[chunk for chunk, _score in scored],
question,
prior_questions,
)
if rag_debug:
pairs = ", ".join(
f"{name}={score:.3f}"
for name, score in zip(retrieved_names, retrieved_scores, strict=False)
)
console.print(f"[dim]RAG retrieve:[/dim] {pairs or 'no matches'}")
except Exception as exc: # noqa: BLE001
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
fallback_reason = str(exc)
console.print(
"[yellow]RAG unavailable (query embedding failed); using full-context "
"fallback.[/yellow]"
) )
except Exception: # noqa: BLE001
user_message = build_followup_message(issue, report, question, prior_questions) user_message = build_followup_message(issue, report, question, prior_questions)
else: else:
fallback_reason = "rag not indexed"
user_message = build_followup_message(issue, report, question, prior_questions) user_message = build_followup_message(issue, report, question, prior_questions)
if logger is not None:
logger.log_event(
"rag_query",
{
"question": question,
"retrieved_chunk_names": retrieved_names,
"scores": retrieved_scores,
"retrieval_ms": round(retrieval_ms, 2),
"top_score": retrieved_scores[0] if retrieved_scores else None,
"used_fallback": fallback_reason is not None,
"fallback_reason": fallback_reason,
"estimated_prompt_tokens": _estimate_tokens(system_prompt + user_message),
},
)
try: try:
chunks: list[str] = [] chunks: list[str] = []
for chunk in ai.stream(system_prompt, user_message): for chunk in ai.stream(system_prompt, user_message):

View File

@@ -99,6 +99,10 @@ def build_followup_message(
"\nAnswer strictly from the collected diagnostics above. " "\nAnswer strictly from the collected diagnostics above. "
"If evidence is insufficient, explicitly say so." "If evidence is insufficient, explicitly say so."
) )
lines.append(
"Keep hypothesis continuity across turns: retain the previous leading "
"hypothesis unless newly retrieved evidence directly contradicts it."
)
return "\n".join(lines) return "\n".join(lines)
@@ -137,4 +141,8 @@ def build_message_with_chunks(
"\nAnswer strictly from the retrieved diagnostics above. " "\nAnswer strictly from the retrieved diagnostics above. "
"If evidence is insufficient, explicitly say so." "If evidence is insufficient, explicitly say so."
) )
lines.append(
"Keep hypothesis continuity across turns: retain the previous leading "
"hypothesis unless newly retrieved evidence directly contradicts it."
)
return "\n".join(lines) return "\n".join(lines)

View File

@@ -12,6 +12,8 @@ from dataclasses import dataclass
from tai.collectors import CollectionReport from tai.collectors import CollectionReport
DEFAULT_MAX_CHUNK_CHARS = 1800
@dataclass(slots=True) @dataclass(slots=True)
class Chunk: class Chunk:
@@ -29,11 +31,25 @@ class EmbeddedChunk:
embedding: list[float] embedding: list[float]
def chunk_report(report: CollectionReport) -> list[Chunk]: def _normalize_text(text: str, *, max_chars: int) -> str:
"""Normalize whitespace and cap text length with a truncation marker."""
compact = text.strip()
if len(compact) <= max_chars:
return compact
clipped = compact[:max_chars].rstrip()
return f"{clipped}\n...[truncated for RAG]"
def chunk_report(
report: CollectionReport,
*,
max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
) -> list[Chunk]:
"""Split a CollectionReport into one Chunk per diagnostic item. """Split a CollectionReport into one Chunk per diagnostic item.
Items that SSH could not execute at all (exit 255, no output) are dropped — Items that SSH could not execute at all (exit 255, no output) are dropped —
they carry no diagnostic signal. they carry no diagnostic signal. Chunk text is normalized and capped so the
prompt shape stays more stable on smaller local models.
""" """
chunks: list[Chunk] = [] chunks: list[Chunk] = []
for item in report.items: for item in report.items:
@@ -46,13 +62,14 @@ def chunk_report(report: CollectionReport) -> list[Chunk]:
f"Exit code: {result.exit_code}", f"Exit code: {result.exit_code}",
] ]
if result.stdout: if result.stdout:
parts.append(f"stdout:\n{result.stdout.strip()}") parts.append(f"stdout:\n{_normalize_text(result.stdout, max_chars=max_chunk_chars)}")
if result.stderr: if result.stderr:
parts.append(f"stderr:\n{result.stderr.strip()}") parts.append(f"stderr:\n{_normalize_text(result.stderr, max_chars=max_chunk_chars)}")
if not result.stdout and not result.stderr: if not result.stdout and not result.stderr:
parts.append("(no output)") parts.append("(no output)")
chunks.append(Chunk(name=item.name, content="\n".join(parts))) content = _normalize_text("\n".join(parts), max_chars=max_chunk_chars)
chunks.append(Chunk(name=item.name, content=content))
return chunks return chunks
@@ -66,17 +83,13 @@ def _cosine_similarity(a: list[float], b: list[float]) -> float:
return dot / (norm_a * norm_b) return dot / (norm_a * norm_b)
def retrieve( def retrieve_scored(
question_embedding: list[float], question_embedding: list[float],
embedded_chunks: list[EmbeddedChunk], embedded_chunks: list[EmbeddedChunk],
*, *,
top_k: int = 5, top_k: int = 5,
) -> list[Chunk]: ) -> list[tuple[Chunk, float]]:
"""Return the *top_k* chunks most similar to *question_embedding*. """Return top-k retrieved chunks with similarity scores."""
Chunks are ranked by cosine similarity in descending order.
If *embedded_chunks* is empty or *top_k* is zero, returns an empty list.
"""
if not embedded_chunks or top_k <= 0: if not embedded_chunks or top_k <= 0:
return [] return []
scored: list[tuple[float, Chunk]] = [ scored: list[tuple[float, Chunk]] = [
@@ -84,4 +97,19 @@ def retrieve(
for ec in embedded_chunks for ec in embedded_chunks
] ]
scored.sort(key=lambda x: x[0], reverse=True) scored.sort(key=lambda x: x[0], reverse=True)
return [chunk for _, chunk in scored[:top_k]] return [(chunk, score) for score, chunk in scored[:top_k]]
def retrieve(
    question_embedding: list[float],
    embedded_chunks: list[EmbeddedChunk],
    *,
    top_k: int = 5,
) -> list[Chunk]:
    """Return the *top_k* chunks most similar to *question_embedding*.

    Backward-compatible wrapper around :func:`retrieve_scored` that drops the
    similarity scores and keeps only the chunks, in the order
    ``retrieve_scored`` returns them (highest score first).

    Returns an empty list when *embedded_chunks* is empty or *top_k* is not
    positive, mirroring ``retrieve_scored``.
    """
    scored = retrieve_scored(
        question_embedding,
        embedded_chunks,
        top_k=top_k,
    )
    return [chunk for chunk, _score in scored]

View File

@@ -4,6 +4,7 @@ from typer.testing import CliRunner
from tai.cli import app from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk
from tai.ssh_client import SSHCommandResult from tai.ssh_client import SSHCommandResult
@@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
assert result.exit_code == 0 assert result.exit_code == 0
assert "AI Response" in result.stdout assert "AI Response" in result.stdout
assert "Check logs." in result.stdout assert "Check logs." in result.stdout
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """The CLI announces the full-context fallback when RAG indexing fails.

    ``_try_embed_report`` is stubbed to return ``(None, "embed failed", 1.0)``,
    i.e. no chunks plus an error message, and the follow-up question must
    still produce an AI response.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Minimal single-item report so the interactive loop has diagnostics.
        return CollectionReport(
            host="ssh.archflux.net",
            items=[
                CollectedItem(
                    name="kernel",
                    result=SSHCommandResult(
                        command="uname -a",
                        exit_code=0,
                        stdout="Linux test",
                        stderr="",
                    ),
                ),
            ],
        )

    # One follow-up question, then leave the interactive loop.
    commands = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    # Simulate an indexing failure: no chunks, an error string, a duration.
    monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "ssh.archflux.net",
            "--port",
            "5566",
            "--no-probe",
            "--interactive",
        ],
    )

    assert result.exit_code == 0
    assert "RAG unavailable (indexing failed)" in result.stdout
    assert "AI Response" in result.stdout
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``--rag-debug`` the CLI prints the retrieved chunk names and scores.

    Indexing is stubbed to yield a single embedded chunk and ``AIClient.embed``
    returns the same vector for the question, so retrieval has one hit to
    report on the ``RAG retrieve:`` debug line.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Minimal single-item report so the interactive loop has diagnostics.
        return CollectionReport(
            host="ssh.archflux.net",
            items=[
                CollectedItem(
                    name="kernel",
                    result=SSHCommandResult(
                        command="uname -a",
                        exit_code=0,
                        stdout="Linux test",
                        stderr="",
                    ),
                ),
            ],
        )

    # One follow-up question, then leave the interactive loop.
    commands = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    # Successful indexing: one chunk, no error, a duration.
    monkeypatch.setattr(
        "tai.cli._try_embed_report",
        lambda *_args: (
            [EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])],
            None,
            1.0,
        ),
    )
    # The question embeds to the same vector as the only indexed chunk.
    monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "ssh.archflux.net",
            "--port",
            "5566",
            "--no-probe",
            "--interactive",
            "--rag-debug",
        ],
    )

    assert result.exit_code == 0
    assert "RAG retrieve:" in result.stdout

View File

@@ -3,7 +3,14 @@
from __future__ import annotations from __future__ import annotations
from tai.collectors import CollectedItem, CollectionReport from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve from tai.rag_retriever import (
Chunk,
EmbeddedChunk,
_cosine_similarity,
chunk_report,
retrieve,
retrieve_scored,
)
from tai.ssh_client import SSHCommandResult from tai.ssh_client import SSHCommandResult
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
assert "(no output)" in chunks[0].content assert "(no output)" in chunks[0].content
def test_chunk_report_caps_large_content() -> None:
    """Oversized command output is truncated and flagged with the RAG marker."""
    report = _report(("huge", "x" * 5000, 0))
    chunks = chunk_report(report, max_chunk_chars=200)
    # Allow headroom above the 200-char cap for the appended truncation marker.
    assert len(chunks[0].content) <= 230
    assert "...[truncated for RAG]" in chunks[0].content
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _cosine_similarity # _cosine_similarity
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
assert result[1].name == "mid" assert result[1].name == "mid"
def test_retrieve_scored_includes_scores() -> None:
    """``retrieve_scored`` returns (chunk, score) pairs, best match first."""
    chunks = [
        _embedded("close", [1.0, 0.0]),
        _embedded("far", [0.0, 1.0]),
    ]
    result = retrieve_scored([1.0, 0.0], chunks, top_k=2)
    assert len(result) == 2
    # The chunk aligned with the query vector must rank first ...
    assert result[0][0].name == "close"
    # ... and carry a strictly higher score than the orthogonal one.
    assert result[0][1] > result[1][1]
def test_retrieve_respects_top_k_larger_than_pool() -> None: def test_retrieve_respects_top_k_larger_than_pool() -> None:
chunks = [_embedded("only", [1.0, 0.0])] chunks = [_embedded("only", [1.0, 0.0])]
result = retrieve([1.0, 0.0], chunks, top_k=10) result = retrieve([1.0, 0.0], chunks, top_k=10)