feat(rag): harden Tier 1 retrieval observability and stability
Some checks failed
CI / test (push) Failing after 15s
Some checks failed
CI / test (push) Failing after 15s
- Add --rag-debug flag to show retrieved chunk names and similarity scores
- Add explicit fallback notices when RAG indexing/query embedding fails
- Log RAG index/query metrics (duration, scores, top hit, token estimate)
- Normalize and cap chunk content for more stable prompt shape on small models
- Add hypothesis-continuity instruction for follow-up prompts
- Add retrieval scoring API and new tests for truncation/fallback/debug paths
This commit is contained in:
186
src/tai/cli.py
186
src/tai/cli.py
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from time import perf_counter
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
@@ -24,7 +25,7 @@ from tai.prompt_builder import (
|
|||||||
build_system_prompt,
|
build_system_prompt,
|
||||||
build_user_message,
|
build_user_message,
|
||||||
)
|
)
|
||||||
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve
|
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
|
||||||
from tai.session_log import SessionLogger
|
from tai.session_log import SessionLogger
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||||
|
|
||||||
@@ -117,6 +118,13 @@ def run(
|
|||||||
help="Embedding model for RAG. Must be pulled in Ollama on the AI host.",
|
help="Embedding model for RAG. Must be pulled in Ollama on the AI host.",
|
||||||
),
|
),
|
||||||
] = DEFAULT_EMBED_MODEL,
|
] = DEFAULT_EMBED_MODEL,
|
||||||
|
rag_debug: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--rag-debug/--no-rag-debug",
|
||||||
|
help="Print retrieved chunk names/scores and log per-question retrieval metrics.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start an interactive troubleshooting session scaffold."""
|
"""Start an interactive troubleshooting session scaffold."""
|
||||||
try:
|
try:
|
||||||
@@ -168,6 +176,7 @@ def run(
|
|||||||
interactive=interactive,
|
interactive=interactive,
|
||||||
ai_config=ai_config,
|
ai_config=ai_config,
|
||||||
no_rag=no_rag,
|
no_rag=no_rag,
|
||||||
|
rag_debug=rag_debug,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -191,6 +200,7 @@ async def _async_main(
|
|||||||
interactive: bool,
|
interactive: bool,
|
||||||
ai_config: AIConfig,
|
ai_config: AIConfig,
|
||||||
no_rag: bool,
|
no_rag: bool,
|
||||||
|
rag_debug: bool,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Open a single SSH session and run probe / collection / analysis through it."""
|
"""Open a single SSH session and run probe / collection / analysis through it."""
|
||||||
@@ -241,7 +251,15 @@ async def _async_main(
|
|||||||
_run_analysis(ai_config, req.issue, report, logger=logger)
|
_run_analysis(ai_config, req.issue, report, logger=logger)
|
||||||
|
|
||||||
if interactive:
|
if interactive:
|
||||||
await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger)
|
await _interactive_loop(
|
||||||
|
session,
|
||||||
|
req,
|
||||||
|
ai_config,
|
||||||
|
report,
|
||||||
|
no_rag=no_rag,
|
||||||
|
rag_debug=rag_debug,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _interactive_loop(
|
async def _interactive_loop(
|
||||||
@@ -251,6 +269,7 @@ async def _interactive_loop(
|
|||||||
report: CollectionReport | None,
|
report: CollectionReport | None,
|
||||||
*,
|
*,
|
||||||
no_rag: bool = False,
|
no_rag: bool = False,
|
||||||
|
rag_debug: bool = False,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a follow-up loop for collecting and conversational analysis."""
|
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||||
@@ -269,9 +288,33 @@ async def _interactive_loop(
|
|||||||
ai_embed = AIClient(ai_config)
|
ai_embed = AIClient(ai_config)
|
||||||
|
|
||||||
if not no_rag and report is not None:
|
if not no_rag and report is not None:
|
||||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||||
|
_try_embed_report, report, ai_embed
|
||||||
|
)
|
||||||
if embedded_chunks is not None:
|
if embedded_chunks is not None:
|
||||||
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
|
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"chunk_count": len(embedded_chunks),
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
"[yellow]RAG unavailable (indexing failed); using full-context fallback.[/yellow]"
|
||||||
|
)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "fallback",
|
||||||
|
"error": index_error,
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -312,11 +355,36 @@ async def _interactive_loop(
|
|||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
if not no_rag:
|
if not no_rag:
|
||||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||||
|
_try_embed_report, report, ai_embed
|
||||||
|
)
|
||||||
if embedded_chunks is not None:
|
if embedded_chunks is not None:
|
||||||
console.print(
|
console.print(
|
||||||
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||||
)
|
)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"chunk_count": len(embedded_chunks),
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
"[yellow]RAG unavailable (indexing failed); "
|
||||||
|
"using full-context fallback.[/yellow]"
|
||||||
|
)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "fallback",
|
||||||
|
"error": index_error,
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event(
|
logger.log_event(
|
||||||
"collection_summary",
|
"collection_summary",
|
||||||
@@ -344,6 +412,7 @@ async def _interactive_loop(
|
|||||||
"Provide an updated diagnosis from the current diagnostics.",
|
"Provide an updated diagnosis from the current diagnostics.",
|
||||||
prior_questions,
|
prior_questions,
|
||||||
embedded_chunks=embedded_chunks,
|
embedded_chunks=embedded_chunks,
|
||||||
|
rag_debug=rag_debug,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append("/analyze")
|
prior_questions.append("/analyze")
|
||||||
@@ -357,11 +426,36 @@ async def _interactive_loop(
|
|||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
if not no_rag:
|
if not no_rag:
|
||||||
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
|
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||||
|
_try_embed_report, report, ai_embed
|
||||||
|
)
|
||||||
if embedded_chunks is not None:
|
if embedded_chunks is not None:
|
||||||
console.print(
|
console.print(
|
||||||
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
|
||||||
)
|
)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"chunk_count": len(embedded_chunks),
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
console.print(
|
||||||
|
"[yellow]RAG unavailable (indexing failed); "
|
||||||
|
"using full-context fallback.[/yellow]"
|
||||||
|
)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_index",
|
||||||
|
{
|
||||||
|
"status": "fallback",
|
||||||
|
"error": index_error,
|
||||||
|
"duration_ms": round(index_ms, 2),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if report is None:
|
if report is None:
|
||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
@@ -374,6 +468,7 @@ async def _interactive_loop(
|
|||||||
command,
|
command,
|
||||||
prior_questions,
|
prior_questions,
|
||||||
embedded_chunks=embedded_chunks,
|
embedded_chunks=embedded_chunks,
|
||||||
|
rag_debug=rag_debug,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append(command)
|
prior_questions.append(command)
|
||||||
@@ -382,21 +477,23 @@ async def _interactive_loop(
|
|||||||
|
|
||||||
|
|
||||||
def _try_embed_report(
|
def _try_embed_report(
|
||||||
report: CollectionReport, ai: AIClient
|
report: CollectionReport,
|
||||||
) -> list[EmbeddedChunk] | None:
|
ai: AIClient,
|
||||||
"""Embed all diagnostic chunks from *report*; returns None on any failure.
|
) -> tuple[list[EmbeddedChunk] | None, str | None, float]:
|
||||||
|
"""Embed all diagnostic chunks from *report*.
|
||||||
|
|
||||||
Failures are expected when the embedding model is not yet pulled or the
|
Returns (chunks, error_message, duration_ms). On failure, chunks is None
|
||||||
AI backend is unavailable — in those cases the caller falls back to
|
and callers should fall back to non-RAG full-context prompts.
|
||||||
sending the full report as context.
|
|
||||||
"""
|
"""
|
||||||
|
start = perf_counter()
|
||||||
try:
|
try:
|
||||||
chunks = chunk_report(report)
|
chunks = chunk_report(report)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return None
|
return None, "no eligible chunks to index", (perf_counter() - start) * 1000.0
|
||||||
return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
|
embedded = [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks]
|
||||||
except Exception: # noqa: BLE001
|
return embedded, None, (perf_counter() - start) * 1000.0
|
||||||
return None
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return None, str(exc), (perf_counter() - start) * 1000.0
|
||||||
|
|
||||||
|
|
||||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||||
@@ -473,6 +570,11 @@ def _run_analysis(
|
|||||||
raise typer.Exit(code=1) from exc
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str) -> int:
|
||||||
|
"""Rough token estimate for metrics and tuning; assumes ~4 chars/token."""
|
||||||
|
return max(1, len(text) // 4)
|
||||||
|
|
||||||
|
|
||||||
def _run_followup_analysis(
|
def _run_followup_analysis(
|
||||||
ai_config: AIConfig,
|
ai_config: AIConfig,
|
||||||
issue: str,
|
issue: str,
|
||||||
@@ -481,13 +583,14 @@ def _run_followup_analysis(
|
|||||||
prior_questions: list[str],
|
prior_questions: list[str],
|
||||||
*,
|
*,
|
||||||
embedded_chunks: list[EmbeddedChunk] | None = None,
|
embedded_chunks: list[EmbeddedChunk] | None = None,
|
||||||
|
rag_debug: bool = False,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
||||||
|
|
||||||
When *embedded_chunks* is provided the question is embedded and the top-5
|
When *embedded_chunks* is provided, the question is embedded and top-k
|
||||||
most relevant chunks are used instead of the full report, reducing token
|
relevant chunks are selected. If retrieval fails, a clear fallback message
|
||||||
usage. Falls back to full-context on any embedding failure.
|
is emitted and full diagnostic context is used.
|
||||||
"""
|
"""
|
||||||
console.print()
|
console.print()
|
||||||
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
|
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
|
||||||
@@ -496,18 +599,59 @@ def _run_followup_analysis(
|
|||||||
system_prompt = build_system_prompt()
|
system_prompt = build_system_prompt()
|
||||||
|
|
||||||
user_message: str
|
user_message: str
|
||||||
|
retrieved_names: list[str] = []
|
||||||
|
retrieved_scores: list[float] = []
|
||||||
|
retrieval_ms = 0.0
|
||||||
|
fallback_reason: str | None = None
|
||||||
|
|
||||||
if embedded_chunks is not None:
|
if embedded_chunks is not None:
|
||||||
|
retrieval_start = perf_counter()
|
||||||
try:
|
try:
|
||||||
q_embedding = ai.embed(question)
|
q_embedding = ai.embed(question)
|
||||||
retrieved = retrieve(q_embedding, embedded_chunks, top_k=5)
|
scored = retrieve_scored(q_embedding, embedded_chunks, top_k=5)
|
||||||
|
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
|
||||||
|
retrieved_names = [chunk.name for chunk, _score in scored]
|
||||||
|
retrieved_scores = [round(score, 4) for _chunk, score in scored]
|
||||||
user_message = build_message_with_chunks(
|
user_message = build_message_with_chunks(
|
||||||
issue, report.host, retrieved, question, prior_questions
|
issue,
|
||||||
|
report.host,
|
||||||
|
[chunk for chunk, _score in scored],
|
||||||
|
question,
|
||||||
|
prior_questions,
|
||||||
|
)
|
||||||
|
if rag_debug:
|
||||||
|
pairs = ", ".join(
|
||||||
|
f"{name}={score:.3f}"
|
||||||
|
for name, score in zip(retrieved_names, retrieved_scores, strict=False)
|
||||||
|
)
|
||||||
|
console.print(f"[dim]RAG retrieve:[/dim] {pairs or 'no matches'}")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
retrieval_ms = (perf_counter() - retrieval_start) * 1000.0
|
||||||
|
fallback_reason = str(exc)
|
||||||
|
console.print(
|
||||||
|
"[yellow]RAG unavailable (query embedding failed); using full-context "
|
||||||
|
"fallback.[/yellow]"
|
||||||
)
|
)
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||||
else:
|
else:
|
||||||
|
fallback_reason = "rag not indexed"
|
||||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||||
|
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event(
|
||||||
|
"rag_query",
|
||||||
|
{
|
||||||
|
"question": question,
|
||||||
|
"retrieved_chunk_names": retrieved_names,
|
||||||
|
"scores": retrieved_scores,
|
||||||
|
"retrieval_ms": round(retrieval_ms, 2),
|
||||||
|
"top_score": retrieved_scores[0] if retrieved_scores else None,
|
||||||
|
"used_fallback": fallback_reason is not None,
|
||||||
|
"fallback_reason": fallback_reason,
|
||||||
|
"estimated_prompt_tokens": _estimate_tokens(system_prompt + user_message),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chunks: list[str] = []
|
chunks: list[str] = []
|
||||||
for chunk in ai.stream(system_prompt, user_message):
|
for chunk in ai.stream(system_prompt, user_message):
|
||||||
|
|||||||
@@ -99,6 +99,10 @@ def build_followup_message(
|
|||||||
"\nAnswer strictly from the collected diagnostics above. "
|
"\nAnswer strictly from the collected diagnostics above. "
|
||||||
"If evidence is insufficient, explicitly say so."
|
"If evidence is insufficient, explicitly say so."
|
||||||
)
|
)
|
||||||
|
lines.append(
|
||||||
|
"Keep hypothesis continuity across turns: retain the previous leading "
|
||||||
|
"hypothesis unless newly retrieved evidence directly contradicts it."
|
||||||
|
)
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@@ -137,4 +141,8 @@ def build_message_with_chunks(
|
|||||||
"\nAnswer strictly from the retrieved diagnostics above. "
|
"\nAnswer strictly from the retrieved diagnostics above. "
|
||||||
"If evidence is insufficient, explicitly say so."
|
"If evidence is insufficient, explicitly say so."
|
||||||
)
|
)
|
||||||
|
lines.append(
|
||||||
|
"Keep hypothesis continuity across turns: retain the previous leading "
|
||||||
|
"hypothesis unless newly retrieved evidence directly contradicts it."
|
||||||
|
)
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
from tai.collectors import CollectionReport
|
from tai.collectors import CollectionReport
|
||||||
|
|
||||||
|
DEFAULT_MAX_CHUNK_CHARS = 1800
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class Chunk:
|
class Chunk:
|
||||||
@@ -29,11 +31,25 @@ class EmbeddedChunk:
|
|||||||
embedding: list[float]
|
embedding: list[float]
|
||||||
|
|
||||||
|
|
||||||
def chunk_report(report: CollectionReport) -> list[Chunk]:
|
def _normalize_text(text: str, *, max_chars: int) -> str:
|
||||||
|
"""Normalize whitespace and cap text length with a truncation marker."""
|
||||||
|
compact = text.strip()
|
||||||
|
if len(compact) <= max_chars:
|
||||||
|
return compact
|
||||||
|
clipped = compact[:max_chars].rstrip()
|
||||||
|
return f"{clipped}\n...[truncated for RAG]"
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_report(
|
||||||
|
report: CollectionReport,
|
||||||
|
*,
|
||||||
|
max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
|
||||||
|
) -> list[Chunk]:
|
||||||
"""Split a CollectionReport into one Chunk per diagnostic item.
|
"""Split a CollectionReport into one Chunk per diagnostic item.
|
||||||
|
|
||||||
Items that SSH could not execute at all (exit 255, no output) are dropped —
|
Items that SSH could not execute at all (exit 255, no output) are dropped —
|
||||||
they carry no diagnostic signal.
|
they carry no diagnostic signal. Chunk text is normalized and capped so the
|
||||||
|
prompt shape stays more stable on smaller local models.
|
||||||
"""
|
"""
|
||||||
chunks: list[Chunk] = []
|
chunks: list[Chunk] = []
|
||||||
for item in report.items:
|
for item in report.items:
|
||||||
@@ -46,13 +62,14 @@ def chunk_report(report: CollectionReport) -> list[Chunk]:
|
|||||||
f"Exit code: {result.exit_code}",
|
f"Exit code: {result.exit_code}",
|
||||||
]
|
]
|
||||||
if result.stdout:
|
if result.stdout:
|
||||||
parts.append(f"stdout:\n{result.stdout.strip()}")
|
parts.append(f"stdout:\n{_normalize_text(result.stdout, max_chars=max_chunk_chars)}")
|
||||||
if result.stderr:
|
if result.stderr:
|
||||||
parts.append(f"stderr:\n{result.stderr.strip()}")
|
parts.append(f"stderr:\n{_normalize_text(result.stderr, max_chars=max_chunk_chars)}")
|
||||||
if not result.stdout and not result.stderr:
|
if not result.stdout and not result.stderr:
|
||||||
parts.append("(no output)")
|
parts.append("(no output)")
|
||||||
|
|
||||||
chunks.append(Chunk(name=item.name, content="\n".join(parts)))
|
content = _normalize_text("\n".join(parts), max_chars=max_chunk_chars)
|
||||||
|
chunks.append(Chunk(name=item.name, content=content))
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
@@ -66,17 +83,13 @@ def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
|||||||
return dot / (norm_a * norm_b)
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
def retrieve(
|
def retrieve_scored(
|
||||||
question_embedding: list[float],
|
question_embedding: list[float],
|
||||||
embedded_chunks: list[EmbeddedChunk],
|
embedded_chunks: list[EmbeddedChunk],
|
||||||
*,
|
*,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
) -> list[Chunk]:
|
) -> list[tuple[Chunk, float]]:
|
||||||
"""Return the *top_k* chunks most similar to *question_embedding*.
|
"""Return top-k retrieved chunks with similarity scores."""
|
||||||
|
|
||||||
Chunks are ranked by cosine similarity in descending order.
|
|
||||||
If *embedded_chunks* is empty or *top_k* is zero, returns an empty list.
|
|
||||||
"""
|
|
||||||
if not embedded_chunks or top_k <= 0:
|
if not embedded_chunks or top_k <= 0:
|
||||||
return []
|
return []
|
||||||
scored: list[tuple[float, Chunk]] = [
|
scored: list[tuple[float, Chunk]] = [
|
||||||
@@ -84,4 +97,19 @@ def retrieve(
|
|||||||
for ec in embedded_chunks
|
for ec in embedded_chunks
|
||||||
]
|
]
|
||||||
scored.sort(key=lambda x: x[0], reverse=True)
|
scored.sort(key=lambda x: x[0], reverse=True)
|
||||||
return [chunk for _, chunk in scored[:top_k]]
|
return [(chunk, score) for score, chunk in scored[:top_k]]
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve(
    question_embedding: list[float],
    embedded_chunks: list[EmbeddedChunk],
    *,
    top_k: int = 5,
) -> list[Chunk]:
    """Return only the chunks from the *top_k* scored retrieval results.

    Thin wrapper over ``retrieve_scored`` that drops the similarity scores and
    keeps the order in which ``retrieve_scored`` returned the hits.
    """
    ranked = retrieve_scored(question_embedding, embedded_chunks, top_k=top_k)
    return [pair[0] for pair in ranked]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typer.testing import CliRunner
|
|||||||
|
|
||||||
from tai.cli import app
|
from tai.cli import app
|
||||||
from tai.collectors import CollectedItem, CollectionReport
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
|
from tai.rag_retriever import Chunk, EmbeddedChunk
|
||||||
from tai.ssh_client import SSHCommandResult
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
@@ -230,3 +231,99 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "AI Response" in result.stdout
|
assert "AI Response" in result.stdout
|
||||||
assert "Check logs." in result.stdout
|
assert "Check logs." in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A failed RAG index surfaces the fallback notice and still yields an AI response."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Build a fresh single-item report on every call, as a real collection would.
        uname_result = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        kernel_item = CollectedItem(name="kernel", result=uname_result)
        return CollectionReport(host="ssh.archflux.net", items=[kernel_item])

    def fake_embed_report(*_args):  # type: ignore[no-untyped-def]
        # Simulated indexing failure: no chunks, an error message, a duration.
        return (None, "embed failed", 1.0)

    # One follow-up question, then leave the interactive loop.
    scripted_input = iter(["what should I check next?", "/quit"])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", fake_embed_report)
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_input))

    cli_args = [
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    result = CliRunner().invoke(app, cli_args)

    assert result.exit_code == 0
    assert "RAG unavailable (indexing failed)" in result.stdout
    assert "AI Response" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With --rag-debug, a successful retrieval prints the "RAG retrieve:" line."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Build a fresh single-item report on every call, as a real collection would.
        uname_result = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        kernel_item = CollectedItem(name="kernel", result=uname_result)
        return CollectionReport(host="ssh.archflux.net", items=[kernel_item])

    def fake_embed_report(*_args):  # type: ignore[no-untyped-def]
        # Simulated successful index: one embedded chunk, no error, a duration.
        indexed = [
            EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])
        ]
        return (indexed, None, 1.0)

    # One follow-up question, then leave the interactive loop.
    scripted_input = iter(["what should I check next?", "/quit"])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", fake_embed_report)
    # The query embedding matches the stored chunk embedding exactly.
    monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_input))

    cli_args = [
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
        "--rag-debug",
    ]
    result = CliRunner().invoke(app, cli_args)

    assert result.exit_code == 0
    assert "RAG retrieve:" in result.stdout
|
||||||
|
|||||||
@@ -3,7 +3,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from tai.collectors import CollectedItem, CollectionReport
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
|
from tai.rag_retriever import (
|
||||||
|
Chunk,
|
||||||
|
EmbeddedChunk,
|
||||||
|
_cosine_similarity,
|
||||||
|
chunk_report,
|
||||||
|
retrieve,
|
||||||
|
retrieve_scored,
|
||||||
|
)
|
||||||
from tai.ssh_client import SSHCommandResult
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +117,13 @@ def test_chunk_report_notes_no_output() -> None:
|
|||||||
assert "(no output)" in chunks[0].content
|
assert "(no output)" in chunks[0].content
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_report_caps_large_content() -> None:
    """Oversized output is cut down to about the requested cap and tagged with the marker."""
    oversized = _report(("huge", "x" * 5000, 0))
    first = chunk_report(oversized, max_chunk_chars=200)[0]
    # The cap is 200 characters; the extra headroom covers the truncation marker.
    assert len(first.content) <= 230
    assert "...[truncated for RAG]" in first.content
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _cosine_similarity
|
# _cosine_similarity
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -158,6 +172,17 @@ def test_retrieve_returns_top_k_by_similarity() -> None:
|
|||||||
assert result[1].name == "mid"
|
assert result[1].name == "mid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieve_scored_includes_scores() -> None:
    """retrieve_scored yields (chunk, score) pairs with the closest chunk first."""
    pool = [_embedded("close", [1.0, 0.0]), _embedded("far", [0.0, 1.0])]
    ranked = retrieve_scored([1.0, 0.0], pool, top_k=2)
    assert len(ranked) == 2
    (best_chunk, best_score), (_runner_up, runner_up_score) = ranked
    assert best_chunk.name == "close"
    assert best_score > runner_up_score
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_respects_top_k_larger_than_pool() -> None:
|
def test_retrieve_respects_top_k_larger_than_pool() -> None:
|
||||||
chunks = [_embedded("only", [1.0, 0.0])]
|
chunks = [_embedded("only", [1.0, 0.0])]
|
||||||
result = retrieve([1.0, 0.0], chunks, top_k=10)
|
result = retrieve([1.0, 0.0], chunks, top_k=10)
|
||||||
|
|||||||
Reference in New Issue
Block a user