Files
tai/tests/test_rag_retriever.py
zphinx e943e84bd2
Some checks failed
CI / test (push) Failing after 15s
feat(rag): harden Tier 1 retrieval observability and stability
- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
2026-05-04 19:13:57 +02:00

199 lines
5.9 KiB
Python

"""Tests for rag_retriever — pure-Python, no network calls."""
from __future__ import annotations
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import (
Chunk,
EmbeddedChunk,
_cosine_similarity,
chunk_report,
retrieve,
retrieve_scored,
)
from tai.ssh_client import SSHCommandResult
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Assemble a ``CollectionReport`` for host ``test-host``.

    Each positional argument is a ``(name, stdout, exit_code)`` triple. The
    command string is synthesised as ``cmd-<name>`` and stderr is left empty.
    """

    def _item(item_name: str, item_stdout: str, item_code: int) -> CollectedItem:
        # One collected item wrapping a fake SSH result for the given spec.
        ssh_result = SSHCommandResult(
            command=f"cmd-{item_name}",
            exit_code=item_code,
            stdout=item_stdout,
            stderr="",
        )
        return CollectedItem(name=item_name, result=ssh_result)

    return CollectionReport(
        host="test-host",
        items=[_item(name, out, code) for name, out, code in items],
    )
# ---------------------------------------------------------------------------
# chunk_report
# ---------------------------------------------------------------------------
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Every collected item yields exactly one chunk, in collection order."""
    chunks = chunk_report(
        _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0))
    )
    assert len(chunks) == 2
    assert [chunk.name for chunk in chunks] == ["kernel", "journal"]
def test_chunk_report_includes_stdout_in_content() -> None:
    """The command's stdout text appears in the resulting chunk content."""
    expected_stdout = "Linux test 6.1"
    first_chunk = chunk_report(_report(("kernel", expected_stdout, 0)))[0]
    assert expected_stdout in first_chunk.content
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit status is surfaced as ``Exit code: N`` in the chunk."""
    failing = _report(("fail", "error output", 1))
    first_chunk = chunk_report(failing)[0]
    assert "Exit code: 1" in first_chunk.content
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """Exit 255 with empty stdout and stderr marks an SSH failure; no chunk is made.

    Only the healthy item should survive chunking.
    """
    unreachable = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(
            command="some-cmd", exit_code=255, stdout="", stderr=""
        ),
    )
    reachable = CollectedItem(
        name="ok",
        result=SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux", stderr=""
        ),
    )
    report = CollectionReport(host="test-host", items=[unreachable, reachable])

    surviving_names = [chunk.name for chunk in chunk_report(report)]
    assert surviving_names == ["ok"]
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit 255 accompanied by stderr is a genuine failure and must be kept.

    The stderr text should be carried into the chunk content.
    """
    denied = SSHCommandResult(
        command="some-cmd",
        exit_code=255,
        stdout="",
        stderr="Permission denied",
    )
    report = CollectionReport(
        host="test-host",
        items=[CollectedItem(name="partial", result=denied)],
    )

    chunks = chunk_report(report)
    assert len(chunks) == 1
    only_chunk = chunks[0]
    assert "Permission denied" in only_chunk.content
def test_chunk_report_notes_no_output() -> None:
    """A command that printed nothing is rendered with a ``(no output)`` marker."""
    quiet_item = CollectedItem(
        name="silent",
        result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""),
    )
    chunks = chunk_report(CollectionReport(host="test-host", items=[quiet_item]))
    assert "(no output)" in chunks[0].content
def test_chunk_report_caps_large_content() -> None:
    """Oversized output is cut down near ``max_chunk_chars`` and visibly marked."""
    cap = 200
    # Slack above the cap the test tolerates; presumably covers the truncation
    # marker and any framing text -- confirm against chunk_report.
    allowed_overshoot = 30

    huge_report = _report(("huge", "x" * 5000, 0))
    capped = chunk_report(huge_report, max_chunk_chars=cap)[0]

    assert len(capped.content) <= cap + allowed_overshoot
    assert "...[truncated for RAG]" in capped.content
# ---------------------------------------------------------------------------
# _cosine_similarity
# ---------------------------------------------------------------------------
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself scores 1 (within float tolerance)."""
    unit_x = [1.0, 0.0, 0.0]
    score = _cosine_similarity(unit_x, unit_x)
    assert abs(score - 1.0) < 1e-9
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors score 0 (within float tolerance)."""
    score = _cosine_similarity([1.0, 0.0], [0.0, 1.0])
    assert abs(score) < 1e-9
def test_cosine_similarity_opposite_vectors() -> None:
    """Anti-parallel vectors score -1 (within float tolerance)."""
    score = _cosine_similarity([1.0, 0.0], [-1.0, 0.0])
    # score - (-1.0) is exactly score + 1.0 in IEEE arithmetic.
    assert abs(score + 1.0) < 1e-9
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """A zero vector on either side yields a similarity of exactly 0.0.

    Cosine similarity is undefined when a norm is zero; the retriever returns
    0.0 instead. Both argument orders and the all-zero pair are checked, so a
    guard that only inspects one operand's norm is caught.
    """
    zero = [0.0, 0.0]
    unit = [1.0, 0.0]
    assert _cosine_similarity(zero, unit) == 0.0
    assert _cosine_similarity(unit, zero) == 0.0
    assert _cosine_similarity(zero, zero) == 0.0
# ---------------------------------------------------------------------------
# retrieve
# ---------------------------------------------------------------------------
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Pair *vec* with a stub chunk named *name* (content ``content of <name>``)."""
    stub = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=stub, embedding=vec)
def test_retrieve_returns_top_k_by_similarity() -> None:
    """Only the ``top_k`` best matches come back, highest similarity first."""
    query = [1.0, 0.0]
    pool = [
        _embedded("close", [1.0, 0.0]),  # parallel to the query
        _embedded("mid", [0.7, 0.7]),  # 45 degrees away
        _embedded("far", [0.0, 1.0]),  # orthogonal to the query
    ]

    hits = retrieve(query, pool, top_k=2)

    assert len(hits) == 2
    assert [hit.name for hit in hits] == ["close", "mid"]
def test_retrieve_scored_includes_scores() -> None:
    """``retrieve_scored`` pairs each chunk with its score, best match first.

    Pins the complete ranking (not only the leading hit) and checks that
    ``top_k`` truncates the scored results, mirroring the ``retrieve`` tests.
    """
    query = [1.0, 0.0]
    pool = [
        _embedded("close", [1.0, 0.0]),
        _embedded("far", [0.0, 1.0]),
    ]

    ranked = retrieve_scored(query, pool, top_k=2)
    assert len(ranked) == 2
    # Each entry is indexed as (chunk, score); both chunks must appear in order.
    assert [entry[0].name for entry in ranked] == ["close", "far"]
    assert ranked[0][1] > ranked[1][1]

    # top_k must also cap the scored API, keeping only the best match.
    truncated = retrieve_scored(query, pool, top_k=1)
    assert [entry[0].name for entry in truncated] == ["close"]
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """Requesting more hits than exist returns only as many as the pool holds."""
    single_item_pool = [_embedded("only", [1.0, 0.0])]
    hits = retrieve([1.0, 0.0], single_item_pool, top_k=10)
    assert len(hits) == 1
def test_retrieve_empty_pool_returns_empty() -> None:
    """With nothing indexed, retrieval yields an empty list."""
    hits = retrieve([1.0, 0.0], [], top_k=5)
    assert hits == []
def test_retrieve_top_k_zero_returns_empty() -> None:
    """``top_k=0`` yields no hits even when the pool is non-empty."""
    pool = [_embedded("x", [1.0, 0.0])]
    hits = retrieve([1.0, 0.0], pool, top_k=0)
    assert hits == []