feat(rag): implement Tier 1 in-memory RAG for interactive follow-ups
Some checks failed
CI / test (push) Failing after 15s

- Add embed() to AIClient using Ollama nomic-embed-text via /v1/embeddings
- Add DEFAULT_EMBED_MODEL and embed_model field to AIConfig
- New rag_retriever.py: chunk_report(), EmbeddedChunk, retrieve() (pure-Python cosine)
- prompt_builder: add build_message_with_chunks() for RAG-aware follow-up prompts
- cli: add --no-rag flag, embed report chunks after collection, retrieve top-5 per question
- Graceful fallback to full-context if embedding model unavailable
- 16 new tests in test_rag_retriever.py (67 total, all passing)
- Add chromadb>=0.5 as optional [rag] dep in pyproject.toml
- README: add step 3 (pull nomic-embed-text), update Suggested Tooling table
This commit is contained in:
2026-05-04 18:36:12 +02:00
parent c1192cdb94
commit be181c2d7f
4 changed files with 377 additions and 4 deletions

View File

@@ -18,7 +18,13 @@ from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest from tai.models import TroubleshootRequest
from tai.plan import plan_from_request from tai.plan import plan_from_request
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message from tai.prompt_builder import (
build_followup_message,
build_message_with_chunks,
build_system_prompt,
build_user_message,
)
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve
from tai.session_log import SessionLogger from tai.session_log import SessionLogger
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
@@ -97,6 +103,13 @@ def run(
help="Optional JSONL file path to log AI and session output.", help="Optional JSONL file path to log AI and session output.",
), ),
] = None, ] = None,
no_rag: Annotated[
bool,
typer.Option(
"--no-rag",
help="Disable RAG; send full diagnostics to AI instead of retrieved chunks.",
),
] = False,
) -> None: ) -> None:
"""Start an interactive troubleshooting session scaffold.""" """Start an interactive troubleshooting session scaffold."""
try: try:
@@ -147,6 +160,7 @@ def run(
analyze=analyze, analyze=analyze,
interactive=interactive, interactive=interactive,
ai_config=ai_config, ai_config=ai_config,
no_rag=no_rag,
logger=logger, logger=logger,
) )
) )
@@ -169,6 +183,7 @@ async def _async_main(
analyze: bool, analyze: bool,
interactive: bool, interactive: bool,
ai_config: AIConfig, ai_config: AIConfig,
no_rag: bool,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> None:
"""Open a single SSH session and run probe / collection / analysis through it.""" """Open a single SSH session and run probe / collection / analysis through it."""
@@ -219,7 +234,7 @@ async def _async_main(
_run_analysis(ai_config, req.issue, report, logger=logger) _run_analysis(ai_config, req.issue, report, logger=logger)
if interactive: if interactive:
await _interactive_loop(session, req, ai_config, report, logger=logger) await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger)
async def _interactive_loop( async def _interactive_loop(
@@ -227,6 +242,8 @@ async def _interactive_loop(
req: TroubleshootRequest, req: TroubleshootRequest,
ai_config: AIConfig, ai_config: AIConfig,
report: CollectionReport | None, report: CollectionReport | None,
*,
no_rag: bool = False,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> None:
"""Run a follow-up loop for collecting and conversational analysis.""" """Run a follow-up loop for collecting and conversational analysis."""
@@ -241,6 +258,13 @@ async def _interactive_loop(
) )
prior_questions: list[str] = [] prior_questions: list[str] = []
embedded_chunks: list[EmbeddedChunk] | None = None
ai_embed = AIClient(ai_config)
if not no_rag and report is not None:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
if embedded_chunks is not None:
console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]")
while True: while True:
try: try:
@@ -280,6 +304,12 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
if not no_rag:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
if embedded_chunks is not None:
console.print(
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
)
if logger is not None: if logger is not None:
logger.log_event( logger.log_event(
"collection_summary", "collection_summary",
@@ -306,6 +336,7 @@ async def _interactive_loop(
report, report,
"Provide an updated diagnosis from the current diagnostics.", "Provide an updated diagnosis from the current diagnostics.",
prior_questions, prior_questions,
embedded_chunks=embedded_chunks,
logger=logger, logger=logger,
) )
prior_questions.append("/analyze") prior_questions.append("/analyze")
@@ -318,6 +349,12 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
if not no_rag:
embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed)
if embedded_chunks is not None:
console.print(
f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]"
)
if report is None: if report is None:
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
@@ -329,6 +366,7 @@ async def _interactive_loop(
report, report,
command, command,
prior_questions, prior_questions,
embedded_chunks=embedded_chunks,
logger=logger, logger=logger,
) )
prior_questions.append(command) prior_questions.append(command)
@@ -336,6 +374,24 @@ async def _interactive_loop(
logger.log_event("interactive_followup", {"question": command}) logger.log_event("interactive_followup", {"question": command})
def _try_embed_report(
    report: CollectionReport, ai: AIClient
) -> list[EmbeddedChunk] | None:
    """Best-effort embedding of every diagnostic chunk in *report*.

    Returns the embedded chunks, or None when the report yields no chunks or
    when anything at all goes wrong (for example the embedding model has not
    been pulled yet, or the AI backend cannot be reached). A None result tells
    the caller to fall back to sending the full report as context.
    """
    try:
        pieces = chunk_report(report)
        if not pieces:
            return None
        embedded: list[EmbeddedChunk] = []
        for piece in pieces:
            # One embedding request per chunk; the first failure aborts the lot.
            embedded.append(EmbeddedChunk(chunk=piece, embedding=ai.embed(piece.content)))
        return embedded
    except Exception:  # noqa: BLE001
        # Deliberately broad: RAG is optional, so any failure just disables it.
        return None
def _handle_probe_result(result: SSHCommandResult) -> None: def _handle_probe_result(result: SSHCommandResult) -> None:
"""Handle and render probe output for success or failure.""" """Handle and render probe output for success or failure."""
console.print("[dim]▶ SSH probe:[/dim] uname -a") console.print("[dim]▶ SSH probe:[/dim] uname -a")
@@ -417,15 +473,33 @@ def _run_followup_analysis(
question: str, question: str,
prior_questions: list[str], prior_questions: list[str],
*, *,
embedded_chunks: list[EmbeddedChunk] | None = None,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> str: ) -> str:
"""Run grounded follow-up analysis re-anchored to current diagnostics.""" """Run grounded follow-up analysis re-anchored to current diagnostics.
When *embedded_chunks* is provided the question is embedded and the top-5
most relevant chunks are used instead of the full report, reducing token
usage. Falls back to full-context on any embedding failure.
"""
console.print() console.print()
console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan")) console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan"))
console.print() console.print()
ai = AIClient(ai_config) ai = AIClient(ai_config)
system_prompt = build_system_prompt() system_prompt = build_system_prompt()
user_message = build_followup_message(issue, report, question, prior_questions)
user_message: str
if embedded_chunks is not None:
try:
q_embedding = ai.embed(question)
retrieved = retrieve(q_embedding, embedded_chunks, top_k=5)
user_message = build_message_with_chunks(
issue, report.host, retrieved, question, prior_questions
)
except Exception: # noqa: BLE001
user_message = build_followup_message(issue, report, question, prior_questions)
else:
user_message = build_followup_message(issue, report, question, prior_questions)
try: try:
chunks: list[str] = [] chunks: list[str] = []

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from tai.collectors import CollectionReport from tai.collectors import CollectionReport
from tai.rag_retriever import Chunk
_SYSTEM_PROMPT = """\ _SYSTEM_PROMPT = """\
You are an expert Linux systems administrator and troubleshooting assistant. You are an expert Linux systems administrator and troubleshooting assistant.
@@ -99,3 +100,41 @@ def build_followup_message(
"If evidence is insufficient, explicitly say so." "If evidence is insufficient, explicitly say so."
) )
return "\n".join(lines) return "\n".join(lines)
def build_message_with_chunks(
    issue: str,
    host: str,
    chunks: list[Chunk],
    question: str,
    prior_questions: list[str],
) -> str:
    """Compose a follow-up prompt from retrieved diagnostic chunks only.

    This is the RAG variant of the follow-up message: rather than embedding the
    whole report, only the supplied *chunks* are rendered (each under its own
    heading), followed by up to the five most recent prior questions and the
    current *question*.
    """
    sections: list[str] = [
        f"## Issue reported\n\n{issue}\n",
        f"## Target host\n\n{host}\n",
        "## Most relevant diagnostics (retrieved by semantic similarity)\n",
    ]

    for piece in chunks:
        # Heading, body, then an empty entry so chunks are separated by a blank line.
        sections.extend((f"### {piece.name}\n", piece.content, ""))

    sections.append("## Follow-up")
    if prior_questions:
        sections.append("\nRecent user follow-up questions:")
        recent = prior_questions[-5:]
        sections.extend(
            f"{number}. {asked}" for number, asked in enumerate(recent, start=1)
        )

    sections.extend(
        (
            "\nCurrent follow-up question:",
            question,
            "\nAnswer strictly from the retrieved diagnostics above. "
            "If evidence is insufficient, explicitly say so.",
        )
    )
    return "\n".join(sections)

87
src/tai/rag_retriever.py Normal file
View File

@@ -0,0 +1,87 @@
"""In-memory RAG retriever for diagnostic report chunks (Tier 1).
Splits a CollectionReport into one Chunk per collected item. The caller embeds
each chunk via AIClient; this module then ranks the embedded chunks against a
question embedding using pure-Python cosine similarity.
No external vector store required — everything lives in process memory.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from tai.collectors import CollectionReport
@dataclass(slots=True)
class Chunk:
"""A single retrievable piece of diagnostic content."""
name: str
content: str
@dataclass(slots=True)
class EmbeddedChunk:
"""A Chunk paired with its embedding vector."""
chunk: Chunk
embedding: list[float]
def chunk_report(report: CollectionReport) -> list[Chunk]:
    """Split a CollectionReport into one Chunk per diagnostic item.

    Each chunk's content lists the command, its exit code, and whichever of
    stdout / stderr carried text (surrounding whitespace removed). Items that
    produced nothing get an explicit "(no output)" marker.

    Items that SSH could not execute at all (exit 255, no output) are dropped —
    they carry no diagnostic signal.

    Returns:
        Chunks in the same order as ``report.items``, minus dropped items.
    """
    chunks: list[Chunk] = []
    for item in report.items:
        result = item.result
        # Strip *before* testing for emptiness so that whitespace-only streams
        # (e.g. a lone trailing newline) are treated as "no output" instead of
        # producing an empty "stdout:" section.
        stdout = (result.stdout or "").strip()
        stderr = (result.stderr or "").strip()
        has_output = bool(stdout or stderr)

        if result.exit_code == 255 and not has_output:
            continue

        parts: list[str] = [
            f"Command: {result.command}",
            f"Exit code: {result.exit_code}",
        ]
        if stdout:
            parts.append(f"stdout:\n{stdout}")
        if stderr:
            parts.append(f"stderr:\n{stderr}")
        if not has_output:
            parts.append("(no output)")
        chunks.append(Chunk(name=item.name, content="\n".join(parts)))
    return chunks
def _cosine_similarity(a: list[float], b: list[float]) -> float:
"""Return cosine similarity in [-1, 1] using pure Python (no numpy)."""
dot = sum(x * y for x, y in zip(a, b, strict=False))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0.0 or norm_b == 0.0:
return 0.0
return dot / (norm_a * norm_b)
def retrieve(
    question_embedding: list[float],
    embedded_chunks: list[EmbeddedChunk],
    *,
    top_k: int = 5,
) -> list[Chunk]:
    """Pick the *top_k* chunks whose embeddings best match *question_embedding*.

    Results are ordered from highest to lowest cosine similarity; chunks with
    equal scores keep their input order. An empty pool or a non-positive
    *top_k* yields an empty list.
    """
    if not embedded_chunks or top_k <= 0:
        return []

    def score(candidate: EmbeddedChunk) -> float:
        return _cosine_similarity(question_embedding, candidate.embedding)

    # sorted() evaluates the key once per element and is stable, so the result
    # matches sorting precomputed (score, chunk) pairs by score.
    ranked = sorted(embedded_chunks, key=score, reverse=True)
    return [candidate.chunk for candidate in ranked[:top_k]]

173
tests/test_rag_retriever.py Normal file
View File

@@ -0,0 +1,173 @@
"""Tests for rag_retriever — pure-Python, no network calls."""
from __future__ import annotations
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve
from tai.ssh_client import SSHCommandResult
def _report(*items: tuple[str, str, int]) -> CollectionReport:
    """Assemble a CollectionReport for "test-host" from (name, stdout, exit_code) triples.

    Every item gets the command ``cmd-<name>`` and an empty stderr.
    """
    collected: list[CollectedItem] = []
    for name, stdout, code in items:
        outcome = SSHCommandResult(
            command=f"cmd-{name}",
            exit_code=code,
            stdout=stdout,
            stderr="",
        )
        collected.append(CollectedItem(name=name, result=outcome))
    return CollectionReport(host="test-host", items=collected)
# ---------------------------------------------------------------------------
# chunk_report
# ---------------------------------------------------------------------------
def test_chunk_report_creates_one_chunk_per_item() -> None:
    """Two collected items produce two chunks, in input order."""
    produced = chunk_report(
        _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0))
    )
    assert len(produced) == 2
    first, second = produced
    assert first.name == "kernel"
    assert second.name == "journal"
def test_chunk_report_includes_stdout_in_content() -> None:
    """The command's stdout text is carried into the chunk content."""
    first = chunk_report(_report(("kernel", "Linux test 6.1", 0)))[0]
    assert "Linux test 6.1" in first.content
def test_chunk_report_includes_exit_code_in_content() -> None:
    """A non-zero exit code is spelled out in the chunk content."""
    first = chunk_report(_report(("fail", "error output", 1)))[0]
    assert "Exit code: 1" in first.content
def test_chunk_report_skips_ssh_unreachable_items() -> None:
    """Exit 255 with no output at all means SSH failed, so the item is dropped."""
    unreachable = CollectedItem(
        name="unreachable",
        result=SSHCommandResult(command="some-cmd", exit_code=255, stdout="", stderr=""),
    )
    healthy = CollectedItem(
        name="ok",
        result=SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux", stderr=""),
    )
    report = CollectionReport(host="test-host", items=[unreachable, healthy])

    produced = chunk_report(report)

    assert len(produced) == 1
    assert produced[0].name == "ok"
def test_chunk_report_keeps_exit_255_with_output() -> None:
    """Exit 255 accompanied by stderr is a genuine failure and must be kept."""
    failing = CollectedItem(
        name="partial",
        result=SSHCommandResult(
            command="some-cmd",
            exit_code=255,
            stdout="",
            stderr="Permission denied",
        ),
    )
    report = CollectionReport(host="test-host", items=[failing])

    produced = chunk_report(report)

    assert len(produced) == 1
    assert "Permission denied" in produced[0].content
def test_chunk_report_notes_no_output() -> None:
    """A successful command that printed nothing is marked "(no output)"."""
    silent = CollectedItem(
        name="silent",
        result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""),
    )
    produced = chunk_report(CollectionReport(host="test-host", items=[silent]))
    assert "(no output)" in produced[0].content
# ---------------------------------------------------------------------------
# _cosine_similarity
# ---------------------------------------------------------------------------
def test_cosine_similarity_identical_vectors() -> None:
    """A vector compared with itself scores 1."""
    unit = [1.0, 0.0, 0.0]
    score = _cosine_similarity(unit, unit)
    assert abs(score - 1.0) < 1e-9
def test_cosine_similarity_orthogonal_vectors() -> None:
    """Perpendicular vectors score 0."""
    x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
    score = _cosine_similarity(x_axis, y_axis)
    assert abs(score) < 1e-9
def test_cosine_similarity_opposite_vectors() -> None:
    """Vectors pointing in opposite directions score -1."""
    forward, backward = [1.0, 0.0], [-1.0, 0.0]
    score = _cosine_similarity(forward, backward)
    assert abs(score - (-1.0)) < 1e-9
def test_cosine_similarity_zero_vector_returns_zero() -> None:
    """A zero-magnitude vector yields exactly 0.0 rather than dividing by zero."""
    zero, unit = [0.0, 0.0], [1.0, 0.0]
    assert _cosine_similarity(zero, unit) == 0.0
# ---------------------------------------------------------------------------
# retrieve
# ---------------------------------------------------------------------------
def _embedded(name: str, vec: list[float]) -> EmbeddedChunk:
    """Build an EmbeddedChunk named *name* with placeholder content and embedding *vec*."""
    placeholder = Chunk(name=name, content=f"content of {name}")
    return EmbeddedChunk(chunk=placeholder, embedding=vec)
def test_retrieve_returns_top_k_by_similarity() -> None:
    """Only the two best matches come back, best first."""
    query = [1.0, 0.0]
    pool = [
        _embedded("close", [1.0, 0.0]),  # same direction as the query
        _embedded("mid", [0.7, 0.7]),
        _embedded("far", [0.0, 1.0]),  # perpendicular to the query
    ]

    top_two = retrieve(query, pool, top_k=2)

    assert len(top_two) == 2
    best, runner_up = top_two
    assert best.name == "close"
    assert runner_up.name == "mid"
def test_retrieve_respects_top_k_larger_than_pool() -> None:
    """Asking for more chunks than exist returns just the ones available."""
    pool = [_embedded("only", [1.0, 0.0])]
    found = retrieve([1.0, 0.0], pool, top_k=10)
    assert len(found) == 1
def test_retrieve_empty_pool_returns_empty() -> None:
    """With nothing indexed there is nothing to retrieve."""
    found = retrieve([1.0, 0.0], [], top_k=5)
    assert found == []
def test_retrieve_top_k_zero_returns_empty() -> None:
    """A top_k of zero returns no chunks even when the pool is non-empty."""
    pool = [_embedded("x", [1.0, 0.0])]
    found = retrieve([1.0, 0.0], pool, top_k=0)
    assert found == []