124
src/tai/cli.py
124
src/tai/cli.py
@@ -30,6 +30,7 @@ from tai.prompt_builder import (
|
|||||||
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
|
from tai.rag_retriever import EmbeddedChunk, chunk_report, retrieve_scored
|
||||||
from tai.runbook_store import RunbookChunk, RunbookStore
|
from tai.runbook_store import RunbookChunk, RunbookStore
|
||||||
from tai.session_log import SessionLogger
|
from tai.session_log import SessionLogger
|
||||||
|
from tai.session_store import PastSession, SessionStore
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||||
|
|
||||||
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||||
@@ -151,6 +152,16 @@ def run(
|
|||||||
help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.",
|
help="Path to a synced runbook ChromaDB store. Enables Tier 2 RAG.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
|
session_memory_path: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option(
|
||||||
|
"--session-memory",
|
||||||
|
help=(
|
||||||
|
"Path to persistent session memory store for prior-session retrieval "
|
||||||
|
"(Tier 4). Omit to disable."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start an interactive troubleshooting session scaffold."""
|
"""Start an interactive troubleshooting session scaffold."""
|
||||||
try:
|
try:
|
||||||
@@ -207,6 +218,17 @@ def run(
|
|||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}")
|
console.print(f"[yellow]Runbook store unavailable:[/yellow] {exc}")
|
||||||
|
|
||||||
|
session_store: SessionStore | None = None
|
||||||
|
if session_memory_path:
|
||||||
|
try:
|
||||||
|
session_store = SessionStore(session_memory_path)
|
||||||
|
mem_count = session_store.count()
|
||||||
|
console.print(
|
||||||
|
f"[dim]Session memory: {mem_count} indexed at {session_memory_path}[/dim]"
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
console.print(f"[yellow]Session memory unavailable:[/yellow] {exc}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
_async_main(
|
_async_main(
|
||||||
@@ -220,6 +242,7 @@ def run(
|
|||||||
no_rag=no_rag,
|
no_rag=no_rag,
|
||||||
rag_debug=rag_debug,
|
rag_debug=rag_debug,
|
||||||
runbook_store=runbook_store,
|
runbook_store=runbook_store,
|
||||||
|
session_store=session_store,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -245,6 +268,7 @@ async def _async_main(
|
|||||||
no_rag: bool,
|
no_rag: bool,
|
||||||
rag_debug: bool,
|
rag_debug: bool,
|
||||||
runbook_store: RunbookStore | None,
|
runbook_store: RunbookStore | None,
|
||||||
|
session_store: SessionStore | None,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Open a single SSH session and run probe / collection / analysis through it."""
|
"""Open a single SSH session and run probe / collection / analysis through it."""
|
||||||
@@ -291,19 +315,22 @@ async def _async_main(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
initial_response: str | None = None
|
||||||
if analyze and report is not None:
|
if analyze and report is not None:
|
||||||
_run_analysis(
|
initial_response = _run_analysis(
|
||||||
ai_config,
|
ai_config,
|
||||||
req.issue,
|
req.issue,
|
||||||
report,
|
report,
|
||||||
no_rag=no_rag,
|
no_rag=no_rag,
|
||||||
rag_debug=rag_debug,
|
rag_debug=rag_debug,
|
||||||
runbook_store=runbook_store,
|
runbook_store=runbook_store,
|
||||||
|
session_store=session_store,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
interactive_response: str | None = None
|
||||||
if interactive:
|
if interactive:
|
||||||
await _interactive_loop(
|
interactive_response = await _interactive_loop(
|
||||||
session,
|
session,
|
||||||
req,
|
req,
|
||||||
ai_config,
|
ai_config,
|
||||||
@@ -311,9 +338,14 @@ async def _async_main(
|
|||||||
no_rag=no_rag,
|
no_rag=no_rag,
|
||||||
rag_debug=rag_debug,
|
rag_debug=rag_debug,
|
||||||
runbook_store=runbook_store,
|
runbook_store=runbook_store,
|
||||||
|
session_store=session_store,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
final_response = interactive_response or initial_response
|
||||||
|
if session_store is not None and final_response:
|
||||||
|
_index_session_memory(session_store, ai_config, req, final_response, logger=logger)
|
||||||
|
|
||||||
|
|
||||||
async def _interactive_loop(
|
async def _interactive_loop(
|
||||||
session: SSHSession,
|
session: SSHSession,
|
||||||
@@ -324,8 +356,9 @@ async def _interactive_loop(
|
|||||||
no_rag: bool = False,
|
no_rag: bool = False,
|
||||||
rag_debug: bool = False,
|
rag_debug: bool = False,
|
||||||
runbook_store: RunbookStore | None = None,
|
runbook_store: RunbookStore | None = None,
|
||||||
|
session_store: SessionStore | None = None,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> str | None:
|
||||||
"""Run a follow-up loop for collecting and conversational analysis."""
|
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
@@ -340,6 +373,7 @@ async def _interactive_loop(
|
|||||||
prior_questions: list[str] = []
|
prior_questions: list[str] = []
|
||||||
embedded_chunks: list[EmbeddedChunk] | None = None
|
embedded_chunks: list[EmbeddedChunk] | None = None
|
||||||
ai_embed = AIClient(ai_config)
|
ai_embed = AIClient(ai_config)
|
||||||
|
last_response: str | None = None
|
||||||
|
|
||||||
if not no_rag and report is not None:
|
if not no_rag and report is not None:
|
||||||
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
embedded_chunks, index_error, index_ms = await asyncio.to_thread(
|
||||||
@@ -384,7 +418,7 @@ async def _interactive_loop(
|
|||||||
console.print("\n[yellow]Exiting interactive mode.[/yellow]")
|
console.print("\n[yellow]Exiting interactive mode.[/yellow]")
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("interactive_exit", {"reason": "signal_or_eof"})
|
logger.log_event("interactive_exit", {"reason": "signal_or_eof"})
|
||||||
return
|
return last_response
|
||||||
|
|
||||||
if not command:
|
if not command:
|
||||||
continue
|
continue
|
||||||
@@ -393,7 +427,7 @@ async def _interactive_loop(
|
|||||||
console.print("[green]Bye.[/green]")
|
console.print("[green]Bye.[/green]")
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("interactive_exit", {"reason": "user_quit"})
|
logger.log_event("interactive_exit", {"reason": "user_quit"})
|
||||||
return
|
return last_response
|
||||||
|
|
||||||
if command == "/help":
|
if command == "/help":
|
||||||
console.print(
|
console.print(
|
||||||
@@ -466,7 +500,7 @@ async def _interactive_loop(
|
|||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_run_followup_analysis(
|
response = _run_followup_analysis(
|
||||||
ai_config,
|
ai_config,
|
||||||
req.issue,
|
req.issue,
|
||||||
report,
|
report,
|
||||||
@@ -475,11 +509,14 @@ async def _interactive_loop(
|
|||||||
embedded_chunks=embedded_chunks,
|
embedded_chunks=embedded_chunks,
|
||||||
rag_debug=rag_debug,
|
rag_debug=rag_debug,
|
||||||
runbook_store=runbook_store,
|
runbook_store=runbook_store,
|
||||||
|
session_store=session_store,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append("/analyze")
|
prior_questions.append("/analyze")
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("interactive_followup", {"question": "/analyze"})
|
logger.log_event("interactive_followup", {"question": "/analyze"})
|
||||||
|
last_response = response
|
||||||
|
continue
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if report is None:
|
if report is None:
|
||||||
@@ -523,7 +560,7 @@ async def _interactive_loop(
|
|||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_run_followup_analysis(
|
response = _run_followup_analysis(
|
||||||
ai_config,
|
ai_config,
|
||||||
req.issue,
|
req.issue,
|
||||||
report,
|
report,
|
||||||
@@ -532,11 +569,13 @@ async def _interactive_loop(
|
|||||||
embedded_chunks=embedded_chunks,
|
embedded_chunks=embedded_chunks,
|
||||||
rag_debug=rag_debug,
|
rag_debug=rag_debug,
|
||||||
runbook_store=runbook_store,
|
runbook_store=runbook_store,
|
||||||
|
session_store=session_store,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
prior_questions.append(command)
|
prior_questions.append(command)
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("interactive_followup", {"question": command})
|
logger.log_event("interactive_followup", {"question": command})
|
||||||
|
last_response = response
|
||||||
|
|
||||||
|
|
||||||
def _try_embed_report(
|
def _try_embed_report(
|
||||||
@@ -597,8 +636,9 @@ def _run_analysis(
|
|||||||
no_rag: bool = False,
|
no_rag: bool = False,
|
||||||
rag_debug: bool = False,
|
rag_debug: bool = False,
|
||||||
runbook_store: RunbookStore | None = None,
|
runbook_store: RunbookStore | None = None,
|
||||||
|
session_store: SessionStore | None = None,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> None:
|
) -> str:
|
||||||
"""Send collected data to the AI and stream the analysis to stdout."""
|
"""Send collected data to the AI and stream the analysis to stdout."""
|
||||||
console.print()
|
console.print()
|
||||||
console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan"))
|
console.print(Rule("[bold cyan]Analysis[/bold cyan]", style="cyan"))
|
||||||
@@ -606,10 +646,16 @@ def _run_analysis(
|
|||||||
ai = AIClient(ai_config)
|
ai = AIClient(ai_config)
|
||||||
system_prompt = build_system_prompt()
|
system_prompt = build_system_prompt()
|
||||||
runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1)
|
runbook_chunks = _query_runbooks(runbook_store, issue, ai, top_k=1)
|
||||||
|
past_sessions = _query_sessions(session_store, issue, report.host, ai, top_k=2)
|
||||||
|
|
||||||
user_message: str
|
user_message: str
|
||||||
if no_rag:
|
if no_rag:
|
||||||
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None)
|
user_message = build_user_message(
|
||||||
|
issue,
|
||||||
|
report,
|
||||||
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
chunks = chunk_report(report)
|
chunks = chunk_report(report)
|
||||||
@@ -628,16 +674,28 @@ def _run_analysis(
|
|||||||
report.host,
|
report.host,
|
||||||
selected,
|
selected,
|
||||||
runbook_chunks=runbook_chunks or None,
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None)
|
user_message = build_user_message(
|
||||||
|
issue,
|
||||||
|
report,
|
||||||
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
console.print(
|
console.print(
|
||||||
"[yellow]RAG unavailable for initial analysis; using full-context fallback.[/yellow]"
|
"[yellow]RAG unavailable for initial analysis; "
|
||||||
|
"using full-context fallback.[/yellow]"
|
||||||
)
|
)
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("rag_index", {"status": "fallback", "error": str(exc)})
|
logger.log_event("rag_index", {"status": "fallback", "error": str(exc)})
|
||||||
user_message = build_user_message(issue, report, runbook_chunks=runbook_chunks or None)
|
user_message = build_user_message(
|
||||||
|
issue,
|
||||||
|
report,
|
||||||
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = _complete_ai_response(
|
response = _complete_ai_response(
|
||||||
ai,
|
ai,
|
||||||
@@ -662,6 +720,7 @@ def _run_analysis(
|
|||||||
"guardrail_warnings": warnings,
|
"guardrail_warnings": warnings,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return response
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
@@ -688,6 +747,7 @@ def _run_followup_analysis(
|
|||||||
embedded_chunks: list[EmbeddedChunk] | None = None,
|
embedded_chunks: list[EmbeddedChunk] | None = None,
|
||||||
rag_debug: bool = False,
|
rag_debug: bool = False,
|
||||||
runbook_store: RunbookStore | None = None,
|
runbook_store: RunbookStore | None = None,
|
||||||
|
session_store: SessionStore | None = None,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
"""Run grounded follow-up analysis re-anchored to current diagnostics.
|
||||||
@@ -702,6 +762,7 @@ def _run_followup_analysis(
|
|||||||
ai = AIClient(ai_config)
|
ai = AIClient(ai_config)
|
||||||
system_prompt = build_system_prompt()
|
system_prompt = build_system_prompt()
|
||||||
runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1)
|
runbook_chunks = _query_runbooks(runbook_store, question, ai, top_k=1)
|
||||||
|
past_sessions = _query_sessions(session_store, question, report.host, ai, top_k=2)
|
||||||
|
|
||||||
user_message: str
|
user_message: str
|
||||||
retrieved_names: list[str] = []
|
retrieved_names: list[str] = []
|
||||||
@@ -724,6 +785,7 @@ def _run_followup_analysis(
|
|||||||
question,
|
question,
|
||||||
prior_questions,
|
prior_questions,
|
||||||
runbook_chunks=runbook_chunks or None,
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
)
|
)
|
||||||
if rag_debug:
|
if rag_debug:
|
||||||
pairs = ", ".join(
|
pairs = ", ".join(
|
||||||
@@ -741,12 +803,14 @@ def _run_followup_analysis(
|
|||||||
user_message = build_followup_message(
|
user_message = build_followup_message(
|
||||||
issue, report, question, prior_questions,
|
issue, report, question, prior_questions,
|
||||||
runbook_chunks=runbook_chunks or None,
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fallback_reason = "rag not indexed"
|
fallback_reason = "rag not indexed"
|
||||||
user_message = build_followup_message(
|
user_message = build_followup_message(
|
||||||
issue, report, question, prior_questions,
|
issue, report, question, prior_questions,
|
||||||
runbook_chunks=runbook_chunks or None,
|
runbook_chunks=runbook_chunks or None,
|
||||||
|
past_sessions=past_sessions or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
@@ -826,6 +890,42 @@ def _query_runbooks(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _query_sessions(
|
||||||
|
store: SessionStore | None,
|
||||||
|
question: str,
|
||||||
|
host: str,
|
||||||
|
ai: AIClient,
|
||||||
|
*,
|
||||||
|
top_k: int = 2,
|
||||||
|
) -> list[PastSession]:
|
||||||
|
"""Query the session memory store silently; returns empty list on failures."""
|
||||||
|
if store is None:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
return store.query(question, host, ai, top_k=top_k)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _index_session_memory(
|
||||||
|
store: SessionStore,
|
||||||
|
ai_config: AIConfig,
|
||||||
|
req: TroubleshootRequest,
|
||||||
|
summary: str,
|
||||||
|
*,
|
||||||
|
logger: SessionLogger | None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist final session summary for future retrieval; non-fatal on failure."""
|
||||||
|
try:
|
||||||
|
ai = AIClient(ai_config)
|
||||||
|
session_id = store.index_session(req.host, req.issue, summary, ai)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event("session_memory_indexed", {"session_id": session_id})
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event("session_memory_error", {"error": str(exc)})
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# runbooks sub-app
|
# runbooks sub-app
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -149,6 +149,37 @@ _SERVICE_BINARIES: dict[str, list[str]] = {
|
|||||||
"apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"],
|
"apparmor": ["/usr/sbin/aa-status", "/sbin/apparmor_parser"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_SERVICE_PACKAGES: dict[str, list[str]] = {
|
||||||
|
"docker": ["docker", "docker-ce"],
|
||||||
|
"sssd": ["sssd"],
|
||||||
|
"sshd": ["openssh-server", "openssh"],
|
||||||
|
"x2go": ["x2goserver", "x2goserver-xsession"],
|
||||||
|
"xorg": ["xorg-server", "xserver-xorg-core"],
|
||||||
|
"wayland": ["wayland", "xwayland"],
|
||||||
|
"selinux": ["selinux-policy", "selinux-policy-targeted"],
|
||||||
|
"apparmor": ["apparmor"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_GENERIC_SERVICE_STOPWORDS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"a",
|
||||||
|
"an",
|
||||||
|
"and",
|
||||||
|
"app",
|
||||||
|
"application",
|
||||||
|
"daemon",
|
||||||
|
"for",
|
||||||
|
"is",
|
||||||
|
"my",
|
||||||
|
"not",
|
||||||
|
"service",
|
||||||
|
"systemd",
|
||||||
|
"the",
|
||||||
|
"unit",
|
||||||
|
"working",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Command sets
|
# Command sets
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -225,6 +256,9 @@ def plan_from_request(request: TroubleshootRequest) -> CollectionPlan:
|
|||||||
)
|
)
|
||||||
for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1):
|
for idx, binary_path in enumerate(_SERVICE_BINARIES.get(svc, []), start=1):
|
||||||
plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}")
|
plan.add(f"binary-{svc}-{idx}", f"ls -l {binary_path}")
|
||||||
|
for idx, package_name in enumerate(_service_package_candidates(svc), start=1):
|
||||||
|
plan.add(f"package-rpm-{svc}-{idx}", f"rpm -q {package_name}")
|
||||||
|
plan.add(f"package-dpkg-{svc}-{idx}", f"dpkg-query -W {package_name}")
|
||||||
plan.add(f"service-{svc}", f"systemctl status {svc}")
|
plan.add(f"service-{svc}", f"systemctl status {svc}")
|
||||||
plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager")
|
plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager")
|
||||||
for cfg_path in _SERVICE_CONFIGS.get(svc, []):
|
for cfg_path in _SERVICE_CONFIGS.get(svc, []):
|
||||||
@@ -258,7 +292,11 @@ def _issue_words(issue: str) -> set[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _extract_services(issue: str) -> list[str]:
|
def _extract_services(issue: str) -> list[str]:
|
||||||
"""Return known service names mentioned in *issue*."""
|
"""Return service candidates mentioned in *issue*.
|
||||||
|
|
||||||
|
Includes known services plus generic service-like tokens near words such
|
||||||
|
as "service", "daemon", and "unit".
|
||||||
|
"""
|
||||||
words = _issue_words(issue)
|
words = _issue_words(issue)
|
||||||
found: list[str] = []
|
found: list[str] = []
|
||||||
for svc in _KNOWN_SERVICES:
|
for svc in _KNOWN_SERVICES:
|
||||||
@@ -266,6 +304,58 @@ def _extract_services(issue: str) -> list[str]:
|
|||||||
svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")}
|
svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")}
|
||||||
if words & svc_words:
|
if words & svc_words:
|
||||||
found.append(svc)
|
found.append(svc)
|
||||||
|
for svc in _extract_generic_service_candidates(issue):
|
||||||
|
if svc not in found:
|
||||||
|
found.append(svc)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_generic_service_candidates(issue: str) -> list[str]:
|
||||||
|
"""Extract likely service names from free text even when not pre-registered."""
|
||||||
|
tokens = [tok.lower() for tok in re.findall(r"[a-zA-Z0-9_.@-]+", issue)]
|
||||||
|
if not tokens:
|
||||||
|
return []
|
||||||
|
|
||||||
|
candidates: list[str] = []
|
||||||
|
for idx, token in enumerate(tokens):
|
||||||
|
normalized = token[:-8] if token.endswith(".service") else token
|
||||||
|
if _is_safe_service_name(normalized):
|
||||||
|
if token.endswith(".service") and normalized not in _GENERIC_SERVICE_STOPWORDS:
|
||||||
|
candidates.append(normalized)
|
||||||
|
|
||||||
|
if token in {"service", "daemon", "unit"}:
|
||||||
|
for neighbor in (idx - 1, idx + 1):
|
||||||
|
if neighbor < 0 or neighbor >= len(tokens):
|
||||||
|
continue
|
||||||
|
candidate = tokens[neighbor]
|
||||||
|
if candidate.endswith(".service"):
|
||||||
|
candidate = candidate[:-8]
|
||||||
|
if candidate in _GENERIC_SERVICE_STOPWORDS:
|
||||||
|
continue
|
||||||
|
if _is_safe_service_name(candidate):
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
deduped: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for item in candidates:
|
||||||
|
if item in seen:
|
||||||
|
continue
|
||||||
|
seen.add(item)
|
||||||
|
deduped.append(item)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_service_name(name: str) -> bool:
|
||||||
|
"""Return True when *name* is safe to interpolate into read-only commands."""
|
||||||
|
if len(name) < 2 or len(name) > 64:
|
||||||
|
return False
|
||||||
|
return re.fullmatch(r"[a-z0-9_.@-]+", name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _service_package_candidates(service: str) -> list[str]:
|
||||||
|
"""Return package names to probe for *service* presence."""
|
||||||
|
if service in _SERVICE_PACKAGES:
|
||||||
|
return _SERVICE_PACKAGES[service]
|
||||||
|
return [service]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from tai.collectors import CollectionReport
|
from tai.collectors import CollectionReport
|
||||||
from tai.rag_retriever import Chunk
|
from tai.rag_retriever import Chunk
|
||||||
from tai.runbook_store import RunbookChunk
|
from tai.runbook_store import RunbookChunk
|
||||||
|
from tai.session_store import PastSession
|
||||||
|
|
||||||
_SYSTEM_PROMPT = """\
|
_SYSTEM_PROMPT = """\
|
||||||
You are an expert Linux systems administrator and troubleshooting assistant.
|
You are an expert Linux systems administrator and troubleshooting assistant.
|
||||||
@@ -21,7 +22,8 @@ Important rules:
|
|||||||
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
||||||
rejected that specific command — it is not evidence about the service or system state.
|
rejected that specific command — it is not evidence about the service or system state.
|
||||||
- If service presence checks show a unit, binary, package, or config is missing, treat that as
|
- If service presence checks show a unit, binary, package, or config is missing, treat that as
|
||||||
evidence the component may be absent or not installed, not as proof that the component is broken.
|
evidence the component may be absent or not installed, not as proof that the
|
||||||
|
component is broken.
|
||||||
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
||||||
additional commands or log files would be needed.
|
additional commands or log files would be needed.
|
||||||
- Keep the response short. Skip sections that have nothing useful to say.
|
- Keep the response short. Skip sections that have nothing useful to say.
|
||||||
@@ -33,6 +35,7 @@ Important rules:
|
|||||||
|
|
||||||
_MAX_RUNBOOK_CHARS = 500
|
_MAX_RUNBOOK_CHARS = 500
|
||||||
_MAX_DIAGNOSTIC_CHUNK_CHARS = 700
|
_MAX_DIAGNOSTIC_CHUNK_CHARS = 700
|
||||||
|
_MAX_SESSION_SUMMARY_CHARS = 500
|
||||||
|
|
||||||
|
|
||||||
def build_system_prompt() -> str:
|
def build_system_prompt() -> str:
|
||||||
@@ -66,11 +69,34 @@ def _format_diagnostic_chunk(content: str) -> str:
|
|||||||
return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]"
|
return text[:_MAX_DIAGNOSTIC_CHUNK_CHARS].rstrip() + "\n...[truncated diagnostic context]"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_session_context(past_sessions: list[PastSession]) -> str:
|
||||||
|
"""Format similar prior sessions as compact grounding context."""
|
||||||
|
lines: list[str] = ["## Similar prior sessions\n"]
|
||||||
|
lines.append(
|
||||||
|
"The following completed sessions were semantically similar. "
|
||||||
|
"Use them as historical hints, but prioritize current diagnostics if they conflict.\n"
|
||||||
|
)
|
||||||
|
for sess in past_sessions:
|
||||||
|
summary = sess.summary.strip()
|
||||||
|
if len(summary) > _MAX_SESSION_SUMMARY_CHARS:
|
||||||
|
summary = (
|
||||||
|
summary[:_MAX_SESSION_SUMMARY_CHARS].rstrip()
|
||||||
|
+ "\n...[truncated session summary]"
|
||||||
|
)
|
||||||
|
lines.append(f"### Session: {sess.session_id} (host={sess.host})\n")
|
||||||
|
lines.append(f"**Issue:** {sess.issue}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(summary)
|
||||||
|
lines.append("")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_user_message(
|
def build_user_message(
|
||||||
issue: str,
|
issue: str,
|
||||||
report: CollectionReport,
|
report: CollectionReport,
|
||||||
*,
|
*,
|
||||||
runbook_chunks: list[RunbookChunk] | None = None,
|
runbook_chunks: list[RunbookChunk] | None = None,
|
||||||
|
past_sessions: list[PastSession] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Format *issue* and *report* into the user message sent to the AI."""
|
"""Format *issue* and *report* into the user message sent to the AI."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
@@ -81,6 +107,9 @@ def build_user_message(
|
|||||||
if runbook_chunks:
|
if runbook_chunks:
|
||||||
lines.append(_format_runbook_context(runbook_chunks))
|
lines.append(_format_runbook_context(runbook_chunks))
|
||||||
|
|
||||||
|
if past_sessions:
|
||||||
|
lines.append(_format_session_context(past_sessions))
|
||||||
|
|
||||||
lines.append("## Collected diagnostics\n")
|
lines.append("## Collected diagnostics\n")
|
||||||
|
|
||||||
skipped: list[str] = []
|
skipped: list[str] = []
|
||||||
@@ -126,9 +155,15 @@ def build_followup_message(
|
|||||||
prior_questions: list[str],
|
prior_questions: list[str],
|
||||||
*,
|
*,
|
||||||
runbook_chunks: list[RunbookChunk] | None = None,
|
runbook_chunks: list[RunbookChunk] | None = None,
|
||||||
|
past_sessions: list[PastSession] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
|
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
|
||||||
base = build_user_message(issue, report, runbook_chunks=runbook_chunks)
|
base = build_user_message(
|
||||||
|
issue,
|
||||||
|
report,
|
||||||
|
runbook_chunks=runbook_chunks,
|
||||||
|
past_sessions=past_sessions,
|
||||||
|
)
|
||||||
lines: list[str] = [base, "## Follow-up"]
|
lines: list[str] = [base, "## Follow-up"]
|
||||||
|
|
||||||
if prior_questions:
|
if prior_questions:
|
||||||
@@ -157,6 +192,7 @@ def build_message_with_chunks(
|
|||||||
prior_questions: list[str],
|
prior_questions: list[str],
|
||||||
*,
|
*,
|
||||||
runbook_chunks: list[RunbookChunk] | None = None,
|
runbook_chunks: list[RunbookChunk] | None = None,
|
||||||
|
past_sessions: list[PastSession] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a follow-up message using only semantically retrieved diagnostic chunks.
|
"""Build a follow-up message using only semantically retrieved diagnostic chunks.
|
||||||
|
|
||||||
@@ -178,6 +214,9 @@ def build_message_with_chunks(
|
|||||||
if runbook_chunks:
|
if runbook_chunks:
|
||||||
lines.append(_format_runbook_context(runbook_chunks))
|
lines.append(_format_runbook_context(runbook_chunks))
|
||||||
|
|
||||||
|
if past_sessions:
|
||||||
|
lines.append(_format_session_context(past_sessions))
|
||||||
|
|
||||||
lines.append("## Follow-up")
|
lines.append("## Follow-up")
|
||||||
|
|
||||||
if prior_questions:
|
if prior_questions:
|
||||||
@@ -204,6 +243,7 @@ def build_analysis_message_with_chunks(
|
|||||||
chunks: list[Chunk],
|
chunks: list[Chunk],
|
||||||
*,
|
*,
|
||||||
runbook_chunks: list[RunbookChunk] | None = None,
|
runbook_chunks: list[RunbookChunk] | None = None,
|
||||||
|
past_sessions: list[PastSession] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build an initial analysis message from retrieved diagnostic chunks."""
|
"""Build an initial analysis message from retrieved diagnostic chunks."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
@@ -213,6 +253,9 @@ def build_analysis_message_with_chunks(
|
|||||||
if runbook_chunks:
|
if runbook_chunks:
|
||||||
lines.append(_format_runbook_context(runbook_chunks))
|
lines.append(_format_runbook_context(runbook_chunks))
|
||||||
|
|
||||||
|
if past_sessions:
|
||||||
|
lines.append(_format_session_context(past_sessions))
|
||||||
|
|
||||||
lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n")
|
lines.append("## Most relevant diagnostics (retrieved by semantic similarity)\n")
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
lines.append(f"### {chunk.name}\n")
|
lines.append(f"### {chunk.name}\n")
|
||||||
|
|||||||
105
src/tai/session_store.py
Normal file
105
src/tai/session_store.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Persistent session memory store (Tier 4) backed by ChromaDB."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tai.ai_client import AIClient
|
||||||
|
|
||||||
|
DEFAULT_SESSION_STORE_PATH = "~/.tai/sessions"
|
||||||
|
_COLLECTION_NAME = "tai_sessions"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PastSession:
|
||||||
|
"""A retrieved prior session summary for prompt grounding."""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
host: str
|
||||||
|
issue: str
|
||||||
|
summary: str
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStore:
|
||||||
|
"""ChromaDB-backed persistent memory for prior troubleshooting sessions."""
|
||||||
|
|
||||||
|
def __init__(self, store_path: str | Path = DEFAULT_SESSION_STORE_PATH) -> None:
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
path = Path(store_path).expanduser().resolve()
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
settings = None
|
||||||
|
try:
|
||||||
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
settings = Settings(
|
||||||
|
anonymized_telemetry=False,
|
||||||
|
chroma_product_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient",
|
||||||
|
chroma_telemetry_impl="tai.chroma_telemetry.NoOpProductTelemetryClient",
|
||||||
|
)
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
settings = None
|
||||||
|
|
||||||
|
if settings is None:
|
||||||
|
self._client = chromadb.PersistentClient(path=str(path))
|
||||||
|
else:
|
||||||
|
self._client = chromadb.PersistentClient(path=str(path), settings=settings)
|
||||||
|
self._collection = self._client.get_or_create_collection(
|
||||||
|
name=_COLLECTION_NAME,
|
||||||
|
metadata={"hnsw:space": "cosine"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Return number of indexed session summaries."""
|
||||||
|
return self._collection.count()
|
||||||
|
|
||||||
|
def index_session(self, host: str, issue: str, summary: str, ai: AIClient) -> str:
|
||||||
|
"""Embed and upsert one session summary into persistent storage."""
|
||||||
|
session_id = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
embed_text = _build_embed_text(host=host, issue=issue, summary=summary)
|
||||||
|
embedding = ai.embed(embed_text)
|
||||||
|
self._collection.upsert(
|
||||||
|
ids=[session_id],
|
||||||
|
documents=[summary.strip()],
|
||||||
|
embeddings=[embedding],
|
||||||
|
metadatas=[{"host": host, "issue": issue}],
|
||||||
|
)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def query(self, question: str, host: str, ai: AIClient, *, top_k: int = 2) -> list[PastSession]:
|
||||||
|
"""Return top-k semantically similar sessions for this host and question."""
|
||||||
|
if self._collection.count() == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
q_embedding = ai.embed(f"host: {host}\nquestion: {question}")
|
||||||
|
results = self._collection.query(
|
||||||
|
query_embeddings=[q_embedding],
|
||||||
|
n_results=min(top_k, self._collection.count()),
|
||||||
|
include=["documents", "metadatas"],
|
||||||
|
)
|
||||||
|
|
||||||
|
sessions: list[PastSession] = []
|
||||||
|
ids = results.get("ids") or []
|
||||||
|
docs = results.get("documents") or []
|
||||||
|
metas = results.get("metadatas") or []
|
||||||
|
for id_list, doc_list, meta_list in zip(ids, docs, metas, strict=False):
|
||||||
|
for sid, doc, meta in zip(id_list, doc_list, meta_list, strict=False):
|
||||||
|
sessions.append(
|
||||||
|
PastSession(
|
||||||
|
session_id=str(sid),
|
||||||
|
host=str(meta.get("host", "")),
|
||||||
|
issue=str(meta.get("issue", "")),
|
||||||
|
summary=str(doc),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
|
||||||
|
def _build_embed_text(*, host: str, issue: str, summary: str) -> str:
|
||||||
|
"""Build embedding text with host/issue context and summary excerpt."""
|
||||||
|
excerpt = summary.strip()[:1000]
|
||||||
|
return f"host: {host}\nissue: {issue}\nsummary:\n{excerpt}"
|
||||||
@@ -51,6 +51,7 @@ class SSHClient:
|
|||||||
_READ_ONLY_COMMANDS = {
|
_READ_ONLY_COMMANDS = {
|
||||||
"cat",
|
"cat",
|
||||||
"dmesg",
|
"dmesg",
|
||||||
|
"dpkg-query",
|
||||||
"df",
|
"df",
|
||||||
"du",
|
"du",
|
||||||
"find",
|
"find",
|
||||||
@@ -61,6 +62,7 @@ class SSHClient:
|
|||||||
"journalctl",
|
"journalctl",
|
||||||
"ls",
|
"ls",
|
||||||
"netstat",
|
"netstat",
|
||||||
|
"rpm",
|
||||||
"sed",
|
"sed",
|
||||||
"ss",
|
"ss",
|
||||||
"stat",
|
"stat",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
from tai.collectors import CollectedItem, CollectionReport
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
||||||
|
from tai.session_store import PastSession
|
||||||
from tai.ssh_client import SSHCommandResult
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -221,6 +222,24 @@ def test_build_user_message_handles_no_output() -> None:
|
|||||||
assert "no output" in msg
|
assert "no output" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_includes_prior_session_context() -> None:
|
||||||
|
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
||||||
|
msg = build_user_message(
|
||||||
|
"sssd broken",
|
||||||
|
report,
|
||||||
|
past_sessions=[
|
||||||
|
PastSession(
|
||||||
|
session_id="20260506T120000Z",
|
||||||
|
host="web01",
|
||||||
|
issue="sssd broken",
|
||||||
|
summary="Root cause was missing sssd package.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert "Similar prior sessions" in msg
|
||||||
|
assert "missing sssd package" in msg
|
||||||
|
|
||||||
|
|
||||||
def test_build_followup_message_includes_question_context() -> None:
|
def test_build_followup_message_includes_question_context() -> None:
|
||||||
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
||||||
msg = build_followup_message(
|
msg = build_followup_message(
|
||||||
|
|||||||
@@ -107,8 +107,12 @@ def test_sssd_in_issue_adds_presence_service_and_config_commands() -> None:
|
|||||||
assert "binary-sssd-1" in names
|
assert "binary-sssd-1" in names
|
||||||
assert "service-sssd" in names
|
assert "service-sssd" in names
|
||||||
assert "journal-sssd" in names
|
assert "journal-sssd" in names
|
||||||
|
assert "package-rpm-sssd-1" in names
|
||||||
|
assert "package-dpkg-sssd-1" in names
|
||||||
assert any("cat /etc/sssd/sssd.conf" in c for c in cmds)
|
assert any("cat /etc/sssd/sssd.conf" in c for c in cmds)
|
||||||
assert any("ls -l /usr/sbin/sssd" in c for c in cmds)
|
assert any("ls -l /usr/sbin/sssd" in c for c in cmds)
|
||||||
|
assert any("rpm -q sssd" in c for c in cmds)
|
||||||
|
assert any("dpkg-query -W sssd" in c for c in cmds)
|
||||||
assert any("list-unit-files sssd.service" in c for c in cmds)
|
assert any("list-unit-files sssd.service" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
@@ -119,8 +123,12 @@ def test_docker_presence_probe_checks_package_and_binary() -> None:
|
|||||||
assert "unit-file-docker" in names
|
assert "unit-file-docker" in names
|
||||||
assert "binary-docker-1" in names
|
assert "binary-docker-1" in names
|
||||||
assert "binary-docker-2" in names
|
assert "binary-docker-2" in names
|
||||||
|
assert "package-rpm-docker-1" in names
|
||||||
|
assert "package-dpkg-docker-1" in names
|
||||||
assert any("ls -l /usr/bin/docker" in c for c in cmds)
|
assert any("ls -l /usr/bin/docker" in c for c in cmds)
|
||||||
assert any("ls -l /usr/bin/dockerd" in c for c in cmds)
|
assert any("ls -l /usr/bin/dockerd" in c for c in cmds)
|
||||||
|
assert any("rpm -q docker" in c for c in cmds)
|
||||||
|
assert any("dpkg-query -W docker" in c for c in cmds)
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_service_name_no_config_cat() -> None:
|
def test_unknown_service_name_no_config_cat() -> None:
|
||||||
@@ -183,6 +191,16 @@ def test_extract_services_case_insensitive() -> None:
|
|||||||
assert "nginx" in _extract_services("NGINX failed")
|
assert "nginx" in _extract_services("NGINX failed")
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_services_detects_generic_service_name() -> None:
|
||||||
|
services = _extract_services("myweirdapp service keeps failing")
|
||||||
|
assert "myweirdapp" in services
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_services_detects_dot_service_pattern() -> None:
|
||||||
|
services = _extract_services("please check foobar.service on this host")
|
||||||
|
assert "foobar" in services
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Plan length sanity
|
# Plan length sanity
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
79
tests/test_session_store.py
Normal file
79
tests/test_session_store.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for session_store with mocked ChromaDB."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tai.session_store import PastSession, SessionStore, _build_embed_text
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chromadb_mock() -> MagicMock:
|
||||||
|
collection = MagicMock()
|
||||||
|
collection.count.return_value = 0
|
||||||
|
client = MagicMock()
|
||||||
|
client.get_or_create_collection.return_value = collection
|
||||||
|
chroma_mod = MagicMock()
|
||||||
|
chroma_mod.PersistentClient.return_value = client
|
||||||
|
return chroma_mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ai_mock(embedding: list[float] | None = None) -> MagicMock:
|
||||||
|
ai = MagicMock()
|
||||||
|
ai.embed.return_value = embedding or [0.1, 0.2, 0.3]
|
||||||
|
return ai
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_embed_text_contains_host_issue_and_summary() -> None:
|
||||||
|
text = _build_embed_text(host="web01", issue="sssd broken", summary="Unit missing")
|
||||||
|
assert "host: web01" in text
|
||||||
|
assert "issue: sssd broken" in text
|
||||||
|
assert "Unit missing" in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_index_session_upserts_with_metadata(tmp_path: Path) -> None:
|
||||||
|
chroma_mock = _make_chromadb_mock()
|
||||||
|
collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value
|
||||||
|
ai = _make_ai_mock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"chromadb": chroma_mock}):
|
||||||
|
store = SessionStore(tmp_path / "store")
|
||||||
|
session_id = store.index_session("web01", "sssd broken", "summary text", ai)
|
||||||
|
|
||||||
|
assert session_id
|
||||||
|
collection.upsert.assert_called_once()
|
||||||
|
args = collection.upsert.call_args.kwargs
|
||||||
|
assert args["metadatas"][0]["host"] == "web01"
|
||||||
|
assert args["metadatas"][0]["issue"] == "sssd broken"
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_returns_empty_when_no_docs(tmp_path: Path) -> None:
|
||||||
|
chroma_mock = _make_chromadb_mock()
|
||||||
|
ai = _make_ai_mock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"chromadb": chroma_mock}):
|
||||||
|
store = SessionStore(tmp_path / "store")
|
||||||
|
results = store.query("why sssd", "web01", ai)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_returns_past_sessions(tmp_path: Path) -> None:
|
||||||
|
chroma_mock = _make_chromadb_mock()
|
||||||
|
collection = chroma_mock.PersistentClient.return_value.get_or_create_collection.return_value
|
||||||
|
collection.count.return_value = 1
|
||||||
|
collection.query.return_value = {
|
||||||
|
"ids": [["20260506T120000Z"]],
|
||||||
|
"documents": [["Root cause: package missing"]],
|
||||||
|
"metadatas": [[{"host": "web01", "issue": "sssd broken"}]],
|
||||||
|
}
|
||||||
|
ai = _make_ai_mock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"chromadb": chroma_mock}):
|
||||||
|
store = SessionStore(tmp_path / "store")
|
||||||
|
results = store.query("sssd issue", "web01", ai)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], PastSession)
|
||||||
|
assert results[0].host == "web01"
|
||||||
|
assert "package missing" in results[0].summary
|
||||||
@@ -82,6 +82,8 @@ def test_allows_expected_read_only_commands() -> None:
|
|||||||
"systemctl status apache2",
|
"systemctl status apache2",
|
||||||
"cat /etc/hosts",
|
"cat /etc/hosts",
|
||||||
"ss -lntp",
|
"ss -lntp",
|
||||||
|
"rpm -q sssd",
|
||||||
|
"dpkg-query -W sssd",
|
||||||
]:
|
]:
|
||||||
client.validate_read_only_command(command)
|
client.validate_read_only_command(command)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user