update
Some checks failed
CI / test (push) Failing after 15s

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-05-04 04:51:48 +02:00
parent e6233f237b
commit 61d3e2c4e6
8 changed files with 757 additions and 89 deletions

View File

@@ -0,0 +1,110 @@
name: Release
on:
push:
tags:
- "v*"
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Ensure git is available
run: |
if command -v git >/dev/null 2>&1; then
git --version
exit 0
fi
if command -v apt-get >/dev/null 2>&1; then
apt-get update
apt-get install -y git
elif command -v dnf >/dev/null 2>&1; then
dnf install -y git
elif command -v yum >/dev/null 2>&1; then
yum install -y git
else
echo "No supported package manager found to install git."
exit 1
fi
- name: Checkout source (native git)
env:
CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
run: |
if [ -z "${CI_GIT_TOKEN:-}" ]; then
echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
exit 1
fi
auth_server="${GITHUB_SERVER_URL#https://}"
auth_server="${auth_server#http://}"
remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
if [ -n "${GITHUB_WORKSPACE:-}" ]; then
cd "$GITHUB_WORKSPACE"
fi
if [ ! -d .git ]; then
git init
fi
git remote remove origin >/dev/null 2>&1 || true
git remote add origin "$remote_url"
# Fetch the tag by SHA so we get the exact tagged commit
git fetch --depth 1 origin "$GITHUB_SHA"
git checkout --force FETCH_HEAD
- name: Ensure Python and build dependencies are available
run: |
if ! command -v python3 >/dev/null 2>&1; then
if command -v apt-get >/dev/null 2>&1; then
apt-get update
apt-get install -y python3 python3-pip python3-venv patchelf ccache
elif command -v dnf >/dev/null 2>&1; then
dnf install -y python3 python3-pip patchelf ccache
fi
fi
# patchelf is required by Nuitka for standalone Linux binaries
command -v patchelf >/dev/null 2>&1 || {
apt-get update && apt-get install -y patchelf
}
python3 --version
- name: Set up venv and install package + build deps
run: |
python3 -m venv .venv
. .venv/bin/activate
python -m pip install --upgrade pip
python -m pip install -e ".[build]"
- name: Derive version from tag
id: version
run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
- name: Build standalone binary with Nuitka
run: |
. .venv/bin/activate
python -m nuitka \
--standalone \
--onefile \
--output-filename=tai \
--output-dir=dist \
--assume-yes-for-downloads \
--include-package=tai \
src/tai/cli.py
- name: Smoke-test the binary
run: dist/tai --help
- name: Upload binary artifact
uses: actions/upload-artifact@v3
with:
name: tai-linux-amd64-${{ steps.version.outputs.tag }}
path: dist/tai
if-no-files-found: error
retention-days: 90

93
src/tai/ai_client.py Normal file
View File

@@ -0,0 +1,93 @@
"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint."""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from openai import OpenAI
DEFAULT_AI_HOST = "http://localhost:11434/v1"
DEFAULT_MODEL = "gemma3:4b"
@dataclass(slots=True)
class AIConfig:
"""Connection parameters for an OpenAI-compatible AI backend."""
host: str = DEFAULT_AI_HOST
model: str = DEFAULT_MODEL
api_key: str = "ollama" # Ollama ignores this; required by the openai client
timeout_seconds: float = 120.0
max_tokens: int = 4096
extra_headers: dict[str, str] = field(default_factory=dict)
@dataclass(slots=True)
class AIResponse:
"""Structured response from an AI completion."""
model: str
content: str
prompt_tokens: int
completion_tokens: int
@property
def total_tokens(self) -> int:
return self.prompt_tokens + self.completion_tokens
class AIClient:
"""Thin wrapper around the openai client targeting a configurable endpoint."""
def __init__(self, config: AIConfig) -> None:
self._config = config
self._client = OpenAI(
base_url=config.host,
api_key=config.api_key,
timeout=config.timeout_seconds,
default_headers=config.extra_headers,
)
def complete(self, system_prompt: str, user_message: str) -> AIResponse:
"""Send a completion request and return the full response."""
response = self._client.chat.completions.create(
model=self._config.model,
max_tokens=self._config.max_tokens,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
],
)
choice = response.choices[0]
content = choice.message.content or ""
usage = response.usage
return AIResponse(
model=response.model,
content=content,
prompt_tokens=usage.prompt_tokens if usage else 0,
completion_tokens=usage.completion_tokens if usage else 0,
)
def stream(self, system_prompt: str, user_message: str) -> Iterator[str]:
"""Stream a completion, yielding text chunks as they arrive."""
stream = self._client.chat.completions.create(
model=self._config.model,
max_tokens=self._config.max_tokens,
stream=True,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
],
)
for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
yield delta
def summary(self) -> str:
"""Human-readable description of the AI config."""
return f"host={self._config.host} model={self._config.model}"

View File

@@ -7,11 +7,14 @@ from typing import Annotated
import typer
from rich.console import Console
from rich.markdown import Markdown
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest
from tai.plan import plan_from_request
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig
app = typer.Typer(no_args_is_help=True, add_completion=False)
@@ -56,6 +59,25 @@ def run(
help="Collect baseline diagnostics after probe.",
),
] = False,
analyze: Annotated[
bool,
typer.Option(
"--analyze/--no-analyze",
help="Send collected diagnostics to AI for analysis.",
),
] = False,
ai_host: Annotated[
str,
typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."),
] = DEFAULT_AI_HOST,
model: Annotated[
str,
typer.Option("--model", help="Model name to use for AI analysis."),
] = DEFAULT_MODEL,
ai_key: Annotated[
str,
typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."),
] = "ollama",
) -> None:
"""Start an interactive troubleshooting session scaffold."""
try:
@@ -81,50 +103,69 @@ def run(
)
summary = SSHClient(config).summary()
console.print("[bold green]tai scaffold ready[/bold green]")
console.print("[bold green]tai[/bold green]")
console.print(f"Issue: {req.issue}")
console.print(f"SSH: {summary}")
console.print(f"SSH: {summary}")
if req.target_paths:
console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")
if not (probe or collect or analyze):
return # nothing SSH-related requested
ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key)
if analyze:
console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}")
try:
asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze,
ai_config=ai_config))
except typer.Exit:
raise
except TimeoutError as exc:
console.print(f"[red]SSH timeout:[/red] {exc}")
raise typer.Exit(code=1) from exc
except OSError as exc:
console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}")
raise typer.Exit(code=1) from exc
async def _async_main(
config: SSHConnectionConfig,
req: TroubleshootRequest,
*,
probe: bool,
collect: bool,
analyze: bool,
ai_config: AIConfig,
) -> None:
"""Open a single SSH session and run probe / collection / analysis through it."""
client = SSHClient(config)
async with client.connect() as session:
if probe:
result = await session.probe()
_handle_probe_result(result)
if probe:
_run_probe(client)
report: CollectionReport | None = None
if collect or analyze:
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
if collect:
_run_collection(client, req)
if analyze and report is not None:
_run_analysis(ai_config, req.issue, report)
def _run_probe(client: SSHClient) -> None:
"""Run a live SSH probe and exit non-zero on failure."""
def _handle_probe_result(result: SSHCommandResult) -> None:
"""Handle and render probe output for success or failure."""
console.print("[cyan]Running SSH probe:[/cyan] uname -a")
try:
result = asyncio.run(client.probe())
except TimeoutError as exc:
console.print(f"[red]Probe failed:[/red] {exc}")
raise typer.Exit(code=1) from exc
except OSError as exc:
console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}")
raise typer.Exit(code=1) from exc
_handle_probe_result(result)
def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None:
"""Run issue-aware collection and print a compact summary."""
plan = plan_from_request(request)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
try:
report = asyncio.run(collect_from_plan(client, plan))
except TimeoutError as exc:
console.print(f"[red]Collection failed:[/red] {exc}")
raise typer.Exit(code=1) from exc
except OSError as exc:
console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}")
raise typer.Exit(code=1) from exc
_handle_collection_report(report)
if result.exit_code != 0:
details = result.stderr or result.stdout or "no error output from ssh"
console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
raise typer.Exit(code=1)
output = result.stdout or "(no output)"
console.print("[bold green]Probe succeeded.[/bold green]")
console.print(f"Remote: {output}")
def _handle_collection_report(report: CollectionReport) -> None:
@@ -134,22 +175,25 @@ def _handle_collection_report(report: CollectionReport) -> None:
)
for item in report.items:
status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}"
trunc = ""
if item.result.stdout_truncated or item.result.stderr_truncated:
trunc = " (truncated)"
truncated = item.result.stdout_truncated or item.result.stderr_truncated
trunc = " (truncated)" if truncated else ""
console.print(f"- {item.name}: {status}{trunc}")
def _handle_probe_result(result: SSHCommandResult) -> None:
"""Handle and render probe output for success or failure."""
if result.exit_code != 0:
details = result.stderr or result.stdout or "no error output from ssh"
console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
raise typer.Exit(code=1)
output = result.stdout or "(no output)"
console.print("[bold green]Probe succeeded.[/bold green]")
console.print(f"Remote: {output}")
def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None:
"""Send collected data to the AI and stream the analysis to stdout."""
console.print("[cyan]Analyzing...[/cyan]\n")
ai = AIClient(ai_config)
system_prompt = build_system_prompt()
user_message = build_user_message(issue, report)
try:
chunks: list[str] = []
for chunk in ai.stream(system_prompt, user_message):
chunks.append(chunk)
console.print(Markdown("".join(chunks)))
except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}")
raise typer.Exit(code=1) from exc
def main() -> None:

View File

@@ -3,7 +3,7 @@
from dataclasses import dataclass
from tai.plan import CollectionPlan
from tai.ssh_client import SSHClient, SSHCommandResult
from tai.ssh_client import SSHCommandResult, SSHSession
@dataclass(slots=True)
@@ -31,20 +31,20 @@ class CollectionReport:
async def collect_from_plan(
client: SSHClient,
session: SSHSession,
plan: CollectionPlan,
*,
max_output_bytes: int = 32768,
) -> CollectionReport:
"""Execute all commands in *plan* and return a :class:`CollectionReport`."""
"""Execute all commands in *plan* over a shared SSH session."""
items: list[CollectedItem] = []
for name, command in plan.commands:
result = await client.run_read_only_command(
result = await session.run_read_only_command(
command,
timeout_seconds=30.0,
max_output_bytes=max_output_bytes,
)
items.append(CollectedItem(name=name, result=result))
return CollectionReport(host=client.summary(), items=items)
return CollectionReport(host=session._client.summary(), items=items)

74
src/tai/prompt_builder.py Normal file
View File

@@ -0,0 +1,74 @@
"""Formats collected diagnostics into prompts for the AI backend."""
from __future__ import annotations
from tai.collectors import CollectionReport
_SYSTEM_PROMPT = """\
You are an expert Linux systems administrator and troubleshooting assistant.
You are given diagnostic data collected read-only from a remote Linux host via SSH.
Your job:
1. Identify the root cause of the reported issue based only on the data provided.
2. Cite the specific output that supports your conclusion.
3. Give concise, actionable remediation steps.
Important rules:
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say.
- Never suggest commands that modify the system unless explicitly asked.
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
"""
def build_system_prompt() -> str:
"""Return the static system prompt for the troubleshooting agent."""
return _SYSTEM_PROMPT.strip()
def build_user_message(issue: str, report: CollectionReport) -> str:
"""Format *issue* and *report* into the user message sent to the AI."""
lines: list[str] = []
lines.append(f"## Issue reported\n\n{issue}\n")
lines.append(f"## Target host\n\n{report.host}\n")
lines.append("## Collected diagnostics\n")
skipped: list[str] = []
for item in report.items:
result = item.result
# Exit 255 with no output = SSH couldn't execute the command at all.
# Exclude these entirely to prevent the AI from speculating on them.
if result.exit_code == 255 and not result.stdout and not result.stderr:
skipped.append(item.name)
continue
lines.append(f"### {item.name}\n")
lines.append(f"**Command:** `{result.command}` ")
lines.append(f"**Exit code:** {result.exit_code}\n")
if result.stdout:
trunc = " *(output truncated)*" if result.stdout_truncated else ""
lines.append(f"**stdout:**{trunc}\n```\n{result.stdout.strip()}\n```\n")
if result.stderr:
trunc = " *(output truncated)*" if result.stderr_truncated else ""
lines.append(f"**stderr:**{trunc}\n```\n{result.stderr.strip()}\n```\n")
if not result.stdout and not result.stderr:
lines.append("*(no output)*\n")
if skipped:
lines.append(
f"**Note:** The following commands could not be executed on this host "
f"and produced no output: {', '.join(skipped)}. "
f"Do not draw any conclusions from their absence.\n"
)
return "\n".join(lines)

View File

@@ -1,9 +1,12 @@
"""SSH configuration and read-only command execution."""
import asyncio
import os
import shlex
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
@dataclass(slots=True)
@@ -88,8 +91,8 @@ class SSHClient:
f"key={key} jump={jump} mode={mode}"
)
def build_ssh_argv(self, remote_command: str) -> list[str]:
"""Build argv for a secure non-interactive SSH invocation."""
def _build_base_argv(self) -> list[str]:
"""Build the common SSH argv flags (no host or command appended)."""
argv = [
"ssh",
"-p",
@@ -111,9 +114,16 @@ class SSHClient:
if self._config.jump_host:
argv += ["-J", self._config.jump_host]
argv += [self._config.host, remote_command]
return argv
def build_ssh_argv(self, remote_command: str) -> list[str]:
"""Build argv for a secure non-interactive SSH invocation."""
return self._build_base_argv() + [self._config.host, remote_command]
def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession":
"""Return an :class:`SSHSession` async context manager for this host."""
return SSHSession(self, connect_timeout=connect_timeout)
def validate_read_only_command(self, command: str) -> None:
"""Validate that a command appears read-only and non-destructive."""
normalized = command.strip()
@@ -168,8 +178,7 @@ class SSHClient:
max_output_bytes=4096,
)
async def _run_ssh(
self,
async def _run_ssh( self,
command: str,
*,
timeout_seconds: float,
@@ -231,3 +240,144 @@ class SSHClient:
trimmed_bytes = encoded[:keep]
trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip()
return f"{trimmed_text}{marker}", True
class SSHSession:
"""A persistent SSH connection using ControlMaster multiplexing.
All commands run over the same underlying TCP connection — no per-command
SSH handshake. Use as an async context manager::
async with client.connect() as session:
result = await session.run_read_only_command("df -h")
"""
def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None:
self._client = client
self._connect_timeout = connect_timeout
self._socket_path: Path | None = None
self._master_proc: asyncio.subprocess.Process | None = None
async def __aenter__(self) -> "SSHSession":
fd, path = tempfile.mkstemp(prefix="tai-ssh-", suffix=".sock")
os.close(fd)
os.unlink(path) # SSH needs to create this itself as a socket
self._socket_path = Path(path)
master_argv = self._client._build_base_argv() + [
"-o", "ControlMaster=yes",
"-o", f"ControlPath={self._socket_path}",
"-o", "ControlPersist=no",
"-N",
self._client._config.host,
]
self._master_proc = await asyncio.create_subprocess_exec(
*master_argv,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
# Wait for the control socket to appear (master is ready)
loop = asyncio.get_event_loop()
deadline = loop.time() + self._connect_timeout
while not self._socket_path.exists():
if loop.time() > deadline:
await self._teardown()
raise TimeoutError(
f"SSH ControlMaster did not connect within "
f"{self._connect_timeout}s to {self._client._config.host}"
)
await asyncio.sleep(0.05)
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self._teardown()
async def _teardown(self) -> None:
if self._master_proc and self._master_proc.returncode is None:
self._master_proc.terminate()
try:
await asyncio.wait_for(self._master_proc.wait(), timeout=5.0)
except TimeoutError:
self._master_proc.kill()
if self._socket_path and self._socket_path.exists():
self._socket_path.unlink(missing_ok=True)
def _command_argv(self, remote_command: str) -> list[str]:
return self._client._build_base_argv() + [
"-o", f"ControlPath={self._socket_path}",
"-o", "ControlMaster=no",
self._client._config.host,
remote_command,
]
async def probe(self) -> SSHCommandResult:
"""Run uname -a to confirm connectivity."""
return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096)
async def run_read_only_command(
self,
command: str,
*,
timeout_seconds: float = 30.0,
max_output_bytes: int = 32768,
) -> SSHCommandResult:
"""Validate and run a read-only command over the shared connection."""
self._client.validate_read_only_command(command)
return await self._run(
command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes
)
async def _run(
self,
command: str,
*,
timeout_seconds: float,
max_output_bytes: int,
) -> SSHCommandResult:
argv = self._command_argv(command)
proc = await asyncio.create_subprocess_exec(
*argv,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(), timeout=timeout_seconds
)
except TimeoutError as exc:
proc.kill()
await proc.wait()
raise TimeoutError(
f"SSH command timed out after {timeout_seconds}s: {command}"
) from exc
if proc.returncode is None:
raise RuntimeError("SSH process did not provide an exit code.")
stdout_text, stdout_truncated = SSHClient._truncate_output(
stdout_bytes.decode("utf-8", errors="replace"),
max_output_bytes=max_output_bytes,
)
stderr_text, stderr_truncated = SSHClient._truncate_output(
stderr_bytes.decode("utf-8", errors="replace"),
max_output_bytes=max_output_bytes,
)
return SSHCommandResult(
command=command,
exit_code=proc.returncode,
stdout=stdout_text,
stderr=stderr_text,
stdout_truncated=stdout_truncated,
stderr_truncated=stderr_truncated,
)

192
tests/test_ai.py Normal file
View File

@@ -0,0 +1,192 @@
"""Tests for the AI client and prompt builder."""
from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult
# ---------------------------------------------------------------------------
# AIConfig defaults
# ---------------------------------------------------------------------------
def test_ai_config_defaults() -> None:
config = AIConfig()
assert config.host == DEFAULT_AI_HOST
assert config.model == DEFAULT_MODEL
assert config.api_key == "ollama"
def test_ai_config_custom_values() -> None:
config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test")
assert config.host == "https://api.openai.com/v1"
assert config.model == "gpt-4o"
assert config.api_key == "sk-test"
# ---------------------------------------------------------------------------
# AIClient.summary
# ---------------------------------------------------------------------------
def test_ai_client_summary_contains_host_and_model() -> None:
config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b")
client = AIClient(config)
summary = client.summary()
assert "http://myserver:11434/v1" in summary
assert "llama3.1:8b" in summary
# ---------------------------------------------------------------------------
# AIClient.complete (mocked)
# ---------------------------------------------------------------------------
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
usage = MagicMock()
usage.prompt_tokens = 10
usage.completion_tokens = 20
message = MagicMock()
message.content = content
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
response.model = model
response.usage = usage
return response
def test_complete_returns_ai_response() -> None:
config = AIConfig()
client = AIClient(config)
mock_response = _make_mock_response("The root cause is X.")
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
result = client.complete("system prompt", "user message")
assert result.content == "The root cause is X."
assert result.prompt_tokens == 10
assert result.completion_tokens == 20
assert result.total_tokens == 30
def test_complete_handles_empty_content() -> None:
config = AIConfig()
client = AIClient(config)
mock_response = _make_mock_response(None) # type: ignore[arg-type]
mock_response.choices[0].message.content = None
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
result = client.complete("system", "user")
assert result.content == ""
# ---------------------------------------------------------------------------
# AIClient.stream (mocked)
# ---------------------------------------------------------------------------
def test_stream_yields_chunks() -> None:
config = AIConfig()
client = AIClient(config)
def _make_chunk(text: str | None) -> MagicMock:
delta = MagicMock()
delta.content = text
choice = MagicMock()
choice.delta = delta
chunk = MagicMock()
chunk.choices = [choice]
return chunk
mock_chunks = [
_make_chunk("Root "), _make_chunk("cause "), _make_chunk(None), _make_chunk("found."),
]
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
result = list(client.stream("system", "user"))
assert result == ["Root ", "cause ", "found."]
# ---------------------------------------------------------------------------
# prompt_builder
# ---------------------------------------------------------------------------
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
"""Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples."""
return CollectionReport(
host="root@testhost",
items=[
CollectedItem(
name=name,
result=SSHCommandResult(
command=command,
exit_code=exit_code,
stdout=stdout,
stderr=stderr,
),
)
for name, command, exit_code, stdout, stderr in items
],
)
def test_build_system_prompt_contains_key_instructions() -> None:
prompt = build_system_prompt()
assert "Root Cause" in prompt
assert "Evidence" in prompt
assert "Recommended Actions" in prompt
assert "read-only" in prompt.lower()
def test_build_user_message_contains_issue_and_host() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
msg = build_user_message("nginx is failing", report)
assert "nginx is failing" in msg
assert "root@testhost" in msg
def test_build_user_message_includes_command_output() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")])
msg = build_user_message("test issue", report)
assert "uname -a" in msg
assert "Linux web01 6.1.0" in msg
def test_build_user_message_shows_stderr() -> None:
report = _make_report(
[("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")]
)
msg = build_user_message("nginx not found", report)
assert "Unit nginx.service not found." in msg
def test_build_user_message_notes_truncation() -> None:
result = SSHCommandResult(
command="journalctl -n 100 --no-pager",
exit_code=0,
stdout="...",
stderr="",
stdout_truncated=True,
)
report = CollectionReport(
host="root@testhost",
items=[CollectedItem(name="journal", result=result)],
)
msg = build_user_message("disk issue", report)
assert "truncated" in msg
def test_build_user_message_handles_no_output() -> None:
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
msg = build_user_message("test", report)
assert "no output" in msg

View File

@@ -1,3 +1,5 @@
from unittest.mock import AsyncMock, MagicMock
from typer.testing import CliRunner
from tai.cli import app
@@ -5,6 +7,24 @@ from tai.collectors import CollectedItem, CollectionReport
from tai.ssh_client import SSHCommandResult
def _mock_session(
monkeypatch, # type: ignore[no-untyped-def]
*,
probe_result: SSHCommandResult | None = None,
probe_raises: Exception | None = None,
) -> MagicMock:
"""Patch SSHClient.connect to return a mock session."""
session = MagicMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=None)
if probe_raises:
session.probe = AsyncMock(side_effect=probe_raises)
else:
session.probe = AsyncMock(return_value=probe_result)
monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session)
return session
def test_run_command_prints_scaffold_summary() -> None:
runner = CliRunner()
result = runner.invoke(
@@ -25,33 +45,23 @@ def test_run_command_prints_scaffold_summary() -> None:
)
assert result.exit_code == 0
assert "tai scaffold ready" in result.stdout
assert "tai" in result.stdout
assert "host=web01" in result.stdout
assert "port=5566" in result.stdout
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def]
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
return SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux ssh 6.12.0",
stderr="",
)
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
_mock_session(
monkeypatch,
probe_result=SSHCommandResult(
command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
),
)
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--probe",
],
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
)
assert result.exit_code == 0
@@ -60,27 +70,20 @@ def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: #
def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def]
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
return SSHCommandResult(
_mock_session(
monkeypatch,
probe_result=SSHCommandResult(
command="uname -a",
exit_code=255,
stdout="",
stderr="Permission denied (publickey,password).",
)
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
),
)
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--probe",
],
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
)
assert result.exit_code == 1
@@ -88,7 +91,9 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no
def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def]
async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[