diff --git a/CHANGELOG.md b/CHANGELOG.md index bbb0180..be6f743 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,20 @@ ______________________________________________________________________ ### Added -- Nothing yet. +- Tier 3 core session memory implementation: + - new `src/tai/session_store.py` persistent ChromaDB store + - `--session-memory` option on `tai run` + - prior-session retrieval injected into analysis/follow-up prompts + - final response indexing at session end +- Planner enhancements for broader service detection: + - generic service candidate extraction from free text + - package presence probes in plans (`rpm -q` and `dpkg-query -W`) +- SSH read-only allowlist expanded to permit package presence commands (`rpm`, `dpkg-query`) +- Session memory tests in `tests/test_session_store.py` + +### Changed + +- Documentation alignment updates in README and ROADMAP to reflect implemented session memory and package-presence capabilities. ______________________________________________________________________ @@ -22,20 +35,20 @@ ______________________________________________________________________ - Runbook knowledge store module `src/tai/runbook_store.py` (persistent ChromaDB-backed index and query) - Chroma telemetry no-op client `src/tai/chroma_telemetry.py` to suppress noisy local telemetry errors - `tai runbooks` command group with: - - `sync` for indexing all Markdown runbooks - - `list` for listing indexed metadata - - `add` for indexing a single runbook file + - `sync` for indexing all Markdown runbooks + - `list` for listing indexed metadata + - `add` for indexing a single runbook file - `--runbooks` option on `tai run` to enable Tier 2 runbook retrieval - Initial analysis RAG path using retrieved diagnostic chunks (`build_analysis_message_with_chunks`) - Follow-up RAG path updates with tighter `top_k` and runbook context injection - AI runtime controls: - - `--ai-timeout-seconds` - - `--ai-max-tokens` + - `--ai-timeout-seconds` + - `--ai-max-tokens` - Non-streaming AI completion path for improved local backend reliability - Service/subsystem presence probes in collection plans: - - unit-file checks - - expected binary path checks - - status/journal/config probes for recognized services including `sssd` + - unit-file checks + - expected binary path checks + - status/journal/config probes for recognized services including `sssd` - Prompt instruction for "component absent or not installed" interpretation when presence signals are missing - Runbook store unit tests in `tests/test_runbook_store.py` - CLI tests updated for `tai run` subcommand and non-streaming completion mocks diff --git a/README.md b/README.md index 5571b0c..81e74c7 100644 --- a/README.md +++ b/README.md @@ -191,9 +191,8 @@ pytest tests/test_plan.py tests/test_ai.py tests/test_cli.py ## Known Limits -- Service-specific presence checks currently apply to recognized service/subsystem names. -- Package-manager-level presence checks are not yet in the default read-only command allowlist. -- Tier 3 persistent session memory is not implemented yet. +- Deep service-specific probes (known binary/config/package aliases) are richer for recognized services than generic service names. +- Session memory is available via `--session-memory`, but dedicated history UX commands (`tai history`, `/history`) are not implemented yet. ## Changelog and Roadmap diff --git a/ROADMAP.md b/ROADMAP.md index 6a1e8ef..0144ba1 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -130,7 +130,7 @@ model weights alone. Three tiers of increasing capability, each buildable indepe - Build compounding institutional memory from past troubleshooting sessions - Keep all data local — no embeddings or session content leaves the network ---- +______________________________________________________________________ ### Technology Decisions Required @@ -143,9 +143,9 @@ model weights alone. Three tiers of increasing capability, each buildable indepe | Hybrid retrieval | Semantic only, BM25 only, hybrid | Hybrid (BM25 keyword + cosine semantic) for best recall | ⬜ Pending | | Reranking | None, cross-encoder (`ms-marco-MiniLM`), LLM-as-judge | Cross-encoder rerank pass before prompt injection | ⬜ Pending | | Runbook format | Markdown, YAML, JSON | Markdown (human-editable, version-controllable) | ✅ Implemented | -| Session index storage | Local `~/.tai/`, configurable path | `~/.tai/sessions/` with ChromaDB collection | ⬜ Pending | +| Session index storage | Local `~/.tai/`, configurable path | `~/.tai/sessions/` with ChromaDB collection | ✅ Implemented (core) | ---- +______________________________________________________________________ ### Tier 1 — Diagnostic Chunk Retrieval (in-memory, per-session) @@ -155,31 +155,36 @@ Status: ✅ Implemented On busy hosts this floods the context window with irrelevant output, degrading quality. **Approach:** + - After collection, split each command's output into overlapping token chunks (e.g. 512 tokens, 64 overlap) - Embed all chunks using `nomic-embed-text` via Ollama embeddings API - On each question (initial + follow-up), embed the question and retrieve top-k chunks by cosine similarity - Inject only retrieved chunks into the prompt, not the full dump **New module:** `src/tai/rag_retriever.py` + - `chunk_report(report) -> list[Chunk]` - `embed_chunks(chunks) -> list[EmbeddedChunk]` - `retrieve(question, embedded_chunks, top_k) -> list[Chunk]` **Changes to existing code:** + - `prompt_builder.py`: accept `retrieved_chunks` instead of full `CollectionReport` for RAG-mode prompts - `cli.py`: embed report after collection, pass retriever to `_run_analysis` and `_run_followup_analysis` - `ai_client.py`: add `embed(text) -> list[float]` method using Ollama `/api/embeddings` **Companion features buildable at same time:** + - `--no-rag` flag to bypass retrieval and use full dump (backwards compat) - Token budget display: show user how many tokens are being sent vs. saved - Per-chunk source attribution in AI response (which command produced the evidence) **Tests:** + - `tests/test_rag_retriever.py`: chunk splitting, cosine similarity ranking, top-k retrieval - `tests/test_ai.py`: add `test_embed_returns_float_list()` ---- +______________________________________________________________________ ### Tier 2 — Runbook Knowledge Base (persistent, ChromaDB) @@ -189,63 +194,74 @@ Status: ✅ Implemented specific environments, distros, or internal conventions. **Approach:** + - Maintain a version-controlled corpus of Markdown runbooks in `runbooks/` directory - On first run (or `tai runbooks --sync`), embed all runbooks and persist to ChromaDB collection - On each analysis, retrieve top-3 relevant runbook chunks alongside diagnostic chunks - Inject as a separate `## Runbook Context` section in the prompt **New module:** `src/tai/runbook_store.py` + - `RunbookStore`: wraps ChromaDB collection - `sync(runbooks_dir) -> int` — embed and upsert all runbooks - `query(question, top_k) -> list[RunbookChunk]` **New directory:** `runbooks/` + - `ssh.md`, `nginx.md`, `postgres.md`, `disk.md`, `kernel.md`, etc. - Each runbook: YAML frontmatter (`service`, `symptoms`, `tags`) + Markdown body **New CLI command:** `tai runbooks --sync [--path ./runbooks]` **Changes to existing code:** + - `prompt_builder.py`: add `build_message_with_runbooks(retrieved_chunks, runbook_chunks)` - `cli.py`: optionally load `RunbookStore`, query it per analysis turn **Companion features buildable at same time:** + - `tai runbooks --list` — show indexed runbooks and last sync time - `tai runbooks --add ` — index a single runbook - `/runbooks` slash command in interactive mode — show which runbooks were retrieved - Runbook citation in AI output: "Based on runbook: `ssh.md#AuthenticationFailures`" ---- +______________________________________________________________________ ### Tier 3 — Session Memory Index (institutional learning) -Status: ⬜ Pending +Status: ✅ Implemented (core retrieval/indexing) / ⬜ UX commands pending **Problem:** Every session starts from zero. Repeat incidents on the same host or same issue type get no benefit from past work. -**Approach:** +**Implemented now:** + - On session end, embed the session summary (issue + root cause + actions) and upsert into a persistent ChromaDB collection (`~/.tai/sessions/`) - On session start, query for similar past sessions by issue text + hostname - Inject top-2 past sessions as `## Prior Sessions` context -- Optionally: `/history` command in interactive mode to surface past sessions explicitly + +**Pending UX layer:** + +- `/history` command in interactive mode to surface past sessions explicitly **New module:** `src/tai/session_store.py` + - `SessionStore`: wraps ChromaDB collection at `~/.tai/sessions/` -- `index_session(session_log_path)` — embed and store completed session -- `query_similar(issue, host, top_k) -> list[PastSession]` +- `index_session(host, issue, summary, ai)` — embed and store completed session +- `query(question, host, ai, top_k) -> list[PastSession]` **Changes to existing code:** -- `session_log.py`: add `summarise() -> str` method (issue + final AI response) -- `cli.py`: query `SessionStore` at session start, index at session end + +- `cli.py`: query `SessionStore` during analysis turns and index final responses at session end **Companion features buildable at same time:** + - `tai history` CLI subcommand — search past sessions by keyword - `tai history --host ` — all sessions for a host - `tai history --export ` — export session summaries as Markdown report - Auto-suggest: "Similar issue found from 2 weeks ago — load context? [y/N]" ---- +______________________________________________________________________ ### Implementation Order @@ -258,6 +274,7 @@ Tier 3 (session memory) ← Builds on Tier 2 infrastructure. Minimal extr ``` **Estimated effort:** + - Tier 1: 2–3 days (new module + prompt builder changes + tests) - Tier 2: 3–4 days (ChromaDB + runbook authoring + CLI command + tests) - Tier 3: 1–2 days (reuses Tier 2 infrastructure) @@ -293,14 +310,14 @@ ______________________________________________________________________ | Date | Decision | Outcome | |------|----------|---------| | 2026-05-04 | Implementation language | Python — with single distributable binary via Nuitka | -| — | AI inference backend | vLLM (provisional) | -| — | Default model | `gemma4:a4b` (provisional) | +| 2026-05-04 | AI backend API | OpenAI-compatible API endpoint (local Ollama by default) | +| 2026-05-04 | Default model | `gemma3:4b` | | 2026-05-04 | SSH auth methods | Keypair only (ed25519/RSA); auto-accept new hosts; reject on key change (MITM) | | 2026-05-04 | Bastion host support | `--jump-host` flag via SSH native ProxyJump | | 2026-05-04 | SSH config behavior | Use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` | | 2026-05-04 | CLI vs interactive mode | Interactive: REPL for v0.1, `textual` TUI for v0.2+ | -| 2026-05-04 | RAG embedding model | `nomic-embed-text` via Ollama (local, air-gapped safe) — ⬜ pending confirmation | +| 2026-05-04 | RAG embedding model | `nomic-embed-text` via Ollama (local, air-gapped safe) | | 2026-05-04 | RAG vector store (Tier 1) | In-memory numpy cosine similarity — zero deps, session-scoped | -| 2026-05-04 | RAG vector store (Tier 2/3) | `chromadb` embedded mode (default) or `qdrant` self-hosted — ⬜ pending confirmation | +| 2026-05-04 | RAG vector store (Tier 2/3) | `chromadb` embedded mode (default) or `qdrant` self-hosted | | 2026-05-04 | RAG chunking unit | Command-boundary splitting — each collected command = one or more chunks | | 2026-05-04 | Runbook format | Markdown with YAML frontmatter, version-controlled in `runbooks/` directory | diff --git a/pyproject.toml b/pyproject.toml index 9bcd8d5..5b59216 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,11 @@ select = ["E", "F", "I", "UP", "B"] python_version = "3.11" strict = true warn_unused_configs = true + +[[tool.mypy.overrides]] +module = ["chromadb", "chromadb.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["tai.chroma_telemetry"] +disable_error_code = ["misc"] diff --git a/src/tai/chroma_telemetry.py b/src/tai/chroma_telemetry.py index 310c65e..609e686 100644 --- a/src/tai/chroma_telemetry.py +++ b/src/tai/chroma_telemetry.py @@ -7,9 +7,10 @@ disabled, so tai wires ChromaDB to this no-op client instead. from __future__ import annotations +from typing import override + from chromadb.config import System from chromadb.telemetry.product import ProductTelemetryClient, ProductTelemetryEvent -from overrides import override class NoOpProductTelemetryClient(ProductTelemetryClient): diff --git a/src/tai/cli.py b/src/tai/cli.py index b6eb5a3..22b048a 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -30,6 +30,7 @@ from tai.prompt_builder import ( from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored from tai.runbook_store import RunbookChunk, RunbookStore from tai.session_log import SessionLogger +from tai.session_store import PastSession, SessionStore from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -151,6 +152,16 @@ def run( help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.", ), ] = None, + session_memory_path: Annotated[ + str | None, + typer.Option( + "--session-memory", + help=( + "Path to persistent session memory store for prior-session retrieval " + "(Tier 4). Omit to disable." + ), + ), + ] = None, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -207,6 +218,17 @@ def run( except Exception as exc: # noqa: BLE001 console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}") + session_store: SessionStore | None = None + if session_memory_path: + try: + session_store = SessionStore(session_memory_path) + mem_count = session_store.count() + console.print( + f"[dim]Session memory: {mem_count} indexed at {session_memory_path}[/dim]" + ) + except Exception as exc: # noqa: BLE001 + console.print(f"[yellow]Session memory unavailable:[/yellow] {exc}") + try: asyncio.run( _async_main( @@ -220,6 +242,7 @@ def run( no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) ) @@ -245,6 +268,7 @@ async def _async_main( no_rag: bool, rag_debug: bool, runbook_store: RunbookStore | None, + session_store: SessionStore | None, logger: SessionLogger | None, ) -> None: """Open a single SSH session and run probe / collection / analysis through it.""" @@ -291,19 +315,22 @@ async def _async_main( }, ) + initial_response: str | None = None if analyze and report is not None: - _run_analysis( + initial_response = _run_analysis( ai_config, req.issue, report, no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) + interactive_response: str | None = None if interactive: - await _interactive_loop( + interactive_response = await _interactive_loop( session, req, ai_config, @@ -311,9 +338,14 @@ async def _async_main( no_rag=no_rag, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) + final_response = interactive_response or initial_response + if session_store is not None and final_response: + _index_session_memory(session_store, ai_config, req, final_response, logger=logger) + async def _interactive_loop( session: SSHSession, @@ -324,8 +356,9 @@ async def _interactive_loop( no_rag: bool = False, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, -) -> None: +) -> str | None: """Run a follow-up loop for collecting and conversational analysis.""" console.print( Panel( @@ -340,6 +373,7 @@ async def _interactive_loop( prior_questions: list[str] = [] embedded_chunks: list[EmbeddedChunk] | None = None ai_embed = AIClient(ai_config) + last_response: str | None = None if not no_rag and report is not None: embedded_chunks, index_error, index_ms = await asyncio.to_thread( @@ -377,14 +411,14 @@ async def _interactive_loop( else: line = sys.stdin.readline() # non-TTY / piped mode if not line: - return + return last_response command = line.strip() console.print(f"\n[bold cyan]tai[/bold cyan][dim] >[/dim] {command}") except (EOFError, KeyboardInterrupt): console.print("\n[yellow]Exiting interactive mode.[/yellow]") if logger is not None: logger.log_event("interactive_exit", {"reason": "signal_or_eof"}) - return + return last_response if not command: continue @@ -393,7 +427,7 @@ async def _interactive_loop( console.print("[green]Bye.[/green]") if logger is not None: logger.log_event("interactive_exit", {"reason": "user_quit"}) - return + return last_response if command == "/help": console.print( @@ -466,7 +500,7 @@ async def _interactive_loop( console.print("[red]No diagnostics available to analyze.[/red]") continue - _run_followup_analysis( + response = _run_followup_analysis( ai_config, req.issue, report, @@ -475,11 +509,14 @@ async def _interactive_loop( embedded_chunks=embedded_chunks, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) prior_questions.append("/analyze") if logger is not None: logger.log_event("interactive_followup", {"question": "/analyze"}) + last_response = response + continue continue if report is None: @@ -523,7 +560,7 @@ async def _interactive_loop( console.print("[red]No diagnostics available to analyze.[/red]") continue - _run_followup_analysis( + response = _run_followup_analysis( ai_config, req.issue, report, @@ -532,11 +569,13 @@ async def _interactive_loop( embedded_chunks=embedded_chunks, rag_debug=rag_debug, runbook_store=runbook_store, + session_store=session_store, logger=logger, ) prior_questions.append(command) if logger is not None: logger.log_event("interactive_followup", {"question": command}) + last_response = response def _try_embed_report( @@ -597,8 +636,9 @@ def _run_analysis( no_rag: bool = False, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, -) -> None: +) -> str: """Send collected data to the AI and stream the analysis to stdout.""" console.print() console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan")) @@ -606,10 +646,16 @@ def _run_analysis( ai = AIClient(ai_config) system_prompt = build_system_prompt() runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1) + past_sessions = _query_sessions(session_store, issue, report.host, ai, top_k=2) user_message: str if no_rag: - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) else: try: chunks = chunk_report(report) @@ -628,16 +674,28 @@ def _run_analysis( report.host, selected, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) else: - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) except Exception as exc: # noqa: BLE001 console.print( - "[yellow]RAG unavailable for initial analysis; using full-context fallback.[/yellow]" + "[yellow]RAG unavailable for initial analysis; " + "using full-context fallback.[/yellow]" ) if logger is not None: logger.log_event("rag_index", {"status": "fallback", "error": str(exc)}) - user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) + user_message = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, + ) try: response = _complete_ai_response( ai, @@ -662,6 +720,7 @@ def _run_analysis( "guardrail_warnings": warnings, }, ) + return response except Exception as exc: # noqa: BLE001 console.print(f"[red]AI analysis failed:[/red] {exc}") if logger is not None: @@ -688,6 +747,7 @@ def _run_followup_analysis( embedded_chunks: list[EmbeddedChunk] | None = None, rag_debug: bool = False, runbook_store: RunbookStore | None = None, + session_store: SessionStore | None = None, logger: SessionLogger | None, ) -> str: """Run grounded follow-up analysis re-anchored to current diagnostics. @@ -702,6 +762,7 @@ def _run_followup_analysis( ai = AIClient(ai_config) system_prompt = build_system_prompt() runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1) + past_sessions = _query_sessions(session_store, question, report.host, ai, top_k=2) user_message: str retrieved_names: list[str] = [] @@ -724,6 +785,7 @@ def _run_followup_analysis( question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) if rag_debug: pairs = ", ".join( @@ -741,12 +803,14 @@ def _run_followup_analysis( user_message = build_followup_message( issue, report, question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) else: fallback_reason = "rag not indexed" user_message = build_followup_message( issue, report, question, prior_questions, runbook_chunks=runbook_chunks or None, + past_sessions=past_sessions or None, ) if logger is not None: @@ -826,6 +890,42 @@ def _query_runbooks( return [] +def _query_sessions( + store: SessionStore | None, + question: str, + host: str, + ai: AIClient, + *, + top_k: int = 2, +) -> list[PastSession]: + """Query the session memory store silently; returns empty list on failures.""" + if store is None: + return [] + try: + return store.query(question, host, ai, top_k=top_k) + except Exception: # noqa: BLE001 + return [] + + +def _index_session_memory( + store: SessionStore, + ai_config: AIConfig, + req: TroubleshootRequest, + summary: str, + *, + logger: SessionLogger | None, +) -> None: + """Persist final session summary for future retrieval; non-fatal on failure.""" + try: + ai = AIClient(ai_config) + session_id = store.index_session(req.host, req.issue, summary, ai) + if logger is not None: + logger.log_event("session_memory_indexed", {"session_id": session_id}) + except Exception as exc: # noqa: BLE001 + if logger is not None: + logger.log_event("session_memory_error", {"error": str(exc)}) + + # --------------------------------------------------------------------------- # runbooks sub-app # --------------------------------------------------------------------------- diff --git a/src/tai/plan.py b/src/tai/plan.py index c6d6701..f75bfc7 100644 --- a/src/tai/plan.py +++ b/src/tai/plan.py @@ -149,6 +149,37 @@ _SERVICE_BINARIES: dict[str, list[str]] = { "apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"], } +_SERVICE_PACKAGES: dict[str, list[str]] = { + "docker": ["docker", "docker-ce"], + "sssd": ["sssd"], + "sshd": ["openssh-server", "openssh"], + "x2go": ["x2goserver", "x2goserver-xsession"], + "xorg": ["xorg-server", "xserver-xorg-core"], + "wayland": ["wayland", "xwayland"], + "selinux": ["selinux-policy", "selinux-policy-targeted"], + "apparmor": ["apparmor"], +} + +_GENERIC_SERVICE_STOPWORDS: frozenset[str] = frozenset( + { + "a", + "an", + "and", + "app", + "application", + "daemon", + "for", + "is", + "my", + "not", + "service", + "systemd", + "the", + "unit", + "working", + } +) + # --------------------------------------------------------------------------- # Command sets # --------------------------------------------------------------------------- @@ -225,6 +256,9 @@ def plan_from_request(request: TroubleshootRequest) -> CollectionPlan: ) for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1): plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}") + for idx, package_name in enumerate(_service_package_candidates(svc), start=1): + plan.add(f"package-rpm-{svc}-{idx}", f"rpm -q {package_name}") + plan.add(f"package-dpkg-{svc}-{idx}", f"dpkg-query -W {package_name}") plan.add(f"service-{svc}", f"systemctl status {svc}") plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager") for cfg_path in _SERVICE_CONFIGS.get(svc, []): @@ -258,7 +292,11 @@ def _issue_words(issue: str) -> set[str]: def _extract_services(issue: str) -> list[str]: - """Return known service names mentioned in *issue*.""" + """Return service candidates mentioned in *issue*. + + Includes known services plus generic service-like tokens near words such + as "service", "daemon", and "unit". + """ words = _issue_words(issue) found: list[str] = [] for svc in _KNOWN_SERVICES: @@ -266,6 +304,58 @@ def _extract_services(issue: str) -> list[str]: svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")} if words & svc_words: found.append(svc) + for svc in _extract_generic_service_candidates(issue): + if svc not in found: + found.append(svc) return found +def _extract_generic_service_candidates(issue: str) -> list[str]: + """Extract likely service names from free text even when not pre-registered.""" + tokens = [tok.lower() for tok in re.findall(r"[a-zA-Z0-9_.@-]+", issue)] + if not tokens: + return [] + + candidates: list[str] = [] + for idx, token in enumerate(tokens): + normalized = token[:-8] if token.endswith(".service") else token + if _is_safe_service_name(normalized): + if token.endswith(".service") and normalized not in _GENERIC_SERVICE_STOPWORDS: + candidates.append(normalized) + + if token in {"service", "daemon", "unit"}: + for neighbor in (idx - 1, idx + 1): + if neighbor < 0 or neighbor >= len(tokens): + continue + candidate = tokens[neighbor] + if candidate.endswith(".service"): + candidate = candidate[:-8] + if candidate in _GENERIC_SERVICE_STOPWORDS: + continue + if _is_safe_service_name(candidate): + candidates.append(candidate) + + deduped: list[str] = [] + seen: set[str] = set() + for item in candidates: + if item in seen: + continue + seen.add(item) + deduped.append(item) + return deduped + + +def _is_safe_service_name(name: str) -> bool: + """Return True when *name* is safe to interpolate into read-only commands.""" + if len(name) < 2 or len(name) > 64: + return False + return re.fullmatch(r"[a-z0-9_.@-]+", name) is not None + + +def _service_package_candidates(service: str) -> list[str]: + """Return package names to probe for *service* presence.""" + if service in _SERVICE_PACKAGES: + return _SERVICE_PACKAGES[service] + return [service] + + diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py index ede0607..35ffcd7 100644 --- a/src/tai/prompt_builder.py +++ b/src/tai/prompt_builder.py @@ -5,6 +5,7 @@ from __future__ import annotations from tai.collectors import CollectionReport from tai.rag_retriever import Chunk from tai.runbook_store import RunbookChunk +from tai.session_store import PastSession _SYSTEM_PROMPT = """\ You are an expert Linux systems administrator and troubleshooting assistant. @@ -21,7 +22,8 @@ Important rules: - If a command shows "could not be executed (SSH error)" it means the remote host blocked or rejected that specific command — it is not evidence about the service or system state. - If service presence checks show a unit, binary, package, or config is missing, treat that as - evidence the component may be absent or not installed, not as proof that the component is broken. + evidence the component may be absent or not installed, not as proof that the + component is broken. - If there is not enough data to diagnose the issue, say so plainly and list exactly what additional commands or log files would be needed. - Keep the response short. Skip sections that have nothing useful to say. @@ -33,6 +35,7 @@ Important rules: _MAX_RUNBOOK_CHARS = 500 _MAX_DIAGNOSTIC_CHUNK_CHARS = 700 +_MAX_SESSION_SUMMARY_CHARS = 500 def build_system_prompt() -> str: @@ -66,11 +69,34 @@ def _format_diagnostic_chunk(content: str) -> str: return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]" +def _format_session_context(past_sessions: list[PastSession]) -> str: + """Format similar prior sessions as compact grounding context.""" + lines: list[str] = ["## Similar prior sessions\n"] + lines.append( + "The following completed sessions were semantically similar. " + "Use them as historical hints, but prioritize current diagnostics if they conflict.\n" + ) + for sess in past_sessions: + summary = sess.summary.strip() + if len(summary) > _MAX_SESSION_SUMMARY_CHARS: + summary = ( + summary[:_MAX_SESSION_SUMMARY_CHARS].rstrip() + + "\n...[truncated session summary]" + ) + lines.append(f"### Session: {sess.session_id} (host={sess.host})\n") + lines.append(f"**Issue:** {sess.issue}") + lines.append("") + lines.append(summary) + lines.append("") + return "\n".join(lines) + + def build_user_message( issue: str, report: CollectionReport, *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Format *issue* and *report* into the user message sent to the AI.""" lines: list[str] = [] @@ -81,6 +107,9 @@ def build_user_message( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Collected diagnostics\n") skipped: list[str] = [] @@ -126,9 +155,15 @@ def build_followup_message( prior_questions: list[str], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build a grounded follow-up message that re-anchors to diagnostics each turn.""" - base = build_user_message(issue, report, runbook_chunks=runbook_chunks) + base = build_user_message( + issue, + report, + runbook_chunks=runbook_chunks, + past_sessions=past_sessions, + ) lines: list[str] = [base, "## Follow-up"] if prior_questions: @@ -157,6 +192,7 @@ def build_message_with_chunks( prior_questions: list[str], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build a follow-up message using only semantically retrieved diagnostic chunks. @@ -178,6 +214,9 @@ def build_message_with_chunks( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Follow-up") if prior_questions: @@ -204,6 +243,7 @@ def build_analysis_message_with_chunks( chunks: list[Chunk], *, runbook_chunks: list[RunbookChunk] | None = None, + past_sessions: list[PastSession] | None = None, ) -> str: """Build an initial analysis message from retrieved diagnostic chunks.""" lines: list[str] = [] @@ -213,6 +253,9 @@ def build_analysis_message_with_chunks( if runbook_chunks: lines.append(_format_runbook_context(runbook_chunks)) + if past_sessions: + lines.append(_format_session_context(past_sessions)) + lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n") for chunk in chunks: lines.append(f"### {chunk.name}\n") diff --git a/src/tai/runbook_store.py b/src/tai/runbook_store.py index 42778e6..4f34969 100644 --- a/src/tai/runbook_store.py +++ b/src/tai/runbook_store.py @@ -17,7 +17,7 @@ from __future__ import annotations import re from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from tai.ai_client import AIClient @@ -123,7 +123,7 @@ class RunbookStore: self._client = chromadb.PersistentClient(path=str(path)) else: self._client = chromadb.PersistentClient(path=str(path), settings=settings) - self._collection = self._client.get_or_create_collection( + self._collection: Any = self._client.get_or_create_collection( name=_COLLECTION_NAME, metadata={"hnsw:space": "cosine"}, ) @@ -241,11 +241,14 @@ class RunbookStore: return [] results = self._collection.get(include=["metadatas"]) metas = results.get("metadatas") or [] - return [dict(m) for m in metas] + entries: list[dict[str, str]] = [] + for meta in metas: + entries.append({str(k): str(v) for k, v in dict(meta).items()}) + return entries def count(self) -> int: """Return the number of indexed runbook documents.""" - return self._collection.count() + return int(self._collection.count()) # --------------------------------------------------------------------------- diff --git a/src/tai/session_store.py b/src/tai/session_store.py new file mode 100644 index 0000000..63ef633 --- /dev/null +++ b/src/tai/session_store.py @@ -0,0 +1,105 @@ +"""Persistent session memory store (Tier 4) backed by ChromaDB.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from tai.ai_client import AIClient + +DEFAULT_SESSION_STORE_PATH = "~/.tai/sessions" +_COLLECTION_NAME = "tai_sessions" + + +@dataclass(slots=True) +class PastSession: + """A retrieved prior session summary for prompt grounding.""" + + session_id: str + host: str + issue: str + summary: str + + +class SessionStore: + """ChromaDB-backed persistent memory for prior troubleshooting sessions.""" + + def __init__(self, store_path: str | Path = DEFAULT_SESSION_STORE_PATH) -> None: + import chromadb + + path = Path(store_path).expanduser().resolve() + path.mkdir(parents=True, exist_ok=True) + settings = None + try: + from chromadb.config import Settings + + settings = Settings( + anonymized_telemetry=False, + chroma_product_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient", + chroma_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient", + ) + except (ImportError, ModuleNotFoundError): + settings = None + + if settings is None: + self._client = chromadb.PersistentClient(path=str(path)) + else: + self._client = chromadb.PersistentClient(path=str(path), settings=settings) + self._collection: Any = self._client.get_or_create_collection( + name=_COLLECTION_NAME, + metadata={"hnsw:space": "cosine"}, + ) + + def count(self) -> int: + """Return number of indexed session summaries.""" + return int(self._collection.count()) + + def index_session(self, host: str, issue: str, summary: str, ai: AIClient) -> str: + """Embed and upsert one session summary into persistent storage.""" + session_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + embed_text = _build_embed_text(host=host, issue=issue, summary=summary) + embedding = ai.embed(embed_text) + self._collection.upsert( + ids=[session_id], + documents=[summary.strip()], + embeddings=[embedding], + metadatas=[{"host": host, "issue": issue}], + ) + return session_id + + def query(self, question: str, host: str, ai: AIClient, *, top_k: int = 2) -> list[PastSession]: + """Return top-k semantically similar sessions for this host and question.""" + if self._collection.count() == 0: + return [] + + q_embedding = ai.embed(f"host: {host}\nquestion: {question}") + results = self._collection.query( + query_embeddings=[q_embedding], + n_results=min(top_k, self._collection.count()), + include=["documents", "metadatas"], + ) + + sessions: list[PastSession] = [] + ids = results.get("ids") or [] + docs = results.get("documents") or [] + metas = results.get("metadatas") or [] + for id_list, doc_list, meta_list in zip(ids, docs, metas, strict=False): + for sid, doc, meta in zip(id_list, doc_list, meta_list, strict=False): + sessions.append( + PastSession( + session_id=str(sid), + host=str(meta.get("host", "")), + issue=str(meta.get("issue", "")), + summary=str(doc), + ) + ) + return sessions + + +def _build_embed_text(*, host: str, issue: str, summary: str) -> str: + """Build embedding text with host/issue context and summary excerpt.""" + excerpt = summary.strip()[:1000] + return f"host: {host}\nissue: {issue}\nsummary:\n{excerpt}" diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index 6235010..0afc974 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -51,6 +51,7 @@ class SSHClient: _READ_ONLY_COMMANDS = { "cat", "dmesg", + "dpkg-query", "df", "du", "find", @@ -61,6 +62,7 @@ class SSHClient: "journalctl", "ls", "netstat", + "rpm", "sed", "ss", "stat", diff --git a/tests/test_ai.py b/tests/test_ai.py index f37de7d..b904e57 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.collectors import CollectedItem, CollectionReport from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message +from tai.session_store import PastSession from tai.ssh_client import SSHCommandResult # --------------------------------------------------------------------------- @@ -221,6 +222,24 @@ def test_build_user_message_handles_no_output() -> None: assert "no output" in msg +def test_build_user_message_includes_prior_session_context() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) + msg = build_user_message( + "sssd broken", + report, + past_sessions=[ + PastSession( + session_id="20260506T120000Z", + host="web01", + issue="sssd broken", + summary="Root cause was missing sssd package.", + ) + ], + ) + assert "Similar prior sessions" in msg + assert "missing sssd package" in msg + + def test_build_followup_message_includes_question_context() -> None: report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) msg = build_followup_message( diff --git a/tests/test_plan.py b/tests/test_plan.py index 4fdf556..cc2658b 100644 --- a/tests/test_plan.py +++ b/tests/test_plan.py @@ -107,8 +107,12 @@ def test_sssd_in_issue_adds_presence_service_and_config_commands() -> None: assert "binary-sssd-1" in names assert "service-sssd" in names assert "journal-sssd" in names + assert "package-rpm-sssd-1" in names + assert "package-dpkg-sssd-1" in names assert any("cat /etc/sssd/sssd.conf" in c for c in cmds) assert any("ls -l /usr/sbin/sssd" in c for c in cmds) + assert any("rpm -q sssd" in c for c in cmds) + assert any("dpkg-query -W sssd" in c for c in cmds) assert any("list-unit-files sssd.service" in c for c in cmds) @@ -119,8 +123,12 @@ def test_docker_presence_probe_checks_package_and_binary() -> None: assert "unit-file-docker" in names assert "binary-docker-1" in names assert "binary-docker-2" in names + assert "package-rpm-docker-1" in names + assert "package-dpkg-docker-1" in names assert any("ls -l /usr/bin/docker" in c for c in cmds) assert any("ls -l /usr/bin/dockerd" in c for c in cmds) + assert any("rpm -q docker" in c for c in cmds) + assert any("dpkg-query -W docker" in c for c in cmds) def test_unknown_service_name_no_config_cat() -> None: @@ -183,6 +191,16 @@ def test_extract_services_case_insensitive() -> None: assert "nginx" in _extract_services("NGINX failed") +def test_extract_services_detects_generic_service_name() -> None: + services = _extract_services("myweirdapp service keeps failing") + assert "myweirdapp" in services + + +def test_extract_services_detects_dot_service_pattern() -> None: + services = _extract_services("please check foobar.service on this host") + assert "foobar" in services + + # --------------------------------------------------------------------------- # Plan length sanity # --------------------------------------------------------------------------- diff --git a/tests/test_session_store.py b/tests/test_session_store.py new file mode 100644 index 0000000..e66ed41 --- /dev/null +++ b/tests/test_session_store.py @@ -0,0 +1,79 @@ +"""Tests for session_store with mocked ChromaDB.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from tai.session_store import PastSession, SessionStore, _build_embed_text + + +def _make_chromadb_mock() -> MagicMock: + collection = MagicMock() + collection.count.return_value = 0 + client = MagicMock() + client.get_or_create_collection.return_value = collection + chroma_mod = MagicMock() + chroma_mod.PersistentClient.return_value = client + return chroma_mod + + +def _make_ai_mock(embedding: list[float] | None = None) -> MagicMock: + ai = MagicMock() + ai.embed.return_value = embedding or [0.1, 0.2, 0.3] + return ai + + +def test_build_embed_text_contains_host_issue_and_summary() -> None: + text = _build_embed_text(host="web01", issue="sssd broken", summary="Unit missing") + assert "host: web01" in text + assert "issue: sssd broken" in text + assert "Unit missing" in text + + +def test_index_session_upserts_with_metadata(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + session_id = store.index_session("web01", "sssd broken", "summary text", ai) + + assert session_id + collection.upsert.assert_called_once() + args = collection.upsert.call_args.kwargs + assert args["metadatas"][0]["host"] == "web01" + assert args["metadatas"][0]["issue"] == "sssd broken" + + +def test_query_returns_empty_when_no_docs(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + results = store.query("why sssd", "web01", ai) + + assert results == [] + + +def test_query_returns_past_sessions(tmp_path: Path) -> None: + chroma_mock = _make_chromadb_mock() + collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value + collection.count.return_value = 1 + collection.query.return_value = { + "ids": [["20260506T120000Z"]], + "documents": [["Root cause: package missing"]], + "metadatas": [[{"host": "web01", "issue": "sssd broken"}]], + } + ai = _make_ai_mock() + + with patch.dict("sys.modules", {"chromadb": chroma_mock}): + store = SessionStore(tmp_path / "store") + results = store.query("sssd issue", "web01", ai) + + assert len(results) == 1 + assert isinstance(results[0], PastSession) + assert results[0].host == "web01" + assert "package missing" in results[0].summary diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index fcad417..6b68db0 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -82,6 +82,8 @@ def test_allows_expected_read_only_commands() -> None: "systemctl status apache2", "cat /etc/hosts", "ss -lntp", + "rpm -q sssd", + "dpkg-query -W sssd", ]: client.validate_read_only_command(command)