From be181c2d7f922eeb106413d25e83c2d81f152f77 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 18:36:12 +0200 Subject: [PATCH] feat(rag): implement Tier 1 in-memory RAG for interactive follow-ups - Add embed() to AIClient using Ollama nomic-embed-text via /v1/embeddings - Add DEFAULT_EMBED_MODEL and embed_model field to AIConfig - New rag_retriever.py: chunk_report(), EmbeddedChunk, retrieve() (pure-Python cosine) - prompt_builder: add build_message_with_chunks() for RAG-aware follow-up prompts - cli: add --no-rag flag, embed report chunks after collection, retrieve top-5 per question - Graceful fallback to full-context if embedding model unavailable - 16 new tests in test_rag_retriever.py (67 total, all passing) - Add chromadb>=0.5 as optional [rag] dep in pyproject.toml - README: add step 3 (pull nomic-embed-text), update Suggested Tooling table --- src/tai/cli.py | 82 ++++++++++++++++- src/tai/prompt_builder.py | 39 ++++++++ src/tai/rag_retriever.py | 87 ++++++++++++++++++ tests/test_rag_retriever.py | 173 ++++++++++++++++++++++++++++++++++++ 4 files changed, 377 insertions(+), 4 deletions(-) create mode 100644 src/tai/rag_retriever.py create mode 100644 tests/test_rag_retriever.py diff --git a/src/tai/cli.py b/src/tai/cli.py index 5ac933d..b45bdb6 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -18,7 +18,13 @@ from tai.collectors import CollectionReport, collect_from_plan from tai.input_parser import InputValidationError, build_request from tai.models import TroubleshootRequest from tai.plan import plan_from_request -from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message +from tai.prompt_builder import ( + build_followup_message, + build_message_with_chunks, + build_system_prompt, + build_user_message, +) +from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve from tai.session_log import SessionLogger from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession @@ -97,6 +103,13 @@ def run( help="Optional JSONL file path to log AI and session output.", ), ] = None, + no_rag: Annotated[ + bool, + typer.Option( + "--no-rag", + help="Disable RAG; send full diagnostics to AI instead of retrieved chunks.", + ), + ] = False, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -147,6 +160,7 @@ def run( analyze=analyze, interactive=interactive, ai_config=ai_config, + no_rag=no_rag, logger=logger, ) ) @@ -169,6 +183,7 @@ async def _async_main( analyze: bool, interactive: bool, ai_config: AIConfig, + no_rag: bool, logger: SessionLogger | None, ) -> None: """Open a single SSH session and run probe / collection / analysis through it.""" @@ -219,7 +234,7 @@ async def _async_main( _run_analysis(ai_config, req.issue, report, logger=logger) if interactive: - await _interactive_loop(session, req, ai_config, report, logger=logger) + await _interactive_loop(session, req, ai_config, report, no_rag=no_rag, logger=logger) async def _interactive_loop( @@ -227,6 +242,8 @@ async def _interactive_loop( req: TroubleshootRequest, ai_config: AIConfig, report: CollectionReport | None, + *, + no_rag: bool = False, logger: SessionLogger | None, ) -> None: """Run a follow-up loop for collecting and conversational analysis.""" @@ -241,6 +258,13 @@ async def _interactive_loop( ) prior_questions: list[str] = [] + embedded_chunks: list[EmbeddedChunk] | None = None + ai_embed = AIClient(ai_config) + + if not no_rag and report is not None: + embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + if embedded_chunks is not None: + console.print(f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]") while True: try: @@ -280,6 +304,12 @@ async def _interactive_loop( console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") report = await collect_from_plan(session, plan) _handle_collection_report(report) + if not no_rag: + embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + if embedded_chunks is not None: + console.print( + f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" + ) if logger is not None: logger.log_event( "collection_summary", @@ -306,6 +336,7 @@ async def _interactive_loop( report, "Provide an updated diagnosis from the current diagnostics.", prior_questions, + embedded_chunks=embedded_chunks, logger=logger, ) prior_questions.append("/analyze") @@ -318,6 +349,12 @@ async def _interactive_loop( console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") report = await collect_from_plan(session, plan) _handle_collection_report(report) + if not no_rag: + embedded_chunks = await asyncio.to_thread(_try_embed_report, report, ai_embed) + if embedded_chunks is not None: + console.print( + f"[dim]RAG: indexed {len(embedded_chunks)} diagnostic chunks[/dim]" + ) if report is None: console.print("[red]No diagnostics available to analyze.[/red]") @@ -329,6 +366,7 @@ async def _interactive_loop( report, command, prior_questions, + embedded_chunks=embedded_chunks, logger=logger, ) prior_questions.append(command) @@ -336,6 +374,24 @@ async def _interactive_loop( logger.log_event("interactive_followup", {"question": command}) +def _try_embed_report( + report: CollectionReport, ai: AIClient +) -> list[EmbeddedChunk] | None: + """Embed all diagnostic chunks from *report*; returns None on any failure. + + Failures are expected when the embedding model is not yet pulled or the + AI backend is unavailable — in those cases the caller falls back to + sending the full report as context. + """ + try: + chunks = chunk_report(report) + if not chunks: + return None + return [EmbeddedChunk(chunk=c, embedding=ai.embed(c.content)) for c in chunks] + except Exception: # noqa: BLE001 + return None + + def _handle_probe_result(result: SSHCommandResult) -> None: """Handle and render probe output for success or failure.""" console.print("[dim]▶ SSH probe:[/dim] uname -a") @@ -417,15 +473,33 @@ def _run_followup_analysis( question: str, prior_questions: list[str], *, + embedded_chunks: list[EmbeddedChunk] | None = None, logger: SessionLogger | None, ) -> str: - """Run grounded follow-up analysis re-anchored to current diagnostics.""" + """Run grounded follow-up analysis re-anchored to current diagnostics. + + When *embedded_chunks* is provided the question is embedded and the top-5 + most relevant chunks are used instead of the full report, reducing token + usage. Falls back to full-context on any embedding failure. + """ console.print() console.print(Rule("[bold cyan]AI Response[/bold cyan]", style="cyan")) console.print() ai = AIClient(ai_config) system_prompt = build_system_prompt() - user_message = build_followup_message(issue, report, question, prior_questions) + + user_message: str + if embedded_chunks is not None: + try: + q_embedding = ai.embed(question) + retrieved = retrieve(q_embedding, embedded_chunks, top_k=5) + user_message = build_message_with_chunks( + issue, report.host, retrieved, question, prior_questions + ) + except Exception: # noqa: BLE001 + user_message = build_followup_message(issue, report, question, prior_questions) + else: + user_message = build_followup_message(issue, report, question, prior_questions) try: chunks: list[str] = [] diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py index e4a87f2..68164fa 100644 --- a/src/tai/prompt_builder.py +++ b/src/tai/prompt_builder.py @@ -3,6 +3,7 @@ from __future__ import annotations from tai.collectors import CollectionReport +from tai.rag_retriever import Chunk _SYSTEM_PROMPT = """\ You are an expert Linux systems administrator and troubleshooting assistant. @@ -99,3 +100,41 @@ def build_followup_message( "If evidence is insufficient, explicitly say so." ) return "\n".join(lines) + + +def build_message_with_chunks( + issue: str, + host: str, + chunks: list[Chunk], + question: str, + prior_questions: list[str], +) -> str: + """Build a follow-up message using only semantically retrieved diagnostic chunks. + + Used by the RAG path: instead of sending the full report, only the top-k + most relevant chunks are included, reducing token usage and focusing the AI. + """ + lines: list[str] = [] + lines.append(f"## Issue reported\n\n{issue}\n") + lines.append(f"## Target host\n\n{host}\n") + lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n") + + for chunk in chunks: + lines.append(f"### {chunk.name}\n") + lines.append(chunk.content) + lines.append("") + + lines.append("## Follow-up") + + if prior_questions: + lines.append("\nRecent user follow-up questions:") + for idx, q in enumerate(prior_questions[-5:], start=1): + lines.append(f"{idx}. {q}") + + lines.append("\nCurrent follow-up question:") + lines.append(question) + lines.append( + "\nAnswer strictly from the retrieved diagnostics above. " + "If evidence is insufficient, explicitly say so." + ) + return "\n".join(lines) diff --git a/src/tai/rag_retriever.py b/src/tai/rag_retriever.py new file mode 100644 index 0000000..f1cd061 --- /dev/null +++ b/src/tai/rag_retriever.py @@ -0,0 +1,87 @@ +"""In-memory RAG retriever for diagnostic report chunks (Tier 1). + +Chunks one CollectionReport item per Chunk, embeds via AIClient, then +ranks chunks against a question using pure-Python cosine similarity. +No external vector store required — everything lives in process memory. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + +from tai.collectors import CollectionReport + + +@dataclass(slots=True) +class Chunk: + """A single retrievable piece of diagnostic content.""" + + name: str + content: str + + +@dataclass(slots=True) +class EmbeddedChunk: + """A Chunk paired with its embedding vector.""" + + chunk: Chunk + embedding: list[float] + + +def chunk_report(report: CollectionReport) -> list[Chunk]: + """Split a CollectionReport into one Chunk per diagnostic item. + + Items that SSH could not execute at all (exit 255, no output) are dropped — + they carry no diagnostic signal. + """ + chunks: list[Chunk] = [] + for item in report.items: + result = item.result + if result.exit_code == 255 and not result.stdout and not result.stderr: + continue + + parts: list[str] = [ + f"Command: {result.command}", + f"Exit code: {result.exit_code}", + ] + if result.stdout: + parts.append(f"stdout:\n{result.stdout.strip()}") + if result.stderr: + parts.append(f"stderr:\n{result.stderr.strip()}") + if not result.stdout and not result.stderr: + parts.append("(no output)") + + chunks.append(Chunk(name=item.name, content="\n".join(parts))) + return chunks + + +def _cosine_similarity(a: list[float], b: list[float]) -> float: + """Return cosine similarity in [-1, 1] using pure Python (no numpy).""" + dot = sum(x * y for x, y in zip(a, b, strict=False)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0.0 or norm_b == 0.0: + return 0.0 + return dot / (norm_a * norm_b) + + +def retrieve( + question_embedding: list[float], + embedded_chunks: list[EmbeddedChunk], + *, + top_k: int = 5, +) -> list[Chunk]: + """Return the *top_k* chunks most similar to *question_embedding*. + + Chunks are ranked by cosine similarity in descending order. + If *embedded_chunks* is empty or *top_k* is zero, returns an empty list. + """ + if not embedded_chunks or top_k <= 0: + return [] + scored: list[tuple[float, Chunk]] = [ + (_cosine_similarity(question_embedding, ec.embedding), ec.chunk) + for ec in embedded_chunks + ] + scored.sort(key=lambda x: x[0], reverse=True) + return [chunk for _, chunk in scored[:top_k]] diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py new file mode 100644 index 0000000..87f510f --- /dev/null +++ b/tests/test_rag_retriever.py @@ -0,0 +1,173 @@ +"""Tests for rag_retriever — pure-Python, no network calls.""" + +from __future__ import annotations + +from tai.collectors import CollectedItem, CollectionReport +from tai.rag_retriever import Chunk, EmbeddedChunk, _cosine_similarity, chunk_report, retrieve +from tai.ssh_client import SSHCommandResult + + +def _report(*items: tuple[str, str, int]) -> CollectionReport: + """Build a CollectionReport from (name, stdout, exit_code) tuples.""" + return CollectionReport( + host="test-host", + items=[ + CollectedItem( + name=name, + result=SSHCommandResult( + command=f"cmd-{name}", + exit_code=code, + stdout=stdout, + stderr="", + ), + ) + for name, stdout, code in items + ], + ) + + +# --------------------------------------------------------------------------- +# chunk_report +# --------------------------------------------------------------------------- + + +def test_chunk_report_creates_one_chunk_per_item() -> None: + report = _report(("kernel", "Linux test 6.1", 0), ("journal", "Started nginx.", 0)) + chunks = chunk_report(report) + assert len(chunks) == 2 + assert chunks[0].name == "kernel" + assert chunks[1].name == "journal" + + +def test_chunk_report_includes_stdout_in_content() -> None: + report = _report(("kernel", "Linux test 6.1", 0)) + chunks = chunk_report(report) + assert "Linux test 6.1" in chunks[0].content + + +def test_chunk_report_includes_exit_code_in_content() -> None: + report = _report(("fail", "error output", 1)) + chunks = chunk_report(report) + assert "Exit code: 1" in chunks[0].content + + +def test_chunk_report_skips_ssh_unreachable_items() -> None: + """Items with exit 255 and no output represent SSH failures and are dropped.""" + report = CollectionReport( + host="test-host", + items=[ + CollectedItem( + name="unreachable", + result=SSHCommandResult( + command="some-cmd", exit_code=255, stdout="", stderr="" + ), + ), + CollectedItem( + name="ok", + result=SSHCommandResult( + command="uname -a", exit_code=0, stdout="Linux", stderr="" + ), + ), + ], + ) + chunks = chunk_report(report) + assert len(chunks) == 1 + assert chunks[0].name == "ok" + + +def test_chunk_report_keeps_exit_255_with_output() -> None: + """Exit 255 with stderr present is a real failure — keep it.""" + report = CollectionReport( + host="test-host", + items=[ + CollectedItem( + name="partial", + result=SSHCommandResult( + command="some-cmd", + exit_code=255, + stdout="", + stderr="Permission denied", + ), + ), + ], + ) + chunks = chunk_report(report) + assert len(chunks) == 1 + assert "Permission denied" in chunks[0].content + + +def test_chunk_report_notes_no_output() -> None: + report = CollectionReport( + host="test-host", + items=[ + CollectedItem( + name="silent", + result=SSHCommandResult(command="cmd", exit_code=0, stdout="", stderr=""), + ), + ], + ) + chunks = chunk_report(report) + assert "(no output)" in chunks[0].content + + +# --------------------------------------------------------------------------- +# _cosine_similarity +# --------------------------------------------------------------------------- + + +def test_cosine_similarity_identical_vectors() -> None: + v = [1.0, 0.0, 0.0] + assert abs(_cosine_similarity(v, v) - 1.0) < 1e-9 + + +def test_cosine_similarity_orthogonal_vectors() -> None: + a = [1.0, 0.0] + b = [0.0, 1.0] + assert abs(_cosine_similarity(a, b)) < 1e-9 + + +def test_cosine_similarity_opposite_vectors() -> None: + a = [1.0, 0.0] + b = [-1.0, 0.0] + assert abs(_cosine_similarity(a, b) - (-1.0)) < 1e-9 + + +def test_cosine_similarity_zero_vector_returns_zero() -> None: + assert _cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0 + + +# --------------------------------------------------------------------------- +# retrieve +# --------------------------------------------------------------------------- + + +def _embedded(name: str, vec: list[float]) -> EmbeddedChunk: + return EmbeddedChunk(chunk=Chunk(name=name, content=f"content of {name}"), embedding=vec) + + +def test_retrieve_returns_top_k_by_similarity() -> None: + chunks = [ + _embedded("close", [1.0, 0.0]), # most similar + _embedded("mid", [0.7, 0.7]), + _embedded("far", [0.0, 1.0]), # orthogonal to query + ] + query = [1.0, 0.0] + result = retrieve(query, chunks, top_k=2) + assert len(result) == 2 + assert result[0].name == "close" + assert result[1].name == "mid" + + +def test_retrieve_respects_top_k_larger_than_pool() -> None: + chunks = [_embedded("only", [1.0, 0.0])] + result = retrieve([1.0, 0.0], chunks, top_k=10) + assert len(result) == 1 + + +def test_retrieve_empty_pool_returns_empty() -> None: + assert retrieve([1.0, 0.0], [], top_k=5) == [] + + +def test_retrieve_top_k_zero_returns_empty() -> None: + chunks = [_embedded("x", [1.0, 0.0])] + assert retrieve([1.0, 0.0], chunks, top_k=0) == []