diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml new file mode 100644 index 0000000..ad24b64 --- /dev/null +++ b/.gitea/workflows/release.yml @@ -0,0 +1,110 @@ +name: Release + +on: + push: + tags: + - "v*" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Ensure git is available + run: | + if command -v git >/dev/null 2>&1; then + git --version + exit 0 + fi + + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y git + elif command -v dnf >/dev/null 2>&1; then + dnf install -y git + elif command -v yum >/dev/null 2>&1; then + yum install -y git + else + echo "No supported package manager found to install git." + exit 1 + fi + + - name: Checkout source (native git) + env: + CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }} + run: | + if [ -z "${CI_GIT_TOKEN:-}" ]; then + echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets." + exit 1 + fi + + auth_server="${GITHUB_SERVER_URL#https://}" + auth_server="${auth_server#http://}" + remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git" + + if [ -n "${GITHUB_WORKSPACE:-}" ]; then + cd "$GITHUB_WORKSPACE" + fi + + if [ ! -d .git ]; then + git init + fi + + git remote remove origin >/dev/null 2>&1 || true + git remote add origin "$remote_url" + + # Fetch the tag by SHA so we get the exact tagged commit + git fetch --depth 1 origin "$GITHUB_SHA" + git checkout --force FETCH_HEAD + + - name: Ensure Python and build dependencies are available + run: | + if ! command -v python3 >/dev/null 2>&1; then + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y python3 python3-pip python3-venv patchelf ccache + elif command -v dnf >/dev/null 2>&1; then + dnf install -y python3 python3-pip patchelf ccache + fi + fi + + # patchelf is required by Nuitka for standalone Linux binaries + command -v patchelf >/dev/null 2>&1 || { + apt-get update && apt-get install -y patchelf + } + + python3 --version + + - name: Set up venv and install package + build deps + run: | + python3 -m venv .venv + . .venv/bin/activate + python -m pip install --upgrade pip + python -m pip install -e ".[build]" + + - name: Derive version from tag + id: version + run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT" + + - name: Build standalone binary with Nuitka + run: | + . .venv/bin/activate + python -m nuitka \ + --standalone \ + --onefile \ + --output-filename=tai \ + --output-dir=dist \ + --assume-yes-for-downloads \ + --include-package=tai \ + src/tai/cli.py + + - name: Smoke-test the binary + run: dist/tai --help + + - name: Upload binary artifact + uses: actions/upload-artifact@v3 + with: + name: tai-linux-amd64-${{ steps.version.outputs.tag }} + path: dist/tai + if-no-files-found: error + retention-days: 90 diff --git a/src/tai/ai_client.py b/src/tai/ai_client.py new file mode 100644 index 0000000..c80103e --- /dev/null +++ b/src/tai/ai_client.py @@ -0,0 +1,93 @@ +"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint.""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass, field + +from openai import OpenAI + +DEFAULT_AI_HOST = "http://localhost:11434/v1" +DEFAULT_MODEL = "gemma3:4b" + + +@dataclass(slots=True) +class AIConfig: + """Connection parameters for an OpenAI-compatible AI backend.""" + + host: str = DEFAULT_AI_HOST + model: str = DEFAULT_MODEL + api_key: str = "ollama" # Ollama ignores this; required by the openai client + timeout_seconds: float = 120.0 + max_tokens: int = 4096 + extra_headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(slots=True) +class AIResponse: + """Structured response from an AI completion.""" + + model: str + content: str + prompt_tokens: int + completion_tokens: int + + @property + def total_tokens(self) -> int: + return self.prompt_tokens + self.completion_tokens + + +class AIClient: + """Thin wrapper around the openai client targeting a configurable endpoint.""" + + def __init__(self, config: AIConfig) -> None: + self._config = config + self._client = OpenAI( + base_url=config.host, + api_key=config.api_key, + timeout=config.timeout_seconds, + default_headers=config.extra_headers, + ) + + def complete(self, system_prompt: str, user_message: str) -> AIResponse: + """Send a completion request and return the full response.""" + response = self._client.chat.completions.create( + model=self._config.model, + max_tokens=self._config.max_tokens, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + ) + + choice = response.choices[0] + content = choice.message.content or "" + usage = response.usage + + return AIResponse( + model=response.model, + content=content, + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + ) + + def stream(self, system_prompt: str, user_message: str) -> Iterator[str]: + """Stream a completion, yielding text chunks as they arrive.""" + stream = self._client.chat.completions.create( + model=self._config.model, + max_tokens=self._config.max_tokens, + stream=True, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + ) + + for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + yield delta + + def summary(self) -> str: + """Human-readable description of the AI config.""" + return f"host={self._config.host} model={self._config.model}" diff --git a/src/tai/cli.py b/src/tai/cli.py index aa5baa3..4481be6 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -7,11 +7,14 @@ from typing import Annotated import typer from rich.console import Console +from rich.markdown import Markdown +from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.collectors import CollectionReport, collect_from_plan from tai.input_parser import InputValidationError, build_request from tai.models import TroubleshootRequest from tai.plan import plan_from_request +from tai.prompt_builder import build_system_prompt, build_user_message from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -56,6 +59,25 @@ def run( help="Collect baseline diagnostics after probe.", ), ] = False, + analyze: Annotated[ + bool, + typer.Option( + "--analyze/--no-analyze", + help="Send collected diagnostics to AI for analysis.", + ), + ] = False, + ai_host: Annotated[ + str, + typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."), + ] = DEFAULT_AI_HOST, + model: Annotated[ + str, + typer.Option("--model", help="Model name to use for AI analysis."), + ] = DEFAULT_MODEL, + ai_key: Annotated[ + str, + typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."), + ] = "ollama", ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -81,50 +103,69 @@ def run( ) summary = SSHClient(config).summary() - console.print("[bold green]tai scaffold ready[/bold green]") + console.print("[bold green]tai[/bold green]") console.print(f"Issue: {req.issue}") - console.print(f"SSH: {summary}") + console.print(f"SSH: {summary}") if req.target_paths: console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") + if not (probe or collect or analyze): + return # nothing SSH-related requested + + ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key) + if analyze: + console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}") + + try: + asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze, + ai_config=ai_config)) + except typer.Exit: + raise + except TimeoutError as exc: + console.print(f"[red]SSH timeout:[/red] {exc}") + raise typer.Exit(code=1) from exc + except OSError as exc: + console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}") + raise typer.Exit(code=1) from exc + + +async def _async_main( + config: SSHConnectionConfig, + req: TroubleshootRequest, + *, + probe: bool, + collect: bool, + analyze: bool, + ai_config: AIConfig, +) -> None: + """Open a single SSH session and run probe / collection / analysis through it.""" client = SSHClient(config) + async with client.connect() as session: + if probe: + result = await session.probe() + _handle_probe_result(result) - if probe: - _run_probe(client) + report: CollectionReport | None = None + if collect or analyze: + plan = plan_from_request(req) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + report = await collect_from_plan(session, plan) + _handle_collection_report(report) - if collect: - _run_collection(client, req) + if analyze and report is not None: + _run_analysis(ai_config, req.issue, report) -def _run_probe(client: SSHClient) -> None: - """Run a live SSH probe and exit non-zero on failure.""" +def _handle_probe_result(result: SSHCommandResult) -> None: + """Handle and render probe output for success or failure.""" console.print("[cyan]Running SSH probe:[/cyan] uname -a") - try: - result = asyncio.run(client.probe()) - except TimeoutError as exc: - console.print(f"[red]Probe failed:[/red] {exc}") - raise typer.Exit(code=1) from exc - except OSError as exc: - console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}") - raise typer.Exit(code=1) from exc - - _handle_probe_result(result) - - -def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None: - """Run issue-aware collection and print a compact summary.""" - plan = plan_from_request(request) - console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") - try: - report = asyncio.run(collect_from_plan(client, plan)) - except TimeoutError as exc: - console.print(f"[red]Collection failed:[/red] {exc}") - raise typer.Exit(code=1) from exc - except OSError as exc: - console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}") - raise typer.Exit(code=1) from exc - - _handle_collection_report(report) + if result.exit_code != 0: + details = result.stderr or result.stdout or "no error output from ssh" + console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}") + raise typer.Exit(code=1) + output = result.stdout or "(no output)" + console.print("[bold green]Probe succeeded.[/bold green]") + console.print(f"Remote: {output}") def _handle_collection_report(report: CollectionReport) -> None: @@ -134,22 +175,25 @@ def _handle_collection_report(report: CollectionReport) -> None: ) for item in report.items: status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}" - trunc = "" - if item.result.stdout_truncated or item.result.stderr_truncated: - trunc = " (truncated)" + truncated = item.result.stdout_truncated or item.result.stderr_truncated + trunc = " (truncated)" if truncated else "" console.print(f"- {item.name}: {status}{trunc}") -def _handle_probe_result(result: SSHCommandResult) -> None: - """Handle and render probe output for success or failure.""" - if result.exit_code != 0: - details = result.stderr or result.stdout or "no error output from ssh" - console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}") - raise typer.Exit(code=1) - - output = result.stdout or "(no output)" - console.print("[bold green]Probe succeeded.[/bold green]") - console.print(f"Remote: {output}") +def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None: + """Send collected data to the AI and stream the analysis to stdout.""" + console.print("[cyan]Analyzing...[/cyan]\n") + ai = AIClient(ai_config) + system_prompt = build_system_prompt() + user_message = build_user_message(issue, report) + try: + chunks: list[str] = [] + for chunk in ai.stream(system_prompt, user_message): + chunks.append(chunk) + console.print(Markdown("".join(chunks))) + except Exception as exc: # noqa: BLE001 + console.print(f"[red]AI analysis failed:[/red] {exc}") + raise typer.Exit(code=1) from exc def main() -> None: diff --git a/src/tai/collectors.py b/src/tai/collectors.py index 9ad4754..72fbdd2 100644 --- a/src/tai/collectors.py +++ b/src/tai/collectors.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from tai.plan import CollectionPlan -from tai.ssh_client import SSHClient, SSHCommandResult +from tai.ssh_client import SSHCommandResult, SSHSession @dataclass(slots=True) @@ -31,20 +31,20 @@ class CollectionReport: async def collect_from_plan( - client: SSHClient, + session: SSHSession, plan: CollectionPlan, *, max_output_bytes: int = 32768, ) -> CollectionReport: - """Execute all commands in *plan* and return a :class:`CollectionReport`.""" + """Execute all commands in *plan* over a shared SSH session.""" items: list[CollectedItem] = [] for name, command in plan.commands: - result = await client.run_read_only_command( + result = await session.run_read_only_command( command, timeout_seconds=30.0, max_output_bytes=max_output_bytes, ) items.append(CollectedItem(name=name, result=result)) - return CollectionReport(host=client.summary(), items=items) + return CollectionReport(host=session._client.summary(), items=items) diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py new file mode 100644 index 0000000..360cfb7 --- /dev/null +++ b/src/tai/prompt_builder.py @@ -0,0 +1,74 @@ +"""Formats collected diagnostics into prompts for the AI backend.""" + +from __future__ import annotations + +from tai.collectors import CollectionReport + +_SYSTEM_PROMPT = """\ +You are an expert Linux systems administrator and troubleshooting assistant. +You are given diagnostic data collected read-only from a remote Linux host via SSH. + +Your job: +1. Identify the root cause of the reported issue based only on the data provided. +2. Cite the specific output that supports your conclusion. +3. Give concise, actionable remediation steps. + +Important rules: +- Only draw conclusions from data that is actually present. Do not speculate or invent evidence. +- If a command shows "could not be executed (SSH error)" it means the remote host blocked or + rejected that specific command — it is not evidence about the service or system state. +- If there is not enough data to diagnose the issue, say so plainly and list exactly what + additional commands or log files would be needed. +- Keep the response short. Skip sections that have nothing useful to say. +- Never suggest commands that modify the system unless explicitly asked. +- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**. +""" + + +def build_system_prompt() -> str: + """Return the static system prompt for the troubleshooting agent.""" + return _SYSTEM_PROMPT.strip() + + +def build_user_message(issue: str, report: CollectionReport) -> str: + """Format *issue* and *report* into the user message sent to the AI.""" + lines: list[str] = [] + + lines.append(f"## Issue reported\n\n{issue}\n") + lines.append(f"## Target host\n\n{report.host}\n") + lines.append("## Collected diagnostics\n") + + skipped: list[str] = [] + + for item in report.items: + result = item.result + + # Exit 255 with no output = SSH couldn't execute the command at all. + # Exclude these entirely to prevent the AI from speculating on them. + if result.exit_code == 255 and not result.stdout and not result.stderr: + skipped.append(item.name) + continue + + lines.append(f"### {item.name}\n") + lines.append(f"**Command:** `{result.command}` ") + lines.append(f"**Exit code:** {result.exit_code}\n") + + if result.stdout: + trunc = " *(output truncated)*" if result.stdout_truncated else "" + lines.append(f"**stdout:**{trunc}\n```\n{result.stdout.strip()}\n```\n") + + if result.stderr: + trunc = " *(output truncated)*" if result.stderr_truncated else "" + lines.append(f"**stderr:**{trunc}\n```\n{result.stderr.strip()}\n```\n") + + if not result.stdout and not result.stderr: + lines.append("*(no output)*\n") + + if skipped: + lines.append( + f"**Note:** The following commands could not be executed on this host " + f"and produced no output: {', '.join(skipped)}. " + f"Do not draw any conclusions from their absence.\n" + ) + + return "\n".join(lines) diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index 4f7164a..6235010 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -1,9 +1,12 @@ """SSH configuration and read-only command execution.""" import asyncio +import os import shlex +import tempfile from dataclasses import dataclass from pathlib import Path +from types import TracebackType @dataclass(slots=True) @@ -88,8 +91,8 @@ class SSHClient: f"key={key} jump={jump} mode={mode}" ) - def build_ssh_argv(self, remote_command: str) -> list[str]: - """Build argv for a secure non-interactive SSH invocation.""" + def _build_base_argv(self) -> list[str]: + """Build the common SSH argv flags (no host or command appended).""" argv = [ "ssh", "-p", @@ -111,9 +114,16 @@ class SSHClient: if self._config.jump_host: argv += ["-J", self._config.jump_host] - argv += [self._config.host, remote_command] return argv + def build_ssh_argv(self, remote_command: str) -> list[str]: + """Build argv for a secure non-interactive SSH invocation.""" + return self._build_base_argv() + [self._config.host, remote_command] + + def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession": + """Return an :class:`SSHSession` async context manager for this host.""" + return SSHSession(self, connect_timeout=connect_timeout) + def validate_read_only_command(self, command: str) -> None: """Validate that a command appears read-only and non-destructive.""" normalized = command.strip() @@ -168,8 +178,7 @@ class SSHClient: max_output_bytes=4096, ) - async def _run_ssh( - self, + async def _run_ssh( self, command: str, *, timeout_seconds: float, @@ -231,3 +240,144 @@ class SSHClient: trimmed_bytes = encoded[:keep] trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip() return f"{trimmed_text}{marker}", True + + +class SSHSession: + """A persistent SSH connection using ControlMaster multiplexing. + + All commands run over the same underlying TCP connection — no per-command + SSH handshake. Use as an async context manager:: + + async with client.connect() as session: + result = await session.run_read_only_command("df -h") + """ + + def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None: + self._client = client + self._connect_timeout = connect_timeout + self._socket_path: Path | None = None + self._master_proc: asyncio.subprocess.Process | None = None + + async def __aenter__(self) -> "SSHSession": + fd, path = tempfile.mkstemp(prefix="tai-ssh-", suffix=".sock") + os.close(fd) + os.unlink(path) # SSH needs to create this itself as a socket + self._socket_path = Path(path) + + master_argv = self._client._build_base_argv() + [ + "-o", "ControlMaster=yes", + "-o", f"ControlPath={self._socket_path}", + "-o", "ControlPersist=no", + "-N", + self._client._config.host, + ] + + self._master_proc = await asyncio.create_subprocess_exec( + *master_argv, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + + # Wait for the control socket to appear (master is ready) + loop = asyncio.get_event_loop() + deadline = loop.time() + self._connect_timeout + while not self._socket_path.exists(): + if loop.time() > deadline: + await self._teardown() + raise TimeoutError( + f"SSH ControlMaster did not connect within " + f"{self._connect_timeout}s to {self._client._config.host}" + ) + await asyncio.sleep(0.05) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._teardown() + + async def _teardown(self) -> None: + if self._master_proc and self._master_proc.returncode is None: + self._master_proc.terminate() + try: + await asyncio.wait_for(self._master_proc.wait(), timeout=5.0) + except TimeoutError: + self._master_proc.kill() + if self._socket_path and self._socket_path.exists(): + self._socket_path.unlink(missing_ok=True) + + def _command_argv(self, remote_command: str) -> list[str]: + return self._client._build_base_argv() + [ + "-o", f"ControlPath={self._socket_path}", + "-o", "ControlMaster=no", + self._client._config.host, + remote_command, + ] + + async def probe(self) -> SSHCommandResult: + """Run uname -a to confirm connectivity.""" + return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096) + + async def run_read_only_command( + self, + command: str, + *, + timeout_seconds: float = 30.0, + max_output_bytes: int = 32768, + ) -> SSHCommandResult: + """Validate and run a read-only command over the shared connection.""" + self._client.validate_read_only_command(command) + return await self._run( + command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes + ) + + async def _run( + self, + command: str, + *, + timeout_seconds: float, + max_output_bytes: int, + ) -> SSHCommandResult: + argv = self._command_argv(command) + proc = await asyncio.create_subprocess_exec( + *argv, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout_seconds + ) + except TimeoutError as exc: + proc.kill() + await proc.wait() + raise TimeoutError( + f"SSH command timed out after {timeout_seconds}s: {command}" + ) from exc + + if proc.returncode is None: + raise RuntimeError("SSH process did not provide an exit code.") + + stdout_text, stdout_truncated = SSHClient._truncate_output( + stdout_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + stderr_text, stderr_truncated = SSHClient._truncate_output( + stderr_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + + return SSHCommandResult( + command=command, + exit_code=proc.returncode, + stdout=stdout_text, + stderr=stderr_text, + stdout_truncated=stdout_truncated, + stderr_truncated=stderr_truncated, + ) + diff --git a/tests/test_ai.py b/tests/test_ai.py new file mode 100644 index 0000000..08d1510 --- /dev/null +++ b/tests/test_ai.py @@ -0,0 +1,192 @@ +"""Tests for the AI client and prompt builder.""" + +from unittest.mock import MagicMock, patch + +from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig +from tai.collectors import CollectedItem, CollectionReport +from tai.prompt_builder import build_system_prompt, build_user_message +from tai.ssh_client import SSHCommandResult + +# --------------------------------------------------------------------------- +# AIConfig defaults +# --------------------------------------------------------------------------- + + +def test_ai_config_defaults() -> None: + config = AIConfig() + assert config.host == DEFAULT_AI_HOST + assert config.model == DEFAULT_MODEL + assert config.api_key == "ollama" + + +def test_ai_config_custom_values() -> None: + config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test") + assert config.host == "https://api.openai.com/v1" + assert config.model == "gpt-4o" + assert config.api_key == "sk-test" + + +# --------------------------------------------------------------------------- +# AIClient.summary +# --------------------------------------------------------------------------- + + +def test_ai_client_summary_contains_host_and_model() -> None: + config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b") + client = AIClient(config) + summary = client.summary() + assert "http://myserver:11434/v1" in summary + assert "llama3.1:8b" in summary + + +# --------------------------------------------------------------------------- +# AIClient.complete (mocked) +# --------------------------------------------------------------------------- + + +def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 20 + + message = MagicMock() + message.content = content + + choice = MagicMock() + choice.message = message + + response = MagicMock() + response.choices = [choice] + response.model = model + response.usage = usage + return response + + +def test_complete_returns_ai_response() -> None: + config = AIConfig() + client = AIClient(config) + mock_response = _make_mock_response("The root cause is X.") + + with patch.object(client._client.chat.completions, "create", return_value=mock_response): + result = client.complete("system prompt", "user message") + + assert result.content == "The root cause is X." + assert result.prompt_tokens == 10 + assert result.completion_tokens == 20 + assert result.total_tokens == 30 + + +def test_complete_handles_empty_content() -> None: + config = AIConfig() + client = AIClient(config) + mock_response = _make_mock_response(None) # type: ignore[arg-type] + mock_response.choices[0].message.content = None + + with patch.object(client._client.chat.completions, "create", return_value=mock_response): + result = client.complete("system", "user") + + assert result.content == "" + + +# --------------------------------------------------------------------------- +# AIClient.stream (mocked) +# --------------------------------------------------------------------------- + + +def test_stream_yields_chunks() -> None: + config = AIConfig() + client = AIClient(config) + + def _make_chunk(text: str | None) -> MagicMock: + delta = MagicMock() + delta.content = text + choice = MagicMock() + choice.delta = delta + chunk = MagicMock() + chunk.choices = [choice] + return chunk + + mock_chunks = [ + _make_chunk("Root "), _make_chunk("cause "), _make_chunk(None), _make_chunk("found."), + ] + + with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)): + result = list(client.stream("system", "user")) + + assert result == ["Root ", "cause ", "found."] + + +# --------------------------------------------------------------------------- +# prompt_builder +# --------------------------------------------------------------------------- + + +def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport: + """Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples.""" + return CollectionReport( + host="root@testhost", + items=[ + CollectedItem( + name=name, + result=SSHCommandResult( + command=command, + exit_code=exit_code, + stdout=stdout, + stderr=stderr, + ), + ) + for name, command, exit_code, stdout, stderr in items + ], + ) + + +def test_build_system_prompt_contains_key_instructions() -> None: + prompt = build_system_prompt() + assert "Root Cause" in prompt + assert "Evidence" in prompt + assert "Recommended Actions" in prompt + assert "read-only" in prompt.lower() + + +def test_build_user_message_contains_issue_and_host() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) + msg = build_user_message("nginx is failing", report) + assert "nginx is failing" in msg + assert "root@testhost" in msg + + +def test_build_user_message_includes_command_output() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")]) + msg = build_user_message("test issue", report) + assert "uname -a" in msg + assert "Linux web01 6.1.0" in msg + + +def test_build_user_message_shows_stderr() -> None: + report = _make_report( + [("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")] + ) + msg = build_user_message("nginx not found", report) + assert "Unit nginx.service not found." in msg + + +def test_build_user_message_notes_truncation() -> None: + result = SSHCommandResult( + command="journalctl -n 100 --no-pager", + exit_code=0, + stdout="...", + stderr="", + stdout_truncated=True, + ) + report = CollectionReport( + host="root@testhost", + items=[CollectedItem(name="journal", result=result)], + ) + msg = build_user_message("disk issue", report) + assert "truncated" in msg + + +def test_build_user_message_handles_no_output() -> None: + report = _make_report([("empty", "cat /nonexistent", 1, "", "")]) + msg = build_user_message("test", report) + assert "no output" in msg diff --git a/tests/test_cli.py b/tests/test_cli.py index 9bac274..c26716f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,5 @@ +from unittest.mock import AsyncMock, MagicMock + from typer.testing import CliRunner from tai.cli import app @@ -5,6 +7,24 @@ from tai.collectors import CollectedItem, CollectionReport from tai.ssh_client import SSHCommandResult +def _mock_session( + monkeypatch, # type: ignore[no-untyped-def] + *, + probe_result: SSHCommandResult | None = None, + probe_raises: Exception | None = None, +) -> MagicMock: + """Patch SSHClient.connect to return a mock session.""" + session = MagicMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=None) + if probe_raises: + session.probe = AsyncMock(side_effect=probe_raises) + else: + session.probe = AsyncMock(return_value=probe_result) + monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session) + return session + + def test_run_command_prints_scaffold_summary() -> None: runner = CliRunner() result = runner.invoke( @@ -25,33 +45,23 @@ def test_run_command_prints_scaffold_summary() -> None: ) assert result.exit_code == 0 - assert "tai scaffold ready" in result.stdout + assert "tai" in result.stdout assert "host=web01" in result.stdout assert "port=5566" in result.stdout def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] - return SSHCommandResult( - command="uname -a", - exit_code=0, - stdout="Linux ssh 6.12.0", - stderr="", - ) - - monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + _mock_session( + monkeypatch, + probe_result=SSHCommandResult( + command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr="" + ), + ) runner = CliRunner() result = runner.invoke( app, - [ - "apache failed", - "--host", - "ssh.archflux.net", - "--port", - "5566", - "--probe", - ], + ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"], ) assert result.exit_code == 0 @@ -60,27 +70,20 @@ def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] - return SSHCommandResult( + _mock_session( + monkeypatch, + probe_result=SSHCommandResult( command="uname -a", exit_code=255, stdout="", stderr="Permission denied (publickey,password).", - ) - - monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + ), + ) runner = CliRunner() result = runner.invoke( app, - [ - "apache failed", - "--host", - "ssh.archflux.net", - "--port", - "5566", - "--probe", - ], + ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"], ) assert result.exit_code == 1 @@ -88,7 +91,9 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] return CollectionReport( host="ssh.archflux.net", items=[