5 Commits

10 changed files with 527 additions and 11 deletions

View File

@@ -85,7 +85,7 @@ jobs:
id: version id: version
run: | run: |
tag="${GITHUB_REF_NAME}" tag="${GITHUB_REF_NAME}"
deb_version="${tag}" deb_version="${tag#v}" # Remove leading 'v' if present
echo "tag=${tag}" >> "$GITHUB_OUTPUT" echo "tag=${tag}" >> "$GITHUB_OUTPUT"
echo "deb_version=${deb_version}" >> "$GITHUB_OUTPUT" echo "deb_version=${deb_version}" >> "$GITHUB_OUTPUT"

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, cast
from openai import OpenAI from openai import OpenAI
@@ -88,6 +89,20 @@ class AIClient:
if delta: if delta:
yield delta yield delta
def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]:
"""Stream a completion from an explicit chat history."""
stream = self._client.chat.completions.create(
model=self._config.model,
max_tokens=self._config.max_tokens,
stream=True,
messages=cast(Any, messages),
)
for chunk in cast(Iterator[Any], stream):
delta = chunk.choices[0].delta.content
if delta:
yield delta
def summary(self) -> str: def summary(self) -> str:
"""Human-readable description of the AI config.""" """Human-readable description of the AI config."""
return f"host={self._config.host} model={self._config.model}" return f"host={self._config.host} model={self._config.model}"

36
src/tai/ai_guardrails.py Normal file
View File

@@ -0,0 +1,36 @@
"""Heuristic checks for AI response quality and safety."""
from __future__ import annotations
import re
_RISKY_ACTION_PATTERNS = [
r"\bsystemctl\s+(restart|stop|start)\b",
r"\b(edit|modify|change)\s+/etc/",
r"\bpasswd\b",
r"\bapt\s+install\b",
r"\bdnf\s+install\b",
r"\byum\s+install\b",
]
def validate_ai_response(response: str) -> list[str]:
"""Return warning messages for potentially unsafe or weakly grounded output."""
warnings: list[str] = []
if "Evidence" not in response:
warnings.append("Response is missing an Evidence section.")
if "`" not in response:
warnings.append("Response does not include quoted evidence snippets.")
lower_response = response.lower()
for pattern in _RISKY_ACTION_PATTERNS:
if re.search(pattern, lower_response):
warnings.append(
"Response suggests potentially modifying actions; "
"prefer read-only verification unless remediation was explicitly requested."
)
break
return warnings

View File

@@ -10,12 +10,14 @@ from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.ai_guardrails import validate_ai_response
from tai.collectors import CollectionReport, collect_from_plan from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest from tai.models import TroubleshootRequest
from tai.plan import plan_from_request from tai.plan import plan_from_request
from tai.prompt_builder import build_system_prompt, build_user_message from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig from tai.session_log import SessionLogger
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
app = typer.Typer(no_args_is_help=True, add_completion=False) app = typer.Typer(no_args_is_help=True, add_completion=False)
console = Console() console = Console()
@@ -66,6 +68,13 @@ def run(
help="Send collected diagnostics to AI for analysis.", help="Send collected diagnostics to AI for analysis.",
), ),
] = False, ] = False,
interactive: Annotated[
bool,
typer.Option(
"--interactive/--no-interactive",
help="Start interactive follow-up mode (/collect, /analyze, /quit).",
),
] = False,
ai_host: Annotated[ ai_host: Annotated[
str, str,
typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."), typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."),
@@ -78,6 +87,13 @@ def run(
str, str,
typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."), typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."),
] = "ollama", ] = "ollama",
log_file: Annotated[
str | None,
typer.Option(
"--log-file",
help="Optional JSONL file path to log AI and session output.",
),
] = None,
) -> None: ) -> None:
"""Start an interactive troubleshooting session scaffold.""" """Start an interactive troubleshooting session scaffold."""
try: try:
@@ -109,16 +125,27 @@ def run(
if req.target_paths: if req.target_paths:
console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")
if not (probe or collect or analyze): if not (probe or collect or analyze or interactive):
return # nothing SSH-related requested return # nothing SSH-related requested
ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key) ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key)
if analyze: logger = SessionLogger.create(log_file) if log_file else None
if analyze or interactive:
console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}") console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}")
try: try:
asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze, asyncio.run(
ai_config=ai_config)) _async_main(
config,
req,
probe=probe,
collect=collect,
analyze=analyze,
interactive=interactive,
ai_config=ai_config,
logger=logger,
)
)
except typer.Exit: except typer.Exit:
raise raise
except TimeoutError as exc: except TimeoutError as exc:
@@ -136,14 +163,38 @@ async def _async_main(
probe: bool, probe: bool,
collect: bool, collect: bool,
analyze: bool, analyze: bool,
interactive: bool,
ai_config: AIConfig, ai_config: AIConfig,
logger: SessionLogger | None,
) -> None: ) -> None:
"""Open a single SSH session and run probe / collection / analysis through it.""" """Open a single SSH session and run probe / collection / analysis through it."""
client = SSHClient(config) client = SSHClient(config)
if logger is not None:
logger.log_event(
"session_start",
{
"host": req.host,
"port": req.port,
"issue": req.issue,
"probe": probe,
"collect": collect,
"analyze": analyze,
"interactive": interactive,
},
)
async with client.connect() as session: async with client.connect() as session:
if probe: if probe:
result = await session.probe() result = await session.probe()
_handle_probe_result(result) _handle_probe_result(result)
if logger is not None:
logger.log_event(
"probe_result",
{
"exit_code": result.exit_code,
"stdout": result.stdout,
"stderr": result.stderr,
},
)
report: CollectionReport | None = None report: CollectionReport | None = None
if collect or analyze: if collect or analyze:
@@ -151,9 +202,119 @@ async def _async_main(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
if logger is not None:
logger.log_event(
"collection_summary",
{
"total": report.total,
"failed": report.failed,
},
)
if analyze and report is not None: if analyze and report is not None:
_run_analysis(ai_config, req.issue, report) _run_analysis(ai_config, req.issue, report, logger=logger)
if interactive:
await _interactive_loop(session, req, ai_config, report, logger=logger)
async def _interactive_loop(
session: SSHSession,
req: TroubleshootRequest,
ai_config: AIConfig,
report: CollectionReport | None,
logger: SessionLogger | None,
) -> None:
"""Run a follow-up loop for collecting and conversational analysis."""
console.print(
"[cyan]Interactive mode:[/cyan] "
"ask questions directly, or use /collect, /analyze, /help, /quit"
)
prior_questions: list[str] = []
while True:
try:
command = input("tai> ").strip()
except (EOFError, KeyboardInterrupt):
console.print("\n[yellow]Exiting interactive mode.[/yellow]")
if logger is not None:
logger.log_event("interactive_exit", {"reason": "signal_or_eof"})
return
if not command:
continue
if command in {"/quit", "/exit"}:
console.print("[green]Bye.[/green]")
if logger is not None:
logger.log_event("interactive_exit", {"reason": "user_quit"})
return
if command == "/help":
console.print("Commands: /collect, /analyze, /help, /quit")
console.print("Tip: any non-slash text is treated as a follow-up AI question.")
continue
if command == "/collect":
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
if logger is not None:
logger.log_event(
"collection_summary",
{
"total": report.total,
"failed": report.failed,
},
)
continue
if command == "/analyze":
if report is None:
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
if report is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
_run_followup_analysis(
ai_config,
req.issue,
report,
"Provide an updated diagnosis from the current diagnostics.",
prior_questions,
logger=logger,
)
prior_questions.append("/analyze")
if logger is not None:
logger.log_event("interactive_followup", {"question": "/analyze"})
continue
if report is None:
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
if report is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
_run_followup_analysis(
ai_config,
req.issue,
report,
command,
prior_questions,
logger=logger,
)
prior_questions.append(command)
if logger is not None:
logger.log_event("interactive_followup", {"question": command})
def _handle_probe_result(result: SSHCommandResult) -> None: def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -180,7 +341,13 @@ def _handle_collection_report(report: CollectionReport) -> None:
console.print(f"- {item.name}: {status}{trunc}") console.print(f"- {item.name}: {status}{trunc}")
def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None: def _run_analysis(
ai_config: AIConfig,
issue: str,
report: CollectionReport,
*,
logger: SessionLogger | None,
) -> None:
"""Send collected data to the AI and stream the analysis to stdout.""" """Send collected data to the AI and stream the analysis to stdout."""
console.print("[cyan]Analyzing...[/cyan]\n") console.print("[cyan]Analyzing...[/cyan]\n")
ai = AIClient(ai_config) ai = AIClient(ai_config)
@@ -190,9 +357,69 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) ->
chunks: list[str] = [] chunks: list[str] = []
for chunk in ai.stream(system_prompt, user_message): for chunk in ai.stream(system_prompt, user_message):
chunks.append(chunk) chunks.append(chunk)
console.print(Markdown("".join(chunks))) response = "".join(chunks)
console.print(Markdown(response))
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None:
logger.log_event(
"analysis_response",
{
"issue": issue,
"response": response,
"guardrail_warnings": warnings,
},
)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}") console.print(f"[red]AI analysis failed:[/red] {exc}")
if logger is not None:
logger.log_event("analysis_error", {"error": str(exc)})
raise typer.Exit(code=1) from exc
def _run_followup_analysis(
ai_config: AIConfig,
issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
*,
logger: SessionLogger | None,
) -> str:
"""Run grounded follow-up analysis re-anchored to current diagnostics."""
console.print("[cyan]Analyzing...[/cyan]\n")
ai = AIClient(ai_config)
system_prompt = build_system_prompt()
user_message = build_followup_message(issue, report, question, prior_questions)
try:
chunks: list[str] = []
for chunk in ai.stream(system_prompt, user_message):
chunks.append(chunk)
response = "".join(chunks)
console.print(Markdown(response))
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None:
logger.log_event(
"analysis_response",
{
"last_user_message": question,
"response": response,
"guardrail_warnings": warnings,
},
)
return response
except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}")
if logger is not None:
logger.log_event("analysis_error", {"error": str(exc), "question": question})
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc

View File

@@ -15,12 +15,15 @@ Your job:
Important rules: Important rules:
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence. - Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or - If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state. rejected that specific command — it is not evidence about the service or system state.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what - If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed. additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say. - Keep the response short. Skip sections that have nothing useful to say.
- Never suggest commands that modify the system unless explicitly asked. - Never suggest commands that modify the system unless explicitly asked.
- Default to read-only verification steps. Do not suggest restarting services or editing configs
unless the user explicitly asks for remediation actions.
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**. - Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
""" """
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
) )
return "\n".join(lines) return "\n".join(lines)
def build_followup_message(
issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
) -> str:
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
base = build_user_message(issue, report)
lines: list[str] = [base, "## Follow-up"]
if prior_questions:
lines.append("\nRecent user follow-up questions:")
for idx, item in enumerate(prior_questions[-5:], start=1):
lines.append(f"{idx}. {item}")
lines.append("\nCurrent follow-up question:")
lines.append(question)
lines.append(
"\nAnswer strictly from the collected diagnostics above. "
"If evidence is insufficient, explicitly say so."
)
return "\n".join(lines)

34
src/tai/session_log.py Normal file
View File

@@ -0,0 +1,34 @@
"""Structured session logging helpers for troubleshooting runs."""
from __future__ import annotations
import json
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
@dataclass(slots=True)
class SessionLogger:
"""Append JSONL events to a log file for post-run analysis."""
path: Path
@classmethod
def create(cls, file_path: str) -> SessionLogger:
"""Create a logger for *file_path*, ensuring parent directories exist."""
path = Path(file_path).expanduser()
path.parent.mkdir(parents=True, exist_ok=True)
return cls(path=path)
def log_event(self, event: str, payload: dict[str, Any]) -> None:
"""Write one timestamped event row to the JSONL log."""
row = {
"ts": datetime.now(UTC).isoformat(),
"event": event,
"payload": payload,
}
with self.path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, ensure_ascii=True))
handle.write("\n")

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult from tai.ssh_client import SSHCommandResult
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None:
assert result == ["Root ", "cause ", "found."] assert result == ["Root ", "cause ", "found."]
def test_stream_messages_yields_chunks() -> None:
config = AIConfig()
client = AIClient(config)
def _make_chunk(text: str | None) -> MagicMock:
delta = MagicMock()
delta.content = text
choice = MagicMock()
choice.delta = delta
chunk = MagicMock()
chunk.choices = [choice]
return chunk
mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")]
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
result = list(
client.stream_messages(
[
{"role": "system", "content": "sys"},
{"role": "user", "content": "question"},
]
)
)
assert result == ["A", "B"]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# prompt_builder # prompt_builder
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -190,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
report = _make_report([("empty", "cat /nonexistent", 1, "", "")]) report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
msg = build_user_message("test", report) msg = build_user_message("test", report)
assert "no output" in msg assert "no output" in msg
def test_build_followup_message_includes_question_context() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
msg = build_followup_message(
"nginx is failing",
report,
"what should I check next?",
["is nginx running?", "show me logs"],
)
assert "Current follow-up question" in msg
assert "what should I check next?" in msg
assert "Recent user follow-up questions" in msg

View File

@@ -0,0 +1,24 @@
"""Tests for AI response guardrails."""
from tai.ai_guardrails import validate_ai_response
def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
warnings = validate_ai_response("Root cause only, no structure.")
assert any("Evidence section" in item for item in warnings)
assert any("quoted evidence" in item for item in warnings)
def test_validate_ai_response_flags_risky_actions() -> None:
text = "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
warnings = validate_ai_response(text)
assert any("modifying actions" in item for item in warnings)
def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
text = (
"Evidence: `PasswordAuthentication no`\n"
"Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
)
warnings = validate_ai_response(text)
assert not warnings

View File

@@ -139,3 +139,93 @@ def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no
assert "Collection complete" in result.stdout assert "Collection complete" in result.stdout
assert "kernel: ok" in result.stdout assert "kernel: ok" in result.stdout
assert "journal: ok (truncated)" in result.stdout assert "journal: ok (truncated)" in result.stdout
def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["/collect", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--no-probe",
"--interactive",
],
)
assert result.exit_code == 0
assert "Interactive mode" in result.stdout
assert "Collection complete" in result.stdout
assert "Bye." in result.stdout
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr(
"tai.cli.AIClient.stream",
lambda *_args, **_kwargs: iter(["Check logs."]),
)
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--no-probe",
"--interactive",
],
)
assert result.exit_code == 0
assert "Analyzing..." in result.stdout
assert "Check logs." in result.stdout

22
tests/test_session_log.py Normal file
View File

@@ -0,0 +1,22 @@
"""Tests for structured session logging."""
from __future__ import annotations
import json
from tai.session_log import SessionLogger
def test_session_logger_writes_jsonl_row(tmp_path) -> None: # type: ignore[no-untyped-def]
log_path = tmp_path / "logs" / "session.jsonl"
logger = SessionLogger.create(str(log_path))
logger.log_event("analysis_response", {"response": "Root cause is X"})
lines = log_path.read_text(encoding="utf-8").splitlines()
assert len(lines) == 1
row = json.loads(lines[0])
assert row["event"] == "analysis_response"
assert row["payload"]["response"] == "Root cause is X"
assert "ts" in row