Some checks failed
CI / test (push) Failing after 15s
- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
199 lines
5.9 KiB
Python
199 lines
5.9 KiB
Python
"""Tests for rag_retriever — pure-Python, no network calls."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from tai.collectors import CollectedItem, CollectionReport
|
|
from tai.rag_retriever import (
|
|
Chunk,
|
|
EmbeddedChunk,
|
|
_cosine_similarity,
|
|
chunk_report,
|
|
retrieve,
|
|
retrieve_scored,
|
|
)
|
|
from tai.ssh_client import SSHCommandResult
|
|
|
|
|
|
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Assemble a ``CollectionReport`` for host ``"test-host"``.

    Each ``(name, stdout, exit_code)`` triple becomes one ``CollectedItem``
    whose command is ``cmd-<name>`` and whose stderr is empty. Items keep
    the order in which they were passed.
    """
    collected: list[CollectedItem] = []
    for item_name, item_stdout, item_code in items:
        ssh_result = SSHCommandResult(
            command=f"cmd-{item_name}",
            exit_code=item_code,
            stdout=item_stdout,
            stderr="",
        )
        collected.append(CollectedItem(name=item_name, result=ssh_result))
    return CollectionReport(host="test-host", items=collected)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# chunk_report
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Each collected item yields exactly one chunk, in report order."""
    two_items = _report(
        ("kernel", "Linux test 6.1", 0),
        ("journal", "Started nginx.", 0),
    )
    chunk_names = [chunk.name for chunk in chunk_report(two_items)]
    assert chunk_names == ["kernel", "journal"]
|
|
|
|
|
|
def test_chunk_report_includes_stdout_in_content() -> None:
    """The command's stdout text appears in the resulting chunk body."""
    expected_stdout = "Linux test 6.1"
    first_chunk = chunk_report(_report(("kernel", expected_stdout, 0)))[0]
    assert expected_stdout in first_chunk.content
|
|
|
|
|
|
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit status is rendered as ``Exit code: N`` in the chunk body."""
    failing = _report(("fail", "error output", 1))
    body = chunk_report(failing)[0].content
    assert "Exit code: 1" in body
|
|
|
|
|
|
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """Items with exit 255 and no output represent SSH failures and are dropped.

    Only the successful ``ok`` item should survive chunking.
    """
    unreachable = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(command="some-cmd", exit_code=255, stdout="", stderr=""),
    )
    reachable = CollectedItem(
        name="ok",
        result=SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux", stderr=""),
    )
    report = CollectionReport(host="test-host", items=[unreachable, reachable])
    surviving = [chunk.name for chunk in chunk_report(report)]
    assert surviving == ["ok"]
|
|
|
|
|
|
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit 255 with stderr present is a real failure, so the item is kept.

    The stderr text must be carried into the chunk so the model can see it.
    """
    denied = SSHCommandResult(
        command="some-cmd",
        exit_code=255,
        stdout="",
        stderr="Permission denied",
    )
    report = CollectionReport(
        host="test-host",
        items=[CollectedItem(name="partial", result=denied)],
    )
    chunks = chunk_report(report)
    assert len(chunks) == 1
    only = chunks[0]
    assert "Permission denied" in only.content
|
|
|
|
|
|
def test_chunk_report_notes_no_output() -> None:
    """A successful command with empty stdout/stderr is labelled ``(no output)``."""
    silent_result = SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr="")
    silent_item = CollectedItem(name="silent", result=silent_result)
    report = CollectionReport(host="test-host", items=[silent_item])
    placeholder_chunk = chunk_report(report)[0]
    assert "(no output)" in placeholder_chunk.content
|
|
|
|
|
|
def test_chunk_report_caps_large_content() -> None:
    """Oversized stdout is truncated near ``max_chunk_chars`` and marked as such.

    The length bound is derived from the cap, the truncation marker and a
    small slack. It equals the former hard-coded 230, so changing the cap or
    the marker here keeps the assertion meaningful.
    """
    max_chars = 200
    marker = "...[truncated for RAG]"
    # NOTE(review): slack covers any framing chunk_report adds around the
    # payload (header/exit-code line?) — confirm against chunk_report.
    slack = 8
    report = _report(("huge", "x" * 5000, 0))
    content = chunk_report(report, max_chunk_chars=max_chars)[0].content
    assert len(content) <= max_chars + len(marker) + slack
    assert marker in content
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _cosine_similarity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself has similarity 1 (within 1e-9)."""
    unit_x = [1.0, 0.0, 0.0]
    similarity = _cosine_similarity(unit_x, unit_x)
    assert abs(1.0 - similarity) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors have similarity 0 (within 1e-9)."""
    x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
    assert abs(_cosine_similarity(x_axis, y_axis)) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_opposite_vectors() -> None:
    """Anti-parallel vectors have similarity -1 (within 1e-9)."""
    forward, backward = [1.0, 0.0], [-1.0, 0.0]
    similarity = _cosine_similarity(forward, backward)
    assert abs(similarity + 1.0) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """A zero-length operand yields exactly 0.0 rather than dividing by zero."""
    zero, unit = [0.0, 0.0], [1.0, 0.0]
    assert _cosine_similarity(zero, unit) == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# retrieve
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Pair ``vec`` with a chunk named ``name`` whose content is ``content of <name>``."""
    chunk = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=chunk, embedding=vec)
|
|
|
|
|
|
def test_retrieve_returns_top_k_by_similarity() -> None:
    """retrieve keeps the ``top_k`` most similar chunks, best first.

    "far" is orthogonal to the query, so it is the one left out.
    """
    pool = [
        _embedded("close", [1.0, 0.0]),  # parallel to the query
        _embedded("mid", [0.7, 0.7]),
        _embedded("far", [0.0, 1.0]),  # orthogonal to the query
    ]
    ranked = retrieve([1.0, 0.0], pool, top_k=2)
    assert [chunk.name for chunk in ranked] == ["close", "mid"]
|
|
|
|
|
|
def test_retrieve_scored_includes_scores() -> None:
    """retrieve_scored returns ``(chunk, score)`` pairs ranked best first.

    The query is parallel to "close" and orthogonal to "far". With ``top_k``
    covering the whole pool, both chunks must come back exactly once, in that
    order, and "close" must carry the strictly higher score. Pinning the full
    name list also catches duplicated or missing entries, which the original
    first-entry-only check let through.
    """
    chunks = [
        _embedded("close", [1.0, 0.0]),
        _embedded("far", [0.0, 1.0]),
    ]
    result = retrieve_scored([1.0, 0.0], chunks, top_k=2)
    assert [entry[0].name for entry in result] == ["close", "far"]
    close_score = result[0][1]
    far_score = result[1][1]
    assert close_score > far_score
|
|
|
|
|
|
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """Asking for more chunks than exist returns the whole (single-item) pool."""
    single = [_embedded("only", [1.0, 0.0])]
    assert len(retrieve([1.0, 0.0], single, top_k=10)) == 1
|
|
|
|
|
|
def test_retrieve_empty_pool_returns_empty() -> None:
    """An empty candidate pool yields an empty result list."""
    no_chunks: list[EmbeddedChunk] = []
    hits = retrieve([1.0, 0.0], no_chunks, top_k=5)
    assert hits == []
|
|
|
|
|
|
def test_retrieve_top_k_zero_returns_empty() -> None:
    """``top_k=0`` returns nothing even when the pool is non-empty."""
    pool = [_embedded("x", [1.0, 0.0])]
    hits = retrieve([1.0, 0.0], pool, top_k=0)
    assert hits == []
|