Some checks failed
CI / test (push) Failing after 15s
- Add embed() to AIClient using Ollama nomic-embed-text via /v1/embeddings
- Add DEFAULT_EMBED_MODEL and an embed_model field to AIConfig
- New rag_retriever.py: chunk_report(), EmbeddedChunk, retrieve() (pure-Python cosine similarity)
- prompt_builder: add build_message_with_chunks() for RAG-aware follow-up prompts
- cli: add --no-rag flag, embed report chunks after collection, retrieve top-5 chunks per question
- Fall back gracefully to full context if the embedding model is unavailable
- 16 new tests in test_rag_retriever.py (67 total, all passing)
- Add chromadb>=0.5 as an optional [rag] dependency in pyproject.toml
- README: add step 3 (pull nomic-embed-text), update the Suggested Tooling table
174 lines
5.3 KiB
Python
174 lines
5.3 KiB
Python
"""Tests for rag_retriever — pure-Python, no network calls."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from tai.collectors import CollectedItem, CollectionReport
|
|
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
|
from tai.ssh_client import SSHCommandResult
|
|
|
|
|
|
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Return a CollectionReport for host "test-host" built from *items*.

    Each positional argument is a ``(name, stdout, exit_code)`` tuple. The
    resulting command string is ``cmd-<name>`` and stderr is always empty.
    """
    collected: list[CollectedItem] = []
    for item_name, item_stdout, item_code in items:
        ssh_result = SSHCommandResult(
            command=f"cmd-{item_name}",
            exit_code=item_code,
            stdout=item_stdout,
            stderr="",
        )
        collected.append(CollectedItem(name=item_name, result=ssh_result))
    return CollectionReport(host="test-host", items=collected)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# chunk_report
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Each collected item yields exactly one chunk, in report order."""
    chunks = chunk_report(
        _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0))
    )
    assert len(chunks) == 2
    assert [chunk.name for chunk in chunks] == ["kernel", "journal"]
|
|
|
|
|
|
def test_chunk_report_includes_stdout_in_content() -> None:
    """The chunk content carries the command's stdout text."""
    first = chunk_report(_report(("kernel", "Linux test 6.1", 0)))[0]
    assert "Linux test 6.1" in first.content
|
|
|
|
|
|
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit status is rendered as "Exit code: N" in the chunk content."""
    first = chunk_report(_report(("fail", "error output", 1)))[0]
    assert "Exit code: 1" in first.content
|
|
|
|
|
|
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """Items with exit code 255 and empty stdout and stderr are treated as SSH
    failures and dropped; ordinary items alongside them are kept."""
    unreachable = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(command="some-cmd", exit_code=255, stdout="", stderr=""),
    )
    healthy = CollectedItem(
        name="ok",
        result=SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux", stderr=""),
    )
    chunks = chunk_report(CollectionReport(host="test-host", items=[unreachable, healthy]))
    assert len(chunks) == 1
    assert chunks[0].name == "ok"
|
|
|
|
|
|
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit code 255 accompanied by stderr text is a genuine command failure,
    so the item is kept and its stderr appears in the chunk content."""
    denied = SSHCommandResult(
        command="some-cmd",
        exit_code=255,
        stdout="",
        stderr="Permission denied",
    )
    report = CollectionReport(
        host="test-host",
        items=[CollectedItem(name="partial", result=denied)],
    )
    chunks = chunk_report(report)
    assert len(chunks) == 1
    assert "Permission denied" in chunks[0].content
|
|
|
|
|
|
def test_chunk_report_notes_no_output() -> None:
    """A successful command with empty stdout and stderr is rendered with a
    "(no output)" marker."""
    silent = CollectedItem(
        name="silent",
        result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""),
    )
    first = chunk_report(CollectionReport(host="test-host", items=[silent]))[0]
    assert "(no output)" in first.content
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _cosine_similarity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself has similarity 1.0, whatever its length.

    The non-unit vector matters. For unit vectors the raw dot product equals the
    cosine similarity, so an implementation that forgot to divide by the norms
    would pass on unit inputs alone. For [3, 4, 0] the raw dot product is 25.
    """
    unit = [1.0, 0.0, 0.0]
    assert abs(_cosine_similarity(unit, unit) - 1.0) < 1e-9
    scaled = [3.0, 4.0, 0.0]  # norm 5; un-normalised dot product would be 25
    assert abs(_cosine_similarity(scaled, scaled) - 1.0) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors have similarity 0 (within float tolerance)."""
    similarity = _cosine_similarity([1.0, 0.0], [0.0, 1.0])
    assert abs(similarity) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_opposite_vectors() -> None:
    """Vectors pointing in opposite directions have similarity -1 (within float tolerance)."""
    similarity = _cosine_similarity([1.0, 0.0], [-1.0, 0.0])
    assert abs(similarity + 1.0) < 1e-9
|
|
|
|
|
|
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """A zero vector yields 0.0 whichever argument it is passed as.

    All three positions are checked: zero first, zero second, and both zero.
    A zero-norm guard that only inspects one operand would fail the other
    cases, presumably with a ZeroDivisionError.
    """
    zero = [0.0, 0.0]
    unit = [1.0, 0.0]
    assert _cosine_similarity(zero, unit) == 0.0
    assert _cosine_similarity(unit, zero) == 0.0
    assert _cosine_similarity(zero, zero) == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# retrieve
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Wrap *vec* in an EmbeddedChunk whose chunk is named *name* and has
    content "content of <name>"."""
    chunk = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=chunk, embedding=vec)
|
|
|
|
|
|
def test_retrieve_returns_top_k_by_similarity() -> None:
    """retrieve() ranks chunks by similarity to the query, not by input order.

    The pool is supplied least-similar first. If it were already sorted, an
    implementation that merely returned the first ``top_k`` chunks would pass.
    """
    chunks = [
        _embedded("far", [0.0, 1.0]),  # orthogonal to query: similarity 0
        _embedded("mid", [0.7, 0.7]),  # similarity ~0.707
        _embedded("close", [1.0, 0.0]),  # same direction as query: similarity 1
    ]
    query = [1.0, 0.0]
    result = retrieve(query, chunks, top_k=2)
    assert len(result) == 2
    assert result[0].name == "close"
    assert result[1].name == "mid"
|
|
|
|
|
|
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """When top_k exceeds the pool size, only the available chunks are returned."""
    pool = [_embedded("only", [1.0, 0.0])]
    assert len(retrieve([1.0, 0.0], pool, top_k=10)) == 1
|
|
|
|
|
|
def test_retrieve_empty_pool_returns_empty() -> None:
    """An empty candidate pool yields an empty list even when top_k is positive."""
    result = retrieve([1.0, 0.0], [], top_k=5)
    assert result == []
|
|
|
|
|
|
def test_retrieve_top_k_zero_returns_empty() -> None:
    """top_k=0 returns an empty list even when the pool is non-empty."""
    pool = [_embedded("x", [1.0, 0.0])]
    result = retrieve([1.0, 0.0], pool, top_k=0)
    assert result == []
|