Merge pull request 'initialCommit' (#2) from initialCommit into main
Reviewed-on: #2
This commit is contained in:
101
.gitea/workflows/ci.yml
Normal file
101
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Ensure git is available
|
||||||
|
run: |
|
||||||
|
if command -v git >/dev/null 2>&1; then
|
||||||
|
git --version
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y git
|
||||||
|
elif command -v dnf >/dev/null 2>&1; then
|
||||||
|
dnf install -y git
|
||||||
|
elif command -v yum >/dev/null 2>&1; then
|
||||||
|
yum install -y git
|
||||||
|
else
|
||||||
|
echo "No supported package manager found to install git."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
git --version
|
||||||
|
|
||||||
|
- name: Checkout source (native git)
  env:
    CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
  run: |
    # Native-git replacement for a checkout action: fetch exactly the commit
    # that triggered the workflow (GITHUB_SHA) into the runner workspace.
    if [ -z "${CI_GIT_TOKEN:-}" ]; then
      echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
      exit 1
    fi

    # Strip the scheme from the server URL so credentials can be spliced in.
    auth_server="${GITHUB_SERVER_URL#https://}"
    auth_server="${auth_server#http://}"
    remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
    # Token-free URL: the only URL that gets written to .git/config.
    public_url="https://${auth_server}/${GITHUB_REPOSITORY}.git"

    if [ -n "${GITHUB_WORKSPACE:-}" ]; then
      cd "$GITHUB_WORKSPACE"
    fi

    if [ ! -d .git ]; then
      git init
    fi

    # Re-create origin on every run; this also drops any credentialed URL
    # left behind in a reused workspace.
    git remote remove origin >/dev/null 2>&1 || true
    git remote add origin "$public_url"

    # Pass the credentialed URL on the command line only, so the secret is
    # never persisted in the repository configuration.
    git fetch --depth 1 "$remote_url" "$GITHUB_SHA"
    git checkout --force FETCH_HEAD
|
||||||
|
|
||||||
|
- name: Ensure Python and pip are available
  run: |
    # The next step creates a virtual environment, so "ready" means python3,
    # pip AND a usable venv module (ensurepip ships separately on Debian and
    # Ubuntu, in python3-venv).
    python_ready() {
      command -v python3 >/dev/null 2>&1 &&
        python3 -m pip --version >/dev/null 2>&1 &&
        python3 -c "import venv, ensurepip" >/dev/null 2>&1
    }

    if python_ready; then
      python3 --version
      exit 0
    fi

    # Install whatever is missing with the first supported package manager.
    if command -v apt-get >/dev/null 2>&1; then
      apt-get update
      apt-get install -y python3 python3-pip python3-venv
    elif command -v dnf >/dev/null 2>&1; then
      dnf install -y python3 python3-pip
    elif command -v yum >/dev/null 2>&1; then
      yum install -y python3 python3-pip
    else
      echo "No supported package manager found to install Python."
      exit 1
    fi

    python3 --version
|
||||||
|
|
||||||
|
- name: Install package and dev dependencies
|
||||||
|
run: |
|
||||||
|
python3 -m venv .venv
|
||||||
|
. .venv/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -e .[dev]
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: .venv/bin/python -m ruff check .
|
||||||
|
|
||||||
|
- name: Lint Markdown
|
||||||
|
run: .venv/bin/mdformat --check README.md ROADMAP.md CHANGELOG.md
|
||||||
|
|
||||||
|
- name: Lint YAML
|
||||||
|
run: .venv/bin/yamllint .
|
||||||
|
|
||||||
|
- name: Type-check
|
||||||
|
run: .venv/bin/python -m mypy src
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: .venv/bin/python -m pytest
|
||||||
110
.gitea/workflows/release.yml
Normal file
110
.gitea/workflows/release.yml
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Ensure git is available
|
||||||
|
run: |
|
||||||
|
if command -v git >/dev/null 2>&1; then
|
||||||
|
git --version
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y git
|
||||||
|
elif command -v dnf >/dev/null 2>&1; then
|
||||||
|
dnf install -y git
|
||||||
|
elif command -v yum >/dev/null 2>&1; then
|
||||||
|
yum install -y git
|
||||||
|
else
|
||||||
|
echo "No supported package manager found to install git."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Checkout source (native git)
  env:
    CI_GIT_TOKEN: ${{ secrets.CI_GIT_TOKEN }}
  run: |
    # Native-git replacement for a checkout action: fetch exactly the commit
    # that triggered the workflow (GITHUB_SHA) into the runner workspace.
    if [ -z "${CI_GIT_TOKEN:-}" ]; then
      echo "Missing secret CI_GIT_TOKEN. Add it in repository Actions secrets."
      exit 1
    fi

    # Strip the scheme from the server URL so credentials can be spliced in.
    auth_server="${GITHUB_SERVER_URL#https://}"
    auth_server="${auth_server#http://}"
    remote_url="https://oauth2:${CI_GIT_TOKEN}@${auth_server}/${GITHUB_REPOSITORY}.git"
    # Token-free URL: the only URL that gets written to .git/config.
    public_url="https://${auth_server}/${GITHUB_REPOSITORY}.git"

    if [ -n "${GITHUB_WORKSPACE:-}" ]; then
      cd "$GITHUB_WORKSPACE"
    fi

    if [ ! -d .git ]; then
      git init
    fi

    # Re-create origin on every run; this also drops any credentialed URL
    # left behind in a reused workspace.
    git remote remove origin >/dev/null 2>&1 || true
    git remote add origin "$public_url"

    # Fetch the tag by SHA so we get the exact tagged commit. The credentialed
    # URL is passed on the command line only, so the secret is never persisted
    # in the repository configuration.
    git fetch --depth 1 "$remote_url" "$GITHUB_SHA"
    git checkout --force FETCH_HEAD
|
||||||
|
|
||||||
|
- name: Ensure Python and build dependencies are available
  run: |
    # Required for the build: python3 with a usable venv module, and patchelf
    # (needed by Nuitka for standalone Linux binaries). ccache is installed
    # alongside them but its absence alone does not trigger an install.
    missing=""
    if ! command -v python3 >/dev/null 2>&1 ||
      ! python3 -c "import venv, ensurepip" >/dev/null 2>&1; then
      missing="python3"
    fi
    command -v patchelf >/dev/null 2>&1 || missing="${missing} patchelf"

    if [ -n "$missing" ]; then
      # Use whichever package manager exists instead of assuming apt-get.
      if command -v apt-get >/dev/null 2>&1; then
        apt-get update
        apt-get install -y python3 python3-pip python3-venv patchelf ccache
      elif command -v dnf >/dev/null 2>&1; then
        dnf install -y python3 python3-pip patchelf ccache
      elif command -v yum >/dev/null 2>&1; then
        yum install -y python3 python3-pip patchelf ccache
      else
        echo "No supported package manager found to install:${missing}"
        exit 1
      fi
    fi

    # Fail here with a clear message rather than deep inside the Nuitka build.
    command -v patchelf >/dev/null 2>&1 || {
      echo "patchelf is still unavailable after installation."
      exit 1
    }

    python3 --version
|
||||||
|
|
||||||
|
- name: Set up venv and install package + build deps
|
||||||
|
run: |
|
||||||
|
python3 -m venv .venv
|
||||||
|
. .venv/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -e ".[build]"
|
||||||
|
|
||||||
|
- name: Derive version from tag
|
||||||
|
id: version
|
||||||
|
run: echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Build standalone binary with Nuitka
|
||||||
|
run: |
|
||||||
|
. .venv/bin/activate
|
||||||
|
python -m nuitka \
|
||||||
|
--standalone \
|
||||||
|
--onefile \
|
||||||
|
--output-filename=tai \
|
||||||
|
--output-dir=dist \
|
||||||
|
--assume-yes-for-downloads \
|
||||||
|
--include-package=tai \
|
||||||
|
src/tai/cli.py
|
||||||
|
|
||||||
|
- name: Smoke-test the binary
|
||||||
|
run: dist/tai --help
|
||||||
|
|
||||||
|
- name: Upload binary artifact
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: tai-linux-amd64-${{ steps.version.outputs.tag }}
|
||||||
|
path: dist/tai
|
||||||
|
if-no-files-found: error
|
||||||
|
retention-days: 90
|
||||||
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Python cache and bytecode
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# Tool caches
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
*.egg-info/
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Coverage
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
16
.yamllint.yml
Normal file
16
.yamllint.yml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
extends: default
|
||||||
|
|
||||||
|
ignore: |
|
||||||
|
.git/
|
||||||
|
.venv/
|
||||||
|
.mypy_cache/
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
rules:
|
||||||
|
document-start: disable
|
||||||
|
line-length:
|
||||||
|
max: 120
|
||||||
|
truthy:
|
||||||
|
allowed-values: ["true", "false"]
|
||||||
|
check-keys: false
|
||||||
46
CHANGELOG.md
Normal file
46
CHANGELOG.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- `README.md` — project overview, description, example workflow, supported distributions, and suggested tooling
|
||||||
|
- `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening
|
||||||
|
- `CHANGELOG.md` — this file; established changelog tracking for the project
|
||||||
|
- `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events
|
||||||
|
- Gitea CI now uses native `git` checkout and system Python setup to avoid host-executor JavaScript action path issues
|
||||||
|
- Gitea native checkout now uses `CI_GIT_TOKEN` repository secret for authenticated fetch from private repos
|
||||||
|
- Gitea CI now installs dependencies in a local `.venv` to avoid Debian/PEP 668 externally-managed pip errors
|
||||||
|
- Python package scaffold with `src` layout and project metadata in `pyproject.toml`
|
||||||
|
- Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config`
|
||||||
|
- Input parsing/validation module and core request model
|
||||||
|
- SSH configuration scaffold module for upcoming connection/read-only execution work
|
||||||
|
- Implemented SSH module with real key-based command execution via system `ssh`
|
||||||
|
- Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566)
|
||||||
|
- Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure
|
||||||
|
- Added baseline diagnostics collection via `--collect`, including service, journal, disk, and network checks
|
||||||
|
- Read-only command policy enforcement (allowlist + blocked shell operators)
|
||||||
|
- Added byte-limited SSH output capture with truncation markers for large command output
|
||||||
|
- Test scaffold (`pytest`) with initial parser and CLI coverage
|
||||||
|
- SSH test coverage for policy checks, SSH argument construction, and config summary behavior
|
||||||
|
- CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`)
|
||||||
|
- CI coverage expanded with Markdown formatting checks (`mdformat --check`) and YAML linting (`yamllint`)
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `.github/workflows/ci.yml` — GitHub Actions workflow removed; CI is now Gitea-only
|
||||||
|
|
||||||
|
### Decided
|
||||||
|
|
||||||
|
- Implementation language: **Python**
|
||||||
|
- Distribution strategy: single distributable binary via **Nuitka** (PyInstaller as fallback)
|
||||||
|
- SSH authentication: **keypair only** (ed25519/RSA); auto-accept new hosts; hard reject on host key change with MITM warning
|
||||||
|
- SSH bastion support: `--jump-host` flag using SSH native ProxyJump
|
||||||
|
- SSH config behavior: use `~/.ssh/config` by default; allow override via `--ignore-ssh-config`
|
||||||
|
- Interface: **interactive REPL** for v0.1; `textual`-based TUI (split-pane) for v0.2+
|
||||||
94
README.md
94
README.md
@@ -1,3 +1,93 @@
|
|||||||
# tai
|
# tai — Linux AI Troubleshooting Agent
|
||||||
|
|
||||||
Linux AI driven troubleshooting agent.
|
`tai` is an agentic AI-driven troubleshooting tool for Linux systems. It autonomously investigates issues on remote hosts via SSH, analyzes relevant logs and configuration files, and provides a clear diagnosis along with suggested remediation steps — all without making any changes to the target system.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Given a problem description and a target hostname, `tai` connects to the remote system over SSH, gathers relevant data (logs, configuration files, service status, etc.), and uses a locally-hosted AI model to reason about the root cause and recommend solutions.
|
||||||
|
|
||||||
|
The agent operates in **read-only mode at all times**. It will never modify the target system under any circumstances — all suggestions are presented to the human troubleshooter for review and action.
|
||||||
|
|
||||||
|
## Supported Distributions
|
||||||
|
|
||||||
|
- Ubuntu
|
||||||
|
- Debian
|
||||||
|
- RHEL
|
||||||
|
- Rocky Linux
|
||||||
|
|
||||||
|
## Example Workflow
|
||||||
|
|
||||||
|
A troubleshooter receives a ticket reporting that the Apache service on a remote server has failed to start. They provide `tai` with:
|
||||||
|
|
||||||
|
1. The ticket description or error message
|
||||||
|
1. The hostname of the affected system
|
||||||
|
1. Any relevant directories to focus on
|
||||||
|
|
||||||
|
`tai` then connects to the host, reads through system logs, service configurations, and any other related files, and returns a structured analysis of the likely cause along with recommended next steps.
|
||||||
|
|
||||||
|
## Suggested Tooling
|
||||||
|
|
||||||
|
| Component | Tool |
|
||||||
|
|-----------|------|
|
||||||
|
| AI inference backend | [Ollama](https://ollama.com) |
|
||||||
|
| Model | `gemma3:4b`, `llama3.1:8b`, or `qwen2.5:7b` |
|
||||||
|
| Language | Python 3.11+ |
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## How-To: Setting Up the AI Backend (Arch Linux + RTX 3080)
|
||||||
|
|
||||||
|
`tai` uses [Ollama](https://ollama.com) as its local AI backend. It exposes an OpenAI-compatible HTTP API that `tai` talks to — no cloud services, no data leaving your machine.
|
||||||
|
|
||||||
|
An RTX 3080 (10 GB VRAM) comfortably runs 7–8B parameter models at 4-bit quantisation.
|
||||||
|
|
||||||
|
### 1. Install CUDA and Ollama
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# CUDA runtime (skip if already installed)
|
||||||
|
sudo pacman -S cuda
|
||||||
|
|
||||||
|
# Ollama with CUDA support from the AUR
|
||||||
|
yay -S ollama-cuda
|
||||||
|
# or: paru -S ollama-cuda
|
||||||
|
|
||||||
|
# Enable and start the service
|
||||||
|
sudo systemctl enable --now ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Pull a model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull gemma3:4b # ~3 GB — fast, good for sysadmin tasks
|
||||||
|
ollama pull llama3.1:8b # ~5 GB — stronger reasoning
|
||||||
|
ollama pull qwen2.5:7b # ~4.5 GB — strong structured output
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Verify the model works
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama run gemma3:4b "what causes a systemd service to enter failed state?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Verify the HTTP API is running
|
||||||
|
|
||||||
|
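`tai` communicates with Ollama over its OpenAI-compatible REST API, which is served under `/v1` on the same port. The quick check below uses Ollama's native `/api/generate` endpoint to confirm the server is reachable: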
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:11434/api/generate \
|
||||||
|
-d '{"model":"gemma3:4b","prompt":"hello","stream":false}'
|
||||||
|
```
|
||||||
|
|
||||||
|
A JSON response with a `response` field confirms everything is working.
|
||||||
|
|
||||||
|
### 5. Point tai at your Ollama instance
|
||||||
|
|
||||||
|
Once `tai` AI integration is complete, use these flags:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tai "nginx failing to start" --host web01 \
|
||||||
|
--ai-host http://localhost:11434/v1 \
|
||||||
|
--model gemma3:4b
|
||||||
|
```
|
||||||
|
|
||||||
|
The default values for `--ai-host` and `--model` are `http://localhost:11434/v1` (Ollama's OpenAI-compatible endpoint) and `gemma3:4b` respectively, so for local use you won't need to specify them explicitly.
|
||||||
|
|||||||
130
ROADMAP.md
Normal file
130
ROADMAP.md
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# Roadmap
|
||||||
|
|
||||||
|
This document outlines the major decisions, milestones, and development phases required to bring `tai` from concept to a working tool.
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 0 — Decisions & Prerequisites
|
||||||
|
|
||||||
|
These must be resolved before meaningful development can begin.
|
||||||
|
|
||||||
|
### Language Selection
|
||||||
|
|
||||||
|
- [x] **Decision: Python**
|
||||||
|
- Key factors: native vLLM integration, mature SSH libraries (`paramiko` / `asyncssh`), strong text/log parsing, rapid development
|
||||||
|
- Single binary distribution will be achieved via **Nuitka** (preferred for true compilation) or **PyInstaller** as a fallback
|
||||||
|
- [ ] Evaluate Nuitka vs PyInstaller for binary output quality and CI reproducibility
|
||||||
|
- [ ] Add binary build step to CI pipeline
|
||||||
|
|
||||||
|
### AI Backend & Model
|
||||||
|
|
||||||
|
- [ ] Confirm use of [vLLM](https://github.com/vllm-project/vllm) as the inference backend
|
||||||
|
- [ ] Confirm `gemma4:a4b` as the default model (or select an alternative)
|
||||||
|
- [ ] Define minimum hardware requirements for running the model locally
|
||||||
|
- [ ] Decide whether the AI backend is bundled, self-hosted externally, or user-supplied
|
||||||
|
|
||||||
|
### SSH Strategy
|
||||||
|
|
||||||
|
- [x] **Decision: keypair authentication only** — no password auth; eliminates credential storage risk
|
||||||
|
- Default key resolution: `~/.ssh/id_ed25519`, `~/.ssh/id_rsa` (in order of preference)
|
||||||
|
- CLI override via `--identity-file <path>`
|
||||||
|
- No SSH agent forwarding needed — a shared key is distributed to all managed hosts via Puppet
|
||||||
|
- [x] **Known hosts: auto-accept new hosts; reject on key mismatch** — a changed host key triggers a hard stop with a MITM warning; unknown/new hosts are accepted silently on first connect
|
||||||
|
- [x] **Bastion/jump host: `--jump-host <host>` flag** — delegates to SSH's native ProxyJump functionality
|
||||||
|
- [x] **SSH config behavior: respect existing `~/.ssh/config` by default; allow CLI override**
|
||||||
|
- Default: follow host settings from `~/.ssh/config` (for `User`, `Port`, `ProxyJump`, etc.)
|
||||||
|
- Override switch: `--ignore-ssh-config` to bypass local SSH config when required
|
||||||
|
|
||||||
|
### Scope & Constraints
|
||||||
|
|
||||||
|
- [ ] Define the supported scope of issues (services, network, disk, kernel, etc.)
|
||||||
|
- [ ] Confirm read-only guarantee — document exactly what "read-only" means in practice
|
||||||
|
- [x] **Decision: interactive REPL mode for v0.1, full TUI for v0.2+**
|
||||||
|
- v0.1: chat-loop REPL launched from CLI; human can follow up, correct, and redirect the agent
|
||||||
|
- v0.2+: `textual`-based TUI with split panes (collected data | AI output | input bar)
|
||||||
|
- Built-in slash commands: `/collect`, `/show logs`, `/clear`, `/host <hostname>`, `/help`, `/quit`
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 1 — Project Foundation
|
||||||
|
|
||||||
|
Basic project scaffolding and connectivity.
|
||||||
|
|
||||||
|
- [x] Finalise repository structure and language toolchain
|
||||||
|
- [x] Set up CI pipeline (linting, tests)
|
||||||
|
- [ ] Implement SSH connection module
|
||||||
|
- [x] Define SSH config model and probe interface scaffold
|
||||||
|
- [x] Connect to remote host
|
||||||
|
- [x] Execute read-only commands (e.g. `journalctl`, `systemctl status`, `cat`)
|
||||||
|
- [x] Stream or collect command output safely (byte-limited output with truncation marker)
|
||||||
|
- [x] Implement basic input parsing (ticket text, hostname, target directories)
|
||||||
|
- [x] Write unit tests for SSH and input modules
|
||||||
|
- [x] Input parser and CLI tests added
|
||||||
|
- [x] SSH module tests added for command policy and SSH argv behavior
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 2 — Data Collection Layer
|
||||||
|
|
||||||
|
Define what information the agent gathers and how.
|
||||||
|
|
||||||
|
- [ ] Identify the canonical set of data sources per issue type:
|
||||||
|
- Service failures: `journalctl`, `systemctl`, service config files
|
||||||
|
- Network issues: `ip`, `ss`, `netstat`, firewall rules
|
||||||
|
- Disk issues: `df`, `du`, `dmesg`, `smartctl`
|
||||||
|
- General: `/var/log/syslog`, `/var/log/messages`, `dmesg`
|
||||||
|
- [ ] Implement pluggable "collector" modules per data source
|
||||||
|
- [ ] Implement directory traversal for user-specified paths (read-only)
|
||||||
|
- [ ] Add support for per-distro variations (Ubuntu vs RHEL path differences, etc.)
|
||||||
|
- [ ] Write tests with mocked SSH output
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 3 — AI Integration
|
||||||
|
|
||||||
|
Wire collected data into the local AI model.
|
||||||
|
|
||||||
|
- [ ] Implement vLLM client module
|
||||||
|
- [ ] Design prompt template: system context, collected data, issue description → diagnosis
|
||||||
|
- [ ] Implement response parsing and structured output (root cause + suggested steps)
|
||||||
|
- [ ] Tune context window usage — handle truncation for large log outputs
|
||||||
|
- [ ] Add streaming support for long AI responses
|
||||||
|
- [ ] Evaluate and test model output quality on common issue types
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 4 — CLI & User Experience
|
||||||
|
|
||||||
|
Polish the interface for real-world use.
|
||||||
|
|
||||||
|
- [ ] Design CLI interface (flags, subcommands, interactive prompts)
|
||||||
|
- [ ] Implement structured output: diagnosis, confidence, recommended actions
|
||||||
|
- [ ] Add `--verbose` / `--debug` mode showing raw collected data
|
||||||
|
- [ ] Support output to file or clipboard
|
||||||
|
- [ ] Write man page / `--help` documentation
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Phase 5 — Hardening & Distribution
|
||||||
|
|
||||||
|
Prepare for broader use.
|
||||||
|
|
||||||
|
- [ ] Security review of SSH handling and credential storage
|
||||||
|
- [ ] Ensure no data is written to the remote system under any path
|
||||||
|
- [ ] Package for distribution (binary release, container image, or distro packages)
|
||||||
|
- [ ] Write installation and quickstart documentation
|
||||||
|
- [ ] End-to-end integration tests against a test VM
|
||||||
|
|
||||||
|
______________________________________________________________________
|
||||||
|
|
||||||
|
## Decisions Log
|
||||||
|
|
||||||
|
| Date | Decision | Outcome |
|
||||||
|
|------|----------|---------|
|
||||||
|
| 2026-05-04 | Implementation language | Python — with single distributable binary via Nuitka |
|
||||||
|
| — | AI inference backend | vLLM (provisional) |
|
||||||
|
| — | Default model | `gemma4:a4b` (provisional) |
|
||||||
|
| 2026-05-04 | SSH auth methods | Keypair only (ed25519/RSA); auto-accept new hosts; reject on key change (MITM) |
|
||||||
|
| 2026-05-04 | Bastion host support | `--jump-host` flag via SSH native ProxyJump |
|
||||||
|
| 2026-05-04 | SSH config behavior | Use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` |
|
||||||
|
| 2026-05-04 | CLI vs interactive mode | Interactive: REPL for v0.1, `textual` TUI for v0.2+ |
|
||||||
53
pyproject.toml
Normal file
53
pyproject.toml
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling>=1.25"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "tai"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Linux AI-driven troubleshooting agent"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
authors = [
|
||||||
|
{ name = "tai contributors" }
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"typer>=0.12,<1.0",
|
||||||
|
"rich>=13.7,<14.0",
|
||||||
|
"asyncssh>=2.14,<3.0",
|
||||||
|
"openai>=1.30,<2.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.2,<9.0",
|
||||||
|
"ruff>=0.5,<1.0",
|
||||||
|
"mypy>=1.10,<2.0",
|
||||||
|
"mdformat>=0.7,<1.0",
|
||||||
|
"yamllint>=1.35,<2.0",
|
||||||
|
]
|
||||||
|
build = [
|
||||||
|
"nuitka>=2.4,<3.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
tai = "tai.cli:main"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/tai"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-q"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py311"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "UP", "B"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict = true
|
||||||
|
warn_unused_configs = true
|
||||||
5
src/tai/__init__.py
Normal file
5
src/tai/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""tai package."""
|
||||||
|
|
||||||
|
__all__ = ["__version__"]
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
93
src/tai/ai_client.py
Normal file
93
src/tai/ai_client.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""AI backend client — OpenAI-compatible, works with Ollama, OpenAI, or any compatible endpoint."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
DEFAULT_AI_HOST = "http://localhost:11434/v1"
|
||||||
|
DEFAULT_MODEL = "gemma3:4b"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AIConfig:
|
||||||
|
"""Connection parameters for an OpenAI-compatible AI backend."""
|
||||||
|
|
||||||
|
host: str = DEFAULT_AI_HOST
|
||||||
|
model: str = DEFAULT_MODEL
|
||||||
|
api_key: str = "ollama" # Ollama ignores this; required by the openai client
|
||||||
|
timeout_seconds: float = 120.0
|
||||||
|
max_tokens: int = 4096
|
||||||
|
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AIResponse:
|
||||||
|
"""Structured response from an AI completion."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
content: str
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self.prompt_tokens + self.completion_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AIClient:
    """Thin wrapper around the openai client targeting a configurable endpoint.

    The wrapper sends a two-message conversation (system prompt plus user
    message) to the configured model, either as one blocking request
    (``complete``) or as a stream of text fragments (``stream``).
    """

    def __init__(self, config: AIConfig) -> None:
        """Build the underlying ``OpenAI`` client from *config*.

        Args:
            config: Endpoint, credentials, timeout and headers for the client;
                its model and token limit are reused on every request.
        """
        self._config = config
        self._client = OpenAI(
            base_url=config.host,
            api_key=config.api_key,
            timeout=config.timeout_seconds,
            default_headers=config.extra_headers,
        )

    def complete(self, system_prompt: str, user_message: str) -> AIResponse:
        """Send a completion request and return the full response.

        Args:
            system_prompt: Content of the ``system`` message.
            user_message: Content of the ``user`` message.

        Returns:
            An ``AIResponse`` built from the first choice. Missing content
            becomes ``""`` and missing usage data is reported as zero tokens.

        Raises:
            IndexError: If the backend returns a response without choices.
        """
        response = self._client.chat.completions.create(
            model=self._config.model,
            max_tokens=self._config.max_tokens,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
        )

        choice = response.choices[0]
        content = choice.message.content or ""
        usage = response.usage

        return AIResponse(
            model=response.model,
            content=content,
            prompt_tokens=usage.prompt_tokens if usage else 0,
            completion_tokens=usage.completion_tokens if usage else 0,
        )

    def stream(self, system_prompt: str, user_message: str) -> Iterator[str]:
        """Stream a completion, yielding text chunks as they arrive.

        Args:
            system_prompt: Content of the ``system`` message.
            user_message: Content of the ``user`` message.

        Yields:
            Each non-empty text fragment, in arrival order.
        """
        stream = self._client.chat.completions.create(
            model=self._config.model,
            max_tokens=self._config.max_tokens,
            stream=True,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
        )

        for chunk in stream:
            # OpenAI-compatible backends may emit chunks whose ``choices`` list
            # is empty (for example a usage-only chunk); skip them instead of
            # failing with IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                yield delta

    def summary(self) -> str:
        """Human-readable description of the AI config."""
        return f"host={self._config.host} model={self._config.model}"
|
||||||
205
src/tai/cli.py
Normal file
205
src/tai/cli.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""CLI entrypoint for tai."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
|
from tai.collectors import CollectionReport, collect_from_plan
|
||||||
|
from tai.input_parser import InputValidationError, build_request
|
||||||
|
from tai.models import TroubleshootRequest
|
||||||
|
from tai.plan import plan_from_request
|
||||||
|
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||||
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig
|
||||||
|
|
||||||
|
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def run(
    issue: Annotated[str, typer.Argument(help="Ticket text or issue summary.")],
    host: Annotated[str, typer.Option("--host", help="Target host to troubleshoot.")],
    port: Annotated[int, typer.Option("--port", help="SSH port for the target host.")] = 22,
    path: Annotated[
        list[str] | None,
        typer.Option("--path", help="Path to inspect. Repeatable."),
    ] = None,
    identity_file: Annotated[
        str | None,
        typer.Option("--identity-file", help="SSH private key path."),
    ] = None,
    jump_host: Annotated[
        str | None,
        typer.Option("--jump-host", help="SSH bastion/jump host."),
    ] = None,
    ignore_ssh_config: Annotated[
        bool,
        typer.Option(
            "--ignore-ssh-config",
            help="Ignore ~/.ssh/config and rely only on CLI options.",
        ),
    ] = False,
    probe: Annotated[
        bool,
        typer.Option(
            "--probe/--no-probe",
            help="Enable or disable live SSH connectivity probe (uname -a).",
        ),
    ] = True,
    collect: Annotated[
        bool,
        typer.Option(
            "--collect/--no-collect",
            help="Collect baseline diagnostics after probe.",
        ),
    ] = False,
    analyze: Annotated[
        bool,
        typer.Option(
            "--analyze/--no-analyze",
            help="Send collected diagnostics to AI for analysis.",
        ),
    ] = False,
    ai_host: Annotated[
        str,
        typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."),
    ] = DEFAULT_AI_HOST,
    model: Annotated[
        str,
        typer.Option("--model", help="Model name to use for AI analysis."),
    ] = DEFAULT_MODEL,
    ai_key: Annotated[
        str,
        typer.Option("--ai-key", help="API key for the AI backend (not needed for Ollama)."),
    ] = "ollama",
) -> None:
    """Start an interactive troubleshooting session scaffold."""
    # Flow: validate the raw CLI values, print a summary banner, then (unless
    # every step is disabled) run probe / collection / analysis inside one
    # asyncio event loop. Exit code 2 signals invalid input, 1 an SSH failure.
    requested_paths = path or []
    try:
        req = build_request(
            issue=issue,
            host=host,
            port=port,
            target_paths=requested_paths,
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    except InputValidationError as exc:
        console.print(f"[red]Input error:[/red] {exc}")
        raise typer.Exit(code=2) from exc

    config = SSHConnectionConfig(
        host=req.host,
        port=req.port,
        identity_file=req.identity_file,
        jump_host=req.jump_host,
        ignore_ssh_config=req.ignore_ssh_config,
    )

    # Banner describing what is about to be troubleshot.
    summary = SSHClient(config).summary()
    console.print("[bold green]tai[/bold green]")
    console.print(f"Issue: {req.issue}")
    console.print(f"SSH: {summary}")
    if req.target_paths:
        console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")

    if not probe and not collect and not analyze:
        return  # nothing SSH-related requested

    ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key)
    if analyze:
        console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}")

    session_work = _async_main(
        config,
        req,
        probe=probe,
        collect=collect,
        analyze=analyze,
        ai_config=ai_config,
    )
    try:
        asyncio.run(session_work)
    except typer.Exit:
        # Exit requests raised by the helpers pass through untouched.
        raise
    except TimeoutError as exc:
        # Must precede OSError: TimeoutError is a subclass of OSError.
        console.print(f"[red]SSH timeout:[/red] {exc}")
        raise typer.Exit(code=1) from exc
    except OSError as exc:
        console.print(f"[red]SSH error:[/red] unable to execute ssh: {exc}")
        raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_main(
    config: SSHConnectionConfig,
    req: TroubleshootRequest,
    *,
    probe: bool,
    collect: bool,
    analyze: bool,
    ai_config: AIConfig,
) -> None:
    """Run the requested probe, collection and analysis steps in one SSH session.

    Args:
        config: Connection settings used to build the SSH client.
        req: Normalized troubleshooting request; drives the collection plan.
        probe: Run the connectivity probe first when true.
        collect: Collect diagnostics when true.
        analyze: Collect diagnostics and pass the report to the AI when true.
        ai_config: Backend settings for the analysis step.
    """
    async with SSHClient(config).connect() as session:
        if probe:
            _handle_probe_result(await session.probe())

        report: CollectionReport | None = None
        # Analysis needs collected data, so it implies collection.
        if collect or analyze:
            plan = plan_from_request(req)
            console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
            report = await collect_from_plan(session, plan)
            _handle_collection_report(report)

        if report is not None and analyze:
            _run_analysis(ai_config, req.issue, report)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_probe_result(result: SSHCommandResult) -> None:
    """Print the outcome of the SSH probe and stop the run if it failed.

    Args:
        result: Result of the probe command.

    Raises:
        typer.Exit: With code 1 when the probe exit code is non-zero.
    """
    console.print("[cyan]Running SSH probe:[/cyan] uname -a")

    if result.exit_code == 0:
        output = result.stdout or "(no output)"
        console.print("[bold green]Probe succeeded.[/bold green]")
        console.print(f"Remote: {output}")
        return

    # Prefer stderr, fall back to stdout, then to a fixed placeholder.
    details = result.stderr or result.stdout or "no error output from ssh"
    console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
    raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_collection_report(report: CollectionReport) -> None:
    """Print a one-line summary followed by one status line per collected command.

    Args:
        report: Collection report whose items are rendered.
    """
    console.print(
        f"[bold]Collection complete:[/bold] {report.total} commands, {report.failed} failed"
    )
    for item in report.items:
        outcome = item.result
        if outcome.exit_code == 0:
            status = "ok"
        else:
            status = f"exit {outcome.exit_code}"
        # A cut-off on either stream is flagged on the same line.
        if outcome.stdout_truncated or outcome.stderr_truncated:
            trunc = " (truncated)"
        else:
            trunc = ""
        console.print(f"- {item.name}: {status}{trunc}")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> None:
    """Send the collected data to the AI backend and print its answer as Markdown.

    The streamed chunks are gathered in full before anything is rendered.

    Args:
        ai_config: Backend settings for the AI client.
        issue: Issue text placed in the user message.
        report: Collected diagnostics placed in the user message.

    Raises:
        typer.Exit: With code 1 when streaming or rendering raises.
    """
    console.print("[cyan]Analyzing...[/cyan]\n")
    ai = AIClient(ai_config)
    system_prompt = build_system_prompt()
    user_message = build_user_message(issue, report)
    try:
        answer = "".join(ai.stream(system_prompt, user_message))
        console.print(Markdown(answer))
    except Exception as exc:  # noqa: BLE001
        console.print(f"[red]AI analysis failed:[/red] {exc}")
        raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Console-script entry point: hand control to the Typer application."""
    app()


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()
|
||||||
50
src/tai/collectors.py
Normal file
50
src/tai/collectors.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Data collection routines built on top of the SSH client."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from tai.plan import CollectionPlan
|
||||||
|
from tai.ssh_client import SSHCommandResult, SSHSession
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CollectedItem:
|
||||||
|
"""Single collected diagnostic command result."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
result: SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CollectionReport:
|
||||||
|
"""Collection summary for a batch of diagnostics."""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
items: list[CollectedItem]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return len(self.items)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failed(self) -> int:
|
||||||
|
return sum(1 for item in self.items if item.result.exit_code != 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def collect_from_plan(
    session: SSHSession,
    plan: CollectionPlan,
    *,
    max_output_bytes: int = 32768,
) -> CollectionReport:
    """Run every command of *plan* through *session*, one after another.

    Args:
        session: SSH session used to execute the commands.
        plan: Plan whose ``commands`` are ``(name, command)`` pairs.
        max_output_bytes: Output cap forwarded to each command execution.

    Returns:
        A report holding one item per plan command, in plan order.
    """
    # Each await completes before the next command starts, so the commands
    # still run strictly in sequence with a fixed 30 s timeout each.
    collected = [
        CollectedItem(
            name=label,
            result=await session.run_read_only_command(
                remote_command,
                timeout_seconds=30.0,
                max_output_bytes=max_output_bytes,
            ),
        )
        for label, remote_command in plan.commands
    ]
    return CollectionReport(host=session._client.summary(), items=collected)
|
||||||
46
src/tai/input_parser.py
Normal file
46
src/tai/input_parser.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Helpers to normalize and validate CLI input."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tai.models import TroubleshootRequest
|
||||||
|
|
||||||
|
|
||||||
|
class InputValidationError(ValueError):
    """Signal that user-supplied input is missing or fails validation."""
|
||||||
|
|
||||||
|
|
||||||
|
def build_request(
    *,
    issue: str,
    host: str,
    port: int,
    target_paths: list[str],
    identity_file: str | None,
    jump_host: str | None,
    ignore_ssh_config: bool,
) -> TroubleshootRequest:
    """Validate raw CLI values and pack them into a request object.

    Args:
        issue: Issue text; surrounding whitespace is removed.
        host: Target host; surrounding whitespace is removed.
        port: SSH port, accepted only within 1..65535.
        target_paths: Paths to inspect, converted to ``Path`` objects.
        identity_file: Optional private key path.
        jump_host: Optional jump host; surrounding whitespace is removed.
        ignore_ssh_config: Passed through unchanged.

    Returns:
        The normalized request.

    Raises:
        InputValidationError: If the issue or host is blank, or the port is
            out of range.
    """
    normalized_issue = issue.strip()
    if not normalized_issue:
        raise InputValidationError("Issue description cannot be empty.")

    normalized_host = host.strip()
    if not normalized_host:
        raise InputValidationError("Host cannot be empty.")

    if not 1 <= port <= 65535:
        raise InputValidationError("Port must be between 1 and 65535.")

    # NOTE(review): expanduser() resolves "~" against the local user's home;
    # confirm this is intended for paths that are inspected on the remote host.
    paths = []
    for raw_path in target_paths:
        paths.append(Path(raw_path).expanduser())

    if identity_file:
        identity = Path(identity_file).expanduser()
    else:
        identity = None

    normalized_jump = jump_host.strip() if jump_host else None

    return TroubleshootRequest(
        issue=normalized_issue,
        host=normalized_host,
        port=port,
        target_paths=paths,
        identity_file=identity,
        jump_host=normalized_jump,
        ignore_ssh_config=ignore_ssh_config,
    )
|
||||||
17
src/tai/models.py
Normal file
17
src/tai/models.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Core domain models for tai."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class TroubleshootRequest:
|
||||||
|
"""User-provided troubleshooting input for a single run."""
|
||||||
|
|
||||||
|
issue: str
|
||||||
|
host: str
|
||||||
|
port: int = 22
|
||||||
|
target_paths: list[Path] = field(default_factory=list)
|
||||||
|
identity_file: Path | None = None
|
||||||
|
jump_host: str | None = None
|
||||||
|
ignore_ssh_config: bool = False
|
||||||
244
src/tai/plan.py
Normal file
244
src/tai/plan.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""Collection plan builder — decides what to collect based on the issue."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from tai.models import TroubleshootRequest
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Keyword sets for issue classification
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Words that mark an issue as service related.
_SERVICE_KEYWORDS: frozenset[str] = frozenset(
    {
        "service", "unit", "daemon", "failed", "dead", "inactive", "crash",
        "crashed", "start", "stop", "restart", "status", "systemd", "systemctl",
    }
)

# Words that mark an issue as network related.
_NETWORK_KEYWORDS: frozenset[str] = frozenset(
    {
        "network", "port", "connect", "connection", "listen", "firewall",
        "route", "routing", "interface", "dns", "http", "https", "tcp", "udp",
        "socket", "unreachable", "refused", "timeout", "latency", "bandwidth",
        "packet",
    }
)

# Words that mark an issue as disk or storage related.
_DISK_KEYWORDS: frozenset[str] = frozenset(
    {
        "disk", "space", "storage", "inode", "full", "mount", "filesystem",
        "partition", "quota", "usage", "capacity",
    }
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Known service names and their candidate config paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Service names recognised in issue text. The order here is the order in
# which detected services are reported.
_KNOWN_SERVICES: list[str] = [
    "apache2", "httpd", "nginx",
    "mysql", "mysqld", "mariadb", "postgresql",
    "redis", "redis-server", "mongodb", "mongod",
    "docker", "containerd", "kubelet",
    "sshd",
    "postfix", "dovecot", "sendmail",
    "php-fpm", "elasticsearch",
    "rabbitmq", "rabbitmq-server", "celery", "gunicorn",
    "ufw", "fail2ban",
    "cron", "crond", "rsyslog", "auditd", "firewalld",
    "haproxy", "varnish", "memcached",
]

# Configuration paths to read for a detected service; services without an
# entry get no config command.
# NOTE(review): "/etc/postgresql" looks like a directory rather than a file —
# confirm that reading it with `cat` is intended.
_SERVICE_CONFIGS: dict[str, list[str]] = {
    "apache2": ["/etc/apache2/apache2.conf"],
    "httpd": ["/etc/httpd/conf/httpd.conf"],
    "nginx": ["/etc/nginx/nginx.conf"],
    "mysql": ["/etc/mysql/mysql.conf.d/mysqld.cnf"],
    "mysqld": ["/etc/my.cnf"],
    "mariadb": ["/etc/mysql/mariadb.conf.d/50-server.cnf"],
    "postgresql": ["/etc/postgresql"],
    "sshd": ["/etc/ssh/sshd_config"],
    "postfix": ["/etc/postfix/main.cf"],
    "haproxy": ["/etc/haproxy/haproxy.cfg"],
    "redis": ["/etc/redis/redis.conf"],
    "redis-server": ["/etc/redis/redis.conf"],
    "fail2ban": ["/etc/fail2ban/jail.conf"],
    "ufw": ["/etc/ufw/ufw.conf"],
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command sets
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Each command set is written as a name -> command mapping and flattened to
# the ordered list of (name, command) pairs the plan builder consumes.

# Baseline commands included in every plan.
_ALWAYS: list[tuple[str, str]] = list(
    {
        "kernel": "uname -a",
        "uptime": "cat /proc/uptime",
        "disk-usage": "df -h",
        "memory": "cat /proc/meminfo",
        "running-services": "systemctl list-units --type=service --state=running --no-pager",
    }.items()
)

# Added when the issue mentions service-related keywords.
_SERVICE_EXTRA: list[tuple[str, str]] = list(
    {
        "failed-services": "systemctl list-units --type=service --state=failed --no-pager",
        "journal-errors": "journalctl -p err -n 100 --no-pager",
    }.items()
)

# Added when the issue mentions network-related keywords.
_NETWORK_EXTRA: list[tuple[str, str]] = list(
    {
        "listening-ports": "ss -lntp",
        "ip-addresses": "ip addr show",
        "ip-routes": "ip route show",
        "ip-stats": "ip -s link show",
    }.items()
)

# Added when the issue mentions disk-related keywords.
_DISK_EXTRA: list[tuple[str, str]] = list(
    {
        "disk-inodes": "df -i",
        "dmesg-disk": "dmesg -T --level=err,warn",
        "large-dirs": "du -sh /var /tmp /home /opt",
    }.items()
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CollectionPlan:
|
||||||
|
"""Ordered list of (name, command) pairs to execute on a remote host."""
|
||||||
|
|
||||||
|
commands: list[tuple[str, str]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def add(self, name: str, command: str) -> None:
|
||||||
|
self.commands.append((name, command))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.commands)
|
||||||
|
|
||||||
|
|
||||||
|
def plan_from_request(request: TroubleshootRequest) -> CollectionPlan:
    """Build a :class:`CollectionPlan` tailored to *request*.

    The plan always starts with the baseline commands. Keyword matches in the
    issue text add the service, network and disk command sets; every known
    service named in the issue adds status, journal and config commands; every
    user-supplied path adds a listing and a bounded ``find``.

    Args:
        request: Request providing the issue text and the target paths.

    Returns:
        The plan, in the order described above.
    """
    # Imported here so the fix does not depend on the module's import block.
    import shlex

    plan = CollectionPlan(commands=list(_ALWAYS))
    keywords = _issue_words(request.issue)

    # --- category expansions -------------------------------------------
    category_sets = (
        (_SERVICE_KEYWORDS, _SERVICE_EXTRA),
        (_NETWORK_KEYWORDS, _NETWORK_EXTRA),
        (_DISK_KEYWORDS, _DISK_EXTRA),
    )
    for trigger_words, extra_commands in category_sets:
        if keywords & trigger_words:
            plan.commands.extend(extra_commands)

    # --- named service detection ---------------------------------------
    seen: set[str] = set()
    for svc in _extract_services(request.issue):
        if svc in seen:
            continue
        seen.add(svc)
        plan.add(f"service-{svc}", f"systemctl status {svc}")
        plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager")
        for cfg_path in _SERVICE_CONFIGS.get(svc, []):
            plan.add(f"config-{svc}", f"cat {cfg_path}")

    # --- user-specified paths -----------------------------------------
    for path in request.target_paths:
        # The command string is handed to a shell on the remote side, so the
        # path is quoted: spaces or shell metacharacters must not split or
        # reinterpret the command. Plain paths are left unchanged by quote().
        quoted = shlex.quote(str(path))
        plan.add(f"ls-{path.name}", f"ls -la {quoted}")
        if "log" in str(path).lower():
            plan.add(
                f"find-logs-{path.name}",
                f"find {quoted} -maxdepth 2 -type f -name '*.log'",
            )
        else:
            plan.add(
                f"find-files-{path.name}",
                f"find {quoted} -maxdepth 2 -type f",
            )

    return plan
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_words(issue: str) -> set[str]:
    """Split *issue* into its distinct lowercase word tokens."""
    lowered = issue.lower()
    return {match.group(0) for match in re.finditer(r"\b\w+\b", lowered)}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_services(issue: str) -> list[str]:
    """Return the known service names mentioned in *issue*.

    A service matches when any of its aliases appears as a word of the issue
    text: the name itself, the name without trailing ``d`` characters
    (``sshd`` -> ``ssh``), the name without trailing digits
    (``apache2`` -> ``apache``), the name with hyphens removed, or the name
    without a ``-server`` part. Hyphenated names are additionally searched
    for literally, because word tokenization splits on ``-`` and such a name
    can never be a single token.

    Args:
        issue: Free-text issue description.

    Returns:
        Matching service names in ``_KNOWN_SERVICES`` order.
    """
    words = _issue_words(issue)
    lowered = issue.lower()
    found: list[str] = []
    for svc in _KNOWN_SERVICES:
        aliases = {
            svc,
            svc.rstrip("d"),
            svc.rstrip("0123456789"),
            svc.replace("-", ""),
            svc.replace("-server", ""),
        }
        if words & aliases:
            found.append(svc)
            continue
        # Literal search for names such as "php-fpm"; the look-arounds stop
        # the match from being part of a longer word or hyphenated name.
        if "-" in svc and re.search(rf"(?<![\w-]){re.escape(svc)}(?![\w-])", lowered):
            found.append(svc)
    return found
|
||||||
74
src/tai/prompt_builder.py
Normal file
74
src/tai/prompt_builder.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Formats collected diagnostics into prompts for the AI backend."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from tai.collectors import CollectionReport
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = """\
|
||||||
|
You are an expert Linux systems administrator and troubleshooting assistant.
|
||||||
|
You are given diagnostic data collected read-only from a remote Linux host via SSH.
|
||||||
|
|
||||||
|
Your job:
|
||||||
|
1. Identify the root cause of the reported issue based only on the data provided.
|
||||||
|
2. Cite the specific output that supports your conclusion.
|
||||||
|
3. Give concise, actionable remediation steps.
|
||||||
|
|
||||||
|
Important rules:
|
||||||
|
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
|
||||||
|
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
||||||
|
rejected that specific command — it is not evidence about the service or system state.
|
||||||
|
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
||||||
|
additional commands or log files would be needed.
|
||||||
|
- Keep the response short. Skip sections that have nothing useful to say.
|
||||||
|
- Never suggest commands that modify the system unless explicitly asked.
|
||||||
|
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt() -> str:
    """Return the fixed system prompt with surrounding whitespace removed."""
    prompt = _SYSTEM_PROMPT
    return prompt.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_user_message(issue: str, report: CollectionReport) -> str:
    """Render *issue* and *report* as the Markdown user message for the AI.

    Args:
        issue: Issue text, placed under the first heading.
        report: Collection report providing the host and per-command results.

    Returns:
        Markdown with one section per command, followed by a note listing the
        commands that were left out.
    """
    sections: list[str] = [
        f"## Issue reported\n\n{issue}\n",
        f"## Target host\n\n{report.host}\n",
        "## Collected diagnostics\n",
    ]
    not_executed: list[str] = []

    for entry in report.items:
        res = entry.result
        silent = not res.stdout and not res.stderr

        # Exit 255 with no output = SSH couldn't execute the command at all.
        # Such entries are left out so the AI does not speculate on them.
        if silent and res.exit_code == 255:
            not_executed.append(entry.name)
            continue

        sections.append(f"### {entry.name}\n")
        sections.append(f"**Command:** `{res.command}` ")
        sections.append(f"**Exit code:** {res.exit_code}\n")

        if res.stdout:
            trunc = " *(output truncated)*" if res.stdout_truncated else ""
            sections.append(f"**stdout:**{trunc}\n```\n{res.stdout.strip()}\n```\n")

        if res.stderr:
            trunc = " *(output truncated)*" if res.stderr_truncated else ""
            sections.append(f"**stderr:**{trunc}\n```\n{res.stderr.strip()}\n```\n")

        if silent:
            sections.append("*(no output)*\n")

    if not_executed:
        sections.append(
            f"**Note:** The following commands could not be executed on this host "
            f"and produced no output: {', '.join(not_executed)}. "
            f"Do not draw any conclusions from their absence.\n"
        )

    return "\n".join(sections)
|
||||||
383
src/tai/ssh_client.py
Normal file
383
src/tai/ssh_client.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
"""SSH configuration and read-only command execution."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class SSHConnectionConfig:
|
||||||
|
"""Connection parameters for the target host."""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: int = 22
|
||||||
|
identity_file: Path | None = None
|
||||||
|
jump_host: str | None = None
|
||||||
|
ignore_ssh_config: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class SSHCommandResult:
|
||||||
|
"""Result of a remote SSH command execution."""
|
||||||
|
|
||||||
|
command: str
|
||||||
|
exit_code: int
|
||||||
|
stdout: str
|
||||||
|
stderr: str
|
||||||
|
stdout_truncated: bool = False
|
||||||
|
stderr_truncated: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SSHCommandRejectedError(ValueError):
    """Signal that a command was refused by the read-only policy."""
|
||||||
|
|
||||||
|
|
||||||
|
class SSHClient:
    """Wrapper around SSH operations with read-only safeguards.

    Commands are executed by spawning the ``ssh`` binary. Before a command is
    run through :meth:`run_read_only_command` it must pass
    :meth:`validate_read_only_command`, which rejects shell operators and
    programs outside a fixed allowlist.
    """

    # Substrings that let a shell redirect, chain, background or substitute
    # commands. "&", "`" and "$(" close gaps in the original list: without
    # them `cat $(cmd)`, "cat `cmd`" and "cat x & cmd" passed validation.
    _BLOCKED_TOKENS = {
        ">",
        ">>",
        "<",
        "|",
        "&&",
        "||",
        ";",
        "&",
        "`",
        "$(",
    }
    # Programs that may be executed. NOTE(review): the check looks at the
    # program name only; options such as `find -delete` or `sed -i` are not
    # screened here.
    _READ_ONLY_COMMANDS = {
        "cat",
        "dmesg",
        "df",
        "du",
        "find",
        "grep",
        "head",
        "hostnamectl",
        "ip",
        "journalctl",
        "ls",
        "netstat",
        "sed",
        "ss",
        "stat",
        "systemctl",
        "tail",
        "uname",
    }
    # systemctl is allowed only with one of these subcommands.
    _READ_ONLY_SYSTEMCTL_SUBCOMMANDS = {
        "cat",
        "is-active",
        "is-failed",
        "list-unit-files",
        "list-units",
        "show",
        "status",
    }

    def __init__(self, config: SSHConnectionConfig) -> None:
        """Remember the connection settings; nothing is executed here."""
        self._config = config

    def summary(self) -> str:
        """Return a short summary of connection settings."""
        mode = "ignore ssh config" if self._config.ignore_ssh_config else "use ssh config"
        jump = self._config.jump_host or "none"
        key = str(self._config.identity_file) if self._config.identity_file else "auto"
        return (
            f"host={self._config.host} port={self._config.port} "
            f"key={key} jump={jump} mode={mode}"
        )

    def _build_base_argv(self) -> list[str]:
        """Build the common SSH argv flags (no host or command appended)."""
        argv = [
            "ssh",
            "-p",
            str(self._config.port),
            "-o",
            "BatchMode=yes",
            "-o",
            "ConnectTimeout=15",
            "-o",
            "StrictHostKeyChecking=accept-new",
        ]

        if self._config.ignore_ssh_config:
            argv += ["-F", "/dev/null"]

        if self._config.identity_file:
            argv += ["-i", str(self._config.identity_file)]

        if self._config.jump_host:
            argv += ["-J", self._config.jump_host]

        return argv

    def build_ssh_argv(self, remote_command: str) -> list[str]:
        """Build argv for a secure non-interactive SSH invocation."""
        return self._build_base_argv() + [self._config.host, remote_command]

    def connect(self, *, connect_timeout: float = 15.0) -> "SSHSession":
        """Return an :class:`SSHSession` async context manager for this host."""
        return SSHSession(self, connect_timeout=connect_timeout)

    def validate_read_only_command(self, command: str) -> None:
        """Validate that a command appears read-only and non-destructive.

        Args:
            command: The remote command line to check.

        Raises:
            SSHCommandRejectedError: If the command is empty, spans several
                lines, contains a blocked shell operator, cannot be parsed,
                starts with a program outside the allowlist, or is a
                ``systemctl`` call without an allowed subcommand.
        """
        normalized = command.strip()
        if not normalized:
            raise SSHCommandRejectedError("Command cannot be empty.")

        # A line break would start a second command in the remote shell.
        if "\n" in normalized or "\r" in normalized:
            raise SSHCommandRejectedError("Command contains a line break.")

        # Longest tokens first so the message names ">>" rather than ">" and
        # does not depend on set iteration order.
        for token in sorted(self._BLOCKED_TOKENS, key=lambda tok: (-len(tok), tok)):
            if token in normalized:
                raise SSHCommandRejectedError(
                    f"Command contains blocked shell operator: {token}"
                )

        try:
            parts = shlex.split(normalized)
        except ValueError as exc:
            # Unbalanced quotes and similar parse failures are rejections too.
            raise SSHCommandRejectedError(f"Command could not be parsed: {exc}") from exc
        if not parts:
            raise SSHCommandRejectedError("Command cannot be empty.")

        base = parts[0]
        if base not in self._READ_ONLY_COMMANDS:
            raise SSHCommandRejectedError(
                f"Command '{base}' is not allowed by read-only policy."
            )

        if base == "systemctl":
            if len(parts) < 2:
                raise SSHCommandRejectedError("systemctl requires a subcommand.")
            subcommand = parts[1]
            if subcommand not in self._READ_ONLY_SYSTEMCTL_SUBCOMMANDS:
                raise SSHCommandRejectedError(
                    f"systemctl subcommand '{subcommand}' is not read-only."
                )

    async def run_read_only_command(
        self,
        command: str,
        *,
        timeout_seconds: float = 30.0,
        max_output_bytes: int = 32768,
    ) -> SSHCommandResult:
        """Run a validated read-only command over SSH.

        Args:
            command: Remote command; validated before anything is executed.
            timeout_seconds: Time allowed for the ssh process to finish.
            max_output_bytes: Cap applied to stdout and stderr separately.

        Raises:
            SSHCommandRejectedError: If validation fails.
            TimeoutError: If the command does not finish in time.
        """
        self.validate_read_only_command(command)
        return await self._run_ssh(
            command,
            timeout_seconds=timeout_seconds,
            max_output_bytes=max_output_bytes,
        )

    async def probe(self) -> SSHCommandResult:
        """Probe connectivity using a harmless remote command."""
        return await self._run_ssh(
            "uname -a",
            timeout_seconds=15.0,
            max_output_bytes=4096,
        )

    async def _run_ssh(
        self,
        command: str,
        *,
        timeout_seconds: float,
        max_output_bytes: int,
    ) -> SSHCommandResult:
        """Spawn ssh for *command* and capture its exit code and output.

        Raises:
            TimeoutError: If ssh does not finish within *timeout_seconds*; the
                process is killed and reaped first.
            RuntimeError: If the finished process reports no exit code.
        """
        argv = self.build_ssh_argv(command)
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError as exc:
            # asyncio.TimeoutError is the builtin TimeoutError only from
            # Python 3.11 on; naming the asyncio class also covers 3.10.
            proc.kill()
            await proc.wait()
            raise TimeoutError(
                f"SSH command timed out after {timeout_seconds} seconds: {command}"
            ) from exc

        if proc.returncode is None:
            raise RuntimeError("SSH process did not provide an exit code.")

        stdout_text, stdout_truncated = self._truncate_output(
            stdout_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        stderr_text, stderr_truncated = self._truncate_output(
            stderr_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )

        return SSHCommandResult(
            command=command,
            exit_code=proc.returncode,
            stdout=stdout_text,
            stderr=stderr_text,
            stdout_truncated=stdout_truncated,
            stderr_truncated=stderr_truncated,
        )

    @staticmethod
    def _truncate_output(text: str, *, max_output_bytes: int) -> tuple[str, bool]:
        """Trim output to a maximum byte length while preserving UTF-8 validity.

        Returns:
            ``(text, truncated)``. Text within the limit is returned stripped
            with ``False``; longer text is cut so that it fits the limit
            together with a trailing truncation marker, with ``True``.

        Raises:
            ValueError: If *max_output_bytes* is below 256.
        """
        if max_output_bytes < 256:
            raise ValueError("max_output_bytes must be at least 256.")

        encoded = text.encode("utf-8", errors="replace")
        if len(encoded) <= max_output_bytes:
            return text.strip(), False

        marker = "\n...[truncated]"
        marker_bytes = marker.encode("utf-8")
        keep = max_output_bytes - len(marker_bytes)
        trimmed_bytes = encoded[:keep]
        # errors="ignore" drops a multi-byte character split by the cut.
        trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip()
        return f"{trimmed_text}{marker}", True
|
||||||
|
|
||||||
|
|
||||||
|
class SSHSession:
    """A persistent SSH connection using ControlMaster multiplexing.

    All commands run over the same underlying TCP connection — no per-command
    SSH handshake. Use as an async context manager::

        async with client.connect() as session:
            result = await session.run_read_only_command("df -h")

    Entering the context starts an ``ssh -N`` master process and waits for its
    control socket to appear; leaving it stops the master and removes the
    socket together with the private directory that holds it.
    """

    def __init__(self, client: SSHClient, *, connect_timeout: float = 15.0) -> None:
        """Remember the client and timeout; nothing is started until entry.

        Args:
            client: Supplies the base ssh argv, the target host and the
                read-only command validation.
            connect_timeout: Seconds to wait for the control socket.
        """
        self._client = client
        self._connect_timeout = connect_timeout
        # Private directory holding the control socket (created on enter).
        self._socket_dir: Path | None = None
        self._socket_path: Path | None = None
        self._master_proc: asyncio.subprocess.Process | None = None

    async def __aenter__(self) -> "SSHSession":
        """Start the ControlMaster process and wait until its socket exists.

        Raises:
            TimeoutError: If the control socket does not appear within
                ``connect_timeout`` seconds, or the master process exits
                before creating it.
        """
        # mkdtemp returns an owner-only directory, so ssh creates the socket at
        # a path nobody else can pre-create (unlike mkstemp followed by unlink).
        socket_dir = Path(tempfile.mkdtemp(prefix="tai-ssh-"))
        # Short file name: Unix socket paths have a small length limit.
        socket_path = socket_dir / "ctl"
        self._socket_dir = socket_dir
        self._socket_path = socket_path

        master_argv = self._client._build_base_argv() + [
            "-o", "ControlMaster=yes",
            "-o", f"ControlPath={socket_path}",
            "-o", "ControlPersist=no",
            "-N",
            self._client._config.host,
        ]

        try:
            master = await asyncio.create_subprocess_exec(
                *master_argv,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.DEVNULL,
            )
            self._master_proc = master
            await self._wait_for_control_socket(socket_path, master)
        except BaseException:
            # Timeouts, a failed spawn and task cancellation all land here:
            # never leave a master process or temp directory behind.
            await self._teardown()
            raise

        return self

    async def _wait_for_control_socket(
        self,
        socket_path: Path,
        master: asyncio.subprocess.Process,
    ) -> None:
        """Poll until the control socket exists, the master dies, or time is up."""
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._connect_timeout
        host = self._client._config.host

        while not socket_path.exists():
            if master.returncode is not None:
                # Fail fast instead of waiting out the full timeout. Kept as
                # TimeoutError so callers handling connect failures still work.
                raise TimeoutError(
                    f"SSH ControlMaster exited with code {master.returncode} "
                    f"before connecting to {host}"
                )
            if loop.time() > deadline:
                raise TimeoutError(
                    f"SSH ControlMaster did not connect within "
                    f"{self._connect_timeout}s to {host}"
                )
            await asyncio.sleep(0.05)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Tear the session down; exceptions from the body are not suppressed."""
        await self._teardown()

    async def _teardown(self) -> None:
        """Stop the master process if it is running and remove the socket files.

        Safe to call more than once and on a session that was never entered.
        """
        import contextlib
        import shutil

        master = self._master_proc
        if master is not None and master.returncode is None:
            # The process may exit between the check and the signal.
            with contextlib.suppress(ProcessLookupError):
                master.terminate()
            try:
                await asyncio.wait_for(master.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                with contextlib.suppress(ProcessLookupError):
                    master.kill()
                # Reap the killed process so it does not linger as a zombie.
                await master.wait()

        if self._socket_path is not None:
            self._socket_path.unlink(missing_ok=True)
        if self._socket_dir is not None:
            # rmtree also clears any temporary files ssh left next to the socket.
            shutil.rmtree(self._socket_dir, ignore_errors=True)

    def _command_argv(self, remote_command: str) -> list[str]:
        """Build the ssh argv that runs ``remote_command`` through the master."""
        return self._client._build_base_argv() + [
            "-o", f"ControlPath={self._socket_path}",
            "-o", "ControlMaster=no",
            self._client._config.host,
            remote_command,
        ]

    async def probe(self) -> SSHCommandResult:
        """Run uname -a to confirm connectivity."""
        return await self._run("uname -a", timeout_seconds=15.0, max_output_bytes=4096)

    async def run_read_only_command(
        self,
        command: str,
        *,
        timeout_seconds: float = 30.0,
        max_output_bytes: int = 32768,
    ) -> SSHCommandResult:
        """Validate and run a read-only command over the shared connection."""
        self._client.validate_read_only_command(command)
        return await self._run(
            command, timeout_seconds=timeout_seconds, max_output_bytes=max_output_bytes
        )

    async def _run(
        self,
        command: str,
        *,
        timeout_seconds: float,
        max_output_bytes: int,
    ) -> SSHCommandResult:
        """Run ``command`` via ssh and package its exit code and output.

        Args:
            command: Remote command line, passed to ssh as a single argument.
            timeout_seconds: Maximum time to wait for the command to finish.
            max_output_bytes: Per-stream byte budget applied to stdout/stderr.

        Raises:
            TimeoutError: If the command does not finish in time; the local
                ssh process is killed and reaped first.
            RuntimeError: If the process reports no exit code after finishing.
        """
        argv = self._command_argv(command)
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(), timeout=timeout_seconds
            )
        # asyncio.TimeoutError is the builtin TimeoutError on Python 3.11+ and a
        # separate class before that; naming it here catches both.
        except asyncio.TimeoutError as exc:
            proc.kill()
            await proc.wait()
            raise TimeoutError(
                f"SSH command timed out after {timeout_seconds}s: {command}"
            ) from exc

        if proc.returncode is None:
            raise RuntimeError("SSH process did not provide an exit code.")

        stdout_text, stdout_truncated = SSHClient._truncate_output(
            stdout_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )
        stderr_text, stderr_truncated = SSHClient._truncate_output(
            stderr_bytes.decode("utf-8", errors="replace"),
            max_output_bytes=max_output_bytes,
        )

        return SSHCommandResult(
            command=command,
            exit_code=proc.returncode,
            stdout=stdout_text,
            stderr=stderr_text,
            stdout_truncated=stdout_truncated,
            stderr_truncated=stderr_truncated,
        )
|
||||||
|
|
||||||
192
tests/test_ai.py
Normal file
192
tests/test_ai.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""Tests for the AI client and prompt builder."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
|
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||||
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIConfig defaults
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_config_defaults() -> None:
    """A bare AIConfig uses the module defaults and the "ollama" API key."""
    defaults = AIConfig()
    observed = (defaults.host, defaults.model, defaults.api_key)
    assert observed == (DEFAULT_AI_HOST, DEFAULT_MODEL, "ollama")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_config_custom_values() -> None:
    """Explicit constructor arguments are stored unchanged."""
    overrides = {
        "host": "https://api.openai.com/v1",
        "model": "gpt-4o",
        "api_key": "sk-test",
    }
    custom = AIConfig(**overrides)
    for field_name, expected in overrides.items():
        assert getattr(custom, field_name) == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai_client_summary_contains_host_and_model() -> None:
    """summary() mentions both the configured endpoint and the model name."""
    ai = AIClient(AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b"))
    rendered = ai.summary()
    for fragment in ("http://myserver:11434/v1", "llama3.1:8b"):
        assert fragment in rendered
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.complete (mocked)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
|
||||||
|
usage = MagicMock()
|
||||||
|
usage.prompt_tokens = 10
|
||||||
|
usage.completion_tokens = 20
|
||||||
|
|
||||||
|
message = MagicMock()
|
||||||
|
message.content = content
|
||||||
|
|
||||||
|
choice = MagicMock()
|
||||||
|
choice.message = message
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.choices = [choice]
|
||||||
|
response.model = model
|
||||||
|
response.usage = usage
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_returns_ai_response() -> None:
    """complete() exposes the response content and the token counts."""
    ai = AIClient(AIConfig())
    canned = _make_mock_response("The root cause is X.")

    with patch.object(ai._client.chat.completions, "create", return_value=canned):
        answer = ai.complete("system prompt", "user message")

    assert answer.content == "The root cause is X."
    token_counts = (answer.prompt_tokens, answer.completion_tokens, answer.total_tokens)
    assert token_counts == (10, 20, 30)
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_handles_empty_content() -> None:
    """A response whose message content is None comes back as an empty string."""
    ai = AIClient(AIConfig())
    blank = _make_mock_response(None)  # type: ignore[arg-type]
    blank.choices[0].message.content = None

    with patch.object(ai._client.chat.completions, "create", return_value=blank):
        outcome = ai.complete("system", "user")

    assert outcome.content == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AIClient.stream (mocked)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_yields_chunks() -> None:
    """stream() yields text deltas in order and drops chunks whose content is None."""
    ai = AIClient(AIConfig())

    def _chunk_with(text: str | None) -> MagicMock:
        # One streamed chunk: a single choice carrying a delta with ``text``.
        piece = MagicMock()
        piece.choices = [MagicMock(delta=MagicMock(content=text))]
        return piece

    deltas = ["Root ", "cause ", None, "found."]
    fake_stream = iter([_chunk_with(delta) for delta in deltas])

    with patch.object(ai._client.chat.completions, "create", return_value=fake_stream):
        collected = list(ai.stream("system", "user"))

    assert collected == ["Root ", "cause ", "found."]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# prompt_builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
    """Assemble a CollectionReport for ``root@testhost``.

    Each input tuple is ``(name, command, exit_code, stdout, stderr)`` and
    becomes one CollectedItem, in the order given.
    """
    collected = []
    for name, command, exit_code, stdout, stderr in items:
        outcome = SSHCommandResult(
            command=command,
            exit_code=exit_code,
            stdout=stdout,
            stderr=stderr,
        )
        collected.append(CollectedItem(name=name, result=outcome))
    return CollectionReport(host="root@testhost", items=collected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_prompt_contains_key_instructions() -> None:
    """The system prompt names the three report sections and mentions read-only."""
    system_prompt = build_system_prompt()
    for heading in ("Root Cause", "Evidence", "Recommended Actions"):
        assert heading in system_prompt
    assert "read-only" in system_prompt.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_contains_issue_and_host() -> None:
    """The user message repeats the issue text and the report's host."""
    kernel_report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
    message = build_user_message("nginx is failing", kernel_report)
    for fragment in ("nginx is failing", "root@testhost"):
        assert fragment in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_includes_command_output() -> None:
    """Both the command line and its stdout appear in the user message."""
    kernel_report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")])
    message = build_user_message("test issue", kernel_report)
    for fragment in ("uname -a", "Linux web01 6.1.0"):
        assert fragment in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_shows_stderr() -> None:
    """stderr from a failing command is carried into the user message."""
    failing_unit = ("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")
    message = build_user_message("nginx not found", _make_report([failing_unit]))
    assert "Unit nginx.service not found." in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_notes_truncation() -> None:
    """A result flagged stdout_truncated makes the message mention truncation."""
    clipped = SSHCommandResult(
        command="journalctl -n 100 --no-pager",
        exit_code=0,
        stdout="...",
        stderr="",
        stdout_truncated=True,
    )
    journal_item = CollectedItem(name="journal", result=clipped)
    journal_report = CollectionReport(host="root@testhost", items=[journal_item])

    message = build_user_message("disk issue", journal_report)

    assert "truncated" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_message_handles_no_output() -> None:
    """A command with empty stdout and stderr is reported as "no output"."""
    silent_item = ("empty", "cat /nonexistent", 1, "", "")
    message = build_user_message("test", _make_report([silent_item]))
    assert "no output" in message
|
||||||
141
tests/test_cli.py
Normal file
141
tests/test_cli.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from tai.cli import app
|
||||||
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch ``SSHClient.connect`` to hand back a mock async-context session.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the patch.
        probe_result: Value returned by the session's ``probe()``.
        probe_raises: If given, ``probe()`` raises this instead of returning.

    Returns:
        The mock session, which yields itself when entered with ``async with``.
    """
    session = MagicMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)

    # Compare against None explicitly so the choice never depends on the
    # truthiness of the exception object.
    if probe_raises is not None:
        session.probe = AsyncMock(side_effect=probe_raises)
    else:
        session.probe = AsyncMock(return_value=probe_result)

    monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session)
    return session
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_command_prints_scaffold_summary() -> None:
    """With --no-probe the CLI exits 0 and prints the host and port summary."""
    cli_args = [
        "apache failed",
        "--host", "web01",
        "--port", "5566",
        "--no-probe",
        "--path", "/etc/apache2",
        "--jump-host", "bastion01",
        "--ignore-ssh-config",
    ]

    outcome = CliRunner().invoke(app, cli_args)

    assert outcome.exit_code == 0
    for fragment in ("tai", "host=web01", "port=5566"):
        assert fragment in outcome.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A zero-exit probe reports success and echoes the remote stdout."""
    healthy_probe = SSHCommandResult(
        command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
    )
    _mock_session(monkeypatch, probe_result=healthy_probe)

    cli_args = ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)

    assert outcome.exit_code == 0
    for fragment in ("Probe succeeded", "Linux ssh 6.12.0"):
        assert fragment in outcome.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe that exits 255 makes the CLI exit 1 and report the failure."""
    denied_probe = SSHCommandResult(
        command="uname -a",
        exit_code=255,
        stdout="",
        stderr="Permission denied (publickey,password).",
    )
    _mock_session(monkeypatch, probe_result=denied_probe)

    cli_args = ["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)

    assert outcome.exit_code == 1
    assert "Probe failed" in outcome.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """--collect prints one status line per collected item, flagging truncation."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Two successful items; only the journal one is marked as truncated.
        kernel_result = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        journal_result = SSHCommandResult(
            command="journalctl -n 200",
            exit_code=0,
            stdout="...",
            stderr="",
            stdout_truncated=True,
        )
        gathered = [
            CollectedItem(name="kernel", result=kernel_result),
            CollectedItem(name="journal", result=journal_result),
        ]
        return CollectionReport(host="ssh.archflux.net", items=gathered)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--collect",
    ]
    outcome = CliRunner().invoke(app, cli_args)

    assert outcome.exit_code == 0
    for fragment in ("Collection complete", "kernel: ok", "journal: ok (truncated)"):
        assert fragment in outcome.stdout
|
||||||
65
tests/test_input_parser.py
Normal file
65
tests/test_input_parser.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tai.input_parser import InputValidationError, build_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_request_normalizes_values() -> None:
    """Text fields are stripped and path-like inputs become expanded Paths."""
    normalized = build_request(
        issue=" apache fails to start ",
        host=" web01 ",
        port=5566,
        target_paths=["/etc/apache2", "~/logs"],
        identity_file="~/.ssh/id_ed25519",
        jump_host=" bastion01 ",
        ignore_ssh_config=True,
    )

    # Whitespace-trimmed text fields and untouched scalars.
    assert normalized.issue == "apache fails to start"
    assert normalized.host == "web01"
    assert normalized.port == 5566

    # Paths: absolute kept as-is, "~" expanded.
    first_path, second_path = normalized.target_paths[0], normalized.target_paths[1]
    assert first_path == Path("/etc/apache2")
    assert second_path == Path("~/logs").expanduser()
    assert normalized.identity_file == Path("~/.ssh/id_ed25519").expanduser()

    assert normalized.jump_host == "bastion01"
    assert normalized.ignore_ssh_config is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_request_rejects_empty_issue() -> None:
    """A whitespace-only issue raises InputValidationError."""
    no_paths: list[str] = []
    with pytest.raises(InputValidationError):
        build_request(
            issue=" ", host="web01", port=22,
            target_paths=no_paths, identity_file=None,
            jump_host=None, ignore_ssh_config=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_request_rejects_empty_host() -> None:
    """A whitespace-only host raises InputValidationError."""
    no_paths: list[str] = []
    with pytest.raises(InputValidationError):
        build_request(
            issue="apache down", host=" ", port=22,
            target_paths=no_paths, identity_file=None,
            jump_host=None, ignore_ssh_config=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_request_rejects_invalid_port() -> None:
    """Port 70000 raises InputValidationError."""
    no_paths: list[str] = []
    with pytest.raises(InputValidationError):
        build_request(
            issue="apache down", host="web01", port=70000,
            target_paths=no_paths, identity_file=None,
            jump_host=None, ignore_ssh_config=False,
        )
|
||||||
169
tests/test_plan.py
Normal file
169
tests/test_plan.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Tests for the collection plan builder."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tai.models import TroubleshootRequest
|
||||||
|
from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request
|
||||||
|
|
||||||
|
|
||||||
|
def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest:
    """Build a request for ``root@testhost`` with optional target paths."""
    target_paths = [Path(raw) for raw in (paths or [])]
    return TroubleshootRequest(issue=issue, host="root@testhost", target_paths=target_paths)
|
||||||
|
|
||||||
|
|
||||||
|
def _commands(plan: CollectionPlan) -> list[str]:
    """Collect the command half of every ``(name, command)`` pair in the plan."""
    command_lines = []
    for _, command in plan.commands:
        command_lines.append(command)
    return command_lines
|
||||||
|
|
||||||
|
|
||||||
|
def _names(plan: CollectionPlan) -> list[str]:
    """Collect the name half of every ``(name, command)`` pair in the plan."""
    labels = []
    for label, _ in plan.commands:
        labels.append(label)
    return labels
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Always-present commands
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_plan_always_has_baseline_commands() -> None:
    """Even a generic issue yields the kernel, disk, memory and unit-list commands."""
    planned = _commands(plan_from_request(_req("some generic issue")))
    for needle in ("uname -a", "df -h", "proc/meminfo", "systemctl list-units"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Keyword-based category expansion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_service_keywords_add_failed_services_check() -> None:
    """Service wording adds the failed-units listing and the error journal."""
    planned = _commands(plan_from_request(_req("service failed to start")))
    for needle in ("--state=failed", "journalctl -p err"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_network_keywords_add_network_commands() -> None:
    """Network wording adds socket, address and route commands."""
    planned = _commands(plan_from_request(_req("connection refused on port 80")))
    for needle in ("ss -lntp", "ip addr show", "ip route show"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disk_keywords_add_disk_commands() -> None:
    """Disk wording adds inode, kernel-log and directory-size commands."""
    planned = _commands(plan_from_request(_req("disk full filesystem usage critical")))
    for needle in ("df -i", "dmesg", "du -sh"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unrelated_issue_does_not_add_network_commands() -> None:
    """An issue without network wording gets no routing command."""
    planned = _commands(plan_from_request(_req("apache service crashed")))
    assert all("ip route show" not in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Named service detection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_nginx_in_issue_adds_nginx_service_commands() -> None:
    """Mentioning nginx adds its status and journal entries to the plan."""
    nginx_plan = plan_from_request(_req("nginx is failing to start"))
    labels = _names(nginx_plan)
    planned = _commands(nginx_plan)

    for label in ("service-nginx", "journal-nginx"):
        assert label in labels
    for needle in ("systemctl status nginx", "journalctl -u nginx"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apache2_adds_config_cat() -> None:
    """Mentioning apache2 adds a cat of its main config file."""
    planned = _commands(plan_from_request(_req("apache2 service check")))
    needle = "cat /etc/apache2/apache2.conf"
    assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sshd_adds_config_cat() -> None:
    """Mentioning sshd adds a cat of sshd_config."""
    planned = _commands(plan_from_request(_req("sshd connection problems")))
    needle = "cat /etc/ssh/sshd_config"
    assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_service_name_no_config_cat() -> None:
    """An unrecognised service name adds no ``cat /etc...`` command."""
    planned = _commands(plan_from_request(_req("myweirdapp service crashed")))
    assert all("cat /etc" not in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_service_name_not_repeated() -> None:
    """Repeating a service name in the issue yields a single service entry."""
    labels = _names(plan_from_request(_req("nginx nginx nginx")))
    occurrences = labels.count("service-nginx")
    assert occurrences == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Target path handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_target_path_adds_ls_and_find() -> None:
    """A target path adds both a listing and a find over that path."""
    planned = _commands(plan_from_request(_req("app crash", paths=["/opt/myapp"])))
    for needle in ("ls -la /opt/myapp", "find /opt/myapp"):
        assert any(needle in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_path_uses_log_find_pattern() -> None:
    """A path under /var/log produces a command containing the ``*.log`` pattern."""
    planned = _commands(plan_from_request(_req("app errors", paths=["/var/log/myapp"])))
    assert any("*.log" in command for command in planned)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_log_path_uses_generic_find() -> None:
    """A non-log path gets a find command without the ``*.log`` pattern."""
    planned = _commands(plan_from_request(_req("config issue", paths=["/etc/myapp"])))
    generic_finds = [
        command
        for command in planned
        if "find /etc/myapp" in command and "*.log" not in command
    ]
    assert generic_finds
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue_words_lowercases_and_splits() -> None:
    """_issue_words returns each word of the issue in lower case."""
    tokens = _issue_words("Apache Service FAILED")
    for expected in ("apache", "service", "failed"):
        assert expected in tokens
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_services_finds_nginx() -> None:
    """nginx is recognised as a service name inside a sentence."""
    detected = _extract_services("nginx is down")
    assert "nginx" in detected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_services_finds_nothing_for_unknown() -> None:
    """Text without a recognised service name yields an empty list."""
    detected = _extract_services("the widget is broken")
    assert detected == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_services_case_insensitive() -> None:
    """An upper-case service name is reported in lower case."""
    detected = _extract_services("NGINX failed")
    assert "nginx" in detected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Plan length sanity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_plain_issue_has_only_always_commands() -> None:
    """An issue with no keywords, services or paths yields exactly five commands."""
    baseline_plan = plan_from_request(_req("something went wrong"))
    # Five is the size of the always-run set; nothing else should be added.
    expected_size = 5
    assert len(baseline_plan) == expected_size
|
||||||
108
tests/test_ssh_client.py
Normal file
108
tests/test_ssh_client.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _client(**kwargs: object) -> SSHClient:
    """Build an SSHClient from loosely typed overrides.

    Recognised keys are ``host``, ``port``, ``identity_file``, ``jump_host`` and
    ``ignore_ssh_config``; a missing key falls back to the default used below.

    Raises:
        TypeError: If ``port`` is not an int, ``identity_file`` is neither a
            Path nor None, or ``jump_host`` is neither a str nor None.
    """
    host = str(kwargs.get("host", "root@ssh.archflux.net"))

    port = kwargs.get("port", 22)
    if not isinstance(port, int):
        raise TypeError("port must be an int")

    identity_file = kwargs.get("identity_file")
    jump_host = kwargs.get("jump_host")
    ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False))

    if not (identity_file is None or isinstance(identity_file, Path)):
        raise TypeError("identity_file must be a Path or None")
    if not (jump_host is None or isinstance(jump_host, str)):
        raise TypeError("jump_host must be a string or None")

    config = SSHConnectionConfig(
        host=host,
        port=port,
        identity_file=identity_file,
        jump_host=jump_host,
        ignore_ssh_config=ignore_ssh_config,
    )
    return SSHClient(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary_includes_expected_defaults() -> None:
    """summary() spells out host, port, key, jump host and config mode defaults."""
    rendered = _client().summary()
    expected_fragments = (
        "host=root@ssh.archflux.net",
        "port=22",
        "key=auto",
        "jump=none",
        "mode=use ssh config",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_ssh_argv_respects_flags() -> None:
    """Identity file, jump host and ignore-config all show up in the ssh argv."""
    configured = _client(
        identity_file=Path("/root/.ssh/id_ed25519"),
        jump_host="bastion.archflux.net",
        ignore_ssh_config=True,
    )

    argv = configured.build_ssh_argv("uname -a")

    assert argv[0] == "ssh"
    expected_tokens = (
        "-p", "22",
        "-F", "/dev/null",
        "-i", "/root/.ssh/id_ed25519",
        "-J", "bastion.archflux.net",
    )
    for token in expected_tokens:
        assert token in argv
    # The destination and the remote command close the argv, in that order.
    assert argv[-2] == "root@ssh.archflux.net"
    assert argv[-1] == "uname -a"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_destructive_or_shell_operator_commands() -> None:
    """A destructive command, a pipe and a ``;`` chain are each rejected."""
    validator = _client()
    forbidden = ("rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id")

    for candidate in forbidden:
        with pytest.raises(SSHCommandRejectedError):
            validator.validate_read_only_command(candidate)
|
||||||
|
|
||||||
|
|
||||||
|
def test_allows_expected_read_only_commands() -> None:
    """Typical read-only diagnostics pass validation without raising."""
    validator = _client()
    permitted = (
        "uname -a",
        "journalctl -n 100",
        "systemctl status apache2",
        "cat /etc/hosts",
        "ss -lntp",
    )

    for candidate in permitted:
        validator.validate_read_only_command(candidate)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_non_read_only_systemctl_subcommand() -> None:
    """``systemctl restart`` is rejected."""
    validator = _client()
    mutating_command = "systemctl restart apache2"

    with pytest.raises(SSHCommandRejectedError):
        validator.validate_read_only_command(mutating_command)
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_output_marks_and_limits_content() -> None:
    """Over-long output is flagged, suffixed with the marker and kept within budget."""
    limit = 256
    text = "a" * 400

    rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=limit)

    assert truncated is True
    assert rendered.endswith("...[truncated]")
    # The "limits" half of the test name: the marker must fit inside the budget.
    assert len(rendered.encode("utf-8")) <= limit
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_output_keeps_short_content() -> None:
    """Output within the budget is returned unchanged and not flagged."""
    outcome = SSHClient._truncate_output("short output", max_output_bytes=256)
    assert outcome == ("short output", False)
|
||||||
Reference in New Issue
Block a user