feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s
Some checks failed
CI / test (push) Failing after 15s
- Add --rag-debug flag to show retrieved chunk names and similarity scores - Add explicit fallback notices when RAG indexing/query embedding fails - Log RAG index/query metrics (duration, scores, top hit, token estimate) - Normalize and cap chunk content for more stable prompt shape on small models - Add hypothesis-continuity instruction for follow-up prompts - Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
186
src/tai/cli.py
186
src/tai/cli.py
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from time import perf_counter
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
@@ -24,7 +25,7 @@ from tai.prompt_builder import (
|
||||
build_system_prompt,
|
||||
build_user_message,
|
||||
)
|
||||
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve
|
||||
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
|
||||
from tai.session_log import SessionLogger
|
||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||
|
||||
@@ -117,6 +118,13 @@ def run(
|
||||
help="Embedding model for RAG. Must be pulled in Ollama on the AI host.",
|
||||
),
|
||||
] = DEFAULT_EMBED_MODEL,
|
||||
rag_debug: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--rag-debug/--no-rag-debug",
|
||||
help="Print retrieved chunk names/scores and log per-question retrieval metrics.",
|
||||
),
|
||||
] = False,
|
||||
) -> None:
|
||||
"""Start an interactive troubleshooting session scaffold."""
|
||||
try:
|
||||
@@ -168,6 +176,7 @@ def run(
|
||||
interactive=interactive,
|
||||
ai_config=ai_config,
|
||||
no_rag=no_rag,
|
||||
rag_debug=rag_debug,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
@@ -191,6 +200,7 @@ async def _async_main(
|
||||
interactive: bool,
|
||||
ai_config: AIConfig,
|
||||
no_rag: bool,
|
||||
rag_debug: bool,
|
||||
logger: SessionLogger | None,
|
||||
) -> None:
|
||||
"""Open a single SSH session and run probe / collection / analysis through it."""
|
||||
@@ -241,7 +251,15 @@ async def _async_main(
|
||||
_run_analysis(ai_config, req.issue, report, logger=logger)
|
||||
|
||||
if interactive:
|
||||
await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger)
|
||||
await _interactive_loop(
|
||||
session,
|
||||
req,
|
||||
ai_config,
|
||||
report,
|
||||
no_rag=no_rag,
|
||||
rag_debug=rag_debug,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
async def _interactive_loop(
|
||||
@@ -251,6 +269,7 @@ async def _interactive_loop(
|
||||
report: CollectionReport | None,
|
||||
*,
|
||||
no_rag: bool = False,
|
||||
rag_debug: bool = False,
|
||||
logger: SessionLogger | None,
|
||||
) -> None:
|
||||
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||
@@ -269,9 +288,33 @@ async def _interactive_loop(
|
||||
ai_embed = AIClient(ai_config)
|
||||
|
||||
if not no_rag and report is not None:
|
||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||
_try_embed_report, report, ai_embed
|
||||
)
|
||||
if embedded_chunks is not None:
|
||||
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "ok",
|
||||
"chunk_count": len(embedded_chunks),
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[yellow]RAG unavailable (indexing failed); using full-context fallback.[/yellow]"
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "fallback",
|
||||
"error": index_error,
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -312,11 +355,36 @@ async def _interactive_loop(
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
if not no_rag:
|
||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||
_try_embed_report, report, ai_embed
|
||||
)
|
||||
if embedded_chunks is not None:
|
||||
console.print(
|
||||
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "ok",
|
||||
"chunk_count": len(embedded_chunks),
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[yellow]RAG unavailable (indexing failed); "
|
||||
"using full-context fallback.[/yellow]"
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "fallback",
|
||||
"error": index_error,
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"collection_summary",
|
||||
@@ -344,6 +412,7 @@ async def _interactive_loop(
|
||||
"Provide an updated diagnosis from the current diagnostics.",
|
||||
prior_questions,
|
||||
embedded_chunks=embedded_chunks,
|
||||
rag_debug=rag_debug,
|
||||
logger=logger,
|
||||
)
|
||||
prior_questions.append("/analyze")
|
||||
@@ -357,11 +426,36 @@ async def _interactive_loop(
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
if not no_rag:
|
||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
||||
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||
_try_embed_report, report, ai_embed
|
||||
)
|
||||
if embedded_chunks is not None:
|
||||
console.print(
|
||||
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "ok",
|
||||
"chunk_count": len(embedded_chunks),
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[yellow]RAG unavailable (indexing failed); "
|
||||
"using full-context fallback.[/yellow]"
|
||||
)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_index",
|
||||
{
|
||||
"status": "fallback",
|
||||
"error": index_error,
|
||||
"duration_ms": round(index_ms, 2),
|
||||
},
|
||||
)
|
||||
|
||||
if report is None:
|
||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||
@@ -374,6 +468,7 @@ async def _interactive_loop(
|
||||
command,
|
||||
prior_questions,
|
||||
embedded_chunks=embedded_chunks,
|
||||
rag_debug=rag_debug,
|
||||
logger=logger,
|
||||
)
|
||||
prior_questions.append(command)
|
||||
@@ -382,21 +477,23 @@ async def _interactive_loop(
|
||||
|
||||
|
||||
def _try_embed_report(
|
||||
report: CollectionReport, ai: AIClient
|
||||
) -> list[EmbeddedChunk] | None:
|
||||
"""Embed all diagnostic chunks from *report*; returns None on any failure.
|
||||
report: CollectionReport,
|
||||
ai: AIClient,
|
||||
) -> tuple[list[EmbeddedChunk] | None, str | None, float]:
|
||||
"""Embed all diagnostic chunks from *report*.
|
||||
|
||||
Failures are expected when the embedding model is not yet pulled or the
|
||||
AI backend is unavailable — in those cases the caller falls back to
|
||||
sending the full report as context.
|
||||
Returns (chunks, error_message, duration_ms). On failure, chunks is None
|
||||
and callers should fall back to non-RAG full-context prompts.
|
||||
"""
|
||||
start = perf_counter()
|
||||
try:
|
||||
chunks = chunk_report(report)
|
||||
if not chunks:
|
||||
return None
|
||||
return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
|
||||
except Exception: # noqa: BLE001
|
||||
return None
|
||||
return None, "no eligible chunks to index", (perf_counter() - start) * 1000.0
|
||||
embedded = [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
|
||||
return embedded, None, (perf_counter() - start) * 1000.0
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return None, str(exc), (perf_counter() - start) * 1000.0
|
||||
|
||||
|
||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||
@@ -473,6 +570,11 @@ def _run_analysis(
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
"""Rough token estimate for metrics and tuning; assumes ~4 chars/token."""
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
|
||||
def _run_followup_analysis(
|
||||
ai_config: AIConfig,
|
||||
issue: str,
|
||||
@@ -481,13 +583,14 @@ def _run_followup_analysis(
|
||||
prior_questions: list[str],
|
||||
*,
|
||||
embedded_chunks: list[EmbeddedChunk] | None = None,
|
||||
rag_debug: bool = False,
|
||||
logger: SessionLogger | None,
|
||||
) -> str:
|
||||
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
||||
|
||||
When *embedded_chunks* is provided the question is embedded and the top-5
|
||||
most relevant chunks are used instead of the full report, reducing token
|
||||
usage. Falls back to full-context on any embedding failure.
|
||||
When *embedded_chunks* is provided, the question is embedded and top-k
|
||||
relevant chunks are selected. If retrieval fails, a clear fallback message
|
||||
is emitted and full diagnostic context is used.
|
||||
"""
|
||||
console.print()
|
||||
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
|
||||
@@ -496,18 +599,59 @@ def _run_followup_analysis(
|
||||
system_prompt = build_system_prompt()
|
||||
|
||||
user_message: str
|
||||
retrieved_names: list[str] = []
|
||||
retrieved_scores: list[float] = []
|
||||
retrieval_ms = 0.0
|
||||
fallback_reason: str | None = None
|
||||
|
||||
if embedded_chunks is not None:
|
||||
retrieval_start = perf_counter()
|
||||
try:
|
||||
q_embedding = ai.embed(question)
|
||||
retrieved = retrieve(q_embedding, embedded_chunks, top_k=5)
|
||||
scored = retrieve_scored(q_embedding, embedded_chunks, top_k=5)
|
||||
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
|
||||
retrieved_names = [chunk.name for chunk, _score in scored]
|
||||
retrieved_scores = [round(score, 4) for _chunk, score in scored]
|
||||
user_message = build_message_with_chunks(
|
||||
issue, report.host, retrieved, question, prior_questions
|
||||
issue,
|
||||
report.host,
|
||||
[chunk for chunk, _score in scored],
|
||||
question,
|
||||
prior_questions,
|
||||
)
|
||||
if rag_debug:
|
||||
pairs = ", ".join(
|
||||
f"{name}={score:.3f}"
|
||||
for name, score in zip(retrieved_names, retrieved_scores, strict=False)
|
||||
)
|
||||
console.print(f"[dim]RAG retrieve:[/dim] {pairs or 'no matches'}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
|
||||
fallback_reason = str(exc)
|
||||
console.print(
|
||||
"[yellow]RAG unavailable (query embedding failed); using full-context "
|
||||
"fallback.[/yellow]"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||
else:
|
||||
fallback_reason = "rag not indexed"
|
||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"rag_query",
|
||||
{
|
||||
"question": question,
|
||||
"retrieved_chunk_names": retrieved_names,
|
||||
"scores": retrieved_scores,
|
||||
"retrieval_ms": round(retrieval_ms, 2),
|
||||
"top_score": retrieved_scores[0] if retrieved_scores else None,
|
||||
"used_fallback": fallback_reason is not None,
|
||||
"fallback_reason": fallback_reason,
|
||||
"estimated_prompt_tokens": _estimate_tokens(system_prompt + user_message),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
chunks: list[str] = []
|
||||
for chunk in ai.stream(system_prompt, user_message):
|
||||
|
||||
@@ -99,6 +99,10 @@ def build_followup_message(
|
||||
"\nAnswer strictly from the collected diagnostics above. "
|
||||
"If evidence is insufficient, explicitly say so."
|
||||
)
|
||||
lines.append(
|
||||
"Keep hypothesis continuity across turns: retain the previous leading "
|
||||
"hypothesis unless newly retrieved evidence directly contradicts it."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -137,4 +141,8 @@ def build_message_with_chunks(
|
||||
"\nAnswer strictly from the retrieved diagnostics above. "
|
||||
"If evidence is insufficient, explicitly say so."
|
||||
)
|
||||
lines.append(
|
||||
"Keep hypothesis continuity across turns: retain the previous leading "
|
||||
"hypothesis unless newly retrieved evidence directly contradicts it."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -12,6 +12,8 @@ from dataclasses import dataclass
|
||||
|
||||
from tai.collectors import CollectionReport
|
||||
|
||||
DEFAULT_MAX_CHUNK_CHARS = 1800
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Chunk:
|
||||
@@ -29,11 +31,25 @@ class EmbeddedChunk:
|
||||
embedding: list[float]
|
||||
|
||||
|
||||
def chunk_report(report: CollectionReport) -> list[Chunk]:
|
||||
def _normalize_text(text: str, *, max_chars: int) -> str:
|
||||
"""Normalize whitespace and cap text length with a truncation marker."""
|
||||
compact = text.strip()
|
||||
if len(compact) <= max_chars:
|
||||
return compact
|
||||
clipped = compact[:max_chars].rstrip()
|
||||
return f"{clipped}\n...[truncated for RAG]"
|
||||
|
||||
|
||||
def chunk_report(
|
||||
report: CollectionReport,
|
||||
*,
|
||||
max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
|
||||
) -> list[Chunk]:
|
||||
"""Split a CollectionReport into one Chunk per diagnostic item.
|
||||
|
||||
Items that SSH could not execute at all (exit 255, no output) are dropped —
|
||||
they carry no diagnostic signal.
|
||||
they carry no diagnostic signal. Chunk text is normalized and capped so the
|
||||
prompt shape stays more stable on smaller local models.
|
||||
"""
|
||||
chunks: list[Chunk] = []
|
||||
for item in report.items:
|
||||
@@ -46,13 +62,14 @@ def chunk_report(report: CollectionReport) -> list[Chunk]:
|
||||
f"Exit code: {result.exit_code}",
|
||||
]
|
||||
if result.stdout:
|
||||
parts.append(f"stdout:\n{result.stdout.strip()}")
|
||||
parts.append(f"stdout:\n{_normalize_text(result.stdout, max_chars=max_chunk_chars)}")
|
||||
if result.stderr:
|
||||
parts.append(f"stderr:\n{result.stderr.strip()}")
|
||||
parts.append(f"stderr:\n{_normalize_text(result.stderr, max_chars=max_chunk_chars)}")
|
||||
if not result.stdout and not result.stderr:
|
||||
parts.append("(no output)")
|
||||
|
||||
chunks.append(Chunk(name=item.name, content="\n".join(parts)))
|
||||
content = _normalize_text("\n".join(parts), max_chars=max_chunk_chars)
|
||||
chunks.append(Chunk(name=item.name, content=content))
|
||||
return chunks
|
||||
|
||||
|
||||
@@ -66,17 +83,13 @@ def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
def retrieve(
|
||||
def retrieve_scored(
|
||||
question_embedding: list[float],
|
||||
embedded_chunks: list[EmbeddedChunk],
|
||||
*,
|
||||
top_k: int = 5,
|
||||
) -> list[Chunk]:
|
||||
"""Return the *top_k* chunks most similar to *question_embedding*.
|
||||
|
||||
Chunks are ranked by cosine similarity in descending order.
|
||||
If *embedded_chunks* is empty or *top_k* is zero, returns an empty list.
|
||||
"""
|
||||
) -> list[tuple[Chunk, float]]:
|
||||
"""Return top-k retrieved chunks with similarity scores."""
|
||||
if not embedded_chunks or top_k <= 0:
|
||||
return []
|
||||
scored: list[tuple[float, Chunk]] = [
|
||||
@@ -84,4 +97,19 @@ def retrieve(
|
||||
for ec in embedded_chunks
|
||||
]
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [chunk for _, chunk in scored[:top_k]]
|
||||
return [(chunk, score) for score, chunk in scored[:top_k]]
|
||||
|
||||
|
||||
def retrieve(
|
||||
question_embedding: list[float],
|
||||
embedded_chunks: list[EmbeddedChunk],
|
||||
*,
|
||||
top_k: int = 5,
|
||||
) -> list[Chunk]:
|
||||
"""Return the *top_k* chunks most similar to *question_embedding*."""
|
||||
scored = retrieve_scored(
|
||||
question_embedding,
|
||||
embedded_chunks,
|
||||
top_k=top_k,
|
||||
)
|
||||
return [chunk for chunk, _score in scored]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typer.testing import CliRunner
|
||||
|
||||
from tai.cli import app
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.rag_retriever import Chunk, EmbeddedChunk
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
@@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
||||
assert result.exit_code == 0
|
||||
assert "AI Response" in result.stdout
|
||||
assert "Check logs." in result.stdout
|
||||
|
||||
|
||||
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
_mock_session(monkeypatch)
|
||||
|
||||
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||
return CollectionReport(
|
||||
host="ssh.archflux.net",
|
||||
items=[
|
||||
CollectedItem(
|
||||
name="kernel",
|
||||
result=SSHCommandResult(
|
||||
command="uname -a",
|
||||
exit_code=0,
|
||||
stdout="Linux test",
|
||||
stderr="",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
commands = iter(["what should I check next?", "/quit"])
|
||||
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||
monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
|
||||
monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
|
||||
monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"apache failed",
|
||||
"--host",
|
||||
"ssh.archflux.net",
|
||||
"--port",
|
||||
"5566",
|
||||
"--no-probe",
|
||||
"--interactive",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "RAG unavailable (indexing failed)" in result.stdout
|
||||
assert "AI Response" in result.stdout
|
||||
|
||||
|
||||
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
_mock_session(monkeypatch)
|
||||
|
||||
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||
return CollectionReport(
|
||||
host="ssh.archflux.net",
|
||||
items=[
|
||||
CollectedItem(
|
||||
name="kernel",
|
||||
result=SSHCommandResult(
|
||||
command="uname -a",
|
||||
exit_code=0,
|
||||
stdout="Linux test",
|
||||
stderr="",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
commands = iter(["what should I check next?", "/quit"])
|
||||
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||
monkeypatch.setattr(
|
||||
"tai.cli._try_embed_report",
|
||||
lambda *_args: (
|
||||
[EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])],
|
||||
None,
|
||||
1.0,
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
|
||||
monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
|
||||
monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"apache failed",
|
||||
"--host",
|
||||
"ssh.archflux.net",
|
||||
"--port",
|
||||
"5566",
|
||||
"--no-probe",
|
||||
"--interactive",
|
||||
"--rag-debug",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "RAG retrieve:" in result.stdout
|
||||
|
||||
@@ -3,7 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
||||
from tai.rag_retriever import (
|
||||
Chunk,
|
||||
EmbeddedChunk,
|
||||
_cosine_similarity,
|
||||
chunk_report,
|
||||
retrieve,
|
||||
retrieve_scored,
|
||||
)
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
|
||||
assert "(no output)" in chunks[0].content
|
||||
|
||||
|
||||
def test_chunk_report_caps_large_content() -> None:
|
||||
report = _report(("huge", "x" * 5000, 0))
|
||||
chunks = chunk_report(report, max_chunk_chars=200)
|
||||
assert len(chunks[0].content) <= 230
|
||||
assert "...[truncated for RAG]" in chunks[0].content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
|
||||
assert result[1].name == "mid"
|
||||
|
||||
|
||||
def test_retrieve_scored_includes_scores() -> None:
|
||||
chunks = [
|
||||
_embedded("close", [1.0, 0.0]),
|
||||
_embedded("far", [0.0, 1.0]),
|
||||
]
|
||||
result = retrieve_scored([1.0, 0.0], chunks, top_k=2)
|
||||
assert len(result) == 2
|
||||
assert result[0][0].name == "close"
|
||||
assert result[0][1] > result[1][1]
|
||||
|
||||
|
||||
def test_retrieve_respects_top_k_larger_than_pool() -> None:
|
||||
chunks = [_embedded("only", [1.0, 0.0])]
|
||||
result = retrieve([1.0, 0.0], chunks, top_k=10)
|
||||
|
||||
Reference in New Issue
Block a user