feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s
Some checks failed
CI / test (push) Failing after 15s
- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
@@ -3,7 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
||||
from tai.rag_retriever import (
|
||||
Chunk,
|
||||
EmbeddedChunk,
|
||||
_cosine_similarity,
|
||||
chunk_report,
|
||||
retrieve,
|
||||
retrieve_scored,
|
||||
)
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
|
||||
assert "(no output)" in chunks[0].content
|
||||
|
||||
|
||||
def test_chunk_report_caps_large_content() -> None:
    """Check that chunk_report caps an oversized item's content.

    A report item carrying 5000 characters is chunked with
    ``max_chunk_chars=200``; the first chunk's content must stay within
    230 characters and contain the truncation marker.
    """
    oversized_payload = "x" * 5000
    capped_chunks = chunk_report(
        _report(("huge", oversized_payload, 0)),
        max_chunk_chars=200,
    )
    first_content = capped_chunks[0].content
    # NOTE(review): 230 presumably leaves headroom above the 200-char cap
    # for the appended marker and any chunk header; confirm in chunk_report.
    assert len(first_content) <= 230
    assert "...[truncated for RAG]" in first_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
|
||||
assert result[1].name == "mid"
|
||||
|
||||
|
||||
def test_retrieve_scored_includes_scores() -> None:
    """Check that retrieve_scored pairs each hit with a score, best first.

    The query vector equals the "close" chunk's embedding, so that chunk
    must be ranked first with a strictly higher score than "far".
    """
    query_vector = [1.0, 0.0]
    pool = [
        _embedded("close", [1.0, 0.0]),
        _embedded("far", [0.0, 1.0]),
    ]
    scored_hits = retrieve_scored(query_vector, pool, top_k=2)
    assert len(scored_hits) == 2
    # Each hit is indexed as (chunk, score).
    assert scored_hits[0][0].name == "close"
    assert scored_hits[0][1] > scored_hits[1][1]
|
||||
|
||||
|
||||
def test_retrieve_respects_top_k_larger_than_pool() -> None:
|
||||
chunks = [_embedded("only", [1.0, 0.0])]
|
||||
result = retrieve([1.0, 0.0], chunks, top_k=10)
|
||||
|
||||
Reference in New Issue
Block a user