From d5e1822644d375812af4d42c7944a7740fb21c42 Mon Sep 17 00:00:00 2001 From: zphinx Date: Wed, 6 May 2026 05:02:38 +0200 Subject: [PATCH] update --- src/tai/cli.py | 124 ++++++++++++++++++++++++++++++++---- src/tai/plan.py | 92 +++++++++++++++++++++++++- src/tai/prompt_builder.py | 47 +++++++++++++- src/tai/session_store.py | 105 ++++++++++++++++++++++++++++++ src/tai/ssh_client.py | 2 + tests/test_ai.py | 19 ++++++ tests/test_plan.py | 18 ++++++ tests/test_session_store.py | 79 +++++++++++++++++++++++ tests/test_ssh_client.py | 2 + 9 files changed, 473 insertions(+), 15 deletions(-) create mode 100644 src/tai/session_store.py create mode 100644 tests/test_session_store.py diff --git a/src/tai/cli.py b/src/tai/cli.py index b6eb5a3..8dbd50e 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -30,6 +30,7 @@ from tai.prompt_builder import ( from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored from tai.runbook_store import RunbookChunk, RunbookStore from tai.session_log import SessionLogger +from tai.session_store import PastSession, SessionStore from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -151,6 +152,16 @@ def run( help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.", ), ] = None, + session_memory_path: Annotated[ + str | None, + typer.Option( + "--session-memory", + help=( + "Path to persistent session memory store for prior-session retrieval " + "(Tier 4). Omit to disable." + ), + ), + ] = None, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -207,6 +218,17 @@ def run( except Exception as exc: # noqa: BLE001 console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}") + session_store: SessionStore | None = None + if session_memory_path: + try: + session_store = SessionStore(session_memory_path) + mem_count = session_store.count() + console.print( + f"[dim]Session memory: {mem_count} indexed at {session_memory_path}[/dim]" + ) + except Exception as exc: # noqa: BLE001 + console.print(f"[yellow]Session memory unavailable:[/yellow] {exc}") + try: asyncio.run( _async_main( @@ -220,6 +242,7 @@ def run( no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) ) @@ -245,6 +268,7 @@ async def _async_main( no_rag: bool, rag_debug: bool, runbook_store: RunbookStore | None, + session_store: SessionStore | None, logger: SessionLogger | None, ) -> None: """Open a single SSH session and run probe / collection / analysis through it.""" @@ -291,19 +315,22 @@ async def _async_main( }, ) + initial_response: str | None = None if analyze and report is not None: - _run_analysis( + initial_response = _run_analysis( ai_config, req.issue, report, no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) + interactive_response: str | None = None if interactive: - await _interactive_loop( + interactive_response = await _interactive_loop( session, req, ai_config, @@ -311,9 +338,14 @@ async def _async_main( no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) + final_response = interactive_response or initial_response + if session_store is not None and final_response: + _index_session_memory(session_store, ai_config, req, final_response, logger=logger) + async def _interactive_loop( session: SSHSession, @@ -324,8 +356,9 @@ async def _interactive_loop( no_rag: bool = False, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, -) -> None: +) -> str | None: """Run a follow-up loop for collecting and conversational analysis.""" console.print( Panel( @@ -340,6 +373,7 @@ async def _interactive_loop( prior_questions: list[str] = [] embedded_chunks: list[EmbeddedChunk] | None = None ai_embed = AIClient(ai_config) + last_response: str | None = None if not no_rag and report is not None: embedded_chunks, index_error, index_ms = await asyncio.to_thread( @@ -384,7 +418,7 @@ async def _interactive_loop( console.print("\n[yellow]Exiting interactive mode.[/yellow]") if logger is not None: logger.log_event("interactive_exit", {"reason": "signal_or_eof"}) - return + return last_response if not command: continue @@ -393,7 +427,7 @@ async def _interactive_loop( console.print("[green]Bye.[/green]") if logger is not None: logger.log_event("interactive_exit", {"reason": "user_quit"}) - return + return last_response if command == "/help": console.print( @@ -466,7 +500,7 @@ async def _interactive_loop( console.print("[red]No diagnostics available to analyze.[/red]") continue - _run_followup_analysis( + response = _run_followup_analysis( ai_config, req.issue, report, @@ -475,11 +509,14 @@ async def _interactive_loop( embedded_chunks=embedded_chunks, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) prior_questions.append("/analyze") if logger is not None: logger.log_event("interactive_followup", {"question": "/analyze"}) + last_response = response + continue continue if report is None: @@ -523,7 +560,7 @@ async def _interactive_loop( console.print("[red]No diagnostics available to analyze.[/red]") continue - _run_followup_analysis( + response = _run_followup_analysis( ai_config, req.issue, report, @@ -532,11 +569,13 @@ async def _interactive_loop( embedded_chunks=embedded_chunks, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) prior_questions.append(command) if logger is not None: logger.log_event("interactive_followup", {"question": command}) + last_response = response def _try_embed_report( @@ -597,8 +636,9 @@ def _run_analysis( no_rag: bool = False, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, -) -> None: +) -> str: """Send collected data to the AI and stream the analysis to stdout.""" console.print() console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan")) @@ -606,10 +646,16 @@ def _run_analysis( ai = AIClient(ai_config) system_prompt = build_system_prompt() runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1) + past_sessions = _query_sessions(session_store, issue, report.host, ai, top_k=2) user_message: str if no_rag: - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) else: try: chunks = chunk_report(report) @@ -628,16 +674,28 @@ def _run_analysis( report.host, selected, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) else: - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) except Exception as exc: # noqa: BLE001 console.print( - "[yellow]RAG unavailable for initial analysis; using full-context fallback.[/yellow]" + "[yellow]RAG unavailable for initial analysis; " + "using full-context fallback.[/yellow]" ) if logger is not None: logger.log_event("rag_index", {"status": "fallback", "error": str(exc)}) - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) try: response = _complete_ai_response( ai, @@ -662,6 +720,7 @@ def _run_analysis( "guardrail_warnings": warnings, }, ) + return response except Exception as exc: # noqa: BLE001 console.print(f"[red]AI analysis failed:[/red] {exc}") if logger is not None: @@ -688,6 +747,7 @@ def _run_followup_analysis( embedded_chunks: list[EmbeddedChunk] | None = None, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, ) -> str: """Run grounded follow-up analysis re-anchored to current diagnostics. @@ -702,6 +762,7 @@ def _run_followup_analysis( ai = AIClient(ai_config) system_prompt = build_system_prompt() runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1) + past_sessions = _query_sessions(session_store, question, report.host, ai, top_k=2) user_message: str retrieved_names: list[str] = [] @@ -724,6 +785,7 @@ def _run_followup_analysis( question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) if rag_debug: pairs = ", ".join( @@ -741,12 +803,14 @@ def _run_followup_analysis( user_message = build_followup_message( issue, report, question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) else: fallback_reason = "rag not indexed" user_message = build_followup_message( issue, report, question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) if logger is not None: @@ -826,6 +890,42 @@ def _query_runbooks( return [] +def _query_sessions( + store: SessionStore | None, + question: str, + host: str, + ai: AIClient, + *, + top_k: int = 2, +) -> list[PastSession]: + """Query the session memory store silently; returns empty list on failures.""" + if store is None: + return [] + try: + return store.query(question, host, ai, top_k=top_k) + except Exception: # noqa: BLE001 + return [] + + +def _index_session_memory( + store: SessionStore, + ai_config: AIConfig, + req: TroubleshootRequest, + summary: str, + *, + logger: SessionLogger | None, +) -> None: + """Persist final session summary for future retrieval; non-fatal on failure.""" + try: + ai = AIClient(ai_config) + session_id = store.index_session(req.host, req.issue, summary, ai) + if logger is not None: + logger.log_event("session_memory_indexed", {"session_id": session_id}) + except Exception as exc: # noqa: BLE001 + if logger is not None: + logger.log_event("session_memory_error", {"error": str(exc)}) + + # --------------------------------------------------------------------------- # runbooks sub-app # --------------------------------------------------------------------------- diff --git a/src/tai/plan.py b/src/tai/plan.py index c6d6701..f75bfc7 100644 --- a/src/tai/plan.py +++ b/src/tai/plan.py @@ -149,6 +149,37 @@ _SERVICE_BINARIES: dict[str, list[str]] = { "apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"], } +_SERVICE_PACKAGES: dict[str, list[str]] = { + "docker": ["docker", "docker-ce"], + "sssd": ["sssd"], + "sshd": ["openssh-server", "openssh"], + "x2go": ["x2goserver", "x2goserver-xsession"], + "xorg": ["xorg-server", "xserver-xorg-core"], + "wayland": ["wayland", "xwayland"], + "selinux": ["selinux-policy", "selinux-policy-targeted"], + "apparmor": ["apparmor"], +} + +_GENERIC_SERVICE_STOPWORDS: frozenset[str] = frozenset( + { + "a", + "an", + "and", + "app", + "application", + "daemon", + "for", + "is", + "my", + "not", + "service", + "systemd", + "the", + "unit", + "working", + } +) + # --------------------------------------------------------------------------- # Command sets # --------------------------------------------------------------------------- @@ -225,6 +256,9 @@ def plan_from_request(request: TroubleshootRequest) -> CollectionPlan: ) for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1): plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}") + for idx, package_name in enumerate(_service_package_candidates(svc), start=1): + plan.add(f"package-rpm-{svc}-{idx}", f"rpm -q {package_name}") + plan.add(f"package-dpkg-{svc}-{idx}", f"dpkg-query -W {package_name}") plan.add(f"service-{svc}", f"systemctl status {svc}") plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager") for cfg_path in _SERVICE_CONFIGS.get(svc, []): @@ -258,7 +292,11 @@ def _issue_words(issue: str) -> set[str]: def _extract_services(issue: str) -> list[str]: - """Return known service names mentioned in *issue*.""" + """Return service candidates mentioned in *issue*. + + Includes known services plus generic service-like tokens near words such + as "service", "daemon", and "unit". + """ words = _issue_words(issue) found: list[str] = [] for svc in _KNOWN_SERVICES: @@ -266,6 +304,58 @@ def _extract_services(issue: str) -> list[str]: svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")} if words & svc_words: found.append(svc) + for svc in _extract_generic_service_candidates(issue): + if svc not in found: + found.append(svc) return found +def _extract_generic_service_candidates(issue: str) -> list[str]: + """Extract likely service names from free text even when not pre-registered.""" + tokens = [tok.lower() for tok in re.findall(r"[a-zA-Z0-9_.@-]+", issue)] + if not tokens: + return [] + + candidates: list[str] = [] + for idx, token in enumerate(tokens): + normalized = token[:-8] if token.endswith(".service") else token + if _is_safe_service_name(normalized): + if token.endswith(".service") and normalized not in _GENERIC_SERVICE_STOPWORDS: + candidates.append(normalized) + + if token in {"service", "daemon", "unit"}: + for neighbor in (idx - 1, idx + 1): + if neighbor < 0 or neighbor >= len(tokens): + continue + candidate = tokens[neighbor] + if candidate.endswith(".service"): + candidate = candidate[:-8] + if candidate in _GENERIC_SERVICE_STOPWORDS: + continue + if _is_safe_service_name(candidate): + candidates.append(candidate) + + deduped: list[str] = [] + seen: set[str] = set() + for item in candidates: + if item in seen: + continue + seen.add(item) + deduped.append(item) + return deduped + + +def _is_safe_service_name(name: str) -> bool: + """Return True when *name* is safe to interpolate into read-only commands.""" + if len(name) < 2 or len(name) > 64: + return False + return re.fullmatch(r"[a-z0-9_.@-]+", name) is not None + + +def _service_package_candidates(service: str) -> list[str]: + """Return package names to probe for *service* presence.""" + if service in _SERVICE_PACKAGES: + return _SERVICE_PACKAGES[service] + return [service] + + diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py index ede0607..35ffcd7 100644 --- a/src/tai/prompt_builder.py +++ b/src/tai/prompt_builder.py @@ -5,6 +5,7 @@ from __future__ import annotations from tai.collectors import CollectionReport from tai.rag_retriever import Chunk from tai.runbook_store import RunbookChunk +from tai.session_store import PastSession _SYSTEM_PROMPT = """\ You are an expert Linux systems administrator and troubleshooting assistant. @@ -21,7 +22,8 @@ Important rules: - If a command shows "could not be executed (SSH error)" it means the remote host blocked or rejected that specific command — it is not evidence about the service or system state. - If service presence checks show a unit, binary, package, or config is missing, treat that as - evidence the component may be absent or not installed, not as proof that the component is broken. + evidence the component may be absent or not installed, not as proof that the + component is broken. - If there is not enough data to diagnose the issue, say so plainly and list exactly what additional commands or log files would be needed. - Keep the response short. Skip sections that have nothing useful to say. @@ -33,6 +35,7 @@ Important rules: _MAX_RUNBOOK_CHARS = 500 _MAX_DIAGNOSTIC_CHUNK_CHARS = 700 +_MAX_SESSION_SUMMARY_CHARS = 500 def build_system_prompt() -> str: @@ -66,11 +69,34 @@ def _format_diagnostic_chunk(content: str) -> str: return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]" +def _format_session_context(past_sessions: list[PastSession]) -> str: + """Format similar prior sessions as compact grounding context.""" + lines: list[str] = ["## Similar prior sessions\n"] + lines.append( + "The following completed sessions were semantically similar. " + "Use them as historical hints, but prioritize current diagnostics if they conflict.\n" + ) + for sess in past_sessions: + summary = sess.summary.strip() + if len(summary) > _MAX_SESSION_SUMMARY_CHARS: + summary = ( + summary[:_MAX_SESSION_SUMMARY_CHARS].rstrip() + + "\n...[truncated session summary]" + ) + lines.append(f"### Session: {sess.session_id} (host={sess.host})\n") + lines.append(f"**Issue:** {sess.issue}") + lines.append("") + lines.append(summary) + lines.append("") + return "\n".join(lines) + + def build_user_message( issue: str, report: CollectionReport, *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Format *issue* and *report* into the user message sent to the AI.""" lines: list[str] = [] @@ -81,6 +107,9 @@ def build_user_message( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Collected diagnostics\n") skipped: list[str] = [] @@ -126,9 +155,15 @@ def build_followup_message( prior_questions: list[str], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build a grounded follow-up message that re-anchors to diagnostics each turn.""" - base = build_user_message(issue, report, runbook_chunks=runbook_chunks) + base = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks, + past_sessions=past_sessions, + ) lines: list[str] = [base, "## Follow-up"] if prior_questions: @@ -157,6 +192,7 @@ def build_message_with_chunks( prior_questions: list[str], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build a follow-up message using only semantically retrieved diagnostic chunks. @@ -178,6 +214,9 @@ def build_message_with_chunks( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Follow-up") if prior_questions: @@ -204,6 +243,7 @@ def build_analysis_message_with_chunks( chunks: list[Chunk], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build an initial analysis message from retrieved diagnostic chunks.""" lines: list[str] = [] @@ -213,6 +253,9 @@ def build_analysis_message_with_chunks( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n") for chunk in chunks: lines.append(f"### {chunk.name}\n") diff --git a/src/tai/session_store.py b/src/tai/session_store.py new file mode 100644 index 0000000..a2d3564 --- /dev/null +++ b/src/tai/session_store.py @@ -0,0 +1,105 @@ +"""Persistent session memory store (Tier 4) backed by ChromaDB.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tai.ai_client import AIClient + +DEFAULT_SESSION_STORE_PATH = "~/.tai/sessions" +_COLLECTION_NAME = "tai_sessions" + + +@dataclass(slots=True) +class PastSession: + """A retrieved prior session summary for prompt grounding.""" + + session_id: str + host: str + issue: str + summary: str + + +class SessionStore: + """ChromaDB-backed persistent memory for prior troubleshooting sessions.""" + + def __init__(self, store_path: str | Path = DEFAULT_SESSION_STORE_PATH) -> None: + import chromadb + + path = Path(store_path).expanduser().resolve() + path.mkdir(parents=True, exist_ok=True) + settings = None + try: + from chromadb.config import Settings + + settings = Settings( + anonymized_telemetry=False, + chroma_product_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient", + chroma_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient", + ) + except (ImportError, ModuleNotFoundError): + settings = None + + if settings is None: + self._client = chromadb.PersistentClient(path=str(path)) + else: + self._client = chromadb.PersistentClient(path=str(path), settings=settings) + self._collection = self._client.get_or_create_collection( + name=_COLLECTION_NAME, + metadata={"hnsw:space": "cosine"}, + ) + + def count(self) -> int: + """Return number of indexed session summaries.""" + return self._collection.count() + + def index_session(self, host: str, issue: str, summary: str, ai: AIClient) -> str: + """Embed and upsert one session summary into persistent storage.""" + session_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + embed_text = _build_embed_text(host=host, issue=issue, summary=summary) + embedding = ai.embed(embed_text) + self._collection.upsert( + ids=[session_id], + documents=[summary.strip()], + embeddings=[embedding], + metadatas=[{"host": host, "issue": issue}], + ) + return session_id + + def query(self, question: str, host: str, ai: AIClient, *, top_k: int = 2) -> list[PastSession]: + """Return top-k semantically similar sessions for this host and question.""" + if self._collection.count() == 0: + return [] + + q_embedding = ai.embed(f"host: {host}\nquestion: {question}") + results = self._collection.query( + query_embeddings=[q_embedding], + n_results=min(top_k, self._collection.count()), + include=["documents", "metadatas"], + ) + + sessions: list[PastSession] = [] + ids = results.get("ids") or [] + docs = results.get("documents") or [] + metas = results.get("metadatas") or [] + for id_list, doc_list, meta_list in zip(ids, docs, metas, strict=False): + for sid, doc, meta in zip(id_list, doc_list, meta_list, strict=False): + sessions.append( + PastSession( + session_id=str(sid), + host=str(meta.get("host", "")), + issue=str(meta.get("issue", "")), + summary=str(doc), + ) + ) + return sessions + + +def _build_embed_text(*, host: str, issue: str, summary: str) -> str: + """Build embedding text with host/issue context and summary excerpt.""" + excerpt = summary.strip()[:1000] + return f"host: {host}\nissue: {issue}\nsummary:\n{excerpt}" diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index 6235010..0afc974 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -51,6 +51,7 @@ class SSHClient: _READ_ONLY_COMMANDS = { "cat", "dmesg", + "dpkg-query", "df", "du", "find", @@ -61,6 +62,7 @@ class SSHClient: "journalctl", "ls", "netstat", + "rpm", "sed", "ss", "stat", diff --git a/tests/test_ai.py b/tests/test_ai.py index f37de7d..b904e57 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.collectors import CollectedItem, CollectionReport from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message +from tai.session_store import PastSession from tai.ssh_client import SSHCommandResult # --------------------------------------------------------------------------- @@ -221,6 +222,24 @@ def test_build_user_message_handles_no_output() -> None: assert "no output" in msg +def test_build_user_message_includes_prior_session_context() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) + msg = build_user_message( + "sssd broken", + report, + past_sessions=[ + PastSession( + session_id="20260506T120000Z", + host="web01", + issue="sssd broken", + summary="Root cause was missing sssd package.", + ) + ], + ) + assert "Similar prior sessions" in msg + assert "missing sssd package" in msg + + def test_build_followup_message_includes_question_context() -> None: report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) msg = build_followup_message( diff --git a/tests/test_plan.py b/tests/test_plan.py index 4fdf556..cc2658b 100644 --- a/tests/test_plan.py +++ b/tests/test_plan.py @@ -107,8 +107,12 @@ def test_sssd_in_issue_adds_presence_service_and_config_commands() -> None: assert "binary-sssd-1" in names assert "service-sssd" in names assert "journal-sssd" in names + assert "package-rpm-sssd-1" in names + assert "package-dpkg-sssd-1" in names assert any("cat /etc/sssd/sssd.conf" in c for c in cmds) assert any("ls -l /usr/sbin/sssd" in c for c in cmds) + assert any("rpm -q sssd" in c for c in cmds) + assert any("dpkg-query -W sssd" in c for c in cmds) assert any("list-unit-files sssd.service" in c for c in cmds) @@ -119,8 +123,12 @@ def test_docker_presence_probe_checks_package_and_binary() -> None: assert "unit-file-docker" in names assert "binary-docker-1" in names assert "binary-docker-2" in names + assert "package-rpm-docker-1" in names + assert "package-dpkg-docker-1" in names assert any("ls -l /usr/bin/docker" in c for c in cmds) assert any("ls -l /usr/bin/dockerd" in c for c in cmds) + assert any("rpm -q docker" in c for c in cmds) + assert any("dpkg-query -W docker" in c for c in cmds) def test_unknown_service_name_no_config_cat() -> None: @@ -183,6 +191,16 @@ def test_extract_services_case_insensitive() -> None: assert "nginx" in _extract_services("NGINX failed") +def test_extract_services_detects_generic_service_name() -> None: + services = _extract_services("myweirdapp service keeps failing") + assert "myweirdapp" in services + + +def test_extract_services_detects_dot_service_pattern() -> None: + services = _extract_services("please check foobar.service on this host") + assert "foobar" in services + + # --------------------------------------------------------------------------- # Plan length sanity # --------------------------------------------------------------------------- diff --git a/tests/test_session_store.py b/tests/test_session_store.py new file mode 100644 index 0000000..e66ed41 --- /dev/null +++ b/tests/test_session_store.py @@ -0,0 +1,79 @@ +"""Tests for session_store with mocked ChromaDB.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from tai.session_store import PastSession, SessionStore, _build_embed_text + + +def _make_chromadb_mock() -> MagicMock: + collection = MagicMock() + collection.count.return_value = 0 + client = MagicMock() + client.get_or_create_collection.return_value = collection + chroma_mod = MagicMock() + chroma_mod.PersistentClient.return_value = client + return chroma_mod + + +def _make_ai_mock(embedding: list[float] | None = None) -> MagicMock: + ai = MagicMock() + ai.embed.return_value = embedding or [0.1, 0.2, 0.3] + return ai + + +def test_build_embed_text_contains_host_issue_and_summary() -> None: + text = _build_embed_text(host="web01", issue="sssd broken", summary="Unit missing") + assert "host: web01" in text + assert "issue: sssd broken" in text + assert "Unit missing" in text + + +def test_index_session_upserts_with_metadata(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + session_id = store.index_session("web01", "sssd broken", "summary text", ai) + + assert session_id + collection.upsert.assert_called_once() + args = collection.upsert.call_args.kwargs + assert args["metadatas"][0]["host"] == "web01" + assert args["metadatas"][0]["issue"] == "sssd broken" + + +def test_query_returns_empty_when_no_docs(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + results = store.query("why sssd", "web01", ai) + + assert results == [] + + +def test_query_returns_past_sessions(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value + collection.count.return_value = 1 + collection.query.return_value = { + "ids": [["20260506T120000Z"]], + "documents": [["Root cause: package missing"]], + "metadatas": [[{"host": "web01", "issue": "sssd broken"}]], + } + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + results = store.query("sssd issue", "web01", ai) + + assert len(results) == 1 + assert isinstance(results[0], PastSession) + assert results[0].host == "web01" + assert "package missing" in results[0].summary diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index fcad417..6b68db0 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -82,6 +82,8 @@ def test_allows_expected_read_only_commands() -> None: "systemctl status apache2", "cat /etc/hosts", "ss -lntp", + "rpm -q sssd", + "dpkg-query -W sssd", ]: client.validate_read_only_command(command)