update
Some checks failed
CI / test (push) Failing after 15s

This commit is contained in:
2026-05-06 05:02:38 +02:00
parent 74a56e3113
commit d5e1822644
9 changed files with 473 additions and 15 deletions

View File

@@ -30,6 +30,7 @@ from tai.prompt_builder import (
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
from tai.runbook_store import RunbookChunk, RunbookStore from tai.runbook_store import RunbookChunk, RunbookStore
from tai.session_log import SessionLogger from tai.session_log import SessionLogger
from tai.session_store import PastSession, SessionStore
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
app = typer.Typer(no_args_is_help=True, add_completion=False) app = typer.Typer(no_args_is_help=True, add_completion=False)
@@ -151,6 +152,16 @@ def run(
help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.", help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.",
), ),
] = None, ] = None,
session_memory_path: Annotated[
str | None,
typer.Option(
"--session-memory",
help=(
"Path to persistent session memory store for prior-session retrieval "
"(Tier 4). Omit to disable."
),
),
] = None,
) -> None: ) -> None:
"""Start an interactive troubleshooting session scaffold.""" """Start an interactive troubleshooting session scaffold."""
try: try:
@@ -207,6 +218,17 @@ def run(
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}") console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}")
session_store: SessionStore | None = None
if session_memory_path:
try:
session_store = SessionStore(session_memory_path)
mem_count = session_store.count()
console.print(
f"[dim]Session memory: {mem_count} indexed at {session_memory_path}[/dim]"
)
except Exception as exc: # noqa: BLE001
console.print(f"[yellow]Session memory unavailable:[/yellow] {exc}")
try: try:
asyncio.run( asyncio.run(
_async_main( _async_main(
@@ -220,6 +242,7 @@ def run(
no_rag=no_rag, no_rag=no_rag,
rag_debug=rag_debug, rag_debug=rag_debug,
runbook_store=runbook_store, runbook_store=runbook_store,
session_store=session_store,
logger=logger, logger=logger,
) )
) )
@@ -245,6 +268,7 @@ async def _async_main(
no_rag: bool, no_rag: bool,
rag_debug: bool, rag_debug: bool,
runbook_store: RunbookStore | None, runbook_store: RunbookStore | None,
session_store: SessionStore | None,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> None:
"""Open a single SSH session and run probe / collection / analysis through it.""" """Open a single SSH session and run probe / collection / analysis through it."""
@@ -291,19 +315,22 @@ async def _async_main(
}, },
) )
initial_response: str | None = None
if analyze and report is not None: if analyze and report is not None:
_run_analysis( initial_response = _run_analysis(
ai_config, ai_config,
req.issue, req.issue,
report, report,
no_rag=no_rag, no_rag=no_rag,
rag_debug=rag_debug, rag_debug=rag_debug,
runbook_store=runbook_store, runbook_store=runbook_store,
session_store=session_store,
logger=logger, logger=logger,
) )
interactive_response: str | None = None
if interactive: if interactive:
await _interactive_loop( interactive_response = await _interactive_loop(
session, session,
req, req,
ai_config, ai_config,
@@ -311,9 +338,14 @@ async def _async_main(
no_rag=no_rag, no_rag=no_rag,
rag_debug=rag_debug, rag_debug=rag_debug,
runbook_store=runbook_store, runbook_store=runbook_store,
session_store=session_store,
logger=logger, logger=logger,
) )
final_response = interactive_response or initial_response
if session_store is not None and final_response:
_index_session_memory(session_store, ai_config, req, final_response, logger=logger)
async def _interactive_loop( async def _interactive_loop(
session: SSHSession, session: SSHSession,
@@ -324,8 +356,9 @@ async def _interactive_loop(
no_rag: bool = False, no_rag: bool = False,
rag_debug: bool = False, rag_debug: bool = False,
runbook_store: RunbookStore | None = None, runbook_store: RunbookStore | None = None,
session_store: SessionStore | None = None,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> str | None:
"""Run a follow-up loop for collecting and conversational analysis.""" """Run a follow-up loop for collecting and conversational analysis."""
console.print( console.print(
Panel( Panel(
@@ -340,6 +373,7 @@ async def _interactive_loop(
prior_questions: list[str] = [] prior_questions: list[str] = []
embedded_chunks: list[EmbeddedChunk] | None = None embedded_chunks: list[EmbeddedChunk] | None = None
ai_embed = AIClient(ai_config) ai_embed = AIClient(ai_config)
last_response: str | None = None
if not no_rag and report is not None: if not no_rag and report is not None:
embedded_chunks, index_error, index_ms = await asyncio.to_thread( embedded_chunks, index_error, index_ms = await asyncio.to_thread(
@@ -384,7 +418,7 @@ async def _interactive_loop(
console.print("\n[yellow]Exiting interactive mode.[/yellow]") console.print("\n[yellow]Exiting interactive mode.[/yellow]")
if logger is not None: if logger is not None:
logger.log_event("interactive_exit", {"reason": "signal_or_eof"}) logger.log_event("interactive_exit", {"reason": "signal_or_eof"})
return return last_response
if not command: if not command:
continue continue
@@ -393,7 +427,7 @@ async def _interactive_loop(
console.print("[green]Bye.[/green]") console.print("[green]Bye.[/green]")
if logger is not None: if logger is not None:
logger.log_event("interactive_exit", {"reason": "user_quit"}) logger.log_event("interactive_exit", {"reason": "user_quit"})
return return last_response
if command == "/help": if command == "/help":
console.print( console.print(
@@ -466,7 +500,7 @@ async def _interactive_loop(
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
continue continue
_run_followup_analysis( response = _run_followup_analysis(
ai_config, ai_config,
req.issue, req.issue,
report, report,
@@ -475,11 +509,14 @@ async def _interactive_loop(
embedded_chunks=embedded_chunks, embedded_chunks=embedded_chunks,
rag_debug=rag_debug, rag_debug=rag_debug,
runbook_store=runbook_store, runbook_store=runbook_store,
session_store=session_store,
logger=logger, logger=logger,
) )
prior_questions.append("/analyze") prior_questions.append("/analyze")
if logger is not None: if logger is not None:
logger.log_event("interactive_followup", {"question": "/analyze"}) logger.log_event("interactive_followup", {"question": "/analyze"})
last_response = response
continue
continue continue
if report is None: if report is None:
@@ -523,7 +560,7 @@ async def _interactive_loop(
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
continue continue
_run_followup_analysis( response = _run_followup_analysis(
ai_config, ai_config,
req.issue, req.issue,
report, report,
@@ -532,11 +569,13 @@ async def _interactive_loop(
embedded_chunks=embedded_chunks, embedded_chunks=embedded_chunks,
rag_debug=rag_debug, rag_debug=rag_debug,
runbook_store=runbook_store, runbook_store=runbook_store,
session_store=session_store,
logger=logger, logger=logger,
) )
prior_questions.append(command) prior_questions.append(command)
if logger is not None: if logger is not None:
logger.log_event("interactive_followup", {"question": command}) logger.log_event("interactive_followup", {"question": command})
last_response = response
def _try_embed_report( def _try_embed_report(
@@ -597,8 +636,9 @@ def _run_analysis(
no_rag: bool = False, no_rag: bool = False,
rag_debug: bool = False, rag_debug: bool = False,
runbook_store: RunbookStore | None = None, runbook_store: RunbookStore | None = None,
session_store: SessionStore | None = None,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> None: ) -> str:
"""Send collected data to the AI and stream the analysis to stdout.""" """Send collected data to the AI and stream the analysis to stdout."""
console.print() console.print()
console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan")) console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan"))
@@ -606,10 +646,16 @@ def _run_analysis(
ai = AIClient(ai_config) ai = AIClient(ai_config)
system_prompt = build_system_prompt() system_prompt = build_system_prompt()
runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1) runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1)
past_sessions = _query_sessions(session_store, issue, report.host, ai, top_k=2)
user_message: str user_message: str
if no_rag: if no_rag:
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) user_message = build_user_message(
issue,
report,
runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
)
else: else:
try: try:
chunks = chunk_report(report) chunks = chunk_report(report)
@@ -628,16 +674,28 @@ def _run_analysis(
report.host, report.host,
selected, selected,
runbook_chunks=runbook_chunks or None, runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
) )
else: else:
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) user_message = build_user_message(
issue,
report,
runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
console.print( console.print(
"[yellow]RAG unavailable for initial analysis; using full-context fallback.[/yellow]" "[yellow]RAG unavailable for initial analysis; "
"using full-context fallback.[/yellow]"
) )
if logger is not None: if logger is not None:
logger.log_event("rag_index", {"status": "fallback", "error": str(exc)}) logger.log_event("rag_index", {"status": "fallback", "error": str(exc)})
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None) user_message = build_user_message(
issue,
report,
runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
)
try: try:
response = _complete_ai_response( response = _complete_ai_response(
ai, ai,
@@ -662,6 +720,7 @@ def _run_analysis(
"guardrail_warnings": warnings, "guardrail_warnings": warnings,
}, },
) )
return response
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}") console.print(f"[red]AI analysis failed:[/red] {exc}")
if logger is not None: if logger is not None:
@@ -688,6 +747,7 @@ def _run_followup_analysis(
embedded_chunks: list[EmbeddedChunk] | None = None, embedded_chunks: list[EmbeddedChunk] | None = None,
rag_debug: bool = False, rag_debug: bool = False,
runbook_store: RunbookStore | None = None, runbook_store: RunbookStore | None = None,
session_store: SessionStore | None = None,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> str: ) -> str:
"""Run grounded follow-up analysis re-anchored to current diagnostics. """Run grounded follow-up analysis re-anchored to current diagnostics.
@@ -702,6 +762,7 @@ def _run_followup_analysis(
ai = AIClient(ai_config) ai = AIClient(ai_config)
system_prompt = build_system_prompt() system_prompt = build_system_prompt()
runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1) runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1)
past_sessions = _query_sessions(session_store, question, report.host, ai, top_k=2)
user_message: str user_message: str
retrieved_names: list[str] = [] retrieved_names: list[str] = []
@@ -724,6 +785,7 @@ def _run_followup_analysis(
question, question,
prior_questions, prior_questions,
runbook_chunks=runbook_chunks or None, runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
) )
if rag_debug: if rag_debug:
pairs = ", ".join( pairs = ", ".join(
@@ -741,12 +803,14 @@ def _run_followup_analysis(
user_message = build_followup_message( user_message = build_followup_message(
issue, report, question, prior_questions, issue, report, question, prior_questions,
runbook_chunks=runbook_chunks or None, runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
) )
else: else:
fallback_reason = "rag not indexed" fallback_reason = "rag not indexed"
user_message = build_followup_message( user_message = build_followup_message(
issue, report, question, prior_questions, issue, report, question, prior_questions,
runbook_chunks=runbook_chunks or None, runbook_chunks=runbook_chunks or None,
past_sessions=past_sessions or None,
) )
if logger is not None: if logger is not None:
@@ -826,6 +890,42 @@ def _query_runbooks(
return [] return []
def _query_sessions(
store: SessionStore | None,
question: str,
host: str,
ai: AIClient,
*,
top_k: int = 2,
) -> list[PastSession]:
"""Query the session memory store silently; returns empty list on failures."""
if store is None:
return []
try:
return store.query(question, host, ai, top_k=top_k)
except Exception: # noqa: BLE001
return []
def _index_session_memory(
    store: SessionStore,
    ai_config: AIConfig,
    req: TroubleshootRequest,
    summary: str,
    *,
    logger: SessionLogger | None,
) -> None:
    """Store the session's final summary in session memory.

    Any failure is swallowed (and recorded on *logger* when one is given) so
    that a memory problem never breaks the troubleshooting session itself.
    """
    try:
        client = AIClient(ai_config)
        new_id = store.index_session(req.host, req.issue, summary, client)
        if logger is None:
            return
        logger.log_event("session_memory_indexed", {"session_id": new_id})
    except Exception as exc:  # noqa: BLE001
        if logger is None:
            return
        logger.log_event("session_memory_error", {"error": str(exc)})
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# runbooks sub-app # runbooks sub-app
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -149,6 +149,37 @@ _SERVICE_BINARIES: dict[str, list[str]] = {
"apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"], "apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"],
} }
# Package names probed for each known service. Every candidate is checked with
# both `rpm -q` and `dpkg-query -W` when a collection plan is built. Services
# without an entry here are probed under their own name.
_SERVICE_PACKAGES: dict[str, list[str]] = {
    "docker": ["docker", "docker-ce"],
    "sssd": ["sssd"],
    "sshd": ["openssh-server", "openssh"],
    "x2go": ["x2goserver", "x2goserver-xsession"],
    "xorg": ["xorg-server", "xserver-xorg-core"],
    "wayland": ["wayland", "xwayland"],
    "selinux": ["selinux-policy", "selinux-policy-targeted"],
    "apparmor": ["apparmor"],
}

# Words that are never accepted as a service name when free-text issue
# descriptions are scanned for generic service mentions.
_GENERIC_SERVICE_STOPWORDS: frozenset[str] = frozenset(
    {
        "a",
        "an",
        "and",
        "app",
        "application",
        "daemon",
        "for",
        "is",
        "my",
        "not",
        "service",
        "systemd",
        "the",
        "unit",
        "working",
    }
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Command sets # Command sets
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -225,6 +256,9 @@ def plan_from_request(request: TroubleshootRequest) -> CollectionPlan:
) )
for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1): for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1):
plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}") plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}")
for idx, package_name in enumerate(_service_package_candidates(svc), start=1):
plan.add(f"package-rpm-{svc}-{idx}", f"rpm -q {package_name}")
plan.add(f"package-dpkg-{svc}-{idx}", f"dpkg-query -W {package_name}")
plan.add(f"service-{svc}", f"systemctl status {svc}") plan.add(f"service-{svc}", f"systemctl status {svc}")
plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager") plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager")
for cfg_path in _SERVICE_CONFIGS.get(svc, []): for cfg_path in _SERVICE_CONFIGS.get(svc, []):
@@ -258,7 +292,11 @@ def _issue_words(issue: str) -> set[str]:
def _extract_services(issue: str) -> list[str]: def _extract_services(issue: str) -> list[str]:
"""Return known service names mentioned in *issue*.""" """Return service candidates mentioned in *issue*.
Includes known services plus generic service-like tokens near words such
as "service", "daemon", and "unit".
"""
words = _issue_words(issue) words = _issue_words(issue)
found: list[str] = [] found: list[str] = []
for svc in _KNOWN_SERVICES: for svc in _KNOWN_SERVICES:
@@ -266,6 +304,58 @@ def _extract_services(issue: str) -> list[str]:
svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")} svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")}
if words & svc_words: if words & svc_words:
found.append(svc) found.append(svc)
for svc in _extract_generic_service_candidates(issue):
if svc not in found:
found.append(svc)
return found return found
def _extract_generic_service_candidates(issue: str) -> list[str]:
    """Extract likely service names from free text even when not pre-registered.

    Two patterns are recognised (case-insensitively):

    * an explicit unit reference such as ``foobar.service``;
    * the word directly before or after "service", "daemon" or "unit",
      e.g. ``myweirdapp service`` or ``service myweirdapp``.

    Sentence punctuation glued to a word (``foobar.service.``,
    ``nginx service.``) is ignored. Candidates that are stopwords or that
    fail ``_is_safe_service_name`` are dropped.

    NOTE: both neighbours of a marker word are considered, so
    "myweirdapp service keeps failing" also yields "keeps".

    Returns:
        Candidate names in first-seen order, without duplicates.
    """
    # The token pattern accepts "." (needed for "foo.service"), so a trailing
    # full stop would otherwise stay attached to the word and hide the match.
    tokens = [
        cleaned
        for raw in re.findall(r"[a-zA-Z0-9_.@-]+", issue)
        if (cleaned := raw.lower().rstrip("."))
    ]
    if not tokens:
        return []

    markers = {"service", "daemon", "unit"}
    candidates: list[str] = []
    for idx, token in enumerate(tokens):
        # Explicit "<name>.service" reference.
        if token.endswith(".service"):
            name = token.removesuffix(".service")
            if name not in _GENERIC_SERVICE_STOPWORDS and _is_safe_service_name(name):
                candidates.append(name)

        # "<name> service" / "service <name>" style mention.
        if token in markers:
            for pos in (idx - 1, idx + 1):
                if not 0 <= pos < len(tokens):
                    continue
                name = tokens[pos].removesuffix(".service")
                if name in _GENERIC_SERVICE_STOPWORDS:
                    continue
                if _is_safe_service_name(name):
                    candidates.append(name)

    # dict preserves insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(candidates))
def _is_safe_service_name(name: str) -> bool:
"""Return True when *name* is safe to interpolate into read-only commands."""
if len(name) < 2 or len(name) > 64:
return False
return re.fullmatch(r"[a-z0-9_.@-]+", name) is not None
def _service_package_candidates(service: str) -> list[str]:
    """Return the package names to probe in order to detect *service*.

    Known services use their entry in ``_SERVICE_PACKAGES``; any other
    service is probed under its own name.
    """
    return _SERVICE_PACKAGES.get(service, [service])

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from tai.collectors import CollectionReport from tai.collectors import CollectionReport
from tai.rag_retriever import Chunk from tai.rag_retriever import Chunk
from tai.runbook_store import RunbookChunk from tai.runbook_store import RunbookChunk
from tai.session_store import PastSession
_SYSTEM_PROMPT = """\ _SYSTEM_PROMPT = """\
You are an expert Linux systems administrator and troubleshooting assistant. You are an expert Linux systems administrator and troubleshooting assistant.
@@ -21,7 +22,8 @@ Important rules:
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or - If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state. rejected that specific command — it is not evidence about the service or system state.
- If service presence checks show a unit, binary, package, or config is missing, treat that as - If service presence checks show a unit, binary, package, or config is missing, treat that as
evidence the component may be absent or not installed, not as proof that the component is broken. evidence the component may be absent or not installed, not as proof that the
component is broken.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what - If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed. additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say. - Keep the response short. Skip sections that have nothing useful to say.
@@ -33,6 +35,7 @@ Important rules:
_MAX_RUNBOOK_CHARS = 500 _MAX_RUNBOOK_CHARS = 500
_MAX_DIAGNOSTIC_CHUNK_CHARS = 700 _MAX_DIAGNOSTIC_CHUNK_CHARS = 700
_MAX_SESSION_SUMMARY_CHARS = 500
def build_system_prompt() -> str: def build_system_prompt() -> str:
@@ -66,11 +69,34 @@ def _format_diagnostic_chunk(content: str) -> str:
return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]" return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]"
def _format_session_context(past_sessions: list[PastSession]) -> str:
"""Format similar prior sessions as compact grounding context."""
lines: list[str] = ["## Similar prior sessions\n"]
lines.append(
"The following completed sessions were semantically similar. "
"Use them as historical hints, but prioritize current diagnostics if they conflict.\n"
)
for sess in past_sessions:
summary = sess.summary.strip()
if len(summary) > _MAX_SESSION_SUMMARY_CHARS:
summary = (
summary[:_MAX_SESSION_SUMMARY_CHARS].rstrip()
+ "\n...[truncated session summary]"
)
lines.append(f"### Session: {sess.session_id} (host={sess.host})\n")
lines.append(f"**Issue:** {sess.issue}")
lines.append("")
lines.append(summary)
lines.append("")
return "\n".join(lines)
def build_user_message( def build_user_message(
issue: str, issue: str,
report: CollectionReport, report: CollectionReport,
*, *,
runbook_chunks: list[RunbookChunk] | None = None, runbook_chunks: list[RunbookChunk] | None = None,
past_sessions: list[PastSession] | None = None,
) -> str: ) -> str:
"""Format *issue* and *report* into the user message sent to the AI.""" """Format *issue* and *report* into the user message sent to the AI."""
lines: list[str] = [] lines: list[str] = []
@@ -81,6 +107,9 @@ def build_user_message(
if runbook_chunks: if runbook_chunks:
lines.append(_format_runbook_context(runbook_chunks)) lines.append(_format_runbook_context(runbook_chunks))
if past_sessions:
lines.append(_format_session_context(past_sessions))
lines.append("## Collected diagnostics\n") lines.append("## Collected diagnostics\n")
skipped: list[str] = [] skipped: list[str] = []
@@ -126,9 +155,15 @@ def build_followup_message(
prior_questions: list[str], prior_questions: list[str],
*, *,
runbook_chunks: list[RunbookChunk] | None = None, runbook_chunks: list[RunbookChunk] | None = None,
past_sessions: list[PastSession] | None = None,
) -> str: ) -> str:
"""Build a grounded follow-up message that re-anchors to diagnostics each turn.""" """Build a grounded follow-up message that re-anchors to diagnostics each turn."""
base = build_user_message(issue, report, runbook_chunks=runbook_chunks) base = build_user_message(
issue,
report,
runbook_chunks=runbook_chunks,
past_sessions=past_sessions,
)
lines: list[str] = [base, "## Follow-up"] lines: list[str] = [base, "## Follow-up"]
if prior_questions: if prior_questions:
@@ -157,6 +192,7 @@ def build_message_with_chunks(
prior_questions: list[str], prior_questions: list[str],
*, *,
runbook_chunks: list[RunbookChunk] | None = None, runbook_chunks: list[RunbookChunk] | None = None,
past_sessions: list[PastSession] | None = None,
) -> str: ) -> str:
"""Build a follow-up message using only semantically retrieved diagnostic chunks. """Build a follow-up message using only semantically retrieved diagnostic chunks.
@@ -178,6 +214,9 @@ def build_message_with_chunks(
if runbook_chunks: if runbook_chunks:
lines.append(_format_runbook_context(runbook_chunks)) lines.append(_format_runbook_context(runbook_chunks))
if past_sessions:
lines.append(_format_session_context(past_sessions))
lines.append("## Follow-up") lines.append("## Follow-up")
if prior_questions: if prior_questions:
@@ -204,6 +243,7 @@ def build_analysis_message_with_chunks(
chunks: list[Chunk], chunks: list[Chunk],
*, *,
runbook_chunks: list[RunbookChunk] | None = None, runbook_chunks: list[RunbookChunk] | None = None,
past_sessions: list[PastSession] | None = None,
) -> str: ) -> str:
"""Build an initial analysis message from retrieved diagnostic chunks.""" """Build an initial analysis message from retrieved diagnostic chunks."""
lines: list[str] = [] lines: list[str] = []
@@ -213,6 +253,9 @@ def build_analysis_message_with_chunks(
if runbook_chunks: if runbook_chunks:
lines.append(_format_runbook_context(runbook_chunks)) lines.append(_format_runbook_context(runbook_chunks))
if past_sessions:
lines.append(_format_session_context(past_sessions))
lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n") lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n")
for chunk in chunks: for chunk in chunks:
lines.append(f"### {chunk.name}\n") lines.append(f"### {chunk.name}\n")

105
src/tai/session_store.py Normal file
View File

@@ -0,0 +1,105 @@
"""Persistent session memory store (Tier 4) backed by ChromaDB."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from tai.ai_client import AIClient
DEFAULT_SESSION_STORE_PATH = "~/.tai/sessions"
_COLLECTION_NAME = "tai_sessions"
@dataclass(slots=True)
class PastSession:
    """A retrieved prior session summary for prompt grounding."""

    # Identifier under which the session was stored.
    session_id: str
    # Host recorded in the stored session's metadata.
    host: str
    # Issue text recorded in the stored session's metadata.
    issue: str
    # Stored summary document of the session.
    summary: str
class SessionStore:
    """ChromaDB-backed persistent memory for prior troubleshooting sessions.

    Each session is one record: the summary text is the document, ``host``
    and ``issue`` are stored as metadata, and the embedding comes from the
    AI client supplied by the caller.
    """

    def __init__(self, store_path: str | Path = DEFAULT_SESSION_STORE_PATH) -> None:
        """Open the persistent store at *store_path*, creating it if needed.

        Raises:
            ImportError: if ``chromadb`` is not installed.
        """
        # Imported lazily so this module can be imported without chromadb.
        import chromadb

        path = Path(store_path).expanduser().resolve()
        path.mkdir(parents=True, exist_ok=True)

        settings = None
        try:
            from chromadb.config import Settings

            # Route telemetry to the project's no-op client.
            settings = Settings(
                anonymized_telemetry=False,
                chroma_product_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient",
                chroma_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient",
            )
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            settings = None

        if settings is None:
            self._client = chromadb.PersistentClient(path=str(path))
        else:
            self._client = chromadb.PersistentClient(path=str(path), settings=settings)
        self._collection = self._client.get_or_create_collection(
            name=_COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def count(self) -> int:
        """Return number of indexed session summaries."""
        return self._collection.count()

    def index_session(self, host: str, issue: str, summary: str, ai: AIClient) -> str:
        """Embed and upsert one session summary into persistent storage.

        Returns:
            The generated session id, a UTC timestamp with microseconds.
        """
        # Microsecond resolution: with whole seconds, two sessions indexed in
        # the same second shared an id and the upsert replaced the earlier one.
        session_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%fZ")
        embed_text = _build_embed_text(host=host, issue=issue, summary=summary)
        embedding = ai.embed(embed_text)
        self._collection.upsert(
            ids=[session_id],
            documents=[summary.strip()],
            embeddings=[embedding],
            metadatas=[{"host": host, "issue": issue}],
        )
        return session_id

    def query(self, question: str, host: str, ai: AIClient, *, top_k: int = 2) -> list[PastSession]:
        """Return up to *top_k* stored sessions most similar to the question.

        *host* is folded into the embedded query text only; it is not applied
        as a filter, so sessions recorded for other hosts can be returned.
        An empty store or a non-positive *top_k* yields an empty list.
        """
        total = self._collection.count()
        if total == 0 or top_k <= 0:
            return []

        q_embedding = ai.embed(f"host: {host}\nquestion: {question}")
        results = self._collection.query(
            query_embeddings=[q_embedding],
            n_results=min(top_k, total),
            include=["documents", "metadatas"],
        )

        sessions: list[PastSession] = []
        ids = results.get("ids") or []
        docs = results.get("documents") or []
        metas = results.get("metadatas") or []
        for id_list, doc_list, meta_list in zip(ids, docs, metas, strict=False):
            for sid, doc, meta in zip(id_list, doc_list, meta_list, strict=False):
                # A record without metadata or document must not discard the
                # whole result set; fall back to empty fields instead.
                meta = meta or {}
                sessions.append(
                    PastSession(
                        session_id=str(sid),
                        host=str(meta.get("host", "")),
                        issue=str(meta.get("issue", "")),
                        summary="" if doc is None else str(doc),
                    )
                )
        return sessions
def _build_embed_text(*, host: str, issue: str, summary: str) -> str:
"""Build embedding text with host/issue context and summary excerpt."""
excerpt = summary.strip()[:1000]
return f"host: {host}\nissue: {issue}\nsummary:\n{excerpt}"

View File

@@ -51,6 +51,7 @@ class SSHClient:
_READ_ONLY_COMMANDS = { _READ_ONLY_COMMANDS = {
"cat", "cat",
"dmesg", "dmesg",
"dpkg-query",
"df", "df",
"du", "du",
"find", "find",
@@ -61,6 +62,7 @@ class SSHClient:
"journalctl", "journalctl",
"ls", "ls",
"netstat", "netstat",
"rpm",
"sed", "sed",
"ss", "ss",
"stat", "stat",

View File

@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.session_store import PastSession
from tai.ssh_client import SSHCommandResult from tai.ssh_client import SSHCommandResult
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -221,6 +222,24 @@ def test_build_user_message_handles_no_output() -> None:
assert "no output" in msg assert "no output" in msg
def test_build_user_message_includes_prior_session_context() -> None:
    """Prior sessions handed to build_user_message appear in the prompt text."""
    prior = PastSession(
        session_id="20260506T120000Z",
        host="web01",
        issue="sssd broken",
        summary="Root cause was missing sssd package.",
    )
    report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])

    msg = build_user_message("sssd broken", report, past_sessions=[prior])

    for expected in ("Similar prior sessions", "missing sssd package"):
        assert expected in msg
def test_build_followup_message_includes_question_context() -> None: def test_build_followup_message_includes_question_context() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
msg = build_followup_message( msg = build_followup_message(

View File

@@ -107,8 +107,12 @@ def test_sssd_in_issue_adds_presence_service_and_config_commands() -> None:
assert "binary-sssd-1" in names assert "binary-sssd-1" in names
assert "service-sssd" in names assert "service-sssd" in names
assert "journal-sssd" in names assert "journal-sssd" in names
assert "package-rpm-sssd-1" in names
assert "package-dpkg-sssd-1" in names
assert any("cat /etc/sssd/sssd.conf" in c for c in cmds) assert any("cat /etc/sssd/sssd.conf" in c for c in cmds)
assert any("ls -l /usr/sbin/sssd" in c for c in cmds) assert any("ls -l /usr/sbin/sssd" in c for c in cmds)
assert any("rpm -q sssd" in c for c in cmds)
assert any("dpkg-query -W sssd" in c for c in cmds)
assert any("list-unit-files sssd.service" in c for c in cmds) assert any("list-unit-files sssd.service" in c for c in cmds)
@@ -119,8 +123,12 @@ def test_docker_presence_probe_checks_package_and_binary() -> None:
assert "unit-file-docker" in names assert "unit-file-docker" in names
assert "binary-docker-1" in names assert "binary-docker-1" in names
assert "binary-docker-2" in names assert "binary-docker-2" in names
assert "package-rpm-docker-1" in names
assert "package-dpkg-docker-1" in names
assert any("ls -l /usr/bin/docker" in c for c in cmds) assert any("ls -l /usr/bin/docker" in c for c in cmds)
assert any("ls -l /usr/bin/dockerd" in c for c in cmds) assert any("ls -l /usr/bin/dockerd" in c for c in cmds)
assert any("rpm -q docker" in c for c in cmds)
assert any("dpkg-query -W docker" in c for c in cmds)
def test_unknown_service_name_no_config_cat() -> None: def test_unknown_service_name_no_config_cat() -> None:
@@ -183,6 +191,16 @@ def test_extract_services_case_insensitive() -> None:
assert "nginx" in _extract_services("NGINX failed") assert "nginx" in _extract_services("NGINX failed")
def test_extract_services_detects_generic_service_name() -> None:
    """A word directly before "service" is reported even if it is not a known service."""
    assert "myweirdapp" in _extract_services("myweirdapp service keeps failing")
def test_extract_services_detects_dot_service_pattern() -> None:
    """An explicit ``<name>.service`` mention is reported without the suffix."""
    assert "foobar" in _extract_services("please check foobar.service on this host")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Plan length sanity # Plan length sanity
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -0,0 +1,79 @@
"""Tests for session_store with mocked ChromaDB."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
from tai.session_store import PastSession, SessionStore, _build_embed_text
def _make_chromadb_mock() -> MagicMock:
collection = MagicMock()
collection.count.return_value = 0
client = MagicMock()
client.get_or_create_collection.return_value = collection
chroma_mod = MagicMock()
chroma_mod.PersistentClient.return_value = client
return chroma_mod
def _make_ai_mock(embedding: list[float] | None = None) -> MagicMock:
ai = MagicMock()
ai.embed.return_value = embedding or [0.1, 0.2, 0.3]
return ai
def test_build_embed_text_contains_host_issue_and_summary() -> None:
    """The embedding text carries the host, the issue and the summary."""
    text = _build_embed_text(host="web01", issue="sssd broken", summary="Unit missing")
    for fragment in ("host: web01", "issue: sssd broken", "Unit missing"):
        assert fragment in text
def test_index_session_upserts_with_metadata(tmp_path: Path) -> None:
    """index_session performs one upsert that carries host and issue metadata."""
    chroma_mock = _make_chromadb_mock()
    ai = _make_ai_mock()

    with patch.dict("sys.modules", {"chromadb": chroma_mock}):
        store = SessionStore(tmp_path / "store")
        session_id = store.index_session("web01", "sssd broken", "summary text", ai)

    collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value
    assert session_id
    collection.upsert.assert_called_once()
    stored_meta = collection.upsert.call_args.kwargs["metadatas"][0]
    assert stored_meta["host"] == "web01"
    assert stored_meta["issue"] == "sssd broken"
def test_query_returns_empty_when_no_docs(tmp_path: Path) -> None:
    """Querying a store that holds no records yields an empty list."""
    fake_chroma = _make_chromadb_mock()
    fake_ai = _make_ai_mock()

    with patch.dict("sys.modules", {"chromadb": fake_chroma}):
        found = SessionStore(tmp_path / "store").query("why sssd", "web01", fake_ai)

    assert found == []
def test_query_returns_past_sessions(tmp_path: Path) -> None:
    """A Chroma hit is converted into a PastSession with host and summary set."""
    chroma_mock = _make_chromadb_mock()
    collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value
    collection.count.return_value = 1
    collection.query.return_value = {
        "ids": [["20260506T120000Z"]],
        "documents": [["Root cause: package missing"]],
        "metadatas": [[{"host": "web01", "issue": "sssd broken"}]],
    }
    ai = _make_ai_mock()

    with patch.dict("sys.modules", {"chromadb": chroma_mock}):
        store = SessionStore(tmp_path / "store")
        results = store.query("sssd issue", "web01", ai)

    assert len(results) == 1
    (hit,) = results
    assert isinstance(hit, PastSession)
    assert hit.host == "web01"
    assert "package missing" in hit.summary

View File

@@ -82,6 +82,8 @@ def test_allows_expected_read_only_commands() -> None:
"systemctl status apache2", "systemctl status apache2",
"cat /etc/hosts", "cat /etc/hosts",
"ss -lntp", "ss -lntp",
"rpm -q sssd",
"dpkg-query -W sssd",
]: ]:
client.validate_read_only_command(command) client.validate_read_only_command(command)