feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s

- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
2026-05-04 19:13:57 +02:00
parent 5529960e79
commit e943e84bd2
5 changed files with 337 additions and 35 deletions

View File

@@ -4,6 +4,7 @@ from typer.testing import CliRunner
from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk
from tai.ssh_client import SSHCommandResult
@@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
assert result.exit_code == 0
assert "AI Response" in result.stdout
assert "Check logs." in result.stdout
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--no-probe",
"--interactive",
],
)
assert result.exit_code == 0
assert "RAG unavailable (indexing failed)" in result.stdout
assert "AI Response" in result.stdout
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr(
"tai.cli._try_embed_report",
lambda *_args: (
[EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])],
None,
1.0,
),
)
monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--no-probe",
"--interactive",
"--rag-debug",
],
)
assert result.exit_code == 0
assert "RAG retrieve:" in result.stdout

View File

@@ -3,7 +3,14 @@
from __future__ import annotations
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
from tai.rag_retriever import (
Chunk,
EmbeddedChunk,
_cosine_similarity,
chunk_report,
retrieve,
retrieve_scored,
)
from tai.ssh_client import SSHCommandResult
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
assert "(no output)" in chunks[0].content
def test_chunk_report_caps_large_content() -> None:
report = _report(("huge", "x" * 5000, 0))
chunks = chunk_report(report, max_chunk_chars=200)
assert len(chunks[0].content) <= 230
assert "...[truncated for RAG]" in chunks[0].content
# ---------------------------------------------------------------------------
# _cosine_similarity
# ---------------------------------------------------------------------------
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
assert result[1].name == "mid"
def test_retrieve_scored_includes_scores() -> None:
chunks = [
_embedded("close", [1.0, 0.0]),
_embedded("far", [0.0, 1.0]),
]
result = retrieve_scored([1.0, 0.0], chunks, top_k=2)
assert len(result) == 2
assert result[0][0].name == "close"
assert result[0][1] > result[1][1]
def test_retrieve_respects_top_k_larger_than_pool() -> None:
chunks = [_embedded("only", [1.0, 0.0])]
result = retrieve([1.0, 0.0], chunks, top_k=10)