From 17fd96680bea0038dab3ad8297abc9adfcf8e4ed Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 03:43:41 +0200 Subject: [PATCH 01/11] push Co-authored-by: Copilot --- .gitea/workflows/ci.yml | 32 ++++++ .github/workflows/ci.yml | 32 ++++++ .gitignore | 26 +++++ CHANGELOG.md | 34 +++++++ README.md | 36 ++++++- ROADMAP.md | 126 ++++++++++++++++++++++++ pyproject.toml | 50 ++++++++++ src/tai/__init__.py | 5 + src/tai/cli.py | 117 ++++++++++++++++++++++ src/tai/input_parser.py | 46 +++++++++ src/tai/models.py | 17 ++++ src/tai/ssh_client.py | 193 +++++++++++++++++++++++++++++++++++++ tests/test_cli.py | 86 +++++++++++++++++ tests/test_input_parser.py | 65 +++++++++++++ tests/test_ssh_client.py | 93 ++++++++++++++++++ 15 files changed, 956 insertions(+), 2 deletions(-) create mode 100644 .gitea/workflows/ci.yml create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 CHANGELOG.md create mode 100644 ROADMAP.md create mode 100644 pyproject.toml create mode 100644 src/tai/__init__.py create mode 100644 src/tai/cli.py create mode 100644 src/tai/input_parser.py create mode 100644 src/tai/models.py create mode 100644 src/tai/ssh_client.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_input_parser.py create mode 100644 tests/test_ssh_client.py diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml new file mode 100644 index 0000000..8465b2a --- /dev/null +++ b/.gitea/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install package and dev dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[dev] + + - name: Lint + run: ruff check . 
+ + - name: Type-check + run: mypy src + + - name: Test + run: pytest diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8465b2a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI + +on: + push: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install package and dev dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[dev] + + - name: Lint + run: ruff check . + + - name: Type-check + run: mypy src + + - name: Test + run: pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ccb7ac0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Python cache and bytecode +__pycache__/ +*.py[cod] +*.pyo + +# Virtual environments +.venv/ +venv/ + +# Tool caches +.pytest_cache/ +.ruff_cache/ +.mypy_cache/ + +# Build artifacts +build/ +dist/ +*.egg-info/ +*.spec + +# Coverage +.coverage +htmlcov/ + +# IDE +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3442cdc --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,34 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). 
+ +--- + +## [Unreleased] + +### Added +- `README.md` — project overview, description, example workflow, supported distributions, and suggested tooling +- `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening +- `CHANGELOG.md` — this file; established changelog tracking for the project +- `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events +- Python package scaffold with `src` layout and project metadata in `pyproject.toml` +- Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config` +- Input parsing/validation module and core request model +- SSH configuration scaffold module for upcoming connection/read-only execution work +- Implemented SSH module with real key-based command execution via system `ssh` +- Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566) +- Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure +- Read-only command policy enforcement (allowlist + blocked shell operators) +- Test scaffold (`pytest`) with initial parser and CLI coverage +- SSH test coverage for policy checks, SSH argument construction, and config summary behavior +- CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`) + +### Decided +- Implementation language: **Python** +- Distribution strategy: single distributable binary via **Nuitka** (PyInstaller as fallback) +- SSH authentication: **keypair only** (ed25519/RSA); auto-accept new hosts; hard reject on host key change with MITM warning +- SSH bastion support: `--jump-host` flag using SSH native ProxyJump +- SSH config behavior: use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` +- Interface: **interactive REPL** for v0.1; `textual`-based TUI (split-pane) for v0.2+ diff --git a/README.md b/README.md index 
ca9a26f..296c068 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,35 @@ -# tai +# tai — Linux AI Troubleshooting Agent -Linux AI driven troubleshooting agent. \ No newline at end of file +`tai` is an agentic AI-driven troubleshooting tool for Linux systems. It autonomously investigates issues on remote hosts via SSH, analyzes relevant logs and configuration files, and provides a clear diagnosis along with suggested remediation steps — all without making any changes to the target system. + +## Overview + +Given a problem description and a target hostname, `tai` connects to the remote system over SSH, gathers relevant data (logs, configuration files, service status, etc.), and uses a locally-hosted AI model to reason about the root cause and recommend solutions. + +The agent operates in **read-only mode at all times**. It will never modify the target system under any circumstances — all suggestions are presented to the human troubleshooter for review and action. + +## Supported Distributions + +- Ubuntu +- Debian +- RHEL +- Rocky Linux + +## Example Workflow + +A troubleshooter receives a ticket reporting that the Apache service on a remote server has failed to start. They provide `tai` with: + +1. The ticket description or error message +2. The hostname of the affected system +3. Any relevant directories to focus on + +`tai` then connects to the host, reads through system logs, service configurations, and any other related files, and returns a structured analysis of the likely cause along with recommended next steps. + +## Suggested Tooling + +| Component | Tool | +|-----------|------| +| AI inference backend | [vLLM](https://github.com/vllm-project/vllm) | +| Model | `gemma4:a4b` | + +> **Note:** `tai` is implemented in Python and will be distributed as a single binary built with Nuitka; see `ROADMAP.md` for the recorded decisions. 
\ No newline at end of file diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..235c6a4 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,126 @@ +# Roadmap + +This document outlines the major decisions, milestones, and development phases required to bring `tai` from concept to a working tool. + +--- + +## Phase 0 — Decisions & Prerequisites + +These must be resolved before meaningful development can begin. + +### Language Selection +- [x] **Decision: Python** +- Key factors: native vLLM integration, mature SSH libraries (`paramiko` / `asyncssh`), strong text/log parsing, rapid development +- Single binary distribution will be achieved via **Nuitka** (preferred for true compilation) or **PyInstaller** as a fallback +- [ ] Evaluate Nuitka vs PyInstaller for binary output quality and CI reproducibility +- [ ] Add binary build step to CI pipeline + +### AI Backend & Model +- [ ] Confirm use of [vLLM](https://github.com/vllm-project/vllm) as the inference backend +- [ ] Confirm `gemma4:a4b` as the default model (or select an alternative) +- [ ] Define minimum hardware requirements for running the model locally +- [ ] Decide whether the AI backend is bundled, self-hosted externally, or user-supplied + +### SSH Strategy +- [x] **Decision: keypair authentication only** — no password auth; eliminates credential storage risk + - Default key resolution: `~/.ssh/id_ed25519`, `~/.ssh/id_rsa` (in order of preference) + - CLI override via `--identity-file ` + - No SSH agent forwarding needed — a shared key is distributed to all managed hosts via Puppet +- [x] **Known hosts: auto-accept new hosts; reject on key mismatch** — a changed host key triggers a hard stop with a MITM warning; unknown/new hosts are accepted silently on first connect +- [x] **Bastion/jump host: `--jump-host ` flag** — delegates to SSH's native ProxyJump functionality +- [x] **SSH config behavior: respect existing `~/.ssh/config` by default; allow CLI override** + - Default: follow host settings 
from `~/.ssh/config` (for `User`, `Port`, `ProxyJump`, etc.) + - Override switch: `--ignore-ssh-config` to bypass local SSH config when required + +### Scope & Constraints +- [ ] Define the supported scope of issues (services, network, disk, kernel, etc.) +- [ ] Confirm read-only guarantee — document exactly what "read-only" means in practice +- [x] **Decision: interactive REPL mode for v0.1, full TUI for v0.2+** + - v0.1: chat-loop REPL launched from CLI; human can follow up, correct, and redirect the agent + - v0.2+: `textual`-based TUI with split panes (collected data | AI output | input bar) + - Built-in slash commands: `/collect`, `/show logs`, `/clear`, `/host `, `/help`, `/quit` + +--- + +## Phase 1 — Project Foundation + +Basic project scaffolding and connectivity. + +- [x] Finalise repository structure and language toolchain +- [x] Set up CI pipeline (linting, tests) +- [ ] Implement SSH connection module + - [x] Define SSH config model and probe interface scaffold + - [x] Connect to remote host + - [x] Execute read-only commands (e.g. `journalctl`, `systemctl status`, `cat`) + - [ ] Stream or collect command output safely +- [x] Implement basic input parsing (ticket text, hostname, target directories) +- [x] Write unit tests for SSH and input modules + - [x] Input parser and CLI tests added + - [x] SSH module tests added for command policy and SSH argv behavior + +--- + +## Phase 2 — Data Collection Layer + +Define what information the agent gathers and how. 
+ +- [ ] Identify the canonical set of data sources per issue type: + - Service failures: `journalctl`, `systemctl`, service config files + - Network issues: `ip`, `ss`, `netstat`, firewall rules + - Disk issues: `df`, `du`, `dmesg`, `smartctl` + - General: `/var/log/syslog`, `/var/log/messages`, `dmesg` +- [ ] Implement pluggable "collector" modules per data source +- [ ] Implement directory traversal for user-specified paths (read-only) +- [ ] Add support for per-distro variations (Ubuntu vs RHEL path differences, etc.) +- [ ] Write tests with mocked SSH output + +--- + +## Phase 3 — AI Integration + +Wire collected data into the local AI model. + +- [ ] Implement vLLM client module +- [ ] Design prompt template: system context, collected data, issue description → diagnosis +- [ ] Implement response parsing and structured output (root cause + suggested steps) +- [ ] Tune context window usage — handle truncation for large log outputs +- [ ] Add streaming support for long AI responses +- [ ] Evaluate and test model output quality on common issue types + +--- + +## Phase 4 — CLI & User Experience + +Polish the interface for real-world use. + +- [ ] Design CLI interface (flags, subcommands, interactive prompts) +- [ ] Implement structured output: diagnosis, confidence, recommended actions +- [ ] Add `--verbose` / `--debug` mode showing raw collected data +- [ ] Support output to file or clipboard +- [ ] Write man page / `--help` documentation + +--- + +## Phase 5 — Hardening & Distribution + +Prepare for broader use. 
+ +- [ ] Security review of SSH handling and credential storage +- [ ] Ensure no data is written to the remote system under any path +- [ ] Package for distribution (binary release, container image, or distro packages) +- [ ] Write installation and quickstart documentation +- [ ] End-to-end integration tests against a test VM + +--- + +## Decisions Log + +| Date | Decision | Outcome | +|------|----------|---------| +| 2026-05-04 | Implementation language | Python — with single distributable binary via Nuitka | +| — | AI inference backend | vLLM (provisional) | +| — | Default model | `gemma4:a4b` (provisional) | +| 2026-05-04 | SSH auth methods | Keypair only (ed25519/RSA); auto-accept new hosts; reject on key change (MITM) | +| 2026-05-04 | Bastion host support | `--jump-host` flag via SSH native ProxyJump | +| 2026-05-04 | SSH config behavior | Use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` | +| 2026-05-04 | CLI vs interactive mode | Interactive: REPL for v0.1, `textual` TUI for v0.2+ | diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2e8d855 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["hatchling>=1.25"] +build-backend = "hatchling.build" + +[project] +name = "tai" +version = "0.1.0" +description = "Linux AI-driven troubleshooting agent" +readme = "README.md" +requires-python = ">=3.11" +authors = [ + { name = "tai contributors" } +] +dependencies = [ + "typer>=0.12,<1.0", + "rich>=13.7,<14.0", + "asyncssh>=2.14,<3.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.2,<9.0", + "ruff>=0.5,<1.0", + "mypy>=1.10,<2.0", +] +build = [ + "nuitka>=2.4,<3.0", +] + +[project.scripts] +tai = "tai.cli:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/tai"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-q" + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B"] + +[tool.mypy] 
+python_version = "3.11" +strict = true +warn_unused_configs = true diff --git a/src/tai/__init__.py b/src/tai/__init__.py new file mode 100644 index 0000000..42d8357 --- /dev/null +++ b/src/tai/__init__.py @@ -0,0 +1,5 @@ +"""tai package.""" + +__all__ = ["__version__"] + +__version__ = "0.1.0" diff --git a/src/tai/cli.py b/src/tai/cli.py new file mode 100644 index 0000000..5a3e269 --- /dev/null +++ b/src/tai/cli.py @@ -0,0 +1,117 @@ +"""CLI entrypoint for tai.""" + +from __future__ import annotations + +import asyncio +from typing import Annotated + +import typer +from rich.console import Console + +from tai.input_parser import InputValidationError, build_request +from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig + +app = typer.Typer(no_args_is_help=True, add_completion=False) +console = Console() + + +@app.command() +def run( + issue: Annotated[str, typer.Argument(help="Ticket text or issue summary.")], + host: Annotated[str, typer.Option("--host", help="Target host to troubleshoot.")], + port: Annotated[int, typer.Option("--port", help="SSH port for the target host.")] = 22, + path: Annotated[ + list[str] | None, + typer.Option("--path", help="Path to inspect. 
Repeatable."), + ] = None, + identity_file: Annotated[ + str | None, + typer.Option("--identity-file", help="SSH private key path."), + ] = None, + jump_host: Annotated[ + str | None, + typer.Option("--jump-host", help="SSH bastion/jump host."), + ] = None, + ignore_ssh_config: Annotated[ + bool, + typer.Option( + "--ignore-ssh-config", + help="Ignore ~/.ssh/config and rely only on CLI options.", + ), + ] = False, + probe: Annotated[ + bool, + typer.Option( + "--probe/--no-probe", + help="Enable or disable live SSH connectivity probe (uname -a).", + ), + ] = True, +) -> None: + """Start an interactive troubleshooting session scaffold.""" + try: + req = build_request( + issue=issue, + host=host, + port=port, + target_paths=path or [], + identity_file=identity_file, + jump_host=jump_host, + ignore_ssh_config=ignore_ssh_config, + ) + except InputValidationError as exc: + console.print(f"[red]Input error:[/red] {exc}") + raise typer.Exit(code=2) from exc + + config = SSHConnectionConfig( + host=req.host, + port=req.port, + identity_file=req.identity_file, + jump_host=req.jump_host, + ignore_ssh_config=req.ignore_ssh_config, + ) + + summary = SSHClient(config).summary() + console.print("[bold green]tai scaffold ready[/bold green]") + console.print(f"Issue: {req.issue}") + console.print(f"SSH: {summary}") + if req.target_paths: + console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") + + if probe: + _run_probe(SSHClient(config)) + + +def _run_probe(client: SSHClient) -> None: + """Run a live SSH probe and exit non-zero on failure.""" + console.print("[cyan]Running SSH probe:[/cyan] uname -a") + try: + result = asyncio.run(client.probe()) + except TimeoutError as exc: + console.print(f"[red]Probe failed:[/red] {exc}") + raise typer.Exit(code=1) from exc + except OSError as exc: + console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}") + raise typer.Exit(code=1) from exc + + _handle_probe_result(result) + + +def _handle_probe_result(result: 
"""Helpers to normalize and validate CLI input."""

from pathlib import Path

from tai.models import TroubleshootRequest


class InputValidationError(ValueError):
    """Raised when required user input is missing or invalid."""


def build_request(
    *,
    issue: str,
    host: str,
    port: int,
    target_paths: list[str],
    identity_file: str | None,
    jump_host: str | None,
    ignore_ssh_config: bool,
) -> TroubleshootRequest:
    """Create a normalized request object from raw CLI values.

    Surrounding whitespace is stripped from ``issue``, ``host``,
    ``identity_file`` and ``jump_host``; ``~`` is expanded in paths. Blank
    optional values are normalized to ``None`` rather than an empty string.

    Raises:
        InputValidationError: if the issue or host is empty, the host starts
            with ``-`` or contains whitespace, or the port is outside
            1-65535.
    """
    normalized_issue = issue.strip()
    normalized_host = host.strip()

    if not normalized_issue:
        raise InputValidationError("Issue description cannot be empty.")

    if not normalized_host:
        raise InputValidationError("Host cannot be empty.")

    # The host is later placed on the ssh command line as a positional
    # argument; a leading "-" would be parsed by ssh as an option instead.
    if normalized_host.startswith("-") or any(ch.isspace() for ch in normalized_host):
        raise InputValidationError("Host must not start with '-' or contain whitespace.")

    if not 1 <= port <= 65535:
        raise InputValidationError("Port must be between 1 and 65535.")

    paths = [Path(p).expanduser() for p in target_paths]

    # Strip before expanding: "~" is only expanded at the very start of a path.
    normalized_identity = identity_file.strip() if identity_file else ""
    identity = Path(normalized_identity).expanduser() if normalized_identity else None

    normalized_jump = jump_host.strip() if jump_host else ""

    return TroubleshootRequest(
        issue=normalized_issue,
        host=normalized_host,
        port=port,
        target_paths=paths,
        identity_file=identity,
        jump_host=normalized_jump or None,
        ignore_ssh_config=ignore_ssh_config,
    )
100644 index 0000000..e3a855e --- /dev/null +++ b/src/tai/models.py @@ -0,0 +1,17 @@ +"""Core domain models for tai.""" + +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass(slots=True) +class TroubleshootRequest: + """User-provided troubleshooting input for a single run.""" + + issue: str + host: str + port: int = 22 + target_paths: list[Path] = field(default_factory=list) + identity_file: Path | None = None + jump_host: str | None = None + ignore_ssh_config: bool = False diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py new file mode 100644 index 0000000..c690b6e --- /dev/null +++ b/src/tai/ssh_client.py @@ -0,0 +1,193 @@ +"""SSH configuration and read-only command execution.""" + +import asyncio +import shlex +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(slots=True) +class SSHConnectionConfig: + """Connection parameters for the target host.""" + + host: str + port: int = 22 + identity_file: Path | None = None + jump_host: str | None = None + ignore_ssh_config: bool = False + + +@dataclass(slots=True) +class SSHCommandResult: + """Result of a remote SSH command execution.""" + + command: str + exit_code: int + stdout: str + stderr: str + + +class SSHCommandRejectedError(ValueError): + """Raised when a command violates read-only policy.""" + + +class SSHClient: + """Wrapper around SSH operations with read-only safeguards.""" + + _BLOCKED_TOKENS = { + ">", + ">>", + "<", + "|", + "&&", + "||", + ";", + } + _READ_ONLY_COMMANDS = { + "cat", + "dmesg", + "df", + "du", + "find", + "grep", + "head", + "hostnamectl", + "ip", + "journalctl", + "ls", + "netstat", + "sed", + "ss", + "stat", + "systemctl", + "tail", + "uname", + } + _READ_ONLY_SYSTEMCTL_SUBCOMMANDS = { + "cat", + "is-active", + "is-failed", + "list-unit-files", + "list-units", + "show", + "status", + } + + def __init__(self, config: SSHConnectionConfig) -> None: + self._config = config + + def summary(self) -> str: + """Return a short 
summary of connection settings.""" + mode = "ignore ssh config" if self._config.ignore_ssh_config else "use ssh config" + jump = self._config.jump_host or "none" + key = str(self._config.identity_file) if self._config.identity_file else "auto" + return ( + f"host={self._config.host} port={self._config.port} " + f"key={key} jump={jump} mode={mode}" + ) + + def build_ssh_argv(self, remote_command: str) -> list[str]: + """Build argv for a secure non-interactive SSH invocation.""" + argv = [ + "ssh", + "-p", + str(self._config.port), + "-o", + "BatchMode=yes", + "-o", + "ConnectTimeout=15", + "-o", + "StrictHostKeyChecking=accept-new", + ] + + if self._config.ignore_ssh_config: + argv += ["-F", "/dev/null"] + + if self._config.identity_file: + argv += ["-i", str(self._config.identity_file)] + + if self._config.jump_host: + argv += ["-J", self._config.jump_host] + + argv += [self._config.host, remote_command] + return argv + + def validate_read_only_command(self, command: str) -> None: + """Validate that a command appears read-only and non-destructive.""" + normalized = command.strip() + if not normalized: + raise SSHCommandRejectedError("Command cannot be empty.") + + for token in self._BLOCKED_TOKENS: + if token in normalized: + raise SSHCommandRejectedError( + f"Command contains blocked shell operator: {token}" + ) + + parts = shlex.split(normalized) + if not parts: + raise SSHCommandRejectedError("Command cannot be empty.") + + base = parts[0] + if base not in self._READ_ONLY_COMMANDS: + raise SSHCommandRejectedError( + f"Command '{base}' is not allowed by read-only policy." + ) + + if base == "systemctl": + if len(parts) < 2: + raise SSHCommandRejectedError("systemctl requires a subcommand.") + subcommand = parts[1] + if subcommand not in self._READ_ONLY_SYSTEMCTL_SUBCOMMANDS: + raise SSHCommandRejectedError( + f"systemctl subcommand '{subcommand}' is not read-only." 
+ ) + + async def run_read_only_command( + self, + command: str, + *, + timeout_seconds: float = 30.0, + ) -> SSHCommandResult: + """Run a validated read-only command over SSH.""" + self.validate_read_only_command(command) + return await self._run_ssh(command, timeout_seconds=timeout_seconds) + + async def probe(self) -> SSHCommandResult: + """Probe connectivity using a harmless remote command.""" + return await self._run_ssh("uname -a", timeout_seconds=15.0) + + async def _run_ssh( + self, + command: str, + *, + timeout_seconds: float, + ) -> SSHCommandResult: + argv = self.build_ssh_argv(command) + proc = await asyncio.create_subprocess_exec( + *argv, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), + timeout=timeout_seconds, + ) + except TimeoutError as exc: + proc.kill() + await proc.wait() + raise TimeoutError( + f"SSH command timed out after {timeout_seconds} seconds: {command}" + ) from exc + + if proc.returncode is None: + raise RuntimeError("SSH process did not provide an exit code.") + + return SSHCommandResult( + command=command, + exit_code=proc.returncode, + stdout=stdout_bytes.decode("utf-8", errors="replace").strip(), + stderr=stderr_bytes.decode("utf-8", errors="replace").strip(), + ) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..68f013b --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,86 @@ +from typer.testing import CliRunner + +from tai.cli import app +from tai.ssh_client import SSHCommandResult + + +def test_run_command_prints_scaffold_summary() -> None: + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "web01", + "--port", + "5566", + "--no-probe", + "--path", + "/etc/apache2", + "--jump-host", + "bastion01", + "--ignore-ssh-config", + ], + ) + + assert result.exit_code == 0 + assert "tai scaffold ready" in result.stdout + assert "host=web01" in 
result.stdout + assert "port=5566" in result.stdout + + +def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def] + async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] + return SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux ssh 6.12.0", + stderr="", + ) + + monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--probe", + ], + ) + + assert result.exit_code == 0 + assert "Probe succeeded" in result.stdout + assert "Linux ssh 6.12.0" in result.stdout + + +def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def] + async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] + return SSHCommandResult( + command="uname -a", + exit_code=255, + stdout="", + stderr="Permission denied (publickey,password).", + ) + + monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--probe", + ], + ) + + assert result.exit_code == 1 + assert "Probe failed" in result.stdout diff --git a/tests/test_input_parser.py b/tests/test_input_parser.py new file mode 100644 index 0000000..44a20d0 --- /dev/null +++ b/tests/test_input_parser.py @@ -0,0 +1,65 @@ +from pathlib import Path + +import pytest + +from tai.input_parser import InputValidationError, build_request + + +def test_build_request_normalizes_values() -> None: + req = build_request( + issue=" apache fails to start ", + host=" web01 ", + port=5566, + target_paths=["/etc/apache2", "~/logs"], + identity_file="~/.ssh/id_ed25519", + jump_host=" bastion01 ", + ignore_ssh_config=True, + ) + + assert req.issue == "apache fails to start" + assert req.host == "web01" + assert req.port == 5566 + assert 
req.target_paths[0] == Path("/etc/apache2") + assert req.target_paths[1] == Path("~/logs").expanduser() + assert req.identity_file == Path("~/.ssh/id_ed25519").expanduser() + assert req.jump_host == "bastion01" + assert req.ignore_ssh_config is True + + +def test_build_request_rejects_empty_issue() -> None: + with pytest.raises(InputValidationError): + build_request( + issue=" ", + host="web01", + port=22, + target_paths=[], + identity_file=None, + jump_host=None, + ignore_ssh_config=False, + ) + + +def test_build_request_rejects_empty_host() -> None: + with pytest.raises(InputValidationError): + build_request( + issue="apache down", + host=" ", + port=22, + target_paths=[], + identity_file=None, + jump_host=None, + ignore_ssh_config=False, + ) + + +def test_build_request_rejects_invalid_port() -> None: + with pytest.raises(InputValidationError): + build_request( + issue="apache down", + host="web01", + port=70000, + target_paths=[], + identity_file=None, + jump_host=None, + ignore_ssh_config=False, + ) diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py new file mode 100644 index 0000000..37425ee --- /dev/null +++ b/tests/test_ssh_client.py @@ -0,0 +1,93 @@ +from pathlib import Path + +import pytest + +from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig + + +def _client(**kwargs: object) -> SSHClient: + host = str(kwargs.get("host", "root@ssh.archflux.net")) + port_value = kwargs.get("port", 22) + if not isinstance(port_value, int): + raise TypeError("port must be an int") + port = port_value + identity_file = kwargs.get("identity_file") + jump_host = kwargs.get("jump_host") + ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False)) + + if identity_file is not None and not isinstance(identity_file, Path): + raise TypeError("identity_file must be a Path or None") + + if jump_host is not None and not isinstance(jump_host, str): + raise TypeError("jump_host must be a string or None") + + return SSHClient( + 
SSHConnectionConfig( + host=host, + port=port, + identity_file=identity_file, + jump_host=jump_host, + ignore_ssh_config=ignore_ssh_config, + ) + ) + + +def test_summary_includes_expected_defaults() -> None: + client = _client() + text = client.summary() + + assert "host=root@ssh.archflux.net" in text + assert "port=22" in text + assert "key=auto" in text + assert "jump=none" in text + assert "mode=use ssh config" in text + + +def test_build_ssh_argv_respects_flags() -> None: + client = _client( + identity_file=Path("/root/.ssh/id_ed25519"), + jump_host="bastion.archflux.net", + ignore_ssh_config=True, + ) + + argv = client.build_ssh_argv("uname -a") + + assert argv[0] == "ssh" + assert "-p" in argv + assert "22" in argv + assert "-F" in argv + assert "/dev/null" in argv + assert "-i" in argv + assert "/root/.ssh/id_ed25519" in argv + assert "-J" in argv + assert "bastion.archflux.net" in argv + assert argv[-2] == "root@ssh.archflux.net" + assert argv[-1] == "uname -a" + + +def test_rejects_destructive_or_shell_operator_commands() -> None: + client = _client() + + for command in ["rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id"]: + with pytest.raises(SSHCommandRejectedError): + client.validate_read_only_command(command) + + +def test_allows_expected_read_only_commands() -> None: + client = _client() + + for command in [ + "uname -a", + "journalctl -n 100", + "systemctl status apache2", + "cat /etc/hosts", + "ss -lntp", + ]: + client.validate_read_only_command(command) + + +def test_rejects_non_read_only_systemctl_subcommand() -> None: + client = _client() + + with pytest.raises(SSHCommandRejectedError): + client.validate_read_only_command("systemctl restart apache2") From 6bcf839102e14ca4de9a0252bf3ea04edc5b79c7 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 03:47:32 +0200 Subject: [PATCH 02/11] update pipeline Co-authored-by: Copilot --- .gitea/workflows/ci.yml | 70 ++++++++++++++++++++++++++++++++++------- CHANGELOG.md | 1 + 2 files 
changed, 60 insertions(+), 11 deletions(-) diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 8465b2a..7bf8d19 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -9,24 +9,72 @@ jobs: runs-on: ubuntu-latest steps: - - name: Checkout - uses: actions/checkout@v4 + - name: Ensure git is available + run: | + if command -v git >/dev/null 2>&1; then + git --version + exit 0 + fi - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y git + elif command -v dnf >/dev/null 2>&1; then + dnf install -y git + elif command -v yum >/dev/null 2>&1; then + yum install -y git + else + echo "No supported package manager found to install git." + exit 1 + fi + + git --version + + - name: Checkout source (native git) + run: | + if [ -n "${GITHUB_WORKSPACE:-}" ]; then + cd "$GITHUB_WORKSPACE" + fi + + if [ ! -d .git ]; then + git init + git remote add origin "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" + fi + + git fetch --depth 1 origin "$GITHUB_SHA" + git checkout --force FETCH_HEAD + + - name: Ensure Python and pip are available + run: | + if command -v python3 >/dev/null 2>&1 && python3 -m pip --version >/dev/null 2>&1; then + python3 --version + exit 0 + fi + + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y python3 python3-pip python3-venv + elif command -v dnf >/dev/null 2>&1; then + dnf install -y python3 python3-pip + elif command -v yum >/dev/null 2>&1; then + yum install -y python3 python3-pip + else + echo "No supported package manager found to install Python." + exit 1 + fi + + python3 --version - name: Install package and dev dependencies run: | - python -m pip install --upgrade pip - pip install -e .[dev] + python3 -m pip install --upgrade pip + python3 -m pip install -e .[dev] - name: Lint - run: ruff check . + run: python3 -m ruff check . 
- name: Type-check - run: mypy src + run: python3 -m mypy src - name: Test - run: pytest + run: python3 -m pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 3442cdc..bc49818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening - `CHANGELOG.md` — this file; established changelog tracking for the project - `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events +- Gitea CI now uses native `git` checkout and system Python setup to avoid host-executor JavaScript action path issues - Python package scaffold with `src` layout and project metadata in `pyproject.toml` - Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config` - Input parsing/validation module and core request model From 1bb7084d97d10227b22e1a2f832895466e233646 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 03:59:22 +0200 Subject: [PATCH 05/11] update Co-authored-by: Copilot --- .gitea/workflows/ci.yml | 15 ++++++++++++++- CHANGELOG.md | 1 + 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 7bf8d19..9903253 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -31,16 +31,29 @@ jobs: git --version - name: Checkout source (native git) + env: + CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }} run: | + if [ -z "${CI_GIT_TOKEN:-}" ]; then + echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets." + exit 1 + fi + + auth_server="${GITHUB_SERVER_URL#https://}" + auth_server="${auth_server#http://}" + remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git" + if [ -n "${GITHUB_WORKSPACE:-}" ]; then cd "$GITHUB_WORKSPACE" fi if [ ! 
-d .git ]; then git init - git remote add origin "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" fi + git remote remove origin >/dev/null 2>&1 || true + git remote add origin "$remote_url" + git fetch --depth 1 origin "$GITHUB_SHA" git checkout --force FETCH_HEAD diff --git a/CHANGELOG.md b/CHANGELOG.md index bc49818..2035ff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - `CHANGELOG.md` — this file; established changelog tracking for the project - `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events - Gitea CI now uses native `git` checkout and system Python setup to avoid host-executor JavaScript action path issues +- Gitea native checkout now uses `CI_GIT_TOKEN` repository secret for authenticated fetch from private repos - Python package scaffold with `src` layout and project metadata in `pyproject.toml` - Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config` - Input parsing/validation module and core request model From 1f0286015b0a1b493fc45a2cbcbb13ae409c3fe5 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:05:45 +0200 Subject: [PATCH 06/11] test again Co-authored-by: Copilot --- .gitea/workflows/ci.yml | 12 +++++++----- CHANGELOG.md | 1 + 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 9903253..d37e702 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -80,14 +80,16 @@ jobs: - name: Install package and dev dependencies run: | - python3 -m pip install --upgrade pip - python3 -m pip install -e .[dev] + python3 -m venv .venv + . .venv/bin/activate + python -m pip install --upgrade pip + python -m pip install -e .[dev] - name: Lint - run: python3 -m ruff check . + run: .venv/bin/python -m ruff check . 
- name: Type-check - run: python3 -m mypy src + run: .venv/bin/python -m mypy src - name: Test - run: python3 -m pytest + run: .venv/bin/python -m pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 2035ff4..068452c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events - Gitea CI now uses native `git` checkout and system Python setup to avoid host-executor JavaScript action path issues - Gitea native checkout now uses `CI_GIT_TOKEN` repository secret for authenticated fetch from private repos +- Gitea CI now installs dependencies in a local `.venv` to avoid Debian/PEP 668 externally-managed pip errors - Python package scaffold with `src` layout and project metadata in `pyproject.toml` - Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config` - Input parsing/validation module and core request model From 65c74dde5a555e44cba02f13519103954500a2cd Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:08:50 +0200 Subject: [PATCH 07/11] update Co-authored-by: Copilot --- .gitea/workflows/ci.yml | 6 ++++++ .github/workflows/ci.yml | 6 ++++++ .yamllint.yml | 16 ++++++++++++++++ CHANGELOG.md | 5 ++++- README.md | 6 +++--- ROADMAP.md | 18 +++++++++++------- pyproject.toml | 2 ++ 7 files changed, 48 insertions(+), 11 deletions(-) create mode 100644 .yamllint.yml diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index d37e702..5011c6c 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -88,6 +88,12 @@ jobs: - name: Lint run: .venv/bin/python -m ruff check . + - name: Lint Markdown + run: .venv/bin/mdformat --check README.md ROADMAP.md CHANGELOG.md + + - name: Lint YAML + run: .venv/bin/yamllint . 
+ - name: Type-check run: .venv/bin/python -m mypy src diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8465b2a..08ca86d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,12 @@ jobs: - name: Lint run: ruff check . + - name: Lint Markdown + run: mdformat --check README.md ROADMAP.md CHANGELOG.md + + - name: Lint YAML + run: yamllint . + - name: Type-check run: mypy src diff --git a/.yamllint.yml b/.yamllint.yml new file mode 100644 index 0000000..cb10dc2 --- /dev/null +++ b/.yamllint.yml @@ -0,0 +1,16 @@ +extends: default + +ignore: | + .git/ + .venv/ + .mypy_cache/ + .pytest_cache/ + .ruff_cache/ + +rules: + document-start: disable + line-length: + max: 120 + truthy: + allowed-values: ["true", "false"] + check-keys: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 068452c..4b57691 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ---- +______________________________________________________________________ ## [Unreleased] ### Added + - `README.md` — project overview, description, example workflow, supported distributions, and suggested tooling - `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening - `CHANGELOG.md` — this file; established changelog tracking for the project @@ -27,8 +28,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). 
- Test scaffold (`pytest`) with initial parser and CLI coverage - SSH test coverage for policy checks, SSH argument construction, and config summary behavior - CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`) +- CI coverage expanded with Markdown formatting checks (`mdformat --check`) and YAML linting (`yamllint`) ### Decided + - Implementation language: **Python** - Distribution strategy: single distributable binary via **Nuitka** (PyInstaller as fallback) - SSH authentication: **keypair only** (ed25519/RSA); auto-accept new hosts; hard reject on host key change with MITM warning diff --git a/README.md b/README.md index 296c068..74ed425 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,8 @@ The agent operates in **read-only mode at all times**. It will never modify the A troubleshooter receives a ticket reporting that the Apache service on a remote server has failed to start. They provide `tai` with: 1. The ticket description or error message -2. The hostname of the affected system -3. Any relevant directories to focus on +1. The hostname of the affected system +1. Any relevant directories to focus on `tai` then connects to the host, reads through system logs, service configurations, and any other related files, and returns a structured analysis of the likely cause along with recommended next steps. @@ -32,4 +32,4 @@ A troubleshooter receives a ticket reporting that the Apache service on a remote | AI inference backend | [vLLM](https://github.com/vllm-project/vllm) | | Model | `gemma4:a4b` | -> **Note:** A suitable implementation language for this project is yet to be determined. \ No newline at end of file +> **Note:** A suitable implementation language for this project is yet to be determined. 
diff --git a/ROADMAP.md b/ROADMAP.md index 235c6a4..45dbc1d 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -2,13 +2,14 @@ This document outlines the major decisions, milestones, and development phases required to bring `tai` from concept to a working tool. ---- +______________________________________________________________________ ## Phase 0 — Decisions & Prerequisites These must be resolved before meaningful development can begin. ### Language Selection + - [x] **Decision: Python** - Key factors: native vLLM integration, mature SSH libraries (`paramiko` / `asyncssh`), strong text/log parsing, rapid development - Single binary distribution will be achieved via **Nuitka** (preferred for true compilation) or **PyInstaller** as a fallback @@ -16,12 +17,14 @@ These must be resolved before meaningful development can begin. - [ ] Add binary build step to CI pipeline ### AI Backend & Model + - [ ] Confirm use of [vLLM](https://github.com/vllm-project/vllm) as the inference backend - [ ] Confirm `gemma4:a4b` as the default model (or select an alternative) - [ ] Define minimum hardware requirements for running the model locally - [ ] Decide whether the AI backend is bundled, self-hosted externally, or user-supplied ### SSH Strategy + - [x] **Decision: keypair authentication only** — no password auth; eliminates credential storage risk - Default key resolution: `~/.ssh/id_ed25519`, `~/.ssh/id_rsa` (in order of preference) - CLI override via `--identity-file ` @@ -33,6 +36,7 @@ These must be resolved before meaningful development can begin. - Override switch: `--ignore-ssh-config` to bypass local SSH config when required ### Scope & Constraints + - [ ] Define the supported scope of issues (services, network, disk, kernel, etc.) - [ ] Confirm read-only guarantee — document exactly what "read-only" means in practice - [x] **Decision: interactive REPL mode for v0.1, full TUI for v0.2+** @@ -40,7 +44,7 @@ These must be resolved before meaningful development can begin. 
- v0.2+: `textual`-based TUI with split panes (collected data | AI output | input bar) - Built-in slash commands: `/collect`, `/show logs`, `/clear`, `/host `, `/help`, `/quit` ---- +______________________________________________________________________ ## Phase 1 — Project Foundation @@ -58,7 +62,7 @@ Basic project scaffolding and connectivity. - [x] Input parser and CLI tests added - [x] SSH module tests added for command policy and SSH argv behavior ---- +______________________________________________________________________ ## Phase 2 — Data Collection Layer @@ -74,7 +78,7 @@ Define what information the agent gathers and how. - [ ] Add support for per-distro variations (Ubuntu vs RHEL path differences, etc.) - [ ] Write tests with mocked SSH output ---- +______________________________________________________________________ ## Phase 3 — AI Integration @@ -87,7 +91,7 @@ Wire collected data into the local AI model. - [ ] Add streaming support for long AI responses - [ ] Evaluate and test model output quality on common issue types ---- +______________________________________________________________________ ## Phase 4 — CLI & User Experience @@ -99,7 +103,7 @@ Polish the interface for real-world use. - [ ] Support output to file or clipboard - [ ] Write man page / `--help` documentation ---- +______________________________________________________________________ ## Phase 5 — Hardening & Distribution @@ -111,7 +115,7 @@ Prepare for broader use. 
- [ ] Write installation and quickstart documentation - [ ] End-to-end integration tests against a test VM ---- +______________________________________________________________________ ## Decisions Log diff --git a/pyproject.toml b/pyproject.toml index 2e8d855..3c80449 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dev = [ "pytest>=8.2,<9.0", "ruff>=0.5,<1.0", "mypy>=1.10,<2.0", + "mdformat>=0.7,<1.0", + "yamllint>=1.35,<2.0", ] build = [ "nuitka>=2.4,<3.0", From e589240c67561397fdb8eecafd5c5605018e35f6 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:22:58 +0200 Subject: [PATCH 08/11] update Co-authored-by: Copilot --- .github/workflows/ci.yml | 38 ------ CHANGELOG.md | 6 + ROADMAP.md | 2 +- src/tai/cli.py | 46 +++++++- src/tai/collectors.py | 50 ++++++++ src/tai/plan.py | 244 +++++++++++++++++++++++++++++++++++++++ src/tai/ssh_client.py | 48 +++++++- tests/test_cli.py | 50 ++++++++ tests/test_plan.py | 169 +++++++++++++++++++++++++++ tests/test_ssh_client.py | 15 +++ 10 files changed, 624 insertions(+), 44 deletions(-) delete mode 100644 .github/workflows/ci.yml create mode 100644 src/tai/collectors.py create mode 100644 src/tai/plan.py create mode 100644 tests/test_plan.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 08ca86d..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: CI - -on: - push: - pull_request: - -jobs: - test: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Install package and dev dependencies - run: | - python -m pip install --upgrade pip - pip install -e .[dev] - - - name: Lint - run: ruff check . - - - name: Lint Markdown - run: mdformat --check README.md ROADMAP.md CHANGELOG.md - - - name: Lint YAML - run: yamllint . 
- - - name: Type-check - run: mypy src - - - name: Test - run: pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b57691..cfd146d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,12 +24,18 @@ ______________________________________________________________________ - Implemented SSH module with real key-based command execution via system `ssh` - Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566) - Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure +- Added baseline diagnostics collection via `--collect`, including service, journal, disk, and network checks - Read-only command policy enforcement (allowlist + blocked shell operators) +- Added byte-limited SSH output capture with truncation markers for large command output - Test scaffold (`pytest`) with initial parser and CLI coverage - SSH test coverage for policy checks, SSH argument construction, and config summary behavior - CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`) - CI coverage expanded with Markdown formatting checks (`mdformat --check`) and YAML linting (`yamllint`) +### Removed + +- `.github/workflows/ci.yml` — GitHub Actions workflow removed; CI is now Gitea-only + ### Decided - Implementation language: **Python** diff --git a/ROADMAP.md b/ROADMAP.md index 45dbc1d..c20e242 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -56,7 +56,7 @@ Basic project scaffolding and connectivity. - [x] Define SSH config model and probe interface scaffold - [x] Connect to remote host - [x] Execute read-only commands (e.g. 
`journalctl`, `systemctl status`, `cat`) - - [ ] Stream or collect command output safely + - [x] Stream or collect command output safely (byte-limited output with truncation marker) - [x] Implement basic input parsing (ticket text, hostname, target directories) - [x] Write unit tests for SSH and input modules - [x] Input parser and CLI tests added diff --git a/src/tai/cli.py b/src/tai/cli.py index 5a3e269..aa5baa3 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -8,7 +8,10 @@ from typing import Annotated import typer from rich.console import Console +from tai.collectors import CollectionReport, collect_from_plan from tai.input_parser import InputValidationError, build_request +from tai.models import TroubleshootRequest +from tai.plan import plan_from_request from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -46,6 +49,13 @@ def run( help="Enable or disable live SSH connectivity probe (uname -a).", ), ] = True, + collect: Annotated[ + bool, + typer.Option( + "--collect/--no-collect", + help="Collect baseline diagnostics after probe.", + ), + ] = False, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -77,8 +87,13 @@ def run( if req.target_paths: console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") + client = SSHClient(config) + if probe: - _run_probe(SSHClient(config)) + _run_probe(client) + + if collect: + _run_collection(client, req) def _run_probe(client: SSHClient) -> None: @@ -96,6 +111,35 @@ def _run_probe(client: SSHClient) -> None: _handle_probe_result(result) +def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None: + """Run issue-aware collection and print a compact summary.""" + plan = plan_from_request(request) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + try: + report = asyncio.run(collect_from_plan(client, plan)) + except TimeoutError as exc: + 
console.print(f"[red]Collection failed:[/red] {exc}") + raise typer.Exit(code=1) from exc + except OSError as exc: + console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}") + raise typer.Exit(code=1) from exc + + _handle_collection_report(report) + + +def _handle_collection_report(report: CollectionReport) -> None: + """Render collected command status and truncation hints.""" + console.print( + f"[bold]Collection complete:[/bold] {report.total} commands, {report.failed} failed" + ) + for item in report.items: + status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}" + trunc = "" + if item.result.stdout_truncated or item.result.stderr_truncated: + trunc = " (truncated)" + console.print(f"- {item.name}: {status}{trunc}") + + def _handle_probe_result(result: SSHCommandResult) -> None: """Handle and render probe output for success or failure.""" if result.exit_code != 0: diff --git a/src/tai/collectors.py b/src/tai/collectors.py new file mode 100644 index 0000000..9ad4754 --- /dev/null +++ b/src/tai/collectors.py @@ -0,0 +1,50 @@ +"""Data collection routines built on top of the SSH client.""" + +from dataclasses import dataclass + +from tai.plan import CollectionPlan +from tai.ssh_client import SSHClient, SSHCommandResult + + +@dataclass(slots=True) +class CollectedItem: + """Single collected diagnostic command result.""" + + name: str + result: SSHCommandResult + + +@dataclass(slots=True) +class CollectionReport: + """Collection summary for a batch of diagnostics.""" + + host: str + items: list[CollectedItem] + + @property + def total(self) -> int: + return len(self.items) + + @property + def failed(self) -> int: + return sum(1 for item in self.items if item.result.exit_code != 0) + + +async def collect_from_plan( + client: SSHClient, + plan: CollectionPlan, + *, + max_output_bytes: int = 32768, +) -> CollectionReport: + """Execute all commands in *plan* and return a :class:`CollectionReport`.""" + items: 
list[CollectedItem] = [] + + for name, command in plan.commands: + result = await client.run_read_only_command( + command, + timeout_seconds=30.0, + max_output_bytes=max_output_bytes, + ) + items.append(CollectedItem(name=name, result=result)) + + return CollectionReport(host=client.summary(), items=items) diff --git a/src/tai/plan.py b/src/tai/plan.py new file mode 100644 index 0000000..e3c76a8 --- /dev/null +++ b/src/tai/plan.py @@ -0,0 +1,244 @@ +"""Collection plan builder — decides what to collect based on the issue.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field + +from tai.models import TroubleshootRequest + +# --------------------------------------------------------------------------- +# Keyword sets for issue classification +# --------------------------------------------------------------------------- + +_SERVICE_KEYWORDS: frozenset[str] = frozenset( + { + "service", + "unit", + "daemon", + "failed", + "dead", + "inactive", + "crash", + "crashed", + "start", + "stop", + "restart", + "status", + "systemd", + "systemctl", + } +) + +_NETWORK_KEYWORDS: frozenset[str] = frozenset( + { + "network", + "port", + "connect", + "connection", + "listen", + "firewall", + "route", + "routing", + "interface", + "dns", + "http", + "https", + "tcp", + "udp", + "socket", + "unreachable", + "refused", + "timeout", + "latency", + "bandwidth", + "packet", + } +) + +_DISK_KEYWORDS: frozenset[str] = frozenset( + { + "disk", + "space", + "storage", + "inode", + "full", + "mount", + "filesystem", + "partition", + "quota", + "usage", + "capacity", + } +) + +# --------------------------------------------------------------------------- +# Known service names and their candidate config paths +# --------------------------------------------------------------------------- + +_KNOWN_SERVICES: list[str] = [ + "apache2", + "httpd", + "nginx", + "mysql", + "mysqld", + "mariadb", + "postgresql", + "redis", + "redis-server", + "mongodb", + 
"mongod", + "docker", + "containerd", + "kubelet", + "sshd", + "postfix", + "dovecot", + "sendmail", + "php-fpm", + "elasticsearch", + "rabbitmq", + "rabbitmq-server", + "celery", + "gunicorn", + "ufw", + "fail2ban", + "cron", + "crond", + "rsyslog", + "auditd", + "firewalld", + "haproxy", + "varnish", + "memcached", +] + +_SERVICE_CONFIGS: dict[str, list[str]] = { + "apache2": ["/etc/apache2/apache2.conf"], + "httpd": ["/etc/httpd/conf/httpd.conf"], + "nginx": ["/etc/nginx/nginx.conf"], + "mysql": ["/etc/mysql/mysql.conf.d/mysqld.cnf"], + "mysqld": ["/etc/my.cnf"], + "mariadb": ["/etc/mysql/mariadb.conf.d/50-server.cnf"], + "postgresql": ["/etc/postgresql"], + "sshd": ["/etc/ssh/sshd_config"], + "postfix": ["/etc/postfix/main.cf"], + "haproxy": ["/etc/haproxy/haproxy.cfg"], + "redis": ["/etc/redis/redis.conf"], + "redis-server": ["/etc/redis/redis.conf"], + "fail2ban": ["/etc/fail2ban/jail.conf"], + "ufw": ["/etc/ufw/ufw.conf"], +} + +# --------------------------------------------------------------------------- +# Command sets +# --------------------------------------------------------------------------- + +_ALWAYS: list[tuple[str, str]] = [ + ("kernel", "uname -a"), + ("uptime", "cat /proc/uptime"), + ("disk-usage", "df -h"), + ("memory", "cat /proc/meminfo"), + ("running-services", "systemctl list-units --type=service --state=running --no-pager"), +] + +_SERVICE_EXTRA: list[tuple[str, str]] = [ + ("failed-services", "systemctl list-units --type=service --state=failed --no-pager"), + ("journal-errors", "journalctl -p err -n 100 --no-pager"), +] + +_NETWORK_EXTRA: list[tuple[str, str]] = [ + ("listening-ports", "ss -lntp"), + ("ip-addresses", "ip addr show"), + ("ip-routes", "ip route show"), + ("ip-stats", "ip -s link show"), +] + +_DISK_EXTRA: list[tuple[str, str]] = [ + ("disk-inodes", "df -i"), + ("dmesg-disk", "dmesg -T --level=err,warn"), + ("large-dirs", "du -sh /var /tmp /home /opt"), +] + +# 
--------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class CollectionPlan: + """Ordered list of (name, command) pairs to execute on a remote host.""" + + commands: list[tuple[str, str]] = field(default_factory=list) + + def add(self, name: str, command: str) -> None: + self.commands.append((name, command)) + + def __len__(self) -> int: + return len(self.commands) + + +def plan_from_request(request: TroubleshootRequest) -> CollectionPlan: + """Build a :class:`CollectionPlan` tailored to *request*.""" + plan = CollectionPlan(commands=list(_ALWAYS)) + keywords = _issue_words(request.issue) + + # --- category expansions ------------------------------------------- + if keywords & _SERVICE_KEYWORDS: + plan.commands.extend(_SERVICE_EXTRA) + + if keywords & _NETWORK_KEYWORDS: + plan.commands.extend(_NETWORK_EXTRA) + + if keywords & _DISK_KEYWORDS: + plan.commands.extend(_DISK_EXTRA) + + # --- named service detection --------------------------------------- + services = _extract_services(request.issue) + seen: set[str] = set() + for svc in services: + if svc in seen: + continue + seen.add(svc) + plan.add(f"service-{svc}", f"systemctl status {svc}") + plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager") + for cfg_path in _SERVICE_CONFIGS.get(svc, []): + plan.add(f"config-{svc}", f"cat {cfg_path}") + + # --- user-specified paths ----------------------------------------- + for path in request.target_paths: + plan.add(f"ls-{path.name}", f"ls -la {path}") + if "log" in str(path).lower(): + plan.add( + f"find-logs-{path.name}", + f"find {path} -maxdepth 2 -type f -name '*.log'", + ) + else: + plan.add( + f"find-files-{path.name}", + f"find {path} -maxdepth 2 -type f", + ) + + return plan + + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + + +def _issue_words(issue: str) -> set[str]: + """Return the set of lowercase words in *issue*.""" + return set(re.findall(r"\b\w+\b", issue.lower())) + + +def _extract_services(issue: str) -> list[str]: + """Return known service names mentioned in *issue*.""" + words = _issue_words(issue) + found: list[str] = [] + for svc in _KNOWN_SERVICES: + # Match the service name or its stem (strip trailing 'd', e.g. 'apache' → 'apache2') + svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")} + if words & svc_words: + found.append(svc) + return found diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index c690b6e..4f7164a 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -25,6 +25,8 @@ class SSHCommandResult: exit_code: int stdout: str stderr: str + stdout_truncated: bool = False + stderr_truncated: bool = False class SSHCommandRejectedError(ValueError): @@ -148,20 +150,30 @@ class SSHClient: command: str, *, timeout_seconds: float = 30.0, + max_output_bytes: int = 32768, ) -> SSHCommandResult: """Run a validated read-only command over SSH.""" self.validate_read_only_command(command) - return await self._run_ssh(command, timeout_seconds=timeout_seconds) + return await self._run_ssh( + command, + timeout_seconds=timeout_seconds, + max_output_bytes=max_output_bytes, + ) async def probe(self) -> SSHCommandResult: """Probe connectivity using a harmless remote command.""" - return await self._run_ssh("uname -a", timeout_seconds=15.0) + return await self._run_ssh( + "uname -a", + timeout_seconds=15.0, + max_output_bytes=4096, + ) async def _run_ssh( self, command: str, *, timeout_seconds: float, + max_output_bytes: int, ) -> SSHCommandResult: argv = self.build_ssh_argv(command) proc = await asyncio.create_subprocess_exec( @@ -185,9 +197,37 @@ class SSHClient: if proc.returncode is None: raise RuntimeError("SSH process did not provide an exit code.") 
+ stdout_text, stdout_truncated = self._truncate_output( + stdout_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + stderr_text, stderr_truncated = self._truncate_output( + stderr_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + return SSHCommandResult( command=command, exit_code=proc.returncode, - stdout=stdout_bytes.decode("utf-8", errors="replace").strip(), - stderr=stderr_bytes.decode("utf-8", errors="replace").strip(), + stdout=stdout_text, + stderr=stderr_text, + stdout_truncated=stdout_truncated, + stderr_truncated=stderr_truncated, ) + + @staticmethod + def _truncate_output(text: str, *, max_output_bytes: int) -> tuple[str, bool]: + """Trim output to a maximum byte length while preserving UTF-8 validity.""" + if max_output_bytes < 256: + raise ValueError("max_output_bytes must be at least 256.") + + encoded = text.encode("utf-8", errors="replace") + if len(encoded) <= max_output_bytes: + return text.strip(), False + + marker = "\n...[truncated]" + marker_bytes = marker.encode("utf-8") + keep = max_output_bytes - len(marker_bytes) + trimmed_bytes = encoded[:keep] + trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip() + return f"{trimmed_text}{marker}", True diff --git a/tests/test_cli.py b/tests/test_cli.py index 68f013b..9bac274 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ from typer.testing import CliRunner from tai.cli import app +from tai.collectors import CollectedItem, CollectionReport from tai.ssh_client import SSHCommandResult @@ -84,3 +85,52 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no assert result.exit_code == 1 assert "Probe failed" in result.stdout + + +def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def] + async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + 
host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + CollectedItem( + name="journal", + result=SSHCommandResult( + command="journalctl -n 200", + exit_code=0, + stdout="...", + stderr="", + stdout_truncated=True, + ), + ), + ], + ) + + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--collect", + ], + ) + + assert result.exit_code == 0 + assert "Collection complete" in result.stdout + assert "kernel: ok" in result.stdout + assert "journal: ok (truncated)" in result.stdout diff --git a/tests/test_plan.py b/tests/test_plan.py new file mode 100644 index 0000000..94016e2 --- /dev/null +++ b/tests/test_plan.py @@ -0,0 +1,169 @@ +"""Tests for the collection plan builder.""" + +from pathlib import Path + +from tai.models import TroubleshootRequest +from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request + + +def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest: + return TroubleshootRequest( + issue=issue, + host="root@testhost", + target_paths=[Path(p) for p in (paths or [])], + ) + + +def _commands(plan: CollectionPlan) -> list[str]: + """Return flat list of command strings from plan.""" + return [cmd for _, cmd in plan.commands] + + +def _names(plan: CollectionPlan) -> list[str]: + return [name for name, _ in plan.commands] + + +# --------------------------------------------------------------------------- +# Always-present commands +# --------------------------------------------------------------------------- + + +def test_plan_always_has_baseline_commands() -> None: + plan = plan_from_request(_req("some generic issue")) + cmds = _commands(plan) + assert any("uname -a" in c for c in cmds) + assert any("df -h" in 
c for c in cmds) + assert any("proc/meminfo" in c for c in cmds) + assert any("systemctl list-units" in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Keyword-based category expansion +# --------------------------------------------------------------------------- + + +def test_service_keywords_add_failed_services_check() -> None: + plan = plan_from_request(_req("service failed to start")) + cmds = _commands(plan) + assert any("--state=failed" in c for c in cmds) + assert any("journalctl -p err" in c for c in cmds) + + +def test_network_keywords_add_network_commands() -> None: + plan = plan_from_request(_req("connection refused on port 80")) + cmds = _commands(plan) + assert any("ss -lntp" in c for c in cmds) + assert any("ip addr show" in c for c in cmds) + assert any("ip route show" in c for c in cmds) + + +def test_disk_keywords_add_disk_commands() -> None: + plan = plan_from_request(_req("disk full filesystem usage critical")) + cmds = _commands(plan) + assert any("df -i" in c for c in cmds) + assert any("dmesg" in c for c in cmds) + assert any("du -sh" in c for c in cmds) + + +def test_unrelated_issue_does_not_add_network_commands() -> None: + plan = plan_from_request(_req("apache service crashed")) + cmds = _commands(plan) + assert not any("ip route show" in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Named service detection +# --------------------------------------------------------------------------- + + +def test_nginx_in_issue_adds_nginx_service_commands() -> None: + plan = plan_from_request(_req("nginx is failing to start")) + names = _names(plan) + cmds = _commands(plan) + assert "service-nginx" in names + assert "journal-nginx" in names + assert any("systemctl status nginx" in c for c in cmds) + assert any("journalctl -u nginx" in c for c in cmds) + + +def test_apache2_adds_config_cat() -> None: + plan = plan_from_request(_req("apache2 
service check")) + cmds = _commands(plan) + assert any("cat /etc/apache2/apache2.conf" in c for c in cmds) + + +def test_sshd_adds_config_cat() -> None: + plan = plan_from_request(_req("sshd connection problems")) + cmds = _commands(plan) + assert any("cat /etc/ssh/sshd_config" in c for c in cmds) + + +def test_unknown_service_name_no_config_cat() -> None: + plan = plan_from_request(_req("myweirdapp service crashed")) + cmds = _commands(plan) + assert not any("cat /etc" in c for c in cmds) + + +def test_duplicate_service_name_not_repeated() -> None: + plan = plan_from_request(_req("nginx nginx nginx")) + names = _names(plan) + assert names.count("service-nginx") == 1 + + +# --------------------------------------------------------------------------- +# Target path handling +# --------------------------------------------------------------------------- + + +def test_target_path_adds_ls_and_find() -> None: + plan = plan_from_request(_req("app crash", paths=["/opt/myapp"])) + cmds = _commands(plan) + assert any("ls -la /opt/myapp" in c for c in cmds) + assert any("find /opt/myapp" in c for c in cmds) + + +def test_log_path_uses_log_find_pattern() -> None: + plan = plan_from_request(_req("app errors", paths=["/var/log/myapp"])) + cmds = _commands(plan) + assert any("*.log" in c for c in cmds) + + +def test_non_log_path_uses_generic_find() -> None: + plan = plan_from_request(_req("config issue", paths=["/etc/myapp"])) + cmds = _commands(plan) + assert any("find /etc/myapp" in c and "*.log" not in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Helper unit tests +# --------------------------------------------------------------------------- + + +def test_issue_words_lowercases_and_splits() -> None: + words = _issue_words("Apache Service FAILED") + assert "apache" in words + assert "service" in words + assert "failed" in words + + +def test_extract_services_finds_nginx() -> None: + assert "nginx" in 
_extract_services("nginx is down") + + +def test_extract_services_finds_nothing_for_unknown() -> None: + assert _extract_services("the widget is broken") == [] + + +def test_extract_services_case_insensitive() -> None: + assert "nginx" in _extract_services("NGINX failed") + + +# --------------------------------------------------------------------------- +# Plan length sanity +# --------------------------------------------------------------------------- + + +def test_plain_issue_has_only_always_commands() -> None: + plan = plan_from_request(_req("something went wrong")) + # Only _ALWAYS (5 commands), no category expansion, no service, no paths + assert len(plan) == 5 diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index 37425ee..fcad417 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -91,3 +91,18 @@ def test_rejects_non_read_only_systemctl_subcommand() -> None: with pytest.raises(SSHCommandRejectedError): client.validate_read_only_command("systemctl restart apache2") + + +def test_truncate_output_marks_and_limits_content() -> None: + text = "a" * 400 + rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=256) + + assert truncated is True + assert rendered.endswith("...[truncated]") + + +def test_truncate_output_keeps_short_content() -> None: + rendered, truncated = SSHClient._truncate_output("short output", max_output_bytes=256) + + assert truncated is False + assert rendered == "short output" From e6233f237be75789f184463aadbbad541005775c Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:30:14 +0200 Subject: [PATCH 09/11] update Co-authored-by: Copilot --- README.md | 64 +++++++++++++++++++++++++++++++++++++++++++++++--- pyproject.toml | 1 + 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 74ed425..53e77c5 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,65 @@ A troubleshooter receives a ticket reporting that the Apache service on a remote | 
Component | Tool | |-----------|------| -| AI inference backend | [vLLM](https://github.com/vllm-project/vllm) | -| Model | `gemma4:a4b` | +| AI inference backend | [Ollama](https://ollama.com) | +| Model | `gemma3:4b`, `llama3.1:8b`, or `qwen2.5:7b` | +| Language | Python 3.11+ | -> **Note:** A suitable implementation language for this project is yet to be determined. +--- + +## How-To: Setting Up the AI Backend (Arch Linux + RTX 3080) + +`tai` uses [Ollama](https://ollama.com) as its local AI backend. It exposes an OpenAI-compatible HTTP API that `tai` talks to — no cloud services, no data leaving your machine. + +An RTX 3080 (10 GB VRAM) comfortably runs 7–8B parameter models at 4-bit quantisation. + +### 1. Install CUDA and Ollama + +```bash +# CUDA runtime (skip if already installed) +sudo pacman -S cuda + +# Ollama with CUDA support from the AUR +yay -S ollama-cuda +# or: paru -S ollama-cuda + +# Enable and start the service +sudo systemctl enable --now ollama +``` + +### 2. Pull a model + +```bash +ollama pull gemma3:4b # ~3 GB — fast, good for sysadmin tasks +ollama pull llama3.1:8b # ~5 GB — stronger reasoning +ollama pull qwen2.5:7b # ~4.5 GB — strong structured output +``` + +### 3. Verify the model works + +```bash +ollama run gemma3:4b "what causes a systemd service to enter failed state?" +``` + +### 4. Verify the HTTP API is running + +`tai` communicates with Ollama over its OpenAI-compatible REST API: + +```bash +curl http://localhost:11434/api/generate \ + -d '{"model":"gemma3:4b","prompt":"hello","stream":false}' +``` + +A JSON response with a `response` field confirms everything is working. + +### 5. 
Point tai at your Ollama instance + +Once `tai` AI integration is complete, use these flags: + +```bash +tai "nginx failing to start" --host web01 \ + --ai-host http://localhost:11434 \ + --model gemma3:4b +``` + +The default values for `--ai-host` and `--model` will be `http://localhost:11434` and `gemma3:4b` respectively, so for local use you won't need to specify them explicitly. diff --git a/pyproject.toml b/pyproject.toml index 3c80449..348fc5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "typer>=0.12,<1.0", "rich>=13.7,<14.0", "asyncssh>=2.14,<3.0", + "openai>=1.30,<2.0", ] [project.optional-dependencies] From 61d3e2c4e6a1330e3c72413bc5de1b8ba0a2d442 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:51:48 +0200 Subject: [PATCH 10/11] update Co-authored-by: Copilot --- .gitea/workflows/release.yml | 110 ++++++++++++++++++++ src/tai/ai_client.py | 93 +++++++++++++++++ src/tai/cli.py | 138 ++++++++++++++++--------- src/tai/collectors.py | 10 +- src/tai/prompt_builder.py | 74 ++++++++++++++ src/tai/ssh_client.py | 160 ++++++++++++++++++++++++++++- tests/test_ai.py | 192 +++++++++++++++++++++++++++++++++++ tests/test_cli.py | 69 +++++++------ 8 files changed, 757 insertions(+), 89 deletions(-) create mode 100644 .gitea/workflows/release.yml create mode 100644 src/tai/ai_client.py create mode 100644 src/tai/prompt_builder.py create mode 100644 tests/test_ai.py diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml new file mode 100644 index 0000000..ad24b64 --- /dev/null +++ b/.gitea/workflows/release.yml @@ -0,0 +1,110 @@ +name: Release + +on: + push: + tags: + - "v*" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Ensure git is available + run: | + if command -v git >/dev/null 2>&1; then + git --version + exit 0 + fi + + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y git + elif command -v dnf >/dev/null 2>&1; then + dnf install -y git + elif 
command -v yum >/dev/null 2>&1; then + yum install -y git + else + echo "No supported package manager found to install git." + exit 1 + fi + + - name: Checkout source (native git) + env: + CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }} + run: | + if [ -z "${CI_GIT_TOKEN:-}" ]; then + echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets." + exit 1 + fi + + auth_server="${GITHUB_SERVER_URL#https://}" + auth_server="${auth_server#http://}" + remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git" + + if [ -n "${GITHUB_WORKSPACE:-}" ]; then + cd "$GITHUB_WORKSPACE" + fi + + if [ ! -d .git ]; then + git init + fi + + git remote remove origin >/dev/null 2>&1 || true + git remote add origin "$remote_url" + + # Fetch the tag by SHA so we get the exact tagged commit + git fetch --depth 1 origin "$GITHUB_SHA" + git checkout --force FETCH_HEAD + + - name: Ensure Python and build dependencies are available + run: | + if ! command -v python3 >/dev/null 2>&1; then + if command -v apt-get >/dev/null 2>&1; then + apt-get update + apt-get install -y python3 python3-pip python3-venv patchelf ccache + elif command -v dnf >/dev/null 2>&1; then + dnf install -y python3 python3-pip patchelf ccache + fi + fi + + # patchelf is required by Nuitka for standalone Linux binaries + command -v patchelf >/dev/null 2>&1 || { + apt-get update && apt-get install -y patchelf + } + + python3 --version + + - name: Set up venv and install package + build deps + run: | + python3 -m venv .venv + . .venv/bin/activate + python -m pip install --upgrade pip + python -m pip install -e ".[build]" + + - name: Derive version from tag + id: version + run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT" + + - name: Build standalone binary with Nuitka + run: | + . 
.venv/bin/activate + python -m nuitka \ + --standalone \ + --onefile \ + --output-filename=tai \ + --output-dir=dist \ + --assume-yes-for-downloads \ + --include-package=tai \ + src/tai/cli.py + + - name: Smoke-test the binary + run: dist/tai --help + + - name: Upload binary artifact + uses: actions/upload-artifact@v3 + with: + name: tai-linux-amd64-${{ steps.version.outputs.tag }} + path: dist/tai + if-no-files-found: error + retention-days: 90 diff --git a/src/tai/ai_client.py b/src/tai/ai_client.py new file mode 100644 index 0000000..c80103e --- /dev/null +++ b/src/tai/ai_client.py @@ -0,0 +1,93 @@ +"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint.""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass, field + +from openai import OpenAI + +DEFAULT_AI_HOST = "http://localhost:11434/v1" +DEFAULT_MODEL = "gemma3:4b" + + +@dataclass(slots=True) +class AIConfig: + """Connection parameters for an OpenAI-compatible AI backend.""" + + host: str = DEFAULT_AI_HOST + model: str = DEFAULT_MODEL + api_key: str = "ollama" # Ollama ignores this; required by the openai client + timeout_seconds: float = 120.0 + max_tokens: int = 4096 + extra_headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(slots=True) +class AIResponse: + """Structured response from an AI completion.""" + + model: str + content: str + prompt_tokens: int + completion_tokens: int + + @property + def total_tokens(self) -> int: + return self.prompt_tokens + self.completion_tokens + + +class AIClient: + """Thin wrapper around the openai client targeting a configurable endpoint.""" + + def __init__(self, config: AIConfig) -> None: + self._config = config + self._client = OpenAI( + base_url=config.host, + api_key=config.api_key, + timeout=config.timeout_seconds, + default_headers=config.extra_headers, + ) + + def complete(self, system_prompt: str, user_message: str) -> AIResponse: + """Send 
a completion request and return the full response.""" + response = self._client.chat.completions.create( + model=self._config.model, + max_tokens=self._config.max_tokens, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + ) + + choice = response.choices[0] + content = choice.message.content or "" + usage = response.usage + + return AIResponse( + model=response.model, + content=content, + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + ) + + def stream(self, system_prompt: str, user_message: str) -> Iterator[str]: + """Stream a completion, yielding text chunks as they arrive.""" + stream = self._client.chat.completions.create( + model=self._config.model, + max_tokens=self._config.max_tokens, + stream=True, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, + ], + ) + + for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + yield delta + + def summary(self) -> str: + """Human-readable description of the AI config.""" + return f"host={self._config.host} model={self._config.model}" diff --git a/src/tai/cli.py b/src/tai/cli.py index aa5baa3..4481be6 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -7,11 +7,14 @@ from typing import Annotated import typer from rich.console import Console +from rich.markdown import Markdown +from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.collectors import CollectionReport, collect_from_plan from tai.input_parser import InputValidationError, build_request from tai.models import TroubleshootRequest from tai.plan import plan_from_request +from tai.prompt_builder import build_system_prompt, build_user_message from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -56,6 +59,25 @@ def run( help="Collect baseline 
diagnostics after probe.", ), ] = False, + analyze: Annotated[ + bool, + typer.Option( + "--analyze/--no-analyze", + help="Send collected diagnostics to AI for analysis.", + ), + ] = False, + ai_host: Annotated[ + str, + typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."), + ] = DEFAULT_AI_HOST, + model: Annotated[ + str, + typer.Option("--model", help="Model name to use for AI analysis."), + ] = DEFAULT_MODEL, + ai_key: Annotated[ + str, + typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."), + ] = "ollama", ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -81,50 +103,69 @@ def run( ) summary = SSHClient(config).summary() - console.print("[bold green]tai scaffold ready[/bold green]") + console.print("[bold green]tai[/bold green]") console.print(f"Issue: {req.issue}") - console.print(f"SSH: {summary}") + console.print(f"SSH: {summary}") if req.target_paths: console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") + if not (probe or collect or analyze): + return # nothing SSH-related requested + + ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key) + if analyze: + console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}") + + try: + asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze, + ai_config=ai_config)) + except typer.Exit: + raise + except TimeoutError as exc: + console.print(f"[red]SSH timeout:[/red] {exc}") + raise typer.Exit(code=1) from exc + except OSError as exc: + console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}") + raise typer.Exit(code=1) from exc + + +async def _async_main( + config: SSHConnectionConfig, + req: TroubleshootRequest, + *, + probe: bool, + collect: bool, + analyze: bool, + ai_config: AIConfig, +) -> None: + """Open a single SSH session and run probe / collection / analysis through it.""" client = SSHClient(config) + async with client.connect() as session: + if probe: + result = 
await session.probe() + _handle_probe_result(result) - if probe: - _run_probe(client) + report: CollectionReport | None = None + if collect or analyze: + plan = plan_from_request(req) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + report = await collect_from_plan(session, plan) + _handle_collection_report(report) - if collect: - _run_collection(client, req) + if analyze and report is not None: + _run_analysis(ai_config, req.issue, report) -def _run_probe(client: SSHClient) -> None: - """Run a live SSH probe and exit non-zero on failure.""" +def _handle_probe_result(result: SSHCommandResult) -> None: + """Handle and render probe output for success or failure.""" console.print("[cyan]Running SSH probe:[/cyan] uname -a") - try: - result = asyncio.run(client.probe()) - except TimeoutError as exc: - console.print(f"[red]Probe failed:[/red] {exc}") - raise typer.Exit(code=1) from exc - except OSError as exc: - console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}") - raise typer.Exit(code=1) from exc - - _handle_probe_result(result) - - -def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None: - """Run issue-aware collection and print a compact summary.""" - plan = plan_from_request(request) - console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") - try: - report = asyncio.run(collect_from_plan(client, plan)) - except TimeoutError as exc: - console.print(f"[red]Collection failed:[/red] {exc}") - raise typer.Exit(code=1) from exc - except OSError as exc: - console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}") - raise typer.Exit(code=1) from exc - - _handle_collection_report(report) + if result.exit_code != 0: + details = result.stderr or result.stdout or "no error output from ssh" + console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}") + raise typer.Exit(code=1) + output = result.stdout or "(no output)" + console.print("[bold 
green]Probe succeeded.[/bold green]") + console.print(f"Remote: {output}") def _handle_collection_report(report: CollectionReport) -> None: @@ -134,22 +175,25 @@ def _handle_collection_report(report: CollectionReport) -> None: ) for item in report.items: status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}" - trunc = "" - if item.result.stdout_truncated or item.result.stderr_truncated: - trunc = " (truncated)" + truncated = item.result.stdout_truncated or item.result.stderr_truncated + trunc = " (truncated)" if truncated else "" console.print(f"- {item.name}: {status}{trunc}") -def _handle_probe_result(result: SSHCommandResult) -> None: - """Handle and render probe output for success or failure.""" - if result.exit_code != 0: - details = result.stderr or result.stdout or "no error output from ssh" - console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}") - raise typer.Exit(code=1) - - output = result.stdout or "(no output)" - console.print("[bold green]Probe succeeded.[/bold green]") - console.print(f"Remote: {output}") +def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None: + """Send collected data to the AI and stream the analysis to stdout.""" + console.print("[cyan]Analyzing...[/cyan]\n") + ai = AIClient(ai_config) + system_prompt = build_system_prompt() + user_message = build_user_message(issue, report) + try: + chunks: list[str] = [] + for chunk in ai.stream(system_prompt, user_message): + chunks.append(chunk) + console.print(Markdown("".join(chunks))) + except Exception as exc: # noqa: BLE001 + console.print(f"[red]AI analysis failed:[/red] {exc}") + raise typer.Exit(code=1) from exc def main() -> None: diff --git a/src/tai/collectors.py b/src/tai/collectors.py index 9ad4754..72fbdd2 100644 --- a/src/tai/collectors.py +++ b/src/tai/collectors.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from tai.plan import CollectionPlan -from tai.ssh_client import SSHClient, 
SSHCommandResult +from tai.ssh_client import SSHCommandResult, SSHSession @dataclass(slots=True) @@ -31,20 +31,20 @@ class CollectionReport: async def collect_from_plan( - client: SSHClient, + session: SSHSession, plan: CollectionPlan, *, max_output_bytes: int = 32768, ) -> CollectionReport: - """Execute all commands in *plan* and return a :class:`CollectionReport`.""" + """Execute all commands in *plan* over a shared SSH session.""" items: list[CollectedItem] = [] for name, command in plan.commands: - result = await client.run_read_only_command( + result = await session.run_read_only_command( command, timeout_seconds=30.0, max_output_bytes=max_output_bytes, ) items.append(CollectedItem(name=name, result=result)) - return CollectionReport(host=client.summary(), items=items) + return CollectionReport(host=session._client.summary(), items=items) diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py new file mode 100644 index 0000000..360cfb7 --- /dev/null +++ b/src/tai/prompt_builder.py @@ -0,0 +1,74 @@ +"""Formats collected diagnostics into prompts for the AI backend.""" + +from __future__ import annotations + +from tai.collectors import CollectionReport + +_SYSTEM_PROMPT = """\ +You are an expert Linux systems administrator and troubleshooting assistant. +You are given diagnostic data collected read-only from a remote Linux host via SSH. + +Your job: +1. Identify the root cause of the reported issue based only on the data provided. +2. Cite the specific output that supports your conclusion. +3. Give concise, actionable remediation steps. + +Important rules: +- Only draw conclusions from data that is actually present. Do not speculate or invent evidence. +- If a command shows "could not be executed (SSH error)" it means the remote host blocked or + rejected that specific command — it is not evidence about the service or system state. 
+- If there is not enough data to diagnose the issue, say so plainly and list exactly what + additional commands or log files would be needed. +- Keep the response short. Skip sections that have nothing useful to say. +- Never suggest commands that modify the system unless explicitly asked. +- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**. +""" + + +def build_system_prompt() -> str: + """Return the static system prompt for the troubleshooting agent.""" + return _SYSTEM_PROMPT.strip() + + +def build_user_message(issue: str, report: CollectionReport) -> str: + """Format *issue* and *report* into the user message sent to the AI.""" + lines: list[str] = [] + + lines.append(f"## Issue reported\n\n{issue}\n") + lines.append(f"## Target host\n\n{report.host}\n") + lines.append("## Collected diagnostics\n") + + skipped: list[str] = [] + + for item in report.items: + result = item.result + + # Exit 255 with no output = SSH couldn't execute the command at all. + # Exclude these entirely to prevent the AI from speculating on them. + if result.exit_code == 255 and not result.stdout and not result.stderr: + skipped.append(item.name) + continue + + lines.append(f"### {item.name}\n") + lines.append(f"**Command:** `{result.command}` ") + lines.append(f"**Exit code:** {result.exit_code}\n") + + if result.stdout: + trunc = " *(output truncated)*" if result.stdout_truncated else "" + lines.append(f"**stdout:**{trunc}\n```\n{result.stdout.strip()}\n```\n") + + if result.stderr: + trunc = " *(output truncated)*" if result.stderr_truncated else "" + lines.append(f"**stderr:**{trunc}\n```\n{result.stderr.strip()}\n```\n") + + if not result.stdout and not result.stderr: + lines.append("*(no output)*\n") + + if skipped: + lines.append( + f"**Note:** The following commands could not be executed on this host " + f"and produced no output: {', '.join(skipped)}. 
" + f"Do not draw any conclusions from their absence.\n" + ) + + return "\n".join(lines) diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index 4f7164a..6235010 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -1,9 +1,12 @@ """SSH configuration and read-only command execution.""" import asyncio +import os import shlex +import tempfile from dataclasses import dataclass from pathlib import Path +from types import TracebackType @dataclass(slots=True) @@ -88,8 +91,8 @@ class SSHClient: f"key={key} jump={jump} mode={mode}" ) - def build_ssh_argv(self, remote_command: str) -> list[str]: - """Build argv for a secure non-interactive SSH invocation.""" + def _build_base_argv(self) -> list[str]: + """Build the common SSH argv flags (no host or command appended).""" argv = [ "ssh", "-p", @@ -111,9 +114,16 @@ class SSHClient: if self._config.jump_host: argv += ["-J", self._config.jump_host] - argv += [self._config.host, remote_command] return argv + def build_ssh_argv(self, remote_command: str) -> list[str]: + """Build argv for a secure non-interactive SSH invocation.""" + return self._build_base_argv() + [self._config.host, remote_command] + + def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession": + """Return an :class:`SSHSession` async context manager for this host.""" + return SSHSession(self, connect_timeout=connect_timeout) + def validate_read_only_command(self, command: str) -> None: """Validate that a command appears read-only and non-destructive.""" normalized = command.strip() @@ -168,8 +178,7 @@ class SSHClient: max_output_bytes=4096, ) - async def _run_ssh( - self, + async def _run_ssh( self, command: str, *, timeout_seconds: float, @@ -231,3 +240,144 @@ class SSHClient: trimmed_bytes = encoded[:keep] trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip() return f"{trimmed_text}{marker}", True + + +class SSHSession: + """A persistent SSH connection using ControlMaster multiplexing. 
+ + All commands run over the same underlying TCP connection — no per-command + SSH handshake. Use as an async context manager:: + + async with client.connect() as session: + result = await session.run_read_only_command("df -h") + """ + + def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None: + self._client = client + self._connect_timeout = connect_timeout + self._socket_path: Path | None = None + self._master_proc: asyncio.subprocess.Process | None = None + + async def __aenter__(self) -> "SSHSession": + fd, path = tempfile.mkstemp(prefix="tai-ssh-", suffix=".sock") + os.close(fd) + os.unlink(path) # SSH needs to create this itself as a socket + self._socket_path = Path(path) + + master_argv = self._client._build_base_argv() + [ + "-o", "ControlMaster=yes", + "-o", f"ControlPath={self._socket_path}", + "-o", "ControlPersist=no", + "-N", + self._client._config.host, + ] + + self._master_proc = await asyncio.create_subprocess_exec( + *master_argv, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + + # Wait for the control socket to appear (master is ready) + loop = asyncio.get_event_loop() + deadline = loop.time() + self._connect_timeout + while not self._socket_path.exists(): + if loop.time() > deadline: + await self._teardown() + raise TimeoutError( + f"SSH ControlMaster did not connect within " + f"{self._connect_timeout}s to {self._client._config.host}" + ) + await asyncio.sleep(0.05) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._teardown() + + async def _teardown(self) -> None: + if self._master_proc and self._master_proc.returncode is None: + self._master_proc.terminate() + try: + await asyncio.wait_for(self._master_proc.wait(), timeout=5.0) + except TimeoutError: + self._master_proc.kill() + if self._socket_path and self._socket_path.exists(): + 
self._socket_path.unlink(missing_ok=True) + + def _command_argv(self, remote_command: str) -> list[str]: + return self._client._build_base_argv() + [ + "-o", f"ControlPath={self._socket_path}", + "-o", "ControlMaster=no", + self._client._config.host, + remote_command, + ] + + async def probe(self) -> SSHCommandResult: + """Run uname -a to confirm connectivity.""" + return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096) + + async def run_read_only_command( + self, + command: str, + *, + timeout_seconds: float = 30.0, + max_output_bytes: int = 32768, + ) -> SSHCommandResult: + """Validate and run a read-only command over the shared connection.""" + self._client.validate_read_only_command(command) + return await self._run( + command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes + ) + + async def _run( + self, + command: str, + *, + timeout_seconds: float, + max_output_bytes: int, + ) -> SSHCommandResult: + argv = self._command_argv(command) + proc = await asyncio.create_subprocess_exec( + *argv, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + proc.communicate(), timeout=timeout_seconds + ) + except TimeoutError as exc: + proc.kill() + await proc.wait() + raise TimeoutError( + f"SSH command timed out after {timeout_seconds}s: {command}" + ) from exc + + if proc.returncode is None: + raise RuntimeError("SSH process did not provide an exit code.") + + stdout_text, stdout_truncated = SSHClient._truncate_output( + stdout_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + stderr_text, stderr_truncated = SSHClient._truncate_output( + stderr_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + + return SSHCommandResult( + command=command, + exit_code=proc.returncode, + stdout=stdout_text, + stderr=stderr_text, + stdout_truncated=stdout_truncated, + 
stderr_truncated=stderr_truncated, + ) + diff --git a/tests/test_ai.py b/tests/test_ai.py new file mode 100644 index 0000000..08d1510 --- /dev/null +++ b/tests/test_ai.py @@ -0,0 +1,192 @@ +"""Tests for the AI client and prompt builder.""" + +from unittest.mock import MagicMock, patch + +from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig +from tai.collectors import CollectedItem, CollectionReport +from tai.prompt_builder import build_system_prompt, build_user_message +from tai.ssh_client import SSHCommandResult + +# --------------------------------------------------------------------------- +# AIConfig defaults +# --------------------------------------------------------------------------- + + +def test_ai_config_defaults() -> None: + config = AIConfig() + assert config.host == DEFAULT_AI_HOST + assert config.model == DEFAULT_MODEL + assert config.api_key == "ollama" + + +def test_ai_config_custom_values() -> None: + config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test") + assert config.host == "https://api.openai.com/v1" + assert config.model == "gpt-4o" + assert config.api_key == "sk-test" + + +# --------------------------------------------------------------------------- +# AIClient.summary +# --------------------------------------------------------------------------- + + +def test_ai_client_summary_contains_host_and_model() -> None: + config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b") + client = AIClient(config) + summary = client.summary() + assert "http://myserver:11434/v1" in summary + assert "llama3.1:8b" in summary + + +# --------------------------------------------------------------------------- +# AIClient.complete (mocked) +# --------------------------------------------------------------------------- + + +def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock: + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 20 + + message = 
MagicMock() + message.content = content + + choice = MagicMock() + choice.message = message + + response = MagicMock() + response.choices = [choice] + response.model = model + response.usage = usage + return response + + +def test_complete_returns_ai_response() -> None: + config = AIConfig() + client = AIClient(config) + mock_response = _make_mock_response("The root cause is X.") + + with patch.object(client._client.chat.completions, "create", return_value=mock_response): + result = client.complete("system prompt", "user message") + + assert result.content == "The root cause is X." + assert result.prompt_tokens == 10 + assert result.completion_tokens == 20 + assert result.total_tokens == 30 + + +def test_complete_handles_empty_content() -> None: + config = AIConfig() + client = AIClient(config) + mock_response = _make_mock_response(None) # type: ignore[arg-type] + mock_response.choices[0].message.content = None + + with patch.object(client._client.chat.completions, "create", return_value=mock_response): + result = client.complete("system", "user") + + assert result.content == "" + + +# --------------------------------------------------------------------------- +# AIClient.stream (mocked) +# --------------------------------------------------------------------------- + + +def test_stream_yields_chunks() -> None: + config = AIConfig() + client = AIClient(config) + + def _make_chunk(text: str | None) -> MagicMock: + delta = MagicMock() + delta.content = text + choice = MagicMock() + choice.delta = delta + chunk = MagicMock() + chunk.choices = [choice] + return chunk + + mock_chunks = [ + _make_chunk("Root "), _make_chunk("cause "), _make_chunk(None), _make_chunk("found."), + ] + + with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)): + result = list(client.stream("system", "user")) + + assert result == ["Root ", "cause ", "found."] + + +# --------------------------------------------------------------------------- +# 
prompt_builder +# --------------------------------------------------------------------------- + + +def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport: + """Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples.""" + return CollectionReport( + host="root@testhost", + items=[ + CollectedItem( + name=name, + result=SSHCommandResult( + command=command, + exit_code=exit_code, + stdout=stdout, + stderr=stderr, + ), + ) + for name, command, exit_code, stdout, stderr in items + ], + ) + + +def test_build_system_prompt_contains_key_instructions() -> None: + prompt = build_system_prompt() + assert "Root Cause" in prompt + assert "Evidence" in prompt + assert "Recommended Actions" in prompt + assert "read-only" in prompt.lower() + + +def test_build_user_message_contains_issue_and_host() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")]) + msg = build_user_message("nginx is failing", report) + assert "nginx is failing" in msg + assert "root@testhost" in msg + + +def test_build_user_message_includes_command_output() -> None: + report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")]) + msg = build_user_message("test issue", report) + assert "uname -a" in msg + assert "Linux web01 6.1.0" in msg + + +def test_build_user_message_shows_stderr() -> None: + report = _make_report( + [("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")] + ) + msg = build_user_message("nginx not found", report) + assert "Unit nginx.service not found." 
in msg + + +def test_build_user_message_notes_truncation() -> None: + result = SSHCommandResult( + command="journalctl -n 100 --no-pager", + exit_code=0, + stdout="...", + stderr="", + stdout_truncated=True, + ) + report = CollectionReport( + host="root@testhost", + items=[CollectedItem(name="journal", result=result)], + ) + msg = build_user_message("disk issue", report) + assert "truncated" in msg + + +def test_build_user_message_handles_no_output() -> None: + report = _make_report([("empty", "cat /nonexistent", 1, "", "")]) + msg = build_user_message("test", report) + assert "no output" in msg diff --git a/tests/test_cli.py b/tests/test_cli.py index 9bac274..c26716f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,5 @@ +from unittest.mock import AsyncMock, MagicMock + from typer.testing import CliRunner from tai.cli import app @@ -5,6 +7,24 @@ from tai.collectors import CollectedItem, CollectionReport from tai.ssh_client import SSHCommandResult +def _mock_session( + monkeypatch, # type: ignore[no-untyped-def] + *, + probe_result: SSHCommandResult | None = None, + probe_raises: Exception | None = None, +) -> MagicMock: + """Patch SSHClient.connect to return a mock session.""" + session = MagicMock() + session.__aenter__ = AsyncMock(return_value=session) + session.__aexit__ = AsyncMock(return_value=None) + if probe_raises: + session.probe = AsyncMock(side_effect=probe_raises) + else: + session.probe = AsyncMock(return_value=probe_result) + monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session) + return session + + def test_run_command_prints_scaffold_summary() -> None: runner = CliRunner() result = runner.invoke( @@ -25,33 +45,23 @@ def test_run_command_prints_scaffold_summary() -> None: ) assert result.exit_code == 0 - assert "tai scaffold ready" in result.stdout + assert "tai" in result.stdout assert "host=web01" in result.stdout assert "port=5566" in result.stdout def 
test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] - return SSHCommandResult( - command="uname -a", - exit_code=0, - stdout="Linux ssh 6.12.0", - stderr="", - ) - - monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + _mock_session( + monkeypatch, + probe_result=SSHCommandResult( + command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr="" + ), + ) runner = CliRunner() result = runner.invoke( app, - [ - "apache failed", - "--host", - "ssh.archflux.net", - "--port", - "5566", - "--probe", - ], + ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"], ) assert result.exit_code == 0 @@ -60,27 +70,20 @@ def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def] - return SSHCommandResult( + _mock_session( + monkeypatch, + probe_result=SSHCommandResult( command="uname -a", exit_code=255, stdout="", stderr="Permission denied (publickey,password).", - ) - - monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe) + ), + ) runner = CliRunner() result = runner.invoke( app, - [ - "apache failed", - "--host", - "ssh.archflux.net", - "--port", - "5566", - "--probe", - ], + ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"], ) assert result.exit_code == 1 @@ -88,7 +91,9 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def] - async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] return 
CollectionReport( host="ssh.archflux.net", items=[ From 6e693d0c8332c5e7c1f7867f39c2b7e82161072d Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 04:52:47 +0200 Subject: [PATCH 11/11] update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 53e77c5..5ef7c20 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ A troubleshooter receives a ticket reporting that the Apache service on a remote | Model | `gemma3:4b`, `llama3.1:8b`, or `qwen2.5:7b` | | Language | Python 3.11+ | ---- +______________________________________________________________________ ## How-To: Setting Up the AI Backend (Arch Linux + RTX 3080)