feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s
Some checks failed
CI / test (push) Failing after 15s
- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
@@ -4,6 +4,7 @@ from typer.testing import CliRunner
|
||||
|
||||
from tai.cli import app
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.rag_retriever import Chunk, EmbeddedChunk
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
@@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
||||
assert result.exit_code == 0
|
||||
assert "AI Response" in result.stdout
|
||||
assert "Check logs." in result.stdout
|
||||
|
||||
|
||||
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Interactive mode prints the RAG fallback notice when indexing fails.

    ``_try_embed_report`` is stubbed to return ``(None, "embed failed", 1.0)``;
    the CLI must still exit with code 0, print the fallback notice, and print
    the AI response section.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Build a fresh one-item report on every call, as a canned collection result.
        kernel_result = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        kernel_item = CollectedItem(name="kernel", result=kernel_result)
        return CollectionReport(host="ssh.archflux.net", items=[kernel_item])

    # Scripted user input: one follow-up question, then leave the session.
    commands = iter(["what should I check next?", "/quit"])

    def fake_input(_prompt):  # type: ignore[no-untyped-def]
        return next(commands)

    stubs = (
        ("tai.cli.collect_from_plan", fake_collect_from_plan),
        ("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0)),
        ("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."])),
        ("tai.cli.console.input", fake_input),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    cli_args = [
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    result = CliRunner().invoke(app, cli_args)

    assert result.exit_code == 0
    for expected in ("RAG unavailable (indexing failed)", "AI Response"):
        assert expected in result.stdout
|
||||
|
||||
|
||||
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``--rag-debug``, interactive mode prints a ``RAG retrieve:`` line.

    ``_try_embed_report`` is stubbed to return one embedded "kernel" chunk and
    ``AIClient.embed`` is stubbed to return the same vector, so no real
    embedding call is made by these two patched entry points.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Build a fresh one-item report on every call, as a canned collection result.
        kernel_result = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        kernel_item = CollectedItem(name="kernel", result=kernel_result)
        return CollectionReport(host="ssh.archflux.net", items=[kernel_item])

    def fake_try_embed_report(*_args):  # type: ignore[no-untyped-def]
        # A new chunk list per call, in the same 3-tuple shape as the original stub.
        kernel_chunk = Chunk(name="kernel", content="content")
        embedded = [EmbeddedChunk(chunk=kernel_chunk, embedding=[1.0, 0.0])]
        return (embedded, None, 1.0)

    # Scripted user input: one follow-up question, then leave the session.
    commands = iter(["what should I check next?", "/quit"])

    def fake_input(_prompt):  # type: ignore[no-untyped-def]
        return next(commands)

    stubs = (
        ("tai.cli.collect_from_plan", fake_collect_from_plan),
        ("tai.cli._try_embed_report", fake_try_embed_report),
        ("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0]),
        ("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."])),
        ("tai.cli.console.input", fake_input),
    )
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    cli_args = [
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
        "--rag-debug",
    ]
    result = CliRunner().invoke(app, cli_args)

    assert result.exit_code == 0
    assert "RAG retrieve:" in result.stdout
|
||||
|
||||
@@ -3,7 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
||||
from tai.rag_retriever import (
|
||||
Chunk,
|
||||
EmbeddedChunk,
|
||||
_cosine_similarity,
|
||||
chunk_report,
|
||||
retrieve,
|
||||
retrieve_scored,
|
||||
)
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
|
||||
assert "(no output)" in chunks[0].content
|
||||
|
||||
|
||||
def test_chunk_report_caps_large_content() -> None:
    """A 5000-character output chunked with ``max_chunk_chars=200`` is capped and marked."""
    oversized_output = "x" * 5000
    capped_chunks = chunk_report(_report(("huge", oversized_output, 0)), max_chunk_chars=200)

    # The 230 bound allows some room beyond the 200-character cap.
    assert len(capped_chunks[0].content) <= 230
    assert "...[truncated for RAG]" in capped_chunks[0].content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
|
||||
assert result[1].name == "mid"
|
||||
|
||||
|
||||
def test_retrieve_scored_includes_scores() -> None:
    """``retrieve_scored`` returns (chunk, score) entries with the closest chunk first."""
    query = [1.0, 0.0]
    pool = [
        _embedded(label, vector)
        for label, vector in (("close", [1.0, 0.0]), ("far", [0.0, 1.0]))
    ]

    ranked = retrieve_scored(query, pool, top_k=2)

    assert len(ranked) == 2
    best, runner_up = ranked[0], ranked[1]
    assert best[0].name == "close"
    # The chunk matching the query vector must score strictly higher.
    assert best[1] > runner_up[1]
|
||||
|
||||
|
||||
def test_retrieve_respects_top_k_larger_than_pool() -> None:
|
||||
chunks = [_embedded("only", [1.0, 0.0])]
|
||||
result = retrieve([1.0, 0.0], chunks, top_k=10)
|
||||
|
||||
Reference in New Issue
Block a user