"""Tests for rag_retriever — pure-Python, no network calls.""" from __future__ import annotations from tai.collectors import CollectedItem, CollectionReport from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve from tai.ssh_client import SSHCommandResult def _report(*items: tuple[str, str, int]) -> CollectionReport: """Build a CollectionReport from (name, stdout, exit_code) tuples.""" return CollectionReport( host="test-host", items=[ CollectedItem( name=name, result=SSHCommandResult( command=f"cmd-{name}", exit_code=code, stdout=stdout, stderr="", ), ) for name, stdout, code in items ], ) # --------------------------------------------------------------------------- # chunk_report # --------------------------------------------------------------------------- def test_chunk_report_creates_one_chunk_per_item() -> None: report = _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0)) chunks = chunk_report(report) assert len(chunks) == 2 assert chunks[0].name == "kernel" assert chunks[1].name == "journal" def test_chunk_report_includes_stdout_in_content() -> None: report = _report(("kernel", "Linux test 6.1", 0)) chunks = chunk_report(report) assert "Linux test 6.1" in chunks[0].content def test_chunk_report_includes_exit_code_in_content() -> None: report = _report(("fail", "error output", 1)) chunks = chunk_report(report) assert "Exit code: 1" in chunks[0].content def test_chunk_report_skips_ssh_unreachable_items() -> None: """Items with exit 255 and no output represent SSH failures and are dropped.""" report = CollectionReport( host="test-host", items=[ CollectedItem( name="unreachable", result=SSHCommandResult( command="some-cmd", exit_code=255, stdout="", stderr="" ), ), CollectedItem( name="ok", result=SSHCommandResult( command="uname -a", exit_code=0, stdout="Linux", stderr="" ), ), ], ) chunks = chunk_report(report) assert len(chunks) == 1 assert chunks[0].name == "ok" def test_chunk_report_keeps_exit_255_with_output() -> None: """Exit 255 with stderr present is a real failure — keep it.""" report = CollectionReport( host="test-host", items=[ CollectedItem( name="partial", result=SSHCommandResult( command="some-cmd", exit_code=255, stdout="", stderr="Permission denied", ), ), ], ) chunks = chunk_report(report) assert len(chunks) == 1 assert "Permission denied" in chunks[0].content def test_chunk_report_notes_no_output() -> None: report = CollectionReport( host="test-host", items=[ CollectedItem( name="silent", result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""), ), ], ) chunks = chunk_report(report) assert "(no output)" in chunks[0].content # --------------------------------------------------------------------------- # _cosine_similarity # --------------------------------------------------------------------------- def test_cosine_similarity_identical_vectors() -> None: v = [1.0, 0.0, 0.0] assert abs(_cosine_similarity(v, v) - 1.0) < 1e-9 def test_cosine_similarity_orthogonal_vectors() -> None: a = [1.0, 0.0] b = [0.0, 1.0] assert abs(_cosine_similarity(a, b)) < 1e-9 def test_cosine_similarity_opposite_vectors() -> None: a = [1.0, 0.0] b = [-1.0, 0.0] assert abs(_cosine_similarity(a, b) - (-1.0)) < 1e-9 def test_cosine_similarity_zero_vector_returns_zero() -> None: assert _cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0 # --------------------------------------------------------------------------- # retrieve # --------------------------------------------------------------------------- def _embedded(name: str, vec: list[float]) -> EmbeddedChunk: return EmbeddedChunk(chunk=Chunk(name=name, content=f"content of {name}"), embedding=vec) def test_retrieve_returns_top_k_by_similarity() -> None: chunks = [ _embedded("close", [1.0, 0.0]), # most similar _embedded("mid", [0.7, 0.7]), _embedded("far", [0.0, 1.0]), # orthogonal to query ] query = [1.0, 0.0] result = retrieve(query, chunks, top_k=2) assert len(result) == 2 assert result[0].name == "close" assert result[1].name == "mid" def test_retrieve_respects_top_k_larger_than_pool() -> None: chunks = [_embedded("only", [1.0, 0.0])] result = retrieve([1.0, 0.0], chunks, top_k=10) assert len(result) == 1 def test_retrieve_empty_pool_returns_empty() -> None: assert retrieve([1.0, 0.0], [], top_k=5) == [] def test_retrieve_top_k_zero_returns_empty() -> None: chunks = [_embedded("x", [1.0, 0.0])] assert retrieve([1.0, 0.0], chunks, top_k=0) == []