Files
tai/tests/test_rag_retriever.py
zphinx be181c2d7f
Some checks failed
CI / test (push) Failing after 15s
feat(rag): implement Tier 1 in-memory RAG for interactive follow-ups
- Add embed() to AIClient using Ollama nomic-embed-text via /v1/embeddings
- Add DEFAULT_EMBED_MODEL and embed_model field to AIConfig
- New rag_retriever.py: chunk_report(), EmbeddedChunk, retrieve() (pure-Python cosine)
- prompt_builder: add build_message_with_chunks() for RAG-aware follow-up prompts
- cli: add --no-rag flag, embed report chunks after collection, retrieve top-5 per question
- Graceful fallback to full-context if embedding model unavailable
- 16 new tests in test_rag_retriever.py (67 total, all passing)
- Add chromadb>=0.5 as optional [rag] dep in pyproject.toml
- README: add step 3 (pull nomic-embed-text), update Suggested Tooling table
2026-05-04 18:36:12 +02:00

174 lines
5.3 KiB
Python

"""Tests for rag_retriever — pure-Python, no network calls."""
from __future__ import annotations
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
from tai.ssh_client import SSHCommandResult
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Assemble a ``CollectionReport`` for the host ``test-host``.

    Each positional argument is a ``(name, stdout, exit_code)`` triple. The
    matching item records the command ``cmd-<name>`` with an empty stderr.
    """
    collected: list[CollectedItem] = []
    for item_name, item_stdout, item_code in items:
        ssh_result = SSHCommandResult(
            command=f"cmd-{item_name}",
            exit_code=item_code,
            stdout=item_stdout,
            stderr="",
        )
        collected.append(CollectedItem(name=item_name, result=ssh_result))
    return CollectionReport(host="test-host", items=collected)
# ---------------------------------------------------------------------------
# chunk_report
# ---------------------------------------------------------------------------
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Each collected item yields exactly one chunk, in report order."""
    chunks = chunk_report(
        _report(
            ("kernel", "Linux test 6.1", 0),
            ("journal", "Started nginx.", 0),
        )
    )
    assert len(chunks) == 2
    assert (chunks[0].name, chunks[1].name) == ("kernel", "journal")
def test_chunk_report_includes_stdout_in_content() -> None:
    """The command's stdout text appears in the chunk content."""
    expected = "Linux test 6.1"
    first_chunk = chunk_report(_report(("kernel", expected, 0)))[0]
    assert expected in first_chunk.content
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit status is rendered as ``Exit code: N`` in the chunk."""
    failing = _report(("fail", "error output", 1))
    content = chunk_report(failing)[0].content
    assert "Exit code: 1" in content
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """Exit 255 with empty stdout and stderr marks an SSH failure; it is dropped."""
    unreachable = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(
            command="some-cmd", exit_code=255, stdout="", stderr=""
        ),
    )
    healthy = CollectedItem(
        name="ok",
        result=SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux", stderr=""
        ),
    )
    report = CollectionReport(host="test-host", items=[unreachable, healthy])

    chunks = chunk_report(report)

    # Only the healthy item should survive the filter.
    assert len(chunks) == 1
    assert chunks[0].name == "ok"
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit 255 that still produced stderr is a genuine failure and is kept."""
    denied = SSHCommandResult(
        command="some-cmd",
        exit_code=255,
        stdout="",
        stderr="Permission denied",
    )
    report = CollectionReport(
        host="test-host",
        items=[CollectedItem(name="partial", result=denied)],
    )

    chunks = chunk_report(report)

    assert len(chunks) == 1
    # The stderr text must be carried into the chunk content.
    assert "Permission denied" in chunks[0].content
def test_chunk_report_notes_no_output() -> None:
    """An item that printed nothing gets a ``(no output)`` marker in its chunk."""
    silent_result = SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr="")
    silent_item = CollectedItem(name="silent", result=silent_result)
    report = CollectionReport(host="test-host", items=[silent_item])

    content = chunk_report(report)[0].content
    assert "(no output)" in content
# ---------------------------------------------------------------------------
# _cosine_similarity
# ---------------------------------------------------------------------------
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself has similarity 1."""
    vec = [1.0, 0.0, 0.0]
    delta = _cosine_similarity(vec, vec) - 1.0
    assert -1e-9 < delta < 1e-9
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors have similarity 0."""
    unit_x, unit_y = [1.0, 0.0], [0.0, 1.0]
    sim = _cosine_similarity(unit_x, unit_y)
    assert -1e-9 < sim < 1e-9
def test_cosine_similarity_opposite_vectors() -> None:
    """Vectors pointing in opposite directions have similarity -1."""
    forward, backward = [1.0, 0.0], [-1.0, 0.0]
    offset = _cosine_similarity(forward, backward) + 1.0
    assert -1e-9 < offset < 1e-9
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """Comparing a zero vector against any vector yields exactly 0.0."""
    zero, unit_x = [0.0, 0.0], [1.0, 0.0]
    assert _cosine_similarity(zero, unit_x) == 0.0
# ---------------------------------------------------------------------------
# retrieve
# ---------------------------------------------------------------------------
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Wrap ``vec`` in an ``EmbeddedChunk`` whose chunk is called ``name``.

    The chunk text is the placeholder ``content of <name>``.
    """
    placeholder = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=placeholder, embedding=vec)
def test_retrieve_returns_top_k_by_similarity() -> None:
    """retrieve() ranks candidates by similarity to the query, not by input order.

    The pool is deliberately supplied out of similarity order. An
    implementation that only truncates the input (``chunks[:top_k]``),
    reverses it, or sorts ascending would therefore return the wrong names
    and fail here.
    """
    chunks = [
        _embedded("mid", [0.7, 0.7]),  # ~0.707 to the query
        _embedded("far", [0.0, 1.0]),  # orthogonal to the query (0.0)
        _embedded("close", [1.0, 0.0]),  # same direction as the query (1.0)
    ]
    query = [1.0, 0.0]
    result = retrieve(query, chunks, top_k=2)
    assert len(result) == 2
    assert result[0].name == "close"
    assert result[1].name == "mid"
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """Asking for more hits than there are candidates returns only the pool."""
    pool = [_embedded("only", [1.0, 0.0])]
    hits = retrieve([1.0, 0.0], pool, top_k=10)
    assert len(hits) == 1
def test_retrieve_empty_pool_returns_empty() -> None:
    """With no candidates at all, retrieve() returns an empty list."""
    no_chunks: list[EmbeddedChunk] = []
    hits = retrieve([1.0, 0.0], no_chunks, top_k=5)
    assert hits == []
def test_retrieve_top_k_zero_returns_empty() -> None:
    """Requesting zero hits returns an empty list even when candidates exist."""
    pool = [_embedded("x", [1.0, 0.0])]
    hits = retrieve([1.0, 0.0], pool, top_k=0)
    assert hits == []