initialCommit #2
110
.gitea/workflows/release.yml
Normal file
110
.gitea/workflows/release.yml
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Ensure git is available
|
||||||
|
run: |
|
||||||
|
if command -v git >/dev/null 2>&1; then
|
||||||
|
git --version
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y git
|
||||||
|
elif command -v dnf >/dev/null 2>&1; then
|
||||||
|
dnf install -y git
|
||||||
|
elif command -v yum >/dev/null 2>&1; then
|
||||||
|
yum install -y git
|
||||||
|
else
|
||||||
|
echo "No supported package manager found to install git."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Checkout source (native git)
|
||||||
|
env:
|
||||||
|
CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "${CI_GIT_TOKEN:-}" ]; then
|
||||||
|
echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
auth_server="${GITHUB_SERVER_URL#https://}"
|
||||||
|
auth_server="${auth_server#http://}"
|
||||||
|
remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
|
||||||
|
|
||||||
|
if [ -n "${GITHUB_WORKSPACE:-}" ]; then
|
||||||
|
cd "$GITHUB_WORKSPACE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d .git ]; then
|
||||||
|
git init
|
||||||
|
fi
|
||||||
|
|
||||||
|
git remote remove origin >/dev/null 2>&1 || true
|
||||||
|
git remote add origin "$remote_url"
|
||||||
|
|
||||||
|
# Fetch the tag by SHA so we get the exact tagged commit
|
||||||
|
git fetch --depth 1 origin "$GITHUB_SHA"
|
||||||
|
git checkout --force FETCH_HEAD
|
||||||
|
|
||||||
|
- name: Ensure Python and build dependencies are available
|
||||||
|
run: |
|
||||||
|
if ! command -v python3 >/dev/null 2>&1; then
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y python3 python3-pip python3-venv patchelf ccache
|
||||||
|
elif command -v dnf >/dev/null 2>&1; then
|
||||||
|
dnf install -y python3 python3-pip patchelf ccache
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# patchelf is required by Nuitka for standalone Linux binaries
|
||||||
|
command -v patchelf >/dev/null 2>&1 || {
|
||||||
|
apt-get update && apt-get install -y patchelf
|
||||||
|
}
|
||||||
|
|
||||||
|
python3 --version
|
||||||
|
|
||||||
|
- name: Set up venv and install package + build deps
|
||||||
|
run: |
|
||||||
|
python3 -m venv .venv
|
||||||
|
. .venv/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -e ".[build]"
|
||||||
|
|
||||||
|
- name: Derive version from tag
|
||||||
|
id: version
|
||||||
|
run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Build standalone binary with Nuitka
|
||||||
|
run: |
|
||||||
|
. .venv/bin/activate
|
||||||
|
python -m nuitka \
|
||||||
|
--standalone \
|
||||||
|
--onefile \
|
||||||
|
--output-filename=tai \
|
||||||
|
--output-dir=dist \
|
||||||
|
--assume-yes-for-downloads \
|
||||||
|
--include-package=tai \
|
||||||
|
src/tai/cli.py
|
||||||
|
|
||||||
|
- name: Smoke-test the binary
|
||||||
|
run: dist/tai --help
|
||||||
|
|
||||||
|
- name: Upload binary artifact
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: tai-linux-amd64-${{ steps.version.outputs.tag }}
|
||||||
|
path: dist/tai
|
||||||
|
if-no-files-found: error
|
||||||
|
retention-days: 90
|
||||||
93
src/tai/ai_client.py
Normal file
93
src/tai/ai_client.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
DEFAULT_AI_HOST = "http://localhost:11434/v1"
|
||||||
|
DEFAULT_MODEL = "gemma3:4b"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AIConfig:
|
||||||
|
"""Connection parameters for an OpenAI-compatible AI backend."""
|
||||||
|
|
||||||
|
host: str = DEFAULT_AI_HOST
|
||||||
|
model: str = DEFAULT_MODEL
|
||||||
|
api_key: str = "ollama" # Ollama ignores this; required by the openai client
|
||||||
|
timeout_seconds: float = 120.0
|
||||||
|
max_tokens: int = 4096
|
||||||
|
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AIResponse:
|
||||||
|
"""Structured response from an AI completion."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
content: str
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self.prompt_tokens + self.completion_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AIClient:
|
||||||
|
"""Thin wrapper around the openai client targeting a configurable endpoint."""
|
||||||
|
|
||||||
|
def __init__(self, config: AIConfig) -> None:
|
||||||
|
self._config = config
|
||||||
|
self._client = OpenAI(
|
||||||
|
base_url=config.host,
|
||||||
|
api_key=config.api_key,
|
||||||
|
timeout=config.timeout_seconds,
|
||||||
|
default_headers=config.extra_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def complete(self, system_prompt: str, user_message: str) -> AIResponse:
|
||||||
|
"""Send a completion request and return the full response."""
|
||||||
|
response = self._client.chat.completions.create(
|
||||||
|
model=self._config.model,
|
||||||
|
max_tokens=self._config.max_tokens,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
choice = response.choices[0]
|
||||||
|
content = choice.message.content or ""
|
||||||
|
usage = response.usage
|
||||||
|
|
||||||
|
return AIResponse(
|
||||||
|
model=response.model,
|
||||||
|
content=content,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else 0,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(self, system_prompt: str, user_message: str) -> Iterator[str]:
|
||||||
|
"""Stream a completion, yielding text chunks as they arrive."""
|
||||||
|
stream = self._client.chat.completions.create(
|
||||||
|
model=self._config.model,
|
||||||
|
max_tokens=self._config.max_tokens,
|
||||||
|
stream=True,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
delta = chunk.choices[0].delta.content
|
||||||
|
if delta:
|
||||||
|
yield delta
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
"""Human-readable description of the AI config."""
|
||||||
|
return f"host={self._config.host} model={self._config.model}"
|
||||||
138
src/tai/cli.py
138
src/tai/cli.py
@@ -7,11 +7,14 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
from tai.collectors import CollectionReport, collect_from_plan
|
from tai.collectors import CollectionReport, collect_from_plan
|
||||||
from tai.input_parser import InputValidationError, build_request
|
from tai.input_parser import InputValidationError, build_request
|
||||||
from tai.models import TroubleshootRequest
|
from tai.models import TroubleshootRequest
|
||||||
from tai.plan import plan_from_request
|
from tai.plan import plan_from_request
|
||||||
|
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig
|
||||||
|
|
||||||
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||||
@@ -56,6 +59,25 @@ def run(
|
|||||||
help="Collect baseline diagnostics after probe.",
|
help="Collect baseline diagnostics after probe.",
|
||||||
),
|
),
|
||||||
] = False,
|
] = False,
|
||||||
|
analyze: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
"--analyze/--no-analyze",
|
||||||
|
help="Send collected diagnostics to AI for analysis.",
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
ai_host: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."),
|
||||||
|
] = DEFAULT_AI_HOST,
|
||||||
|
model: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option("--model", help="Model name to use for AI analysis."),
|
||||||
|
] = DEFAULT_MODEL,
|
||||||
|
ai_key: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."),
|
||||||
|
] = "ollama",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start an interactive troubleshooting session scaffold."""
|
"""Start an interactive troubleshooting session scaffold."""
|
||||||
try:
|
try:
|
||||||
@@ -81,50 +103,69 @@ def run(
|
|||||||
)
|
)
|
||||||
|
|
||||||
summary = SSHClient(config).summary()
|
summary = SSHClient(config).summary()
|
||||||
console.print("[bold green]tai scaffold ready[/bold green]")
|
console.print("[bold green]tai[/bold green]")
|
||||||
console.print(f"Issue: {req.issue}")
|
console.print(f"Issue: {req.issue}")
|
||||||
console.print(f"SSH: {summary}")
|
console.print(f"SSH: {summary}")
|
||||||
if req.target_paths:
|
if req.target_paths:
|
||||||
console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")
|
console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")
|
||||||
|
|
||||||
|
if not (probe or collect or analyze):
|
||||||
|
return # nothing SSH-related requested
|
||||||
|
|
||||||
|
ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key)
|
||||||
|
if analyze:
|
||||||
|
console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze,
|
||||||
|
ai_config=ai_config))
|
||||||
|
except typer.Exit:
|
||||||
|
raise
|
||||||
|
except TimeoutError as exc:
|
||||||
|
console.print(f"[red]SSH timeout:[/red] {exc}")
|
||||||
|
raise typer.Exit(code=1) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}")
|
||||||
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_main(
|
||||||
|
config: SSHConnectionConfig,
|
||||||
|
req: TroubleshootRequest,
|
||||||
|
*,
|
||||||
|
probe: bool,
|
||||||
|
collect: bool,
|
||||||
|
analyze: bool,
|
||||||
|
ai_config: AIConfig,
|
||||||
|
) -> None:
|
||||||
|
"""Open a single SSH session and run probe / collection / analysis through it."""
|
||||||
client = SSHClient(config)
|
client = SSHClient(config)
|
||||||
|
async with client.connect() as session:
|
||||||
|
if probe:
|
||||||
|
result = await session.probe()
|
||||||
|
_handle_probe_result(result)
|
||||||
|
|
||||||
if probe:
|
report: CollectionReport | None = None
|
||||||
_run_probe(client)
|
if collect or analyze:
|
||||||
|
plan = plan_from_request(req)
|
||||||
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
|
report = await collect_from_plan(session, plan)
|
||||||
|
_handle_collection_report(report)
|
||||||
|
|
||||||
if collect:
|
if analyze and report is not None:
|
||||||
_run_collection(client, req)
|
_run_analysis(ai_config, req.issue, report)
|
||||||
|
|
||||||
|
|
||||||
def _run_probe(client: SSHClient) -> None:
|
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||||
"""Run a live SSH probe and exit non-zero on failure."""
|
"""Handle and render probe output for success or failure."""
|
||||||
console.print("[cyan]Running SSH probe:[/cyan] uname -a")
|
console.print("[cyan]Running SSH probe:[/cyan] uname -a")
|
||||||
try:
|
if result.exit_code != 0:
|
||||||
result = asyncio.run(client.probe())
|
details = result.stderr or result.stdout or "no error output from ssh"
|
||||||
except TimeoutError as exc:
|
console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
|
||||||
console.print(f"[red]Probe failed:[/red] {exc}")
|
raise typer.Exit(code=1)
|
||||||
raise typer.Exit(code=1) from exc
|
output = result.stdout or "(no output)"
|
||||||
except OSError as exc:
|
console.print("[bold green]Probe succeeded.[/bold green]")
|
||||||
console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}")
|
console.print(f"Remote: {output}")
|
||||||
raise typer.Exit(code=1) from exc
|
|
||||||
|
|
||||||
_handle_probe_result(result)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None:
|
|
||||||
"""Run issue-aware collection and print a compact summary."""
|
|
||||||
plan = plan_from_request(request)
|
|
||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
|
||||||
try:
|
|
||||||
report = asyncio.run(collect_from_plan(client, plan))
|
|
||||||
except TimeoutError as exc:
|
|
||||||
console.print(f"[red]Collection failed:[/red] {exc}")
|
|
||||||
raise typer.Exit(code=1) from exc
|
|
||||||
except OSError as exc:
|
|
||||||
console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}")
|
|
||||||
raise typer.Exit(code=1) from exc
|
|
||||||
|
|
||||||
_handle_collection_report(report)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_collection_report(report: CollectionReport) -> None:
|
def _handle_collection_report(report: CollectionReport) -> None:
|
||||||
@@ -134,22 +175,25 @@ def _handle_collection_report(report: CollectionReport) -> None:
|
|||||||
)
|
)
|
||||||
for item in report.items:
|
for item in report.items:
|
||||||
status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}"
|
status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}"
|
||||||
trunc = ""
|
truncated = item.result.stdout_truncated or item.result.stderr_truncated
|
||||||
if item.result.stdout_truncated or item.result.stderr_truncated:
|
trunc = " (truncated)" if truncated else ""
|
||||||
trunc = " (truncated)"
|
|
||||||
console.print(f"- {item.name}: {status}{trunc}")
|
console.print(f"- {item.name}: {status}{trunc}")
|
||||||
|
|
||||||
|
|
||||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None:
|
||||||
"""Handle and render probe output for success or failure."""
|
"""Send collected data to the AI and stream the analysis to stdout."""
|
||||||
if result.exit_code != 0:
|
console.print("[cyan]Analyzing...[/cyan]\n")
|
||||||
details = result.stderr or result.stdout or "no error output from ssh"
|
ai = AIClient(ai_config)
|
||||||
console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
|
system_prompt = build_system_prompt()
|
||||||
raise typer.Exit(code=1)
|
user_message = build_user_message(issue, report)
|
||||||
|
try:
|
||||||
output = result.stdout or "(no output)"
|
chunks: list[str] = []
|
||||||
console.print("[bold green]Probe succeeded.[/bold green]")
|
for chunk in ai.stream(system_prompt, user_message):
|
||||||
console.print(f"Remote: {output}")
|
chunks.append(chunk)
|
||||||
|
console.print(Markdown("".join(chunks)))
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||||
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from tai.plan import CollectionPlan
|
from tai.plan import CollectionPlan
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult
|
from tai.ssh_client import SSHCommandResult, SSHSession
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -31,20 +31,20 @@ class CollectionReport:
|
|||||||
|
|
||||||
|
|
||||||
async def collect_from_plan(
|
async def collect_from_plan(
|
||||||
client: SSHClient,
|
session: SSHSession,
|
||||||
plan: CollectionPlan,
|
plan: CollectionPlan,
|
||||||
*,
|
*,
|
||||||
max_output_bytes: int = 32768,
|
max_output_bytes: int = 32768,
|
||||||
) -> CollectionReport:
|
) -> CollectionReport:
|
||||||
"""Execute all commands in *plan* and return a :class:`CollectionReport`."""
|
"""Execute all commands in *plan* over a shared SSH session."""
|
||||||
items: list[CollectedItem] = []
|
items: list[CollectedItem] = []
|
||||||
|
|
||||||
for name, command in plan.commands:
|
for name, command in plan.commands:
|
||||||
result = await client.run_read_only_command(
|
result = await session.run_read_only_command(
|
||||||
command,
|
command,
|
||||||
timeout_seconds=30.0,
|
timeout_seconds=30.0,
|
||||||
max_output_bytes=max_output_bytes,
|
max_output_bytes=max_output_bytes,
|
||||||
)
|
)
|
||||||
items.append(CollectedItem(name=name, result=result))
|
items.append(CollectedItem(name=name, result=result))
|
||||||
|
|
||||||
return CollectionReport(host=client.summary(), items=items)
|
return CollectionReport(host=session._client.summary(), items=items)
|
||||||
|
|||||||
74
src/tai/prompt_builder.py
Normal file
74
src/tai/prompt_builder.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Formats collected diagnostics into prompts for the AI backend."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tai.collectors import CollectionReport
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = """\
|
||||||
|
You are an expert Linux systems administrator and troubleshooting assistant.
|
||||||
|
You are given diagnostic data collected read-only from a remote Linux host via SSH.
|
||||||
|
|
||||||
|
Your job:
|
||||||
|
1. Identify the root cause of the reported issue based only on the data provided.
|
||||||
|
2. Cite the specific output that supports your conclusion.
|
||||||
|
3. Give concise, actionable remediation steps.
|
||||||
|
|
||||||
|
Important rules:
|
||||||
|
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
|
||||||
|
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
||||||
|
rejected that specific command — it is not evidence about the service or system state.
|
||||||
|
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
||||||
|
additional commands or log files would be needed.
|
||||||
|
- Keep the response short. Skip sections that have nothing useful to say.
|
||||||
|
- Never suggest commands that modify the system unless explicitly asked.
|
||||||
|
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt() -> str:
|
||||||
|
"""Return the static system prompt for the troubleshooting agent."""
|
||||||
|
return _SYSTEM_PROMPT.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_user_message(issue: str, report: CollectionReport) -> str:
|
||||||
|
"""Format *issue* and *report* into the user message sent to the AI."""
|
||||||
|
lines: list[str] = []
|
||||||
|
|
||||||
|
lines.append(f"## Issue reported\n\n{issue}\n")
|
||||||
|
lines.append(f"## Target host\n\n{report.host}\n")
|
||||||
|
lines.append("## Collected diagnostics\n")
|
||||||
|
|
||||||
|
skipped: list[str] = []
|
||||||
|
|
||||||
|
for item in report.items:
|
||||||
|
result = item.result
|
||||||
|
|
||||||
|
# Exit 255 with no output = SSH couldn't execute the command at all.
|
||||||
|
# Exclude these entirely to prevent the AI from speculating on them.
|
||||||
|
if result.exit_code == 255 and not result.stdout and not result.stderr:
|
||||||
|
skipped.append(item.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines.append(f"### {item.name}\n")
|
||||||
|
lines.append(f"**Command:** `{result.command}` ")
|
||||||
|
lines.append(f"**Exit code:** {result.exit_code}\n")
|
||||||
|
|
||||||
|
if result.stdout:
|
||||||
|
trunc = " *(output truncated)*" if result.stdout_truncated else ""
|
||||||
|
lines.append(f"**stdout:**{trunc}\n```\n{result.stdout.strip()}\n```\n")
|
||||||
|
|
||||||
|
if result.stderr:
|
||||||
|
trunc = " *(output truncated)*" if result.stderr_truncated else ""
|
||||||
|
lines.append(f"**stderr:**{trunc}\n```\n{result.stderr.strip()}\n```\n")
|
||||||
|
|
||||||
|
if not result.stdout and not result.stderr:
|
||||||
|
lines.append("*(no output)*\n")
|
||||||
|
|
||||||
|
if skipped:
|
||||||
|
lines.append(
|
||||||
|
f"**Note:** The following commands could not be executed on this host "
|
||||||
|
f"and produced no output: {', '.join(skipped)}. "
|
||||||
|
f"Do not draw any conclusions from their absence.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
"""SSH configuration and read-only command execution."""
|
"""SSH configuration and read-only command execution."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
import shlex
|
import shlex
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -88,8 +91,8 @@ class SSHClient:
|
|||||||
f"key={key} jump={jump} mode={mode}"
|
f"key={key} jump={jump} mode={mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_ssh_argv(self, remote_command: str) -> list[str]:
|
def _build_base_argv(self) -> list[str]:
|
||||||
"""Build argv for a secure non-interactive SSH invocation."""
|
"""Build the common SSH argv flags (no host or command appended)."""
|
||||||
argv = [
|
argv = [
|
||||||
"ssh",
|
"ssh",
|
||||||
"-p",
|
"-p",
|
||||||
@@ -111,9 +114,16 @@ class SSHClient:
|
|||||||
if self._config.jump_host:
|
if self._config.jump_host:
|
||||||
argv += ["-J", self._config.jump_host]
|
argv += ["-J", self._config.jump_host]
|
||||||
|
|
||||||
argv += [self._config.host, remote_command]
|
|
||||||
return argv
|
return argv
|
||||||
|
|
||||||
|
def build_ssh_argv(self, remote_command: str) -> list[str]:
|
||||||
|
"""Build argv for a secure non-interactive SSH invocation."""
|
||||||
|
return self._build_base_argv() + [self._config.host, remote_command]
|
||||||
|
|
||||||
|
def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession":
|
||||||
|
"""Return an :class:`SSHSession` async context manager for this host."""
|
||||||
|
return SSHSession(self, connect_timeout=connect_timeout)
|
||||||
|
|
||||||
def validate_read_only_command(self, command: str) -> None:
|
def validate_read_only_command(self, command: str) -> None:
|
||||||
"""Validate that a command appears read-only and non-destructive."""
|
"""Validate that a command appears read-only and non-destructive."""
|
||||||
normalized = command.strip()
|
normalized = command.strip()
|
||||||
@@ -168,8 +178,7 @@ class SSHClient:
|
|||||||
max_output_bytes=4096,
|
max_output_bytes=4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _run_ssh(
|
async def _run_ssh( self,
|
||||||
self,
|
|
||||||
command: str,
|
command: str,
|
||||||
*,
|
*,
|
||||||
timeout_seconds: float,
|
timeout_seconds: float,
|
||||||
@@ -231,3 +240,144 @@ class SSHClient:
|
|||||||
trimmed_bytes = encoded[:keep]
|
trimmed_bytes = encoded[:keep]
|
||||||
trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip()
|
trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip()
|
||||||
return f"{trimmed_text}{marker}", True
|
return f"{trimmed_text}{marker}", True
|
||||||
|
|
||||||
|
|
||||||
|
class SSHSession:
|
||||||
|
"""A persistent SSH connection using ControlMaster multiplexing.
|
||||||
|
|
||||||
|
All commands run over the same underlying TCP connection — no per-command
|
||||||
|
SSH handshake. Use as an async context manager::
|
||||||
|
|
||||||
|
async with client.connect() as session:
|
||||||
|
result = await session.run_read_only_command("df -h")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None:
|
||||||
|
self._client = client
|
||||||
|
self._connect_timeout = connect_timeout
|
||||||
|
self._socket_path: Path | None = None
|
||||||
|
self._master_proc: asyncio.subprocess.Process | None = None
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "SSHSession":
|
||||||
|
fd, path = tempfile.mkstemp(prefix="tai-ssh-", suffix=".sock")
|
||||||
|
os.close(fd)
|
||||||
|
os.unlink(path) # SSH needs to create this itself as a socket
|
||||||
|
self._socket_path = Path(path)
|
||||||
|
|
||||||
|
master_argv = self._client._build_base_argv() + [
|
||||||
|
"-o", "ControlMaster=yes",
|
||||||
|
"-o", f"ControlPath={self._socket_path}",
|
||||||
|
"-o", "ControlPersist=no",
|
||||||
|
"-N",
|
||||||
|
self._client._config.host,
|
||||||
|
]
|
||||||
|
|
||||||
|
self._master_proc = await asyncio.create_subprocess_exec(
|
||||||
|
*master_argv,
|
||||||
|
stdout=asyncio.subprocess.DEVNULL,
|
||||||
|
stderr=asyncio.subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for the control socket to appear (master is ready)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
deadline = loop.time() + self._connect_timeout
|
||||||
|
while not self._socket_path.exists():
|
||||||
|
if loop.time() > deadline:
|
||||||
|
await self._teardown()
|
||||||
|
raise TimeoutError(
|
||||||
|
f"SSH ControlMaster did not connect within "
|
||||||
|
f"{self._connect_timeout}s to {self._client._config.host}"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
await self._teardown()
|
||||||
|
|
||||||
|
async def _teardown(self) -> None:
|
||||||
|
if self._master_proc and self._master_proc.returncode is None:
|
||||||
|
self._master_proc.terminate()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._master_proc.wait(), timeout=5.0)
|
||||||
|
except TimeoutError:
|
||||||
|
self._master_proc.kill()
|
||||||
|
if self._socket_path and self._socket_path.exists():
|
||||||
|
self._socket_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def _command_argv(self, remote_command: str) -> list[str]:
|
||||||
|
return self._client._build_base_argv() + [
|
||||||
|
"-o", f"ControlPath={self._socket_path}",
|
||||||
|
"-o", "ControlMaster=no",
|
||||||
|
self._client._config.host,
|
||||||
|
remote_command,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def probe(self) -> SSHCommandResult:
|
||||||
|
"""Run uname -a to confirm connectivity."""
|
||||||
|
return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096)
|
||||||
|
|
||||||
|
async def run_read_only_command(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
timeout_seconds: float = 30.0,
|
||||||
|
max_output_bytes: int = 32768,
|
||||||
|
) -> SSHCommandResult:
|
||||||
|
"""Validate and run a read-only command over the shared connection."""
|
||||||
|
self._client.validate_read_only_command(command)
|
||||||
|
return await self._run(
|
||||||
|
command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
timeout_seconds: float,
|
||||||
|
max_output_bytes: int,
|
||||||
|
) -> SSHCommandResult:
|
||||||
|
argv = self._command_argv(command)
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*argv,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||||
|
proc.communicate(), timeout=timeout_seconds
|
||||||
|
)
|
||||||
|
except TimeoutError as exc:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
raise TimeoutError(
|
||||||
|
f"SSH command timed out after {timeout_seconds}s: {command}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if proc.returncode is None:
|
||||||
|
raise RuntimeError("SSH process did not provide an exit code.")
|
||||||
|
|
||||||
|
stdout_text, stdout_truncated = SSHClient._truncate_output(
|
||||||
|
stdout_bytes.decode("utf-8", errors="replace"),
|
||||||
|
max_output_bytes=max_output_bytes,
|
||||||
|
)
|
||||||
|
stderr_text, stderr_truncated = SSHClient._truncate_output(
|
||||||
|
stderr_bytes.decode("utf-8", errors="replace"),
|
||||||
|
max_output_bytes=max_output_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SSHCommandResult(
|
||||||
|
command=command,
|
||||||
|
exit_code=proc.returncode,
|
||||||
|
stdout=stdout_text,
|
||||||
|
stderr=stderr_text,
|
||||||
|
stdout_truncated=stdout_truncated,
|
||||||
|
stderr_truncated=stderr_truncated,
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
192
tests/test_ai.py
Normal file
192
tests/test_ai.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""Tests for the AI client and prompt builder."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
|
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||||
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIConfig defaults
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_config_defaults() -> None:
|
||||||
|
config = AIConfig()
|
||||||
|
assert config.host == DEFAULT_AI_HOST
|
||||||
|
assert config.model == DEFAULT_MODEL
|
||||||
|
assert config.api_key == "ollama"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_config_custom_values() -> None:
|
||||||
|
config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test")
|
||||||
|
assert config.host == "https://api.openai.com/v1"
|
||||||
|
assert config.model == "gpt-4o"
|
||||||
|
assert config.api_key == "sk-test"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_client_summary_contains_host_and_model() -> None:
|
||||||
|
config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b")
|
||||||
|
client = AIClient(config)
|
||||||
|
summary = client.summary()
|
||||||
|
assert "http://myserver:11434/v1" in summary
|
||||||
|
assert "llama3.1:8b" in summary
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.complete (mocked)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
|
||||||
|
usage = MagicMock()
|
||||||
|
usage.prompt_tokens = 10
|
||||||
|
usage.completion_tokens = 20
|
||||||
|
|
||||||
|
message = MagicMock()
|
||||||
|
message.content = content
|
||||||
|
|
||||||
|
choice = MagicMock()
|
||||||
|
choice.message = message
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.choices = [choice]
|
||||||
|
response.model = model
|
||||||
|
response.usage = usage
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_returns_ai_response() -> None:
|
||||||
|
config = AIConfig()
|
||||||
|
client = AIClient(config)
|
||||||
|
mock_response = _make_mock_response("The root cause is X.")
|
||||||
|
|
||||||
|
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
|
||||||
|
result = client.complete("system prompt", "user message")
|
||||||
|
|
||||||
|
assert result.content == "The root cause is X."
|
||||||
|
assert result.prompt_tokens == 10
|
||||||
|
assert result.completion_tokens == 20
|
||||||
|
assert result.total_tokens == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_handles_empty_content() -> None:
|
||||||
|
config = AIConfig()
|
||||||
|
client = AIClient(config)
|
||||||
|
mock_response = _make_mock_response(None) # type: ignore[arg-type]
|
||||||
|
mock_response.choices[0].message.content = None
|
||||||
|
|
||||||
|
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
|
||||||
|
result = client.complete("system", "user")
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.stream (mocked)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_yields_chunks() -> None:
|
||||||
|
config = AIConfig()
|
||||||
|
client = AIClient(config)
|
||||||
|
|
||||||
|
def _make_chunk(text: str | None) -> MagicMock:
|
||||||
|
delta = MagicMock()
|
||||||
|
delta.content = text
|
||||||
|
choice = MagicMock()
|
||||||
|
choice.delta = delta
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.choices = [choice]
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
mock_chunks = [
|
||||||
|
_make_chunk("Root "), _make_chunk("cause "), _make_chunk(None), _make_chunk("found."),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
|
||||||
|
result = list(client.stream("system", "user"))
|
||||||
|
|
||||||
|
assert result == ["Root ", "cause ", "found."]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# prompt_builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
|
||||||
|
"""Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples."""
|
||||||
|
return CollectionReport(
|
||||||
|
host="root@testhost",
|
||||||
|
items=[
|
||||||
|
CollectedItem(
|
||||||
|
name=name,
|
||||||
|
result=SSHCommandResult(
|
||||||
|
command=command,
|
||||||
|
exit_code=exit_code,
|
||||||
|
stdout=stdout,
|
||||||
|
stderr=stderr,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for name, command, exit_code, stdout, stderr in items
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_prompt_contains_key_instructions() -> None:
|
||||||
|
prompt = build_system_prompt()
|
||||||
|
assert "Root Cause" in prompt
|
||||||
|
assert "Evidence" in prompt
|
||||||
|
assert "Recommended Actions" in prompt
|
||||||
|
assert "read-only" in prompt.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_contains_issue_and_host() -> None:
|
||||||
|
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
||||||
|
msg = build_user_message("nginx is failing", report)
|
||||||
|
assert "nginx is failing" in msg
|
||||||
|
assert "root@testhost" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_includes_command_output() -> None:
|
||||||
|
report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")])
|
||||||
|
msg = build_user_message("test issue", report)
|
||||||
|
assert "uname -a" in msg
|
||||||
|
assert "Linux web01 6.1.0" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_shows_stderr() -> None:
|
||||||
|
report = _make_report(
|
||||||
|
[("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")]
|
||||||
|
)
|
||||||
|
msg = build_user_message("nginx not found", report)
|
||||||
|
assert "Unit nginx.service not found." in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_notes_truncation() -> None:
|
||||||
|
result = SSHCommandResult(
|
||||||
|
command="journalctl -n 100 --no-pager",
|
||||||
|
exit_code=0,
|
||||||
|
stdout="...",
|
||||||
|
stderr="",
|
||||||
|
stdout_truncated=True,
|
||||||
|
)
|
||||||
|
report = CollectionReport(
|
||||||
|
host="root@testhost",
|
||||||
|
items=[CollectedItem(name="journal", result=result)],
|
||||||
|
)
|
||||||
|
msg = build_user_message("disk issue", report)
|
||||||
|
assert "truncated" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_handles_no_output() -> None:
|
||||||
|
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
|
||||||
|
msg = build_user_message("test", report)
|
||||||
|
assert "no output" in msg
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
from tai.cli import app
|
from tai.cli import app
|
||||||
@@ -5,6 +7,24 @@ from tai.collectors import CollectedItem, CollectionReport
|
|||||||
from tai.ssh_client import SSHCommandResult
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_session(
|
||||||
|
monkeypatch, # type: ignore[no-untyped-def]
|
||||||
|
*,
|
||||||
|
probe_result: SSHCommandResult | None = None,
|
||||||
|
probe_raises: Exception | None = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Patch SSHClient.connect to return a mock session."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.__aenter__ = AsyncMock(return_value=session)
|
||||||
|
session.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
if probe_raises:
|
||||||
|
session.probe = AsyncMock(side_effect=probe_raises)
|
||||||
|
else:
|
||||||
|
session.probe = AsyncMock(return_value=probe_result)
|
||||||
|
monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
def test_run_command_prints_scaffold_summary() -> None:
|
def test_run_command_prints_scaffold_summary() -> None:
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
@@ -25,33 +45,23 @@ def test_run_command_prints_scaffold_summary() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "tai scaffold ready" in result.stdout
|
assert "tai" in result.stdout
|
||||||
assert "host=web01" in result.stdout
|
assert "host=web01" in result.stdout
|
||||||
assert "port=5566" in result.stdout
|
assert "port=5566" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||||
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
|
_mock_session(
|
||||||
return SSHCommandResult(
|
monkeypatch,
|
||||||
command="uname -a",
|
probe_result=SSHCommandResult(
|
||||||
exit_code=0,
|
command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
|
||||||
stdout="Linux ssh 6.12.0",
|
),
|
||||||
stderr="",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
|
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
app,
|
app,
|
||||||
[
|
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
|
||||||
"apache failed",
|
|
||||||
"--host",
|
|
||||||
"ssh.archflux.net",
|
|
||||||
"--port",
|
|
||||||
"5566",
|
|
||||||
"--probe",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
@@ -60,27 +70,20 @@ def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: #
|
|||||||
|
|
||||||
|
|
||||||
def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||||
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
|
_mock_session(
|
||||||
return SSHCommandResult(
|
monkeypatch,
|
||||||
|
probe_result=SSHCommandResult(
|
||||||
command="uname -a",
|
command="uname -a",
|
||||||
exit_code=255,
|
exit_code=255,
|
||||||
stdout="",
|
stdout="",
|
||||||
stderr="Permission denied (publickey,password).",
|
stderr="Permission denied (publickey,password).",
|
||||||
)
|
),
|
||||||
|
)
|
||||||
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
|
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
app,
|
app,
|
||||||
[
|
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
|
||||||
"apache failed",
|
|
||||||
"--host",
|
|
||||||
"ssh.archflux.net",
|
|
||||||
"--port",
|
|
||||||
"5566",
|
|
||||||
"--probe",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code == 1
|
assert result.exit_code == 1
|
||||||
@@ -88,7 +91,9 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no
|
|||||||
|
|
||||||
|
|
||||||
def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||||
async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
_mock_session(monkeypatch)
|
||||||
|
|
||||||
|
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||||
return CollectionReport(
|
return CollectionReport(
|
||||||
host="ssh.archflux.net",
|
host="ssh.archflux.net",
|
||||||
items=[
|
items=[
|
||||||
|
|||||||
Reference in New Issue
Block a user