Merge pull request 'initialCommit' (#2) from initialCommit into main

Reviewed-on: #2
This commit is contained in:
2026-05-04 04:54:49 +02:00
22 changed files with 2366 additions and 2 deletions

101
.gitea/workflows/ci.yml Normal file
View File

@@ -0,0 +1,101 @@
# CI for Gitea Actions: lint, type-check, and test on every push and pull request.
# Every step is a plain shell script, so git and Python are bootstrapped by hand.
name: CI
on:
  push:
  pull_request:
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      # Reuse git when the runner already has it; otherwise install it with the distro package manager.
      - name: Ensure git is available
        run: |
          if command -v git >/dev/null 2>&1; then
            git --version
            exit 0
          fi
          if command -v apt-get >/dev/null 2>&1; then
            apt-get update
            apt-get install -y git
          elif command -v dnf >/dev/null 2>&1; then
            dnf install -y git
          elif command -v yum >/dev/null 2>&1; then
            yum install -y git
          else
            echo "No supported package manager found to install git."
            exit 1
          fi
          git --version
      # Shallow-fetch exactly the commit that triggered the run, authenticating with CI_GIT_TOKEN.
      - name: Checkout source (native git)
        env:
          CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
        run: |
          if [ -z "${CI_GIT_TOKEN:-}" ]; then
            echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
            exit 1
          fi
          auth_server="${GITHUB_SERVER_URL#https://}"
          auth_server="${auth_server#http://}"
          # The token-bearing URL is used for the fetch only. It is never written to
          # .git/config, so later steps (pip install, tests) cannot read the secret from disk.
          fetch_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
          if [ -n "${GITHUB_WORKSPACE:-}" ]; then
            cd "$GITHUB_WORKSPACE"
          fi
          if [ ! -d .git ]; then
            git init
          fi
          git remote remove origin >/dev/null 2>&1 || true
          git remote add origin "https://${auth_server}/${GITHUB_REPOSITORY}.git"
          git fetch --depth 1 "$fetch_url" "$GITHUB_SHA"
          git checkout --force FETCH_HEAD
      # Skip installation when python3 and pip already work.
      - name: Ensure Python and pip are available
        run: |
          if command -v python3 >/dev/null 2>&1 && python3 -m pip --version >/dev/null 2>&1; then
            python3 --version
            exit 0
          fi
          if command -v apt-get >/dev/null 2>&1; then
            apt-get update
            apt-get install -y python3 python3-pip python3-venv
          elif command -v dnf >/dev/null 2>&1; then
            dnf install -y python3 python3-pip
          elif command -v yum >/dev/null 2>&1; then
            yum install -y python3 python3-pip
          else
            echo "No supported package manager found to install Python."
            exit 1
          fi
          python3 --version
      # Dependencies go into a local .venv; the remaining steps call its executables directly.
      - name: Install package and dev dependencies
        run: |
          python3 -m venv .venv
          . .venv/bin/activate
          python -m pip install --upgrade pip
          python -m pip install -e ".[dev]"
      - name: Lint
        run: .venv/bin/python -m ruff check .
      - name: Lint Markdown
        run: .venv/bin/mdformat --check README.md ROADMAP.md CHANGELOG.md
      - name: Lint YAML
        run: .venv/bin/yamllint .
      - name: Type-check
        run: .venv/bin/python -m mypy src
      - name: Test
        run: .venv/bin/python -m pytest

View File

@@ -0,0 +1,110 @@
# Release build for Gitea Actions: on a pushed v* tag, compile a standalone binary with Nuitka,
# smoke-test it, and upload it as a workflow artifact.
name: Release
on:
  push:
    tags:
      - "v*"
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      # Reuse git when the runner already has it; otherwise install it with the distro package manager.
      - name: Ensure git is available
        run: |
          if command -v git >/dev/null 2>&1; then
            git --version
            exit 0
          fi
          if command -v apt-get >/dev/null 2>&1; then
            apt-get update
            apt-get install -y git
          elif command -v dnf >/dev/null 2>&1; then
            dnf install -y git
          elif command -v yum >/dev/null 2>&1; then
            yum install -y git
          else
            echo "No supported package manager found to install git."
            exit 1
          fi
          git --version
      # Shallow-fetch exactly the tagged commit, authenticating with CI_GIT_TOKEN.
      - name: Checkout source (native git)
        env:
          CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
        run: |
          if [ -z "${CI_GIT_TOKEN:-}" ]; then
            echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
            exit 1
          fi
          auth_server="${GITHUB_SERVER_URL#https://}"
          auth_server="${auth_server#http://}"
          # The token-bearing URL is used for the fetch only. It is never written to
          # .git/config, so the build steps cannot read the secret from disk.
          fetch_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
          if [ -n "${GITHUB_WORKSPACE:-}" ]; then
            cd "$GITHUB_WORKSPACE"
          fi
          if [ ! -d .git ]; then
            git init
          fi
          git remote remove origin >/dev/null 2>&1 || true
          git remote add origin "https://${auth_server}/${GITHUB_REPOSITORY}.git"
          # Fetch the tag by SHA so we get the exact tagged commit
          git fetch --depth 1 "$fetch_url" "$GITHUB_SHA"
          git checkout --force FETCH_HEAD
      # Install the toolchain when any required piece is missing, not only when python3 is absent.
      - name: Ensure Python and build dependencies are available
        run: |
          need_install=0
          command -v python3 >/dev/null 2>&1 || need_install=1
          python3 -m pip --version >/dev/null 2>&1 || need_install=1
          # patchelf is required by Nuitka for standalone Linux binaries
          command -v patchelf >/dev/null 2>&1 || need_install=1
          if [ "$need_install" -eq 1 ]; then
            if command -v apt-get >/dev/null 2>&1; then
              apt-get update
              apt-get install -y python3 python3-pip python3-venv patchelf ccache
            elif command -v dnf >/dev/null 2>&1; then
              dnf install -y python3 python3-pip patchelf ccache
            elif command -v yum >/dev/null 2>&1; then
              yum install -y python3 python3-pip patchelf ccache
            else
              echo "No supported package manager found to install Python build dependencies."
              exit 1
            fi
          fi
          python3 --version
      - name: Set up venv and install package + build deps
        run: |
          python3 -m venv .venv
          . .venv/bin/activate
          python -m pip install --upgrade pip
          python -m pip install -e ".[build]"
      # Expose the tag name as a step output; the upload step uses it in the artifact name.
      - name: Derive version from tag
        id: version
        run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
      - name: Build standalone binary with Nuitka
        run: |
          . .venv/bin/activate
          python -m nuitka \
            --standalone \
            --onefile \
            --output-filename=tai \
            --output-dir=dist \
            --assume-yes-for-downloads \
            --include-package=tai \
            src/tai/cli.py
      - name: Smoke-test the binary
        run: dist/tai --help
      # The job fails if dist/tai was not produced.
      - name: Upload binary artifact
        uses: actions/upload-artifact@v3
        with:
          name: tai-linux-amd64-${{ steps.version.outputs.tag }}
          path: dist/tai
          if-no-files-found: error
          retention-days: 90

26
.gitignore vendored Normal file
View File

@@ -0,0 +1,26 @@
# Python cache and bytecode
__pycache__/
*.py[cod]
*.pyo
# Virtual environments
.venv/
venv/
# Tool caches
.pytest_cache/
.ruff_cache/
.mypy_cache/
# Build artifacts
build/
dist/
*.egg-info/
*.spec
# Coverage
.coverage
htmlcov/
# IDE
.vscode/

16
.yamllint.yml Normal file
View File

@@ -0,0 +1,16 @@
# yamllint configuration. CI runs `yamllint .` from the repository root.
extends: default
# Paths excluded from linting: VCS metadata, the virtualenv, and tool caches.
ignore: |
  .git/
  .venv/
  .mypy_cache/
  .pytest_cache/
  .ruff_cache/
rules:
  # Do not require a leading `---` document marker.
  document-start: disable
  line-length:
    max: 120
  truthy:
    allowed-values: ["true", "false"]
    # Mapping keys are not checked, so the workflow `on:` key is not reported.
    check-keys: false

46
CHANGELOG.md Normal file
View File

@@ -0,0 +1,46 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
______________________________________________________________________
## [Unreleased]
### Added
- `README.md` — project overview, description, example workflow, supported distributions, and suggested tooling
- `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening
- `CHANGELOG.md` — this file; established changelog tracking for the project
- `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events
- Gitea CI now uses native `git` checkout and system Python setup to avoid host-executor JavaScript action path issues
- Gitea native checkout now uses `CI_GIT_TOKEN` repository secret for authenticated fetch from private repos
- Gitea CI now installs dependencies in a local `.venv` to avoid Debian/PEP 668 externally-managed pip errors
- Python package scaffold with `src` layout and project metadata in `pyproject.toml`
- Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config`
- Input parsing/validation module and core request model
- SSH configuration scaffold module for upcoming connection/read-only execution work
- Implemented SSH module with real key-based command execution via system `ssh`
- Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566)
- Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure
- Added baseline diagnostics collection via `--collect`, including service, journal, disk, and network checks
- Read-only command policy enforcement (allowlist + blocked shell operators)
- Added byte-limited SSH output capture with truncation markers for large command output
- Test scaffold (`pytest`) with initial parser and CLI coverage
- SSH test coverage for policy checks, SSH argument construction, and config summary behavior
- CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`)
- CI coverage expanded with Markdown formatting checks (`mdformat --check`) and YAML linting (`yamllint`)
### Removed
- `.github/workflows/ci.yml` — GitHub Actions workflow removed; CI is now Gitea-only
### Decided
- Implementation language: **Python**
- Distribution strategy: single distributable binary via **Nuitka** (PyInstaller as fallback)
- SSH authentication: **keypair only** (ed25519/RSA); auto-accept new hosts; hard reject on host key change with MITM warning
- SSH bastion support: `--jump-host` flag using SSH native ProxyJump
- SSH config behavior: use `~/.ssh/config` by default; allow override via `--ignore-ssh-config`
- Interface: **interactive REPL** for v0.1; `textual`-based TUI (split-pane) for v0.2+

View File

@@ -1,3 +1,93 @@
# tai
# tai — Linux AI Troubleshooting Agent
Linux AI driven troubleshooting agent.
`tai` is an agentic AI-driven troubleshooting tool for Linux systems. It autonomously investigates issues on remote hosts via SSH, analyzes relevant logs and configuration files, and provides a clear diagnosis along with suggested remediation steps — all without making any changes to the target system.
## Overview
Given a problem description and a target hostname, `tai` connects to the remote system over SSH, gathers relevant data (logs, configuration files, service status, etc.), and uses a locally-hosted AI model to reason about the root cause and recommend solutions.
The agent operates in **read-only mode at all times**. It will never modify the target system under any circumstances — all suggestions are presented to the human troubleshooter for review and action.
## Supported Distributions
- Ubuntu
- Debian
- RHEL
- Rocky Linux
## Example Workflow
A troubleshooter receives a ticket reporting that the Apache service on a remote server has failed to start. They provide `tai` with:
1. The ticket description or error message
1. The hostname of the affected system
1. Any relevant directories to focus on
`tai` then connects to the host, reads through system logs, service configurations, and any other related files, and returns a structured analysis of the likely cause along with recommended next steps.
## Suggested Tooling
| Component | Tool |
|-----------|------|
| AI inference backend | [Ollama](https://ollama.com) |
| Model | `gemma3:4b`, `llama3.1:8b`, or `qwen2.5:7b` |
| Language | Python 3.11+ |
______________________________________________________________________
## How-To: Setting Up the AI Backend (Arch Linux + RTX 3080)
`tai` uses [Ollama](https://ollama.com) as its local AI backend. It exposes an OpenAI-compatible HTTP API that `tai` talks to — no cloud services, no data leaving your machine.
An RTX 3080 (10 GB VRAM) comfortably runs 7–8B parameter models at 4-bit quantisation.
### 1. Install CUDA and Ollama
```bash
# CUDA runtime (skip if already installed)
sudo pacman -S cuda
# Ollama with CUDA support from the AUR
yay -S ollama-cuda
# or: paru -S ollama-cuda
# Enable and start the service
sudo systemctl enable --now ollama
```
### 2. Pull a model
```bash
ollama pull gemma3:4b # ~3 GB — fast, good for sysadmin tasks
ollama pull llama3.1:8b # ~5 GB — stronger reasoning
ollama pull qwen2.5:7b # ~4.5 GB — strong structured output
```
### 3. Verify the model works
```bash
ollama run gemma3:4b "what causes a systemd service to enter failed state?"
```
### 4. Verify the HTTP API is running
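`tai` communicates with Ollama over its OpenAI-compatible REST API (served under `/v1`). The quick check below uses Ollama's native `/api/generate` endpoint to confirm the server is listening: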
```bash
curl http://localhost:11434/api/generate \
-d '{"model":"gemma3:4b","prompt":"hello","stream":false}'
```
A JSON response with a `response` field confirms everything is working.
### 5. Point tai at your Ollama instance
Once `tai` AI integration is complete, use these flags:
```bash
tai "nginx failing to start" --host web01 \
--analyze \
--ai-host http://localhost:11434/v1 \
--model gemma3:4b
```
The default values for `--ai-host` and `--model` will be `http://localhost:11434/v1` and `gemma3:4b` respectively, so for local use you won't need to specify them explicitly.

130
ROADMAP.md Normal file
View File

@@ -0,0 +1,130 @@
# Roadmap
This document outlines the major decisions, milestones, and development phases required to bring `tai` from concept to a working tool.
______________________________________________________________________
## Phase 0 — Decisions & Prerequisites
These must be resolved before meaningful development can begin.
### Language Selection
- [x] **Decision: Python**
- Key factors: native vLLM integration, mature SSH libraries (`paramiko` / `asyncssh`), strong text/log parsing, rapid development
- Single binary distribution will be achieved via **Nuitka** (preferred for true compilation) or **PyInstaller** as a fallback
- [ ] Evaluate Nuitka vs PyInstaller for binary output quality and CI reproducibility
- [ ] Add binary build step to CI pipeline
### AI Backend & Model
- [ ] Confirm use of [vLLM](https://github.com/vllm-project/vllm) as the inference backend
- [ ] Confirm `gemma4:a4b` as the default model (or select an alternative)
- [ ] Define minimum hardware requirements for running the model locally
- [ ] Decide whether the AI backend is bundled, self-hosted externally, or user-supplied
### SSH Strategy
- [x] **Decision: keypair authentication only** — no password auth; eliminates credential storage risk
- Default key resolution: `~/.ssh/id_ed25519`, `~/.ssh/id_rsa` (in order of preference)
- CLI override via `--identity-file <path>`
- No SSH agent forwarding needed — a shared key is distributed to all managed hosts via Puppet
- [x] **Known hosts: auto-accept new hosts; reject on key mismatch** — a changed host key triggers a hard stop with a MITM warning; unknown/new hosts are accepted silently on first connect
- [x] **Bastion/jump host: `--jump-host <host>` flag** — delegates to SSH's native ProxyJump functionality
- [x] **SSH config behavior: respect existing `~/.ssh/config` by default; allow CLI override**
- Default: follow host settings from `~/.ssh/config` (for `User`, `Port`, `ProxyJump`, etc.)
- Override switch: `--ignore-ssh-config` to bypass local SSH config when required
### Scope & Constraints
- [ ] Define the supported scope of issues (services, network, disk, kernel, etc.)
- [ ] Confirm read-only guarantee — document exactly what "read-only" means in practice
- [x] **Decision: interactive REPL mode for v0.1, full TUI for v0.2+**
- v0.1: chat-loop REPL launched from CLI; human can follow up, correct, and redirect the agent
- v0.2+: `textual`-based TUI with split panes (collected data | AI output | input bar)
- Built-in slash commands: `/collect`, `/show logs`, `/clear`, `/host <hostname>`, `/help`, `/quit`
______________________________________________________________________
## Phase 1 — Project Foundation
Basic project scaffolding and connectivity.
- [x] Finalise repository structure and language toolchain
- [x] Set up CI pipeline (linting, tests)
- [ ] Implement SSH connection module
- [x] Define SSH config model and probe interface scaffold
- [x] Connect to remote host
- [x] Execute read-only commands (e.g. `journalctl`, `systemctl status`, `cat`)
- [x] Stream or collect command output safely (byte-limited output with truncation marker)
- [x] Implement basic input parsing (ticket text, hostname, target directories)
- [x] Write unit tests for SSH and input modules
- [x] Input parser and CLI tests added
- [x] SSH module tests added for command policy and SSH argv behavior
______________________________________________________________________
## Phase 2 — Data Collection Layer
Define what information the agent gathers and how.
- [ ] Identify the canonical set of data sources per issue type:
- Service failures: `journalctl`, `systemctl`, service config files
- Network issues: `ip`, `ss`, `netstat`, firewall rules
- Disk issues: `df`, `du`, `dmesg`, `smartctl`
- General: `/var/log/syslog`, `/var/log/messages`, `dmesg`
- [ ] Implement pluggable "collector" modules per data source
- [ ] Implement directory traversal for user-specified paths (read-only)
- [ ] Add support for per-distro variations (Ubuntu vs RHEL path differences, etc.)
- [ ] Write tests with mocked SSH output
______________________________________________________________________
## Phase 3 — AI Integration
Wire collected data into the local AI model.
- [ ] Implement vLLM client module
- [ ] Design prompt template: system context, collected data, issue description → diagnosis
- [ ] Implement response parsing and structured output (root cause + suggested steps)
- [ ] Tune context window usage — handle truncation for large log outputs
- [ ] Add streaming support for long AI responses
- [ ] Evaluate and test model output quality on common issue types
______________________________________________________________________
## Phase 4 — CLI & User Experience
Polish the interface for real-world use.
- [ ] Design CLI interface (flags, subcommands, interactive prompts)
- [ ] Implement structured output: diagnosis, confidence, recommended actions
- [ ] Add `--verbose` / `--debug` mode showing raw collected data
- [ ] Support output to file or clipboard
- [ ] Write man page / `--help` documentation
______________________________________________________________________
## Phase 5 — Hardening & Distribution
Prepare for broader use.
- [ ] Security review of SSH handling and credential storage
- [ ] Ensure no data is written to the remote system under any path
- [ ] Package for distribution (binary release, container image, or distro packages)
- [ ] Write installation and quickstart documentation
- [ ] End-to-end integration tests against a test VM
______________________________________________________________________
## Decisions Log
| Date | Decision | Outcome |
|------|----------|---------|
| 2026-05-04 | Implementation language | Python — with single distributable binary via Nuitka |
| — | AI inference backend | vLLM (provisional) |
| — | Default model | `gemma4:a4b` (provisional) |
| 2026-05-04 | SSH auth methods | Keypair only (ed25519/RSA); auto-accept new hosts; reject on key change (MITM) |
| 2026-05-04 | Bastion host support | `--jump-host` flag via SSH native ProxyJump |
| 2026-05-04 | SSH config behavior | Use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` |
| 2026-05-04 | CLI vs interactive mode | Interactive: REPL for v0.1, `textual` TUI for v0.2+ |

53
pyproject.toml Normal file
View File

@@ -0,0 +1,53 @@
# Packaging metadata and tool configuration for the tai package.
[build-system]
requires = ["hatchling>=1.25"]
build-backend = "hatchling.build"

[project]
name = "tai"
version = "0.1.0"
description = "Linux AI-driven troubleshooting agent"
readme = "README.md"
requires-python = ">=3.11"
authors = [
    { name = "tai contributors" }
]
# Runtime dependencies.
dependencies = [
    "typer>=0.12,<1.0",
    "rich>=13.7,<14.0",
    "asyncssh>=2.14,<3.0",
    "openai>=1.30,<2.0",
]

# Extras: the CI workflow installs "dev"; the release workflow installs "build".
[project.optional-dependencies]
dev = [
    "pytest>=8.2,<9.0",
    "ruff>=0.5,<1.0",
    "mypy>=1.10,<2.0",
    "mdformat>=0.7,<1.0",
    "yamllint>=1.35,<2.0",
]
build = [
    "nuitka>=2.4,<3.0",
]

# Console script: installs a `tai` command that calls tai.cli:main.
[project.scripts]
tai = "tai.cli:main"

# src layout: the wheel packages the src/tai directory.
[tool.hatch.build.targets.wheel]
packages = ["src/tai"]

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-q"

[tool.ruff]
line-length = 100
target-version = "py311"

[tool.ruff.lint]
select = ["E", "F", "I", "UP", "B"]

[tool.mypy]
python_version = "3.11"
strict = true
warn_unused_configs = true

5
src/tai/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""tai package."""
__all__ = ["__version__"]
__version__ = "0.1.0"

93
src/tai/ai_client.py Normal file
View File

@@ -0,0 +1,93 @@
"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint."""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from openai import OpenAI
DEFAULT_AI_HOST = "http://localhost:11434/v1"
DEFAULT_MODEL = "gemma3:4b"
@dataclass(slots=True)
class AIConfig:
    """Connection parameters for an OpenAI-compatible AI backend.

    Every field has a default, so ``AIConfig()`` targets ``DEFAULT_AI_HOST`` with
    ``DEFAULT_MODEL``.
    """

    # Base URL handed to the OpenAI client (``base_url``).
    host: str = DEFAULT_AI_HOST
    # Model name sent with every completion request.
    model: str = DEFAULT_MODEL
    api_key: str = "ollama"  # Ollama ignores this; required by the openai client
    # Passed to the OpenAI client as its ``timeout``.
    timeout_seconds: float = 120.0
    # Sent as ``max_tokens`` on each completion request.
    max_tokens: int = 4096
    # Passed to the OpenAI client as ``default_headers``.
    extra_headers: dict[str, str] = field(default_factory=dict)
@dataclass(slots=True)
class AIResponse:
"""Structured response from an AI completion."""
model: str
content: str
prompt_tokens: int
completion_tokens: int
@property
def total_tokens(self) -> int:
return self.prompt_tokens + self.completion_tokens
class AIClient:
    """Thin wrapper around the openai client targeting a configurable endpoint.

    The wrapper always sends exactly two messages, a system prompt and one user
    message, and offers a blocking (``complete``) and a streaming (``stream``) variant.
    """

    def __init__(self, config: AIConfig) -> None:
        """Build the underlying ``OpenAI`` client from *config*."""
        self._config = config
        self._client = OpenAI(
            base_url=config.host,
            api_key=config.api_key,
            timeout=config.timeout_seconds,
            default_headers=config.extra_headers,
        )

    def complete(self, system_prompt: str, user_message: str) -> AIResponse:
        """Send a completion request and return the full response.

        Args:
            system_prompt: Content of the ``system`` message.
            user_message: Content of the ``user`` message.

        Returns:
            An ``AIResponse`` built from the first choice. Missing content becomes ``""``
            and missing usage data is reported as zero tokens.
        """
        response = self._client.chat.completions.create(
            model=self._config.model,
            max_tokens=self._config.max_tokens,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
        )
        choice = response.choices[0]
        content = choice.message.content or ""
        usage = response.usage
        return AIResponse(
            model=response.model,
            content=content,
            prompt_tokens=usage.prompt_tokens if usage else 0,
            completion_tokens=usage.completion_tokens if usage else 0,
        )

    def stream(self, system_prompt: str, user_message: str) -> Iterator[str]:
        """Stream a completion, yielding text chunks as they arrive.

        Args:
            system_prompt: Content of the ``system`` message.
            user_message: Content of the ``user`` message.

        Yields:
            Each non-empty text delta from the first choice of every chunk.
        """
        stream = self._client.chat.completions.create(
            model=self._config.model,
            max_tokens=self._config.max_tokens,
            stream=True,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
        )
        for chunk in stream:
            # Some chunks carry no choices at all (for example usage-only chunks);
            # indexing them would raise IndexError and abort the whole stream.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta

    def summary(self) -> str:
        """Human-readable description of the AI config."""
        return f"host={self._config.host} model={self._config.model}"

205
src/tai/cli.py Normal file
View File

@@ -0,0 +1,205 @@
"""CLI entrypoint for tai."""
from __future__ import annotations
import asyncio
from typing import Annotated
import typer
from rich.console import Console
from rich.markdown import Markdown
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest
from tai.plan import plan_from_request
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig
app = typer.Typer(no_args_is_help=True, add_completion=False)
console = Console()
@app.command()
def run(
    issue: Annotated[str, typer.Argument(help="Ticket text or issue summary.")],
    host: Annotated[str, typer.Option("--host", help="Target host to troubleshoot.")],
    port: Annotated[int, typer.Option("--port", help="SSH port for the target host.")] = 22,
    path: Annotated[
        list[str] | None,
        typer.Option("--path", help="Path to inspect. Repeatable."),
    ] = None,
    identity_file: Annotated[
        str | None,
        typer.Option("--identity-file", help="SSH private key path."),
    ] = None,
    jump_host: Annotated[
        str | None,
        typer.Option("--jump-host", help="SSH bastion/jump host."),
    ] = None,
    ignore_ssh_config: Annotated[
        bool,
        typer.Option(
            "--ignore-ssh-config",
            help="Ignore ~/.ssh/config and rely only on CLI options.",
        ),
    ] = False,
    probe: Annotated[
        bool,
        typer.Option(
            "--probe/--no-probe",
            help="Enable or disable live SSH connectivity probe (uname -a).",
        ),
    ] = True,
    collect: Annotated[
        bool,
        typer.Option(
            "--collect/--no-collect",
            help="Collect baseline diagnostics after probe.",
        ),
    ] = False,
    analyze: Annotated[
        bool,
        typer.Option(
            "--analyze/--no-analyze",
            help="Send collected diagnostics to AI for analysis.",
        ),
    ] = False,
    ai_host: Annotated[
        str,
        typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."),
    ] = DEFAULT_AI_HOST,
    model: Annotated[
        str,
        typer.Option("--model", help="Model name to use for AI analysis."),
    ] = DEFAULT_MODEL,
    ai_key: Annotated[
        str,
        typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."),
    ] = "ollama",
) -> None:
    """Start an interactive troubleshooting session scaffold."""
    # The docstring above is the command's help text, so it is left untouched.
    #
    # Flow: validate input (exit 2 on bad input), print a session summary, then hand
    # over to _async_main when any stage was requested (exit 1 on timeout or OS error).

    # Stage 1: normalise and validate the raw CLI values.
    try:
        request = build_request(
            issue=issue,
            host=host,
            port=port,
            target_paths=path or [],
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    except InputValidationError as exc:
        console.print(f"[red]Input error:[/red] {exc}")
        raise typer.Exit(code=2) from exc

    # Stage 2: describe the session before touching the network.
    ssh_config = SSHConnectionConfig(
        host=request.host,
        port=request.port,
        identity_file=request.identity_file,
        jump_host=request.jump_host,
        ignore_ssh_config=request.ignore_ssh_config,
    )
    ssh_summary = SSHClient(ssh_config).summary()
    console.print("[bold green]tai[/bold green]")
    console.print(f"Issue: {request.issue}")
    console.print(f"SSH: {ssh_summary}")
    if request.target_paths:
        joined_paths = ", ".join(str(p) for p in request.target_paths)
        console.print(f"Paths: {joined_paths}")

    # Nothing SSH-related requested: the summary above is all there is to do.
    if not probe and not collect and not analyze:
        return

    # Stage 3: run the requested stages over SSH.
    ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key)
    if analyze:
        console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}")
    try:
        asyncio.run(
            _async_main(
                ssh_config,
                request,
                probe=probe,
                collect=collect,
                analyze=analyze,
                ai_config=ai_config,
            )
        )
    except typer.Exit:
        # Exit codes chosen by the stage handlers pass through unchanged.
        raise
    except TimeoutError as exc:
        # Must precede OSError: TimeoutError is a subclass of it.
        console.print(f"[red]SSH timeout:[/red] {exc}")
        raise typer.Exit(code=1) from exc
    except OSError as exc:
        console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}")
        raise typer.Exit(code=1) from exc
async def _async_main(
    config: SSHConnectionConfig,
    req: TroubleshootRequest,
    *,
    probe: bool,
    collect: bool,
    analyze: bool,
    ai_config: AIConfig,
) -> None:
    """Open a single SSH session and run probe / collection / analysis through it.

    Args:
        config: SSH connection settings used to build the client.
        req: Validated request; supplies the issue text and drives the collection plan.
        probe: Run the connectivity probe first.
        collect: Collect diagnostics according to the plan built from *req*.
        analyze: Send the collected report to the AI backend.
        ai_config: Backend settings handed to the analysis step.

    Raises:
        typer.Exit: Propagated from the probe and analysis handlers when they fail.
    """
    client = SSHClient(config)
    async with client.connect() as session:
        if probe:
            result = await session.probe()
            _handle_probe_result(result)
        report: CollectionReport | None = None
        # Analysis needs data to work on, so --analyze triggers collection even without --collect.
        if collect or analyze:
            plan = plan_from_request(req)
            console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
            report = await collect_from_plan(session, plan)
            _handle_collection_report(report)
        if analyze and report is not None:
            _run_analysis(ai_config, req.issue, report)
def _handle_probe_result(result: SSHCommandResult) -> None:
    """Render the outcome of an already-completed SSH probe.

    Args:
        result: Result of the probe command.

    Raises:
        typer.Exit: With code 1 when the probe exited non-zero.
    """
    console.print("[cyan]Running SSH probe:[/cyan] uname -a")
    if result.exit_code == 0:
        remote_banner = result.stdout or "(no output)"
        console.print("[bold green]Probe succeeded.[/bold green]")
        console.print(f"Remote: {remote_banner}")
        return
    # Prefer stderr, then stdout, then a fixed placeholder so the message is never empty.
    failure_text = result.stderr or result.stdout or "no error output from ssh"
    console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {failure_text}")
    raise typer.Exit(code=1)
def _handle_collection_report(report: CollectionReport) -> None:
    """Print a one-line total followed by one status line per collected command.

    Args:
        report: Finished collection report.
    """
    console.print(
        f"[bold]Collection complete:[/bold] {report.total} commands, {report.failed} failed"
    )
    for entry in report.items:
        outcome = entry.result
        label = f"exit {outcome.exit_code}" if outcome.exit_code != 0 else "ok"
        # Flag the entry when either output stream was cut short.
        was_cut = outcome.stdout_truncated or outcome.stderr_truncated
        suffix = " (truncated)" if was_cut else ""
        console.print(f"- {entry.name}: {label}{suffix}")
def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None:
    """Ask the AI backend for an analysis and render it as Markdown.

    The response is requested through the streaming API, but the chunks are buffered
    and rendered once, after the stream has finished.

    Args:
        ai_config: Backend connection settings.
        issue: The user's issue text.
        report: Collected diagnostics to include in the prompt.

    Raises:
        typer.Exit: With code 1 when the request or rendering fails for any reason.
    """
    console.print("[cyan]Analyzing...[/cyan]\n")
    backend = AIClient(ai_config)
    system_text = build_system_prompt()
    user_text = build_user_message(issue, report)
    try:
        analysis = "".join(backend.stream(system_text, user_text))
        console.print(Markdown(analysis))
    except Exception as exc:  # noqa: BLE001
        console.print(f"[red]AI analysis failed:[/red] {exc}")
        raise typer.Exit(code=1) from exc
def main() -> None:
    """Console script entrypoint.

    Referenced by ``[project.scripts]`` in pyproject.toml as ``tai.cli:main``; it simply
    invokes the Typer application.
    """
    app()
if __name__ == "__main__":
main()

50
src/tai/collectors.py Normal file
View File

@@ -0,0 +1,50 @@
"""Data collection routines built on top of the SSH client."""
from dataclasses import dataclass
from tai.plan import CollectionPlan
from tai.ssh_client import SSHCommandResult, SSHSession
@dataclass(slots=True)
class CollectedItem:
    """Single collected diagnostic command result."""

    # Label of the command, taken from the collection plan.
    name: str
    # Result returned by the SSH session for that command.
    result: SSHCommandResult
@dataclass(slots=True)
class CollectionReport:
"""Collection summary for a batch of diagnostics."""
host: str
items: list[CollectedItem]
@property
def total(self) -> int:
return len(self.items)
@property
def failed(self) -> int:
return sum(1 for item in self.items if item.result.exit_code != 0)
async def collect_from_plan(
    session: SSHSession,
    plan: CollectionPlan,
    *,
    max_output_bytes: int = 32768,
    timeout_seconds: float = 30.0,
) -> CollectionReport:
    """Execute all commands in *plan* over a shared SSH session.

    Commands run one after another in plan order. The exit code is not inspected
    here, so a failing command does not stop the batch.

    Args:
        session: Open SSH session used for every command.
        plan: Ordered ``(name, command)`` pairs to execute.
        max_output_bytes: Per-command output limit forwarded to the session.
        timeout_seconds: Per-command timeout forwarded to the session. Defaults to the
            previously hard-coded 30 seconds.

    Returns:
        A report holding one ``CollectedItem`` per command, in execution order.
    """
    items: list[CollectedItem] = []
    for name, command in plan.commands:
        result = await session.run_read_only_command(
            command,
            timeout_seconds=timeout_seconds,
            max_output_bytes=max_output_bytes,
        )
        items.append(CollectedItem(name=name, result=result))
    # NOTE(review): this reads the session's private ``_client`` attribute and stores the
    # client's summary string in ``host``; a public accessor on the session would be cleaner.
    return CollectionReport(host=session._client.summary(), items=items)

46
src/tai/input_parser.py Normal file
View File

@@ -0,0 +1,46 @@
"""Helpers to normalize and validate CLI input."""
from pathlib import Path
from tai.models import TroubleshootRequest
class InputValidationError(ValueError):
    """Raised when required user input is missing or invalid.

    The CLI catches this error, prints its message, and exits with status 2.
    """
def build_request(
    *,
    issue: str,
    host: str,
    port: int,
    target_paths: list[str],
    identity_file: str | None,
    jump_host: str | None,
    ignore_ssh_config: bool,
) -> TroubleshootRequest:
    """Create a normalized request object from raw CLI values.

    Args:
        issue: Free-text issue description; surrounding whitespace is stripped.
        host: Target host; surrounding whitespace is stripped.
        port: SSH port, which must lie in 1..65535.
        target_paths: Paths to inspect; each is converted to ``Path`` with ``~`` expanded.
        identity_file: Optional private-key path; ``~`` is expanded. Empty means none.
        jump_host: Optional bastion host; empty or whitespace-only means none.
        ignore_ssh_config: Passed through unchanged.

    Returns:
        The populated ``TroubleshootRequest``.

    Raises:
        InputValidationError: If the issue or host is empty after stripping, or the port
            is out of range.
    """
    normalized_issue = issue.strip()
    normalized_host = host.strip()
    if not normalized_issue:
        raise InputValidationError("Issue description cannot be empty.")
    if not normalized_host:
        raise InputValidationError("Host cannot be empty.")
    if not 1 <= port <= 65535:
        raise InputValidationError("Port must be between 1 and 65535.")

    # NOTE(review): expanduser() resolves "~" against the local user's home directory,
    # while these paths are later used in commands built for the remote host. Confirm
    # that local expansion is intended.
    paths = [Path(p).expanduser() for p in target_paths]
    identity = Path(identity_file).expanduser() if identity_file else None

    # A whitespace-only value means "no jump host"; never hand an empty string downstream.
    normalized_jump_host = jump_host.strip() if jump_host else ""

    return TroubleshootRequest(
        issue=normalized_issue,
        host=normalized_host,
        port=port,
        target_paths=paths,
        identity_file=identity,
        jump_host=normalized_jump_host or None,
        ignore_ssh_config=ignore_ssh_config,
    )

17
src/tai/models.py Normal file
View File

@@ -0,0 +1,17 @@
"""Core domain models for tai."""
from dataclasses import dataclass, field
from pathlib import Path
@dataclass(slots=True)
class TroubleshootRequest:
    """User-provided troubleshooting input for a single run."""

    # Issue description or ticket text.
    issue: str
    # Target host to troubleshoot.
    host: str
    # SSH port for the target host.
    port: int = 22
    # Paths the user asked to inspect.
    target_paths: list[Path] = field(default_factory=list)
    # SSH private key path, if one was given.
    identity_file: Path | None = None
    # SSH bastion/jump host, if one was given.
    jump_host: str | None = None
    # When true, ~/.ssh/config is ignored in favour of CLI options.
    ignore_ssh_config: bool = False

244
src/tai/plan.py Normal file
View File

@@ -0,0 +1,244 @@
"""Collection plan builder — decides what to collect based on the issue."""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from tai.models import TroubleshootRequest
# ---------------------------------------------------------------------------
# Keyword sets for issue classification
# ---------------------------------------------------------------------------
_SERVICE_KEYWORDS: frozenset[str] = frozenset(
{
"service",
"unit",
"daemon",
"failed",
"dead",
"inactive",
"crash",
"crashed",
"start",
"stop",
"restart",
"status",
"systemd",
"systemctl",
}
)
_NETWORK_KEYWORDS: frozenset[str] = frozenset(
{
"network",
"port",
"connect",
"connection",
"listen",
"firewall",
"route",
"routing",
"interface",
"dns",
"http",
"https",
"tcp",
"udp",
"socket",
"unreachable",
"refused",
"timeout",
"latency",
"bandwidth",
"packet",
}
)
_DISK_KEYWORDS: frozenset[str] = frozenset(
{
"disk",
"space",
"storage",
"inode",
"full",
"mount",
"filesystem",
"partition",
"quota",
"usage",
"capacity",
}
)
# ---------------------------------------------------------------------------
# Known service names and their candidate config paths
# ---------------------------------------------------------------------------
_KNOWN_SERVICES: list[str] = [
"apache2",
"httpd",
"nginx",
"mysql",
"mysqld",
"mariadb",
"postgresql",
"redis",
"redis-server",
"mongodb",
"mongod",
"docker",
"containerd",
"kubelet",
"sshd",
"postfix",
"dovecot",
"sendmail",
"php-fpm",
"elasticsearch",
"rabbitmq",
"rabbitmq-server",
"celery",
"gunicorn",
"ufw",
"fail2ban",
"cron",
"crond",
"rsyslog",
"auditd",
"firewalld",
"haproxy",
"varnish",
"memcached",
]
_SERVICE_CONFIGS: dict[str, list[str]] = {
"apache2": ["/etc/apache2/apache2.conf"],
"httpd": ["/etc/httpd/conf/httpd.conf"],
"nginx": ["/etc/nginx/nginx.conf"],
"mysql": ["/etc/mysql/mysql.conf.d/mysqld.cnf"],
"mysqld": ["/etc/my.cnf"],
"mariadb": ["/etc/mysql/mariadb.conf.d/50-server.cnf"],
"postgresql": ["/etc/postgresql"],
"sshd": ["/etc/ssh/sshd_config"],
"postfix": ["/etc/postfix/main.cf"],
"haproxy": ["/etc/haproxy/haproxy.cfg"],
"redis": ["/etc/redis/redis.conf"],
"redis-server": ["/etc/redis/redis.conf"],
"fail2ban": ["/etc/fail2ban/jail.conf"],
"ufw": ["/etc/ufw/ufw.conf"],
}
# ---------------------------------------------------------------------------
# Command sets
# ---------------------------------------------------------------------------
_ALWAYS: list[tuple[str, str]] = [
("kernel", "uname -a"),
("uptime", "cat /proc/uptime"),
("disk-usage", "df -h"),
("memory", "cat /proc/meminfo"),
("running-services", "systemctl list-units --type=service --state=running --no-pager"),
]
_SERVICE_EXTRA: list[tuple[str, str]] = [
("failed-services", "systemctl list-units --type=service --state=failed --no-pager"),
("journal-errors", "journalctl -p err -n 100 --no-pager"),
]
_NETWORK_EXTRA: list[tuple[str, str]] = [
("listening-ports", "ss -lntp"),
("ip-addresses", "ip addr show"),
("ip-routes", "ip route show"),
("ip-stats", "ip -s link show"),
]
_DISK_EXTRA: list[tuple[str, str]] = [
("disk-inodes", "df -i"),
("dmesg-disk", "dmesg -T --level=err,warn"),
("large-dirs", "du -sh /var /tmp /home /opt"),
]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@dataclass(slots=True)
class CollectionPlan:
"""Ordered list of (name, command) pairs to execute on a remote host."""
commands: list[tuple[str, str]] = field(default_factory=list)
def add(self, name: str, command: str) -> None:
self.commands.append((name, command))
def __len__(self) -> int:
return len(self.commands)
def plan_from_request(request: TroubleshootRequest) -> CollectionPlan:
    """Build a :class:`CollectionPlan` tailored to *request*.

    The plan always starts with the ``_ALWAYS`` baseline. Further commands are
    added, in this order, for:

    1. keyword categories found in ``request.issue`` (service / network / disk),
    2. every known service named in the issue (status, journal, config files),
    3. every path in ``request.target_paths`` (a listing plus a shallow ``find``).

    Args:
        request: The troubleshooting request; ``issue`` and ``target_paths``
            are the only attributes read here.

    Returns:
        A new plan; the module-level command tables are copied, not mutated.
    """
    # Function-scope import: only the path handling below needs it.
    import shlex

    plan = CollectionPlan(commands=list(_ALWAYS))
    keywords = _issue_words(request.issue)

    # --- category expansions -------------------------------------------
    if keywords & _SERVICE_KEYWORDS:
        plan.commands.extend(_SERVICE_EXTRA)
    if keywords & _NETWORK_KEYWORDS:
        plan.commands.extend(_NETWORK_EXTRA)
    if keywords & _DISK_KEYWORDS:
        plan.commands.extend(_DISK_EXTRA)

    # --- named service detection ---------------------------------------
    # dict.fromkeys drops repeated names while keeping first-seen order.
    for svc in dict.fromkeys(_extract_services(request.issue)):
        plan.add(f"service-{svc}", f"systemctl status {svc}")
        plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager")
        # Paths come from the trusted _SERVICE_CONFIGS table and are left
        # unquoted on purpose.
        for cfg_path in _SERVICE_CONFIGS.get(svc, []):
            plan.add(f"config-{svc}", f"cat {cfg_path}")

    # --- user-specified paths -----------------------------------------
    for path in request.target_paths:
        # The command string is handed to a remote shell, so a user-supplied
        # path must be quoted: unquoted, a space splits it into two arguments
        # and shell metacharacters are interpreted. shlex.quote returns plain
        # paths such as /opt/myapp unchanged.
        quoted = shlex.quote(str(path))
        plan.add(f"ls-{path.name}", f"ls -la {quoted}")
        if "log" in str(path).lower():
            plan.add(
                f"find-logs-{path.name}",
                f"find {quoted} -maxdepth 2 -type f -name '*.log'",
            )
        else:
            plan.add(
                f"find-files-{path.name}",
                f"find {quoted} -maxdepth 2 -type f",
            )
    return plan
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _issue_words(issue: str) -> set[str]:
"""Return the set of lowercase words in *issue*."""
return set(re.findall(r"\b\w+\b", issue.lower()))
def _extract_services(issue: str) -> list[str]:
    """Return the known service names mentioned in *issue*.

    A service counts as mentioned when any of its aliases appears as a word of
    the issue text (case-insensitive), or, for hyphenated names, when the full
    name appears verbatim.

    Args:
        issue: Free-form problem description.

    Returns:
        Matching entries of ``_KNOWN_SERVICES``, in that list's order.
    """
    words = _issue_words(issue)
    lowered = issue.lower()
    found: list[str] = []
    for svc in _KNOWN_SERVICES:
        aliases = {
            svc,
            # Daemon stem: 'sshd' -> 'ssh'. removesuffix drops exactly one
            # 'd', whereas rstrip("d") would strip every trailing 'd'.
            svc.removesuffix("d"),
            # Versioned name: 'apache2' -> 'apache'.
            svc.rstrip("0123456789"),
            svc.replace("-", ""),
            svc.replace("-server", ""),
        }
        aliases.discard("")
        mentioned = bool(words & aliases)
        if not mentioned and "-" in svc:
            # _issue_words() splits on '-', so a hyphenated name such as
            # 'php-fpm' never shows up as a single word; look for it in the
            # raw text instead, bounded so it is not part of a longer token.
            pattern = rf"(?<![\w-]){re.escape(svc)}(?![\w-])"
            mentioned = re.search(pattern, lowered) is not None
        if mentioned:
            found.append(svc)
    return found

74
src/tai/prompt_builder.py Normal file
View File

@@ -0,0 +1,74 @@
"""Formats collected diagnostics into prompts for the AI backend."""
from __future__ import annotations
from tai.collectors import CollectionReport
# Static system prompt for the AI backend; build_system_prompt() returns it
# with surrounding whitespace stripped.
# NOTE(review): the quoted phrase "could not be executed (SSH error)" is not
# produced by build_user_message() in this module, which instead names skipped
# commands in a closing note - confirm the wording matches what the model sees.
_SYSTEM_PROMPT = """\
You are an expert Linux systems administrator and troubleshooting assistant.
You are given diagnostic data collected read-only from a remote Linux host via SSH.
Your job:
1. Identify the root cause of the reported issue based only on the data provided.
2. Cite the specific output that supports your conclusion.
3. Give concise, actionable remediation steps.
Important rules:
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say.
- Never suggest commands that modify the system unless explicitly asked.
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
"""
def build_system_prompt() -> str:
    """Return the fixed system prompt with leading/trailing whitespace removed."""
    prompt = _SYSTEM_PROMPT
    return prompt.strip()
def build_user_message(issue: str, report: CollectionReport) -> str:
    """Render *issue* and the diagnostics in *report* as one Markdown message.

    The message has three headed sections (issue, target host, collected
    diagnostics). Each collected item gets its own sub-section with the
    command, exit code and any stdout/stderr in fenced blocks. Items that
    exited with 255 and produced no output at all are left out of the
    diagnostics and only named in a closing note.

    Args:
        issue: The problem description as given by the user.
        report: Collection result; ``host`` and ``items`` are read.

    Returns:
        The message text, pieces joined with newlines.
    """

    def _fenced(title: str, text: str, was_cut: bool) -> str:
        # One labelled, fenced output block; flags truncated output.
        suffix = " *(output truncated)*" if was_cut else ""
        return f"{title}{suffix}\n```\n{text.strip()}\n```\n"

    parts: list[str] = [
        f"## Issue reported\n\n{issue}\n",
        f"## Target host\n\n{report.host}\n",
        "## Collected diagnostics\n",
    ]
    unexecuted: list[str] = []

    for entry in report.items:
        res = entry.result
        silent = not res.stdout and not res.stderr
        # Exit 255 with no output means SSH could not run the command at all;
        # keep it out of the diagnostics so the AI has nothing to speculate on.
        if res.exit_code == 255 and silent:
            unexecuted.append(entry.name)
            continue

        parts += [
            f"### {entry.name}\n",
            f"**Command:** `{res.command}` ",
            f"**Exit code:** {res.exit_code}\n",
        ]
        if res.stdout:
            parts.append(_fenced("**stdout:**", res.stdout, res.stdout_truncated))
        if res.stderr:
            parts.append(_fenced("**stderr:**", res.stderr, res.stderr_truncated))
        if silent:
            parts.append("*(no output)*\n")

    if unexecuted:
        parts.append(
            "**Note:** The following commands could not be executed on this host "
            f"and produced no output: {', '.join(unexecuted)}. "
            "Do not draw any conclusions from their absence.\n"
        )
    return "\n".join(parts)

383
src/tai/ssh_client.py Normal file
View File

@@ -0,0 +1,383 @@
"""SSH configuration and read-only command execution."""
import asyncio
import os
import shlex
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
@dataclass(slots=True)
class SSHConnectionConfig:
"""Connection parameters for the target host."""
host: str
port: int = 22
identity_file: Path | None = None
jump_host: str | None = None
ignore_ssh_config: bool = False
@dataclass(slots=True)
class SSHCommandResult:
"""Result of a remote SSH command execution."""
command: str
exit_code: int
stdout: str
stderr: str
stdout_truncated: bool = False
stderr_truncated: bool = False
class SSHCommandRejectedError(ValueError):
    """Signals that a command was refused because it breaks the read-only policy."""
class SSHClient:
    """Wrapper around SSH operations with read-only safeguards.

    Commands are run by spawning the local ``ssh`` binary. Anything passed to
    :meth:`run_read_only_command` is first checked by
    :meth:`validate_read_only_command`, which only admits an allow-list of
    inspection commands and rejects shell operators and known mutating
    arguments.
    """

    # Rejected wherever they occur in the command text (plain substring match,
    # so quoted occurrences are refused as well). Besides redirection and
    # chaining this covers command substitution and backgrounding.
    _BLOCKED_TOKENS = {
        ">",
        ">>",
        "<",
        "|",
        "&&",
        "||",
        ";",
        "&",
        "`",
        "$(",
    }
    _READ_ONLY_COMMANDS = {
        "cat",
        "dmesg",
        "df",
        "du",
        "find",
        "grep",
        "head",
        "hostnamectl",
        "ip",
        "journalctl",
        "ls",
        "netstat",
        "sed",
        "ss",
        "stat",
        "systemctl",
        "tail",
        "uname",
    }
    _READ_ONLY_SYSTEMCTL_SUBCOMMANDS = {
        "cat",
        "is-active",
        "is-failed",
        "list-unit-files",
        "list-units",
        "show",
        "status",
    }
    # Arguments that make an otherwise allowed command modify the system or
    # run another program. Compared against each argument with any "=value"
    # part removed.
    # NOTE(review): sed scripts can still write files through the ``w``/``e``
    # script commands; this table does not inspect script text.
    _MUTATING_ARGUMENTS = {
        "dmesg": {
            "--clear",
            "--read-clear",
            "--console-off",
            "--console-on",
            "--console-level",
        },
        "find": {
            "-delete",
            "-exec",
            "-execdir",
            "-ok",
            "-okdir",
            "-fprint",
            "-fprint0",
            "-fprintf",
            "-fls",
        },
        "ip": {
            "add",
            "append",
            "change",
            "del",
            "delete",
            "exec",
            "flush",
            "prepend",
            "replace",
            "restore",
            "set",
        },
        "journalctl": {
            "--flush",
            "--rotate",
            "--sync",
            "--relinquish-var",
            "--smart-relinquish-var",
            "--setup-keys",
            "--update-catalog",
            "--vacuum-size",
            "--vacuum-time",
            "--vacuum-files",
        },
        "sed": {"--in-place"},
    }
    # Single-letter flags with the same effect; checked inside clustered short
    # options as well (``sed -ni``, ``sed -i.bak``, ``dmesg -Tc``).
    _MUTATING_SHORT_FLAGS = {
        "dmesg": "CcDEn",
        "sed": "i",
    }

    def __init__(self, config: SSHConnectionConfig) -> None:
        """Store *config*; no connection is opened here."""
        self._config = config

    def summary(self) -> str:
        """Return a short summary of connection settings."""
        mode = "ignore ssh config" if self._config.ignore_ssh_config else "use ssh config"
        jump = self._config.jump_host or "none"
        key = str(self._config.identity_file) if self._config.identity_file else "auto"
        return (
            f"host={self._config.host} port={self._config.port} "
            f"key={key} jump={jump} mode={mode}"
        )

    def _build_base_argv(self) -> list[str]:
        """Build the common SSH argv flags (no host or command appended)."""
        argv = [
            "ssh",
            "-p",
            str(self._config.port),
            "-o",
            "BatchMode=yes",
            "-o",
            "ConnectTimeout=15",
            "-o",
            "StrictHostKeyChecking=accept-new",
        ]
        if self._config.ignore_ssh_config:
            argv += ["-F", "/dev/null"]
        if self._config.identity_file:
            argv += ["-i", str(self._config.identity_file)]
        if self._config.jump_host:
            argv += ["-J", self._config.jump_host]
        return argv

    def build_ssh_argv(self, remote_command: str) -> list[str]:
        """Build argv for a secure non-interactive SSH invocation."""
        return self._build_base_argv() + [self._config.host, remote_command]

    def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession":
        """Return an :class:`SSHSession` async context manager for this host."""
        return SSHSession(self, connect_timeout=connect_timeout)

    def validate_read_only_command(self, command: str) -> None:
        """Validate that a command appears read-only and non-destructive.

        Args:
            command: The remote command line to check.

        Raises:
            SSHCommandRejectedError: If the command is empty, spans several
                lines, contains a blocked shell operator, cannot be parsed,
                does not start with an allow-listed program, uses a
                non-read-only ``systemctl`` subcommand, or carries an argument
                known to modify the system.
        """
        normalized = command.strip()
        if not normalized:
            raise SSHCommandRejectedError("Command cannot be empty.")
        # A line break would let a second command follow the validated one.
        if "\n" in normalized or "\r" in normalized:
            raise SSHCommandRejectedError("Command must be a single line.")
        # Longest tokens first so the message names '&&' rather than '&'.
        for token in sorted(self._BLOCKED_TOKENS, key=lambda t: (-len(t), t)):
            if token in normalized:
                raise SSHCommandRejectedError(
                    f"Command contains blocked shell operator: {token}"
                )
        try:
            parts = shlex.split(normalized)
        except ValueError as exc:
            # Unbalanced quotes etc.; report through the policy error so that
            # callers only have one exception type to handle.
            raise SSHCommandRejectedError(
                f"Command could not be parsed: {exc}"
            ) from exc
        if not parts:
            raise SSHCommandRejectedError("Command cannot be empty.")
        base = parts[0]
        if base not in self._READ_ONLY_COMMANDS:
            raise SSHCommandRejectedError(
                f"Command '{base}' is not allowed by read-only policy."
            )
        if base == "systemctl":
            if len(parts) < 2:
                raise SSHCommandRejectedError("systemctl requires a subcommand.")
            subcommand = parts[1]
            if subcommand not in self._READ_ONLY_SYSTEMCTL_SUBCOMMANDS:
                raise SSHCommandRejectedError(
                    f"systemctl subcommand '{subcommand}' is not read-only."
                )
        self._reject_mutating_arguments(base, parts[1:])

    def _reject_mutating_arguments(self, base: str, args: list[str]) -> None:
        """Raise if *args* make the allow-listed program *base* modify state."""
        forbidden = self._MUTATING_ARGUMENTS.get(base, set())
        short_flags = self._MUTATING_SHORT_FLAGS.get(base, "")
        for arg in args:
            option = arg.split("=", 1)[0]
            if option in forbidden:
                raise SSHCommandRejectedError(
                    f"{base} argument '{option}' is not read-only."
                )
            is_short_cluster = arg.startswith("-") and not arg.startswith("--")
            if is_short_cluster and any(flag in arg[1:] for flag in short_flags):
                raise SSHCommandRejectedError(
                    f"{base} option '{arg}' is not read-only."
                )
        if base == "hostnamectl":
            # A second positional argument assigns a value
            # (``hostnamectl hostname NEW``), as do the ``set-*`` verbs.
            positional = [arg for arg in args if not arg.startswith("-")]
            if len(positional) > 1 or any(a.startswith("set-") for a in positional):
                raise SSHCommandRejectedError(
                    "hostnamectl may only be used to display settings."
                )

    async def run_read_only_command(
        self,
        command: str,
        *,
        timeout_seconds: float = 30.0,
        max_output_bytes: int = 32768,
    ) -> SSHCommandResult:
        """Run a validated read-only command over SSH.

        Raises:
            SSHCommandRejectedError: If *command* fails validation.
            TimeoutError: If the command does not finish in *timeout_seconds*.
        """
        self.validate_read_only_command(command)
        return await self._run_ssh(
            command,
            timeout_seconds=timeout_seconds,
            max_output_bytes=max_output_bytes,
        )

    async def probe(self) -> SSHCommandResult:
        """Probe connectivity using a harmless remote command."""
        return await self._run_ssh(
            "uname -a",
            timeout_seconds=15.0,
            max_output_bytes=4096,
        )

    async def _run_ssh(
        self,
        command: str,
        *,
        timeout_seconds: float,
        max_output_bytes: int,
    ) -> SSHCommandResult:
        """Spawn ``ssh`` for *command* and collect its truncated output.

        Raises:
            TimeoutError: If ssh does not finish within *timeout_seconds*; the
                process is killed and reaped first.
            RuntimeError: If the finished process reports no exit code.
        """
        argv = self.build_ssh_argv(command)
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout_seconds,
            )
        # asyncio.TimeoutError is what wait_for raises; it only became an
        # alias of the builtin TimeoutError in Python 3.11, so catching the
        # builtin would miss it on 3.10 and leave the ssh process running.
        except asyncio.TimeoutError as exc:
            proc.kill()
            await proc.wait()
            raise TimeoutError(
                f"SSH command timed out after {timeout_seconds} seconds: {command}"
            ) from exc
        if proc.returncode is None:
            raise RuntimeError("SSH process did not provide an exit code.")
        stdout_text, stdout_truncated = self._truncate_output(
            stdout_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        stderr_text, stderr_truncated = self._truncate_output(
            stderr_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        return SSHCommandResult(
            command=command,
            exit_code=proc.returncode,
            stdout=stdout_text,
            stderr=stderr_text,
            stdout_truncated=stdout_truncated,
            stderr_truncated=stderr_truncated,
        )

    @staticmethod
    def _truncate_output(text: str, *, max_output_bytes: int) -> tuple[str, bool]:
        """Trim output to a maximum byte length while preserving UTF-8 validity.

        Returns:
            ``(text, truncated)``. Untruncated text is returned stripped;
            truncated text ends with a ``...[truncated]`` marker and, marker
            included, fits in *max_output_bytes* bytes.

        Raises:
            ValueError: If *max_output_bytes* is below 256.
        """
        if max_output_bytes < 256:
            raise ValueError("max_output_bytes must be at least 256.")
        encoded = text.encode("utf-8", errors="replace")
        if len(encoded) <= max_output_bytes:
            return text.strip(), False
        marker = "\n...[truncated]"
        marker_bytes = marker.encode("utf-8")
        keep = max_output_bytes - len(marker_bytes)
        trimmed_bytes = encoded[:keep]
        # errors="ignore" drops a multi-byte character cut in half by the slice.
        trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip()
        return f"{trimmed_text}{marker}", True
class SSHSession:
    """A persistent SSH connection using ControlMaster multiplexing.

    All commands run over the same underlying TCP connection — no per-command
    SSH handshake. Use as an async context manager::

        async with client.connect() as session:
            result = await session.run_read_only_command("df -h")
    """

    def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None:
        """Remember *client* and the timeout; nothing is started until entry."""
        self._client = client
        self._connect_timeout = connect_timeout
        # Private directory that holds the control socket while connected.
        self._control_dir: Path | None = None
        self._socket_path: Path | None = None
        self._master_proc: asyncio.subprocess.Process | None = None

    async def __aenter__(self) -> "SSHSession":
        """Start the ControlMaster process and wait for its control socket.

        Raises:
            TimeoutError: If the socket does not appear within the connect
                timeout, or the master process exits before creating it.
        """
        # mkdtemp yields a directory only this user can access, so the socket
        # path inside it cannot be claimed by someone else. Creating a temp
        # file and unlinking it (the previous approach) leaves the freed,
        # predictable name open for reuse in the shared temp directory.
        self._control_dir = Path(tempfile.mkdtemp(prefix="tai-ssh-"))
        self._socket_path = self._control_dir / "control.sock"
        master_argv = self._client._build_base_argv() + [
            "-o", "ControlMaster=yes",
            "-o", f"ControlPath={self._socket_path}",
            "-o", "ControlPersist=no",
            "-N",
            self._client._config.host,
        ]
        try:
            self._master_proc = await asyncio.create_subprocess_exec(
                *master_argv,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.DEVNULL,
            )
        except BaseException:
            # e.g. the ssh binary is missing: do not leave the directory behind.
            await self._teardown()
            raise
        # Wait for the control socket to appear (master is ready)
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._connect_timeout
        while not self._socket_path.exists():
            exit_code = self._master_proc.returncode
            if exit_code is not None:
                # The master is gone, so the socket will never show up; fail
                # now instead of sleeping until the deadline. TimeoutError is
                # kept as the type so existing handlers for a failed connect
                # still apply.
                await self._teardown()
                raise TimeoutError(
                    f"SSH ControlMaster exited with code {exit_code} before "
                    f"connecting to {self._client._config.host}"
                )
            if loop.time() > deadline:
                await self._teardown()
                raise TimeoutError(
                    f"SSH ControlMaster did not connect within "
                    f"{self._connect_timeout}s to {self._client._config.host}"
                )
            await asyncio.sleep(0.05)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Shut the master connection down; exceptions are not suppressed."""
        await self._teardown()

    async def _teardown(self) -> None:
        """Stop the master process and remove the socket and its directory."""
        proc = self._master_proc
        if proc is not None and proc.returncode is None:
            proc.terminate()
            try:
                await asyncio.wait_for(proc.wait(), timeout=5.0)
            # See _run(): asyncio.TimeoutError also matches on Python 3.10.
            except asyncio.TimeoutError:
                proc.kill()
                # Reap the killed process so it does not linger as a zombie.
                await proc.wait()
        if self._socket_path is not None:
            self._socket_path.unlink(missing_ok=True)
        if self._control_dir is not None:
            try:
                os.rmdir(self._control_dir)
            except OSError:
                # Best effort: the directory may be gone or not empty.
                pass
            self._control_dir = None

    def _command_argv(self, remote_command: str) -> list[str]:
        """Build the ssh argv that runs *remote_command* through the master."""
        return self._client._build_base_argv() + [
            "-o", f"ControlPath={self._socket_path}",
            "-o", "ControlMaster=no",
            self._client._config.host,
            remote_command,
        ]

    async def probe(self) -> SSHCommandResult:
        """Run uname -a to confirm connectivity."""
        return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096)

    async def run_read_only_command(
        self,
        command: str,
        *,
        timeout_seconds: float = 30.0,
        max_output_bytes: int = 32768,
    ) -> SSHCommandResult:
        """Validate and run a read-only command over the shared connection."""
        self._client.validate_read_only_command(command)
        return await self._run(
            command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes
        )

    async def _run(
        self,
        command: str,
        *,
        timeout_seconds: float,
        max_output_bytes: int,
    ) -> SSHCommandResult:
        """Run *command* via the control socket and collect truncated output.

        Raises:
            TimeoutError: If the command does not finish in *timeout_seconds*;
                the ssh process is killed and reaped first.
            RuntimeError: If the finished process reports no exit code.
        """
        argv = self._command_argv(command)
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(), timeout=timeout_seconds
            )
        # wait_for raises asyncio.TimeoutError, which is only an alias of the
        # builtin TimeoutError from Python 3.11 on; catching the builtin would
        # miss it on 3.10 and leave the process running.
        except asyncio.TimeoutError as exc:
            proc.kill()
            await proc.wait()
            raise TimeoutError(
                f"SSH command timed out after {timeout_seconds}s: {command}"
            ) from exc
        if proc.returncode is None:
            raise RuntimeError("SSH process did not provide an exit code.")
        stdout_text, stdout_truncated = SSHClient._truncate_output(
            stdout_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        stderr_text, stderr_truncated = SSHClient._truncate_output(
            stderr_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        return SSHCommandResult(
            command=command,
            exit_code=proc.returncode,
            stdout=stdout_text,
            stderr=stderr_text,
            stdout_truncated=stdout_truncated,
            stderr_truncated=stderr_truncated,
        )

192
tests/test_ai.py Normal file
View File

@@ -0,0 +1,192 @@
"""Tests for the AI client and prompt builder."""
from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult
# ---------------------------------------------------------------------------
# AIConfig defaults
# ---------------------------------------------------------------------------
def test_ai_config_defaults() -> None:
    """A bare ``AIConfig()`` carries the module defaults and the 'ollama' key."""
    defaults = AIConfig()
    assert defaults.host == DEFAULT_AI_HOST
    assert defaults.model == DEFAULT_MODEL
    assert defaults.api_key == "ollama"
def test_ai_config_custom_values() -> None:
    """Explicit constructor arguments are stored unchanged."""
    custom = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test")
    assert custom.host == "https://api.openai.com/v1"
    assert custom.model == "gpt-4o"
    assert custom.api_key == "sk-test"
# ---------------------------------------------------------------------------
# AIClient.summary
# ---------------------------------------------------------------------------
def test_ai_client_summary_contains_host_and_model() -> None:
    """``summary()`` mentions both the configured host and the model name."""
    client = AIClient(AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b"))
    text = client.summary()
    assert "http://myserver:11434/v1" in text
    assert "llama3.1:8b" in text
# ---------------------------------------------------------------------------
# AIClient.complete (mocked)
# ---------------------------------------------------------------------------
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
usage = MagicMock()
usage.prompt_tokens = 10
usage.completion_tokens = 20
message = MagicMock()
message.content = content
choice = MagicMock()
choice.message = message
response = MagicMock()
response.choices = [choice]
response.model = model
response.usage = usage
return response
def test_complete_returns_ai_response() -> None:
    """``complete()`` exposes the reply text and the token accounting."""
    client = AIClient(AIConfig())
    canned = _make_mock_response("The root cause is X.")
    completions = client._client.chat.completions
    with patch.object(completions, "create", return_value=canned):
        answer = client.complete("system prompt", "user message")
    assert answer.content == "The root cause is X."
    assert answer.prompt_tokens == 10
    assert answer.completion_tokens == 20
    assert answer.total_tokens == 30
def test_complete_handles_empty_content() -> None:
    """A reply whose message content is ``None`` comes back as an empty string."""
    client = AIClient(AIConfig())
    canned = _make_mock_response(None)  # type: ignore[arg-type]
    canned.choices[0].message.content = None
    completions = client._client.chat.completions
    with patch.object(completions, "create", return_value=canned):
        answer = client.complete("system", "user")
    assert answer.content == ""
# ---------------------------------------------------------------------------
# AIClient.stream (mocked)
# ---------------------------------------------------------------------------
def test_stream_yields_chunks() -> None:
    """``stream()`` yields chunk texts in order and drops ``None`` deltas."""
    client = AIClient(AIConfig())

    def _chunk_with(text: str | None) -> MagicMock:
        # One streamed chunk holding a single choice whose delta is *text*.
        piece = MagicMock()
        piece.delta.content = text
        wrapper = MagicMock()
        wrapper.choices = [piece]
        return wrapper

    texts = ["Root ", "cause ", None, "found."]
    fake_stream = iter([_chunk_with(text) for text in texts])
    completions = client._client.chat.completions
    with patch.object(completions, "create", return_value=fake_stream):
        received = list(client.stream("system", "user"))
    assert received == ["Root ", "cause ", "found."]
# ---------------------------------------------------------------------------
# prompt_builder
# ---------------------------------------------------------------------------
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
    """Turn (name, command, exit_code, stdout, stderr) rows into a report for root@testhost."""
    collected = []
    for name, command, exit_code, stdout, stderr in items:
        outcome = SSHCommandResult(
            command=command,
            exit_code=exit_code,
            stdout=stdout,
            stderr=stderr,
        )
        collected.append(CollectedItem(name=name, result=outcome))
    return CollectionReport(host="root@testhost", items=collected)
def test_build_system_prompt_contains_key_instructions() -> None:
    """The system prompt names the three answer sections and says 'read-only'."""
    text = build_system_prompt()
    for heading in ("Root Cause", "Evidence", "Recommended Actions"):
        assert heading in text
    assert "read-only" in text.lower()
def test_build_user_message_contains_issue_and_host() -> None:
    """The issue text and the report's host both appear in the message."""
    rows = [("kernel", "uname -a", 0, "Linux web01", "")]
    message = build_user_message("nginx is failing", _make_report(rows))
    assert "nginx is failing" in message
    assert "root@testhost" in message
def test_build_user_message_includes_command_output() -> None:
    """Both the command line and its stdout are rendered."""
    rows = [("kernel", "uname -a", 0, "Linux web01 6.1.0", "")]
    message = build_user_message("test issue", _make_report(rows))
    assert "uname -a" in message
    assert "Linux web01 6.1.0" in message
def test_build_user_message_shows_stderr() -> None:
    """stderr text is rendered even when stdout is empty."""
    rows = [("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")]
    message = build_user_message("nginx not found", _make_report(rows))
    assert "Unit nginx.service not found." in message
def test_build_user_message_notes_truncation() -> None:
    """A result flagged ``stdout_truncated`` makes the message say 'truncated'."""
    journal = CollectedItem(
        name="journal",
        result=SSHCommandResult(
            command="journalctl -n 100 --no-pager",
            exit_code=0,
            stdout="...",
            stderr="",
            stdout_truncated=True,
        ),
    )
    report = CollectionReport(host="root@testhost", items=[journal])
    message = build_user_message("disk issue", report)
    assert "truncated" in message
def test_build_user_message_handles_no_output() -> None:
    """A command with neither stdout nor stderr is marked as 'no output'."""
    rows = [("empty", "cat /nonexistent", 1, "", "")]
    message = build_user_message("test", _make_report(rows))
    assert "no output" in message

141
tests/test_cli.py Normal file
View File

@@ -0,0 +1,141 @@
from unittest.mock import AsyncMock, MagicMock
from typer.testing import CliRunner
from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport
from tai.ssh_client import SSHCommandResult
def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch SSHClient.connect to return a mock session.

    The mock works as an async context manager that yields itself. Its
    ``probe`` either raises *probe_raises* or returns *probe_result*.
    """
    fake = MagicMock()
    fake.__aenter__ = AsyncMock(return_value=fake)
    fake.__aexit__ = AsyncMock(return_value=None)
    fake.probe = (
        AsyncMock(side_effect=probe_raises)
        if probe_raises
        else AsyncMock(return_value=probe_result)
    )
    monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: fake)
    return fake
def test_run_command_prints_scaffold_summary() -> None:
    """With ``--no-probe`` the CLI exits 0 and echoes the connection summary."""
    cli_args = ["apache failed"]
    cli_args += ["--host", "web01"]
    cli_args += ["--port", "5566"]
    cli_args += ["--no-probe"]
    cli_args += ["--path", "/etc/apache2"]
    cli_args += ["--jump-host", "bastion01"]
    cli_args += ["--ignore-ssh-config"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "tai" in outcome.stdout
    assert "host=web01" in outcome.stdout
    assert "port=5566" in outcome.stdout
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe that exits 0 is reported as a success together with its stdout."""
    healthy = SSHCommandResult(
        command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
    )
    _mock_session(monkeypatch, probe_result=healthy)
    cli_args = ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "Probe succeeded" in outcome.stdout
    assert "Linux ssh 6.12.0" in outcome.stdout
def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe that exits 255 makes the CLI exit 1 and print 'Probe failed'."""
    denied = SSHCommandResult(
        command="uname -a",
        exit_code=255,
        stdout="",
        stderr="Permission denied (publickey,password).",
    )
    _mock_session(monkeypatch, probe_result=denied)
    cli_args = ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 1
    assert "Probe failed" in outcome.stdout
def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--collect`` prints one status line per item and flags truncated output."""
    _mock_session(monkeypatch)

    kernel_item = CollectedItem(
        name="kernel",
        result=SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        ),
    )
    journal_item = CollectedItem(
        name="journal",
        result=SSHCommandResult(
            command="journalctl -n 200",
            exit_code=0,
            stdout="...",
            stderr="",
            stdout_truncated=True,
        ),
    )

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Stand-in for the real collector: returns a fixed two-item report.
        return CollectionReport(
            host="ssh.archflux.net",
            items=[kernel_item, journal_item],
        )

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    cli_args = [
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--collect",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "Collection complete" in outcome.stdout
    assert "kernel: ok" in outcome.stdout
    assert "journal: ok (truncated)" in outcome.stdout

View File

@@ -0,0 +1,65 @@
from pathlib import Path
import pytest
from tai.input_parser import InputValidationError, build_request
def test_build_request_normalizes_values() -> None:
    """Whitespace is trimmed and ``~`` paths are expanded by ``build_request``."""
    raw = {
        "issue": " apache fails to start ",
        "host": " web01 ",
        "port": 5566,
        "target_paths": ["/etc/apache2", "~/logs"],
        "identity_file": "~/.ssh/id_ed25519",
        "jump_host": " bastion01 ",
        "ignore_ssh_config": True,
    }
    request = build_request(**raw)
    assert request.issue == "apache fails to start"
    assert request.host == "web01"
    assert request.port == 5566
    assert request.target_paths[0] == Path("/etc/apache2")
    assert request.target_paths[1] == Path("~/logs").expanduser()
    assert request.identity_file == Path("~/.ssh/id_ed25519").expanduser()
    assert request.jump_host == "bastion01"
    assert request.ignore_ssh_config is True
def test_build_request_rejects_empty_issue() -> None:
    """A whitespace-only issue raises ``InputValidationError``."""
    raw = {
        "issue": " ",
        "host": "web01",
        "port": 22,
        "target_paths": [],
        "identity_file": None,
        "jump_host": None,
        "ignore_ssh_config": False,
    }
    with pytest.raises(InputValidationError):
        build_request(**raw)
def test_build_request_rejects_empty_host() -> None:
    """A whitespace-only host raises ``InputValidationError``."""
    raw = {
        "issue": "apache down",
        "host": " ",
        "port": 22,
        "target_paths": [],
        "identity_file": None,
        "jump_host": None,
        "ignore_ssh_config": False,
    }
    with pytest.raises(InputValidationError):
        build_request(**raw)
def test_build_request_rejects_invalid_port() -> None:
    """Port 70000 raises ``InputValidationError``."""
    raw = {
        "issue": "apache down",
        "host": "web01",
        "port": 70000,
        "target_paths": [],
        "identity_file": None,
        "jump_host": None,
        "ignore_ssh_config": False,
    }
    with pytest.raises(InputValidationError):
        build_request(**raw)

169
tests/test_plan.py Normal file
View File

@@ -0,0 +1,169 @@
"""Tests for the collection plan builder."""
from pathlib import Path
from tai.models import TroubleshootRequest
from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request
def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest:
    """Build a request for root@testhost with *issue* and optional target *paths*."""
    targets = [Path(raw) for raw in (paths or [])]
    return TroubleshootRequest(issue=issue, host="root@testhost", target_paths=targets)
def _commands(plan: CollectionPlan) -> list[str]:
"""Return flat list of command strings from plan."""
return [cmd for _, cmd in plan.commands]
def _names(plan: CollectionPlan) -> list[str]:
return [name for name, _ in plan.commands]
# ---------------------------------------------------------------------------
# Always-present commands
# ---------------------------------------------------------------------------
def test_plan_always_has_baseline_commands() -> None:
    """Even a generic issue yields the kernel, disk, memory and service checks."""
    planned = _commands(plan_from_request(_req("some generic issue")))
    for fragment in ("uname -a", "df -h", "proc/meminfo", "systemctl list-units"):
        assert any(fragment in command for command in planned)
# ---------------------------------------------------------------------------
# Keyword-based category expansion
# ---------------------------------------------------------------------------
def test_service_keywords_add_failed_services_check() -> None:
    """Service wording adds the failed-units listing and the error journal."""
    planned = _commands(plan_from_request(_req("service failed to start")))
    for fragment in ("--state=failed", "journalctl -p err"):
        assert any(fragment in command for command in planned)
def test_network_keywords_add_network_commands() -> None:
    """Network wording adds the socket, address and route commands."""
    planned = _commands(plan_from_request(_req("connection refused on port 80")))
    for fragment in ("ss -lntp", "ip addr show", "ip route show"):
        assert any(fragment in command for command in planned)
def test_disk_keywords_add_disk_commands() -> None:
    """Disk wording adds the inode, dmesg and directory-size commands."""
    planned = _commands(plan_from_request(_req("disk full filesystem usage critical")))
    for fragment in ("df -i", "dmesg", "du -sh"):
        assert any(fragment in command for command in planned)
def test_unrelated_issue_does_not_add_network_commands() -> None:
    """An issue without network wording gets no routing command."""
    planned = _commands(plan_from_request(_req("apache service crashed")))
    assert not any("ip route show" in command for command in planned)
# ---------------------------------------------------------------------------
# Named service detection
# ---------------------------------------------------------------------------
def test_nginx_in_issue_adds_nginx_service_commands() -> None:
    """Naming nginx adds its status and journal steps."""
    plan = plan_from_request(_req("nginx is failing to start"))
    labels = _names(plan)
    planned = _commands(plan)
    assert "service-nginx" in labels
    assert "journal-nginx" in labels
    assert any("systemctl status nginx" in command for command in planned)
    assert any("journalctl -u nginx" in command for command in planned)
def test_apache2_adds_config_cat() -> None:
    """Naming apache2 adds a ``cat`` of its main config file."""
    planned = _commands(plan_from_request(_req("apache2 service check")))
    assert any("cat /etc/apache2/apache2.conf" in command for command in planned)
def test_sshd_adds_config_cat() -> None:
    """Naming sshd adds a ``cat`` of sshd_config."""
    planned = _commands(plan_from_request(_req("sshd connection problems")))
    assert any("cat /etc/ssh/sshd_config" in command for command in planned)
def test_unknown_service_name_no_config_cat() -> None:
    """An unrecognised service name adds no config-file ``cat``."""
    planned = _commands(plan_from_request(_req("myweirdapp service crashed")))
    assert not any("cat /etc" in command for command in planned)
def test_duplicate_service_name_not_repeated() -> None:
    """Repeating a service name in the issue yields a single status step."""
    labels = _names(plan_from_request(_req("nginx nginx nginx")))
    assert labels.count("service-nginx") == 1
# ---------------------------------------------------------------------------
# Target path handling
# ---------------------------------------------------------------------------
def test_target_path_adds_ls_and_find() -> None:
    """A target path gets both a listing and a ``find``."""
    planned = _commands(plan_from_request(_req("app crash", paths=["/opt/myapp"])))
    for fragment in ("ls -la /opt/myapp", "find /opt/myapp"):
        assert any(fragment in command for command in planned)
def test_log_path_uses_log_find_pattern() -> None:
    """A path containing 'log' is searched with the ``*.log`` pattern."""
    planned = _commands(plan_from_request(_req("app errors", paths=["/var/log/myapp"])))
    assert any("*.log" in command for command in planned)
def test_non_log_path_uses_generic_find() -> None:
    """A path without 'log' is searched without the ``*.log`` filter."""
    planned = _commands(plan_from_request(_req("config issue", paths=["/etc/myapp"])))
    assert any(
        "find /etc/myapp" in command and "*.log" not in command for command in planned
    )
# ---------------------------------------------------------------------------
# Helper unit tests
# ---------------------------------------------------------------------------
def test_issue_words_lowercases_and_splits() -> None:
    """Mixed-case input comes back as separate lower-case words."""
    tokens = _issue_words("Apache Service FAILED")
    for expected in ("apache", "service", "failed"):
        assert expected in tokens
def test_extract_services_finds_nginx() -> None:
    """'nginx' in the issue text is reported as a service."""
    detected = _extract_services("nginx is down")
    assert "nginx" in detected
def test_extract_services_finds_nothing_for_unknown() -> None:
    """Text that names no known service yields an empty list."""
    detected = _extract_services("the widget is broken")
    assert detected == []
def test_extract_services_case_insensitive() -> None:
    """An upper-case service name is still detected."""
    detected = _extract_services("NGINX failed")
    assert "nginx" in detected
# ---------------------------------------------------------------------------
# Plan length sanity
# ---------------------------------------------------------------------------
def test_plain_issue_has_only_always_commands() -> None:
    """With no keyword, service or path match the plan is just the baseline."""
    baseline_only = plan_from_request(_req("something went wrong"))
    # _ALWAYS holds 5 commands; nothing else should have been appended.
    assert len(baseline_only) == 5

108
tests/test_ssh_client.py Normal file
View File

@@ -0,0 +1,108 @@
from pathlib import Path
import pytest
from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig
def _client(**kwargs: object) -> SSHClient:
    """Build an ``SSHClient`` from loosely typed keyword overrides.

    Defaults: host ``root@ssh.archflux.net``, port 22, no identity file, no
    jump host, ssh config honoured.

    Raises:
        TypeError: If ``port`` is not an int, ``identity_file`` is neither a
            ``Path`` nor ``None``, or ``jump_host`` is neither a ``str`` nor
            ``None``.
    """
    host = str(kwargs.get("host", "root@ssh.archflux.net"))

    port = kwargs.get("port", 22)
    if not isinstance(port, int):
        raise TypeError("port must be an int")

    identity_file = kwargs.get("identity_file")
    jump_host = kwargs.get("jump_host")
    ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False))

    if not (identity_file is None or isinstance(identity_file, Path)):
        raise TypeError("identity_file must be a Path or None")
    if not (jump_host is None or isinstance(jump_host, str)):
        raise TypeError("jump_host must be a string or None")

    config = SSHConnectionConfig(
        host=host,
        port=port,
        identity_file=identity_file,
        jump_host=jump_host,
        ignore_ssh_config=ignore_ssh_config,
    )
    return SSHClient(config)
def test_summary_includes_expected_defaults() -> None:
    """The summary of a default client lists every default setting."""
    rendered = _client().summary()
    expected_fields = (
        "host=root@ssh.archflux.net",
        "port=22",
        "key=auto",
        "jump=none",
        "mode=use ssh config",
    )
    for field_text in expected_fields:
        assert field_text in rendered
def test_build_ssh_argv_respects_flags() -> None:
    """Identity file, jump host and ignore-config all show up in the argv."""
    configured = _client(
        identity_file=Path("/root/.ssh/id_ed25519"),
        jump_host="bastion.archflux.net",
        ignore_ssh_config=True,
    )
    argv = configured.build_ssh_argv("uname -a")
    assert argv[0] == "ssh"
    expected_members = (
        "-p",
        "22",
        "-F",
        "/dev/null",
        "-i",
        "/root/.ssh/id_ed25519",
        "-J",
        "bastion.archflux.net",
    )
    for member in expected_members:
        assert member in argv
    # Host and remote command must be the last two arguments, in that order.
    assert argv[-2] == "root@ssh.archflux.net"
    assert argv[-1] == "uname -a"
def test_rejects_destructive_or_shell_operator_commands() -> None:
    """Destructive programs, pipes and command chaining are all refused."""
    validator = _client()
    refused = ("rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id")
    for candidate in refused:
        with pytest.raises(SSHCommandRejectedError):
            validator.validate_read_only_command(candidate)
def test_allows_expected_read_only_commands() -> None:
    """Typical inspection commands pass validation without raising."""
    validator = _client()
    accepted = (
        "uname -a",
        "journalctl -n 100",
        "systemctl status apache2",
        "cat /etc/hosts",
        "ss -lntp",
    )
    for candidate in accepted:
        validator.validate_read_only_command(candidate)
def test_rejects_non_read_only_systemctl_subcommand() -> None:
    """``systemctl restart`` is refused even though systemctl is allow-listed."""
    validator = _client()
    with pytest.raises(SSHCommandRejectedError):
        validator.validate_read_only_command("systemctl restart apache2")
def test_truncate_output_marks_and_limits_content() -> None:
    """Output over the byte limit is flagged and ends with the marker."""
    oversized = "a" * 400
    rendered, was_cut = SSHClient._truncate_output(oversized, max_output_bytes=256)
    assert was_cut is True
    assert rendered.endswith("...[truncated]")
def test_truncate_output_keeps_short_content() -> None:
    """Output under the byte limit is returned unchanged and unflagged."""
    rendered, was_cut = SSHClient._truncate_output("short output", max_output_bytes=256)
    assert was_cut is False
    assert rendered == "short output"