initialCommit #2
32
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,32 @@
name: CI

on:
  push:
  pull_request:

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install package and dev dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .[dev]

      - name: Lint
        run: ruff check .

      - name: Type-check
        run: mypy src

      - name: Test
        run: pytest
32
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,32 @@
name: CI

on:
  push:
  pull_request:

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install package and dev dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .[dev]

      - name: Lint
        run: ruff check .

      - name: Type-check
        run: mypy src

      - name: Test
        run: pytest
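Review note: the Lint, Type-check, and Test steps in the two workflow files can be reproduced locally before pushing. The following is a minimal sketch, assuming `ruff`, `mypy`, and `pytest` are installed from the `dev` extra. `CHECKS`, `run_checks`, and the injectable `runner` parameter are hypothetical names used for illustration; none of them is part of this commit.

```python
import subprocess
from collections.abc import Callable, Sequence

# Mirrors the Lint, Type-check, and Test steps of the workflow, in order.
CHECKS: list[tuple[str, list[str]]] = [
    ("Lint", ["ruff", "check", "."]),
    ("Type-check", ["mypy", "src"]),
    ("Test", ["pytest"]),
]


def _run(argv: Sequence[str]) -> int:
    """Run one command and return its exit code."""
    return subprocess.run(list(argv), check=False).returncode


def run_checks(runner: Callable[[Sequence[str]], int] = _run) -> int:
    """Run each check in order and stop at the first failure, as a CI job would."""
    for name, argv in CHECKS:
        print(f"==> {name}: {' '.join(argv)}")
        code = runner(argv)
        if code != 0:
            return code
    return 0
```

A wrapper script could call `sys.exit(run_checks())` from the repository root. The `runner` parameter exists only so the ordering and fail-fast behaviour can be tested without invoking the real tools.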
26
.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
# Python cache and bytecode
__pycache__/
*.py[cod]
*.pyo

# Virtual environments
.venv/
venv/

# Tool caches
.pytest_cache/
.ruff_cache/
.mypy_cache/

# Build artifacts
build/
dist/
*.egg-info/
*.spec

# Coverage
.coverage
htmlcov/

# IDE
.vscode/
34
CHANGELOG.md
Normal file
@@ -0,0 +1,34 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

---

## [Unreleased]

### Added
- `README.md` — project overview, description, example workflow, supported distributions, and suggested tooling
- `ROADMAP.md` — phased development plan covering decisions, data collection, AI integration, CLI design, and hardening
- `CHANGELOG.md` — this file; established changelog tracking for the project
- `.gitea/workflows/ci.yml` — Gitea Actions CI workflow for push and pull request events
- Python package scaffold with `src` layout and project metadata in `pyproject.toml`
- Initial CLI entrypoint with agreed SSH flags: `--identity-file`, `--jump-host`, and `--ignore-ssh-config`
- Input parsing/validation module and core request model
- SSH configuration scaffold module for upcoming connection/read-only execution work
- Implemented SSH module with real key-based command execution via system `ssh`
- Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566)
- Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure
- Read-only command policy enforcement (allowlist + blocked shell operators)
- Test scaffold (`pytest`) with initial parser and CLI coverage
- SSH test coverage for policy checks, SSH argument construction, and config summary behavior
- CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`)

### Decided
- Implementation language: **Python**
- Distribution strategy: single distributable binary via **Nuitka** (PyInstaller as fallback)
- SSH authentication: **keypair only** (ed25519/RSA); auto-accept new hosts; hard reject on host key change with MITM warning
- SSH bastion support: `--jump-host` flag using SSH native ProxyJump
- SSH config behavior: use `~/.ssh/config` by default; allow override via `--ignore-ssh-config`
- Interface: **interactive REPL** for v0.1; `textual`-based TUI (split-pane) for v0.2+
36
README.md
@@ -1,3 +1,35 @@
-# tai
+# tai — Linux AI Troubleshooting Agent

-Linux AI driven troubleshooting agent.
+`tai` is an agentic AI-driven troubleshooting tool for Linux systems. It autonomously investigates issues on remote hosts via SSH, analyzes relevant logs and configuration files, and provides a clear diagnosis along with suggested remediation steps — all without making any changes to the target system.
+
+## Overview
+
+Given a problem description and a target hostname, `tai` connects to the remote system over SSH, gathers relevant data (logs, configuration files, service status, etc.), and uses a locally-hosted AI model to reason about the root cause and recommend solutions.
+
+The agent operates in **read-only mode at all times**. It will never modify the target system under any circumstances — all suggestions are presented to the human troubleshooter for review and action.
+
+## Supported Distributions
+
+- Ubuntu
+- Debian
+- RHEL
+- Rocky Linux
+
+## Example Workflow
+
+A troubleshooter receives a ticket reporting that the Apache service on a remote server has failed to start. They provide `tai` with:
+
+1. The ticket description or error message
+2. The hostname of the affected system
+3. Any relevant directories to focus on
+
+`tai` then connects to the host, reads through system logs, service configurations, and any other related files, and returns a structured analysis of the likely cause along with recommended next steps.
+
+## Suggested Tooling
+
+| Component | Tool |
+|-----------|------|
+| AI inference backend | [vLLM](https://github.com/vllm-project/vllm) |
+| Model | `gemma4:a4b` |
+
+> **Note:** A suitable implementation language for this project is yet to be determined.
126
ROADMAP.md
Normal file
@@ -0,0 +1,126 @@
# Roadmap

This document outlines the major decisions, milestones, and development phases required to bring `tai` from concept to a working tool.

---

## Phase 0 — Decisions & Prerequisites

These must be resolved before meaningful development can begin.

### Language Selection
- [x] **Decision: Python**
  - Key factors: native vLLM integration, mature SSH libraries (`paramiko` / `asyncssh`), strong text/log parsing, rapid development
  - Single binary distribution will be achieved via **Nuitka** (preferred for true compilation) or **PyInstaller** as a fallback
- [ ] Evaluate Nuitka vs PyInstaller for binary output quality and CI reproducibility
- [ ] Add binary build step to CI pipeline

### AI Backend & Model
- [ ] Confirm use of [vLLM](https://github.com/vllm-project/vllm) as the inference backend
- [ ] Confirm `gemma4:a4b` as the default model (or select an alternative)
- [ ] Define minimum hardware requirements for running the model locally
- [ ] Decide whether the AI backend is bundled, self-hosted externally, or user-supplied

### SSH Strategy
- [x] **Decision: keypair authentication only** — no password auth; eliminates credential storage risk
  - Default key resolution: `~/.ssh/id_ed25519`, `~/.ssh/id_rsa` (in order of preference)
  - CLI override via `--identity-file <path>`
  - No SSH agent forwarding needed — a shared key is distributed to all managed hosts via Puppet
- [x] **Known hosts: auto-accept new hosts; reject on key mismatch** — a changed host key triggers a hard stop with a MITM warning; unknown/new hosts are accepted silently on first connect
- [x] **Bastion/jump host: `--jump-host <host>` flag** — delegates to SSH's native ProxyJump functionality
- [x] **SSH config behavior: respect existing `~/.ssh/config` by default; allow CLI override**
  - Default: follow host settings from `~/.ssh/config` (for `User`, `Port`, `ProxyJump`, etc.)
  - Override switch: `--ignore-ssh-config` to bypass local SSH config when required

### Scope & Constraints
- [ ] Define the supported scope of issues (services, network, disk, kernel, etc.)
- [ ] Confirm read-only guarantee — document exactly what "read-only" means in practice
- [x] **Decision: interactive REPL mode for v0.1, full TUI for v0.2+**
  - v0.1: chat-loop REPL launched from CLI; human can follow up, correct, and redirect the agent
  - v0.2+: `textual`-based TUI with split panes (collected data | AI output | input bar)
  - Built-in slash commands: `/collect`, `/show logs`, `/clear`, `/host <hostname>`, `/help`, `/quit`

---

## Phase 1 — Project Foundation

Basic project scaffolding and connectivity.

- [x] Finalise repository structure and language toolchain
- [x] Set up CI pipeline (linting, tests)
- [ ] Implement SSH connection module
  - [x] Define SSH config model and probe interface scaffold
  - [x] Connect to remote host
  - [x] Execute read-only commands (e.g. `journalctl`, `systemctl status`, `cat`)
  - [ ] Stream or collect command output safely
- [x] Implement basic input parsing (ticket text, hostname, target directories)
- [x] Write unit tests for SSH and input modules
  - [x] Input parser and CLI tests added
  - [x] SSH module tests added for command policy and SSH argv behavior

---

## Phase 2 — Data Collection Layer

Define what information the agent gathers and how.

- [ ] Identify the canonical set of data sources per issue type:
  - Service failures: `journalctl`, `systemctl`, service config files
  - Network issues: `ip`, `ss`, `netstat`, firewall rules
  - Disk issues: `df`, `du`, `dmesg`, `smartctl`
  - General: `/var/log/syslog`, `/var/log/messages`, `dmesg`
- [ ] Implement pluggable "collector" modules per data source
- [ ] Implement directory traversal for user-specified paths (read-only)
- [ ] Add support for per-distro variations (Ubuntu vs RHEL path differences, etc.)
- [ ] Write tests with mocked SSH output

---

## Phase 3 — AI Integration

Wire collected data into the local AI model.

- [ ] Implement vLLM client module
- [ ] Design prompt template: system context, collected data, issue description → diagnosis
- [ ] Implement response parsing and structured output (root cause + suggested steps)
- [ ] Tune context window usage — handle truncation for large log outputs
- [ ] Add streaming support for long AI responses
- [ ] Evaluate and test model output quality on common issue types

---

## Phase 4 — CLI & User Experience

Polish the interface for real-world use.

- [ ] Design CLI interface (flags, subcommands, interactive prompts)
- [ ] Implement structured output: diagnosis, confidence, recommended actions
- [ ] Add `--verbose` / `--debug` mode showing raw collected data
- [ ] Support output to file or clipboard
- [ ] Write man page / `--help` documentation

---

## Phase 5 — Hardening & Distribution

Prepare for broader use.

- [ ] Security review of SSH handling and credential storage
- [ ] Ensure no data is written to the remote system under any path
- [ ] Package for distribution (binary release, container image, or distro packages)
- [ ] Write installation and quickstart documentation
- [ ] End-to-end integration tests against a test VM

---

## Decisions Log

| Date | Decision | Outcome |
|------|----------|---------|
| 2026-05-04 | Implementation language | Python — with single distributable binary via Nuitka |
| — | AI inference backend | vLLM (provisional) |
| — | Default model | `gemma4:a4b` (provisional) |
| 2026-05-04 | SSH auth methods | Keypair only (ed25519/RSA); auto-accept new hosts; reject on key change (MITM) |
| 2026-05-04 | Bastion host support | `--jump-host` flag via SSH native ProxyJump |
| 2026-05-04 | SSH config behavior | Use `~/.ssh/config` by default; allow override via `--ignore-ssh-config` |
| 2026-05-04 | CLI vs interactive mode | Interactive: REPL for v0.1, `textual` TUI for v0.2+ |
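Review note: Phase 2 of the roadmap mentions pluggable "collector" modules grouped by issue type but leaves the design open. One possible shape is a small registry that maps an issue type to a tuple of read-only commands. The sketch below is illustrative only: `Collector`, `COLLECTORS`, `register`, and `commands_for` are hypothetical names, and the command lists are examples taken from the Phase 2 bullets, not a decided canonical set.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Collector:
    """A named group of read-only commands for one issue type (hypothetical)."""

    name: str
    commands: tuple[str, ...]


# Hypothetical in-memory registry keyed by issue type.
COLLECTORS: dict[str, Collector] = {}


def register(collector: Collector) -> None:
    """Add a collector, refusing duplicate names."""
    if collector.name in COLLECTORS:
        raise ValueError(f"collector already registered: {collector.name}")
    COLLECTORS[collector.name] = collector


def commands_for(issue_types: list[str]) -> list[str]:
    """Return the de-duplicated commands for the given issue types, in order."""
    seen: set[str] = set()
    ordered: list[str] = []
    for issue_type in issue_types:
        for command in COLLECTORS[issue_type].commands:
            if command not in seen:
                seen.add(command)
                ordered.append(command)
    return ordered


# Example registrations; every command starts with a tool named in Phase 2.
register(Collector("service", ("journalctl -n 200", "systemctl list-units")))
register(Collector("disk", ("df -h", "dmesg")))
register(Collector("general", ("dmesg", "tail -n 200 /var/log/syslog")))
```

De-duplicating across collectors keeps a shared source such as `dmesg` from being fetched twice when a ticket touches several issue types. Each resulting command string would still need to pass the read-only policy check before being sent over SSH.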
50
pyproject.toml
Normal file
@@ -0,0 +1,50 @@
[build-system]
requires = ["hatchling>=1.25"]
build-backend = "hatchling.build"

[project]
name = "tai"
version = "0.1.0"
description = "Linux AI-driven troubleshooting agent"
readme = "README.md"
requires-python = ">=3.11"
authors = [
  { name = "tai contributors" }
]
dependencies = [
  "typer>=0.12,<1.0",
  "rich>=13.7,<14.0",
  "asyncssh>=2.14,<3.0",
]

[project.optional-dependencies]
dev = [
  "pytest>=8.2,<9.0",
  "ruff>=0.5,<1.0",
  "mypy>=1.10,<2.0",
]
build = [
  "nuitka>=2.4,<3.0",
]

[project.scripts]
tai = "tai.cli:main"

[tool.hatch.build.targets.wheel]
packages = ["src/tai"]

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-q"

[tool.ruff]
line-length = 100
target-version = "py311"

[tool.ruff.lint]
select = ["E", "F", "I", "UP", "B"]

[tool.mypy]
python_version = "3.11"
strict = true
warn_unused_configs = true
5
src/tai/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""tai package."""

__all__ = ["__version__"]

__version__ = "0.1.0"
117
src/tai/cli.py
Normal file
@@ -0,0 +1,117 @@
"""CLI entrypoint for tai."""

from __future__ import annotations

import asyncio
from typing import Annotated

import typer
from rich.console import Console

from tai.input_parser import InputValidationError, build_request
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig

app = typer.Typer(no_args_is_help=True, add_completion=False)
console = Console()


@app.command()
def run(
    issue: Annotated[str, typer.Argument(help="Ticket text or issue summary.")],
    host: Annotated[str, typer.Option("--host", help="Target host to troubleshoot.")],
    port: Annotated[int, typer.Option("--port", help="SSH port for the target host.")] = 22,
    path: Annotated[
        list[str] | None,
        typer.Option("--path", help="Path to inspect. Repeatable."),
    ] = None,
    identity_file: Annotated[
        str | None,
        typer.Option("--identity-file", help="SSH private key path."),
    ] = None,
    jump_host: Annotated[
        str | None,
        typer.Option("--jump-host", help="SSH bastion/jump host."),
    ] = None,
    ignore_ssh_config: Annotated[
        bool,
        typer.Option(
            "--ignore-ssh-config",
            help="Ignore ~/.ssh/config and rely only on CLI options.",
        ),
    ] = False,
    probe: Annotated[
        bool,
        typer.Option(
            "--probe/--no-probe",
            help="Enable or disable live SSH connectivity probe (uname -a).",
        ),
    ] = True,
) -> None:
    """Start an interactive troubleshooting session scaffold."""
    try:
        req = build_request(
            issue=issue,
            host=host,
            port=port,
            target_paths=path or [],
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    except InputValidationError as exc:
        console.print(f"[red]Input error:[/red] {exc}")
        raise typer.Exit(code=2) from exc

    config = SSHConnectionConfig(
        host=req.host,
        port=req.port,
        identity_file=req.identity_file,
        jump_host=req.jump_host,
        ignore_ssh_config=req.ignore_ssh_config,
    )

    summary = SSHClient(config).summary()
    console.print("[bold green]tai scaffold ready[/bold green]")
    console.print(f"Issue: {req.issue}")
    console.print(f"SSH: {summary}")
    if req.target_paths:
        console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}")

    if probe:
        _run_probe(SSHClient(config))


def _run_probe(client: SSHClient) -> None:
    """Run a live SSH probe and exit non-zero on failure."""
    console.print("[cyan]Running SSH probe:[/cyan] uname -a")
    try:
        result = asyncio.run(client.probe())
    except TimeoutError as exc:
        console.print(f"[red]Probe failed:[/red] {exc}")
        raise typer.Exit(code=1) from exc
    except OSError as exc:
        console.print(f"[red]Probe failed:[/red] unable to execute ssh: {exc}")
        raise typer.Exit(code=1) from exc

    _handle_probe_result(result)


def _handle_probe_result(result: SSHCommandResult) -> None:
    """Handle and render probe output for success or failure."""
    if result.exit_code != 0:
        details = result.stderr or result.stdout or "no error output from ssh"
        console.print(f"[red]Probe failed (exit {result.exit_code}):[/red] {details}")
        raise typer.Exit(code=1)

    output = result.stdout or "(no output)"
    console.print("[bold green]Probe succeeded.[/bold green]")
    console.print(f"Remote: {output}")


def main() -> None:
    """Console script entrypoint."""
    app()


if __name__ == "__main__":
    main()
46
src/tai/input_parser.py
Normal file
@@ -0,0 +1,46 @@
"""Helpers to normalize and validate CLI input."""

from pathlib import Path

from tai.models import TroubleshootRequest


class InputValidationError(ValueError):
    """Raised when required user input is missing or invalid."""


def build_request(
    *,
    issue: str,
    host: str,
    port: int,
    target_paths: list[str],
    identity_file: str | None,
    jump_host: str | None,
    ignore_ssh_config: bool,
) -> TroubleshootRequest:
    """Create a normalized request object from raw CLI values."""
    normalized_issue = issue.strip()
    normalized_host = host.strip()

    if not normalized_issue:
        raise InputValidationError("Issue description cannot be empty.")

    if not normalized_host:
        raise InputValidationError("Host cannot be empty.")

    if port < 1 or port > 65535:
        raise InputValidationError("Port must be between 1 and 65535.")

    paths = [Path(p).expanduser() for p in target_paths]
    identity = Path(identity_file).expanduser() if identity_file else None

    return TroubleshootRequest(
        issue=normalized_issue,
        host=normalized_host,
        port=port,
        target_paths=paths,
        identity_file=identity,
        jump_host=jump_host.strip() if jump_host else None,
        ignore_ssh_config=ignore_ssh_config,
    )
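Review note: the expression `jump_host.strip() if jump_host else None` in `build_request` yields an empty string, not `None`, when `--jump-host` is given as whitespace only, because a string of spaces is truthy before it is stripped. A small helper could normalize optional values in one place. This is a sketch; `clean_optional` is a hypothetical name, not a function in this commit.

```python
from typing import Optional


def clean_optional(value: Optional[str]) -> Optional[str]:
    """Strip whitespace and map empty or whitespace-only input to None."""
    if value is None:
        return None
    stripped = value.strip()
    # An empty string is falsy, so `or None` collapses it to None.
    return stripped or None


# Behaviour of the expression as committed, shown for comparison.
raw = "   "
as_committed = raw.strip() if raw else None
print(repr(as_committed))         # ''
print(repr(clean_optional(raw)))  # None
```

The practical effect today is small, since `build_ssh_argv` skips a falsy jump host either way, but returning `None` keeps the request model's `str | None` field meaningful for later callers.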
17
src/tai/models.py
Normal file
@@ -0,0 +1,17 @@
"""Core domain models for tai."""

from dataclasses import dataclass, field
from pathlib import Path


@dataclass(slots=True)
class TroubleshootRequest:
    """User-provided troubleshooting input for a single run."""

    issue: str
    host: str
    port: int = 22
    target_paths: list[Path] = field(default_factory=list)
    identity_file: Path | None = None
    jump_host: str | None = None
    ignore_ssh_config: bool = False
193
src/tai/ssh_client.py
Normal file
@@ -0,0 +1,193 @@
"""SSH configuration and read-only command execution."""

import asyncio
import shlex
from dataclasses import dataclass
from pathlib import Path


@dataclass(slots=True)
class SSHConnectionConfig:
    """Connection parameters for the target host."""

    host: str
    port: int = 22
    identity_file: Path | None = None
    jump_host: str | None = None
    ignore_ssh_config: bool = False


@dataclass(slots=True)
class SSHCommandResult:
    """Result of a remote SSH command execution."""

    command: str
    exit_code: int
    stdout: str
    stderr: str


class SSHCommandRejectedError(ValueError):
    """Raised when a command violates read-only policy."""


class SSHClient:
    """Wrapper around SSH operations with read-only safeguards."""

    _BLOCKED_TOKENS = {
        ">",
        ">>",
        "<",
        "|",
        "&&",
        "||",
        ";",
    }
    _READ_ONLY_COMMANDS = {
        "cat",
        "dmesg",
        "df",
        "du",
        "find",
        "grep",
        "head",
        "hostnamectl",
        "ip",
        "journalctl",
        "ls",
        "netstat",
        "sed",
        "ss",
        "stat",
        "systemctl",
        "tail",
        "uname",
    }
    _READ_ONLY_SYSTEMCTL_SUBCOMMANDS = {
        "cat",
        "is-active",
        "is-failed",
        "list-unit-files",
        "list-units",
        "show",
        "status",
    }

    def __init__(self, config: SSHConnectionConfig) -> None:
        self._config = config

    def summary(self) -> str:
        """Return a short summary of connection settings."""
        mode = "ignore ssh config" if self._config.ignore_ssh_config else "use ssh config"
        jump = self._config.jump_host or "none"
        key = str(self._config.identity_file) if self._config.identity_file else "auto"
        return (
            f"host={self._config.host} port={self._config.port} "
            f"key={key} jump={jump} mode={mode}"
        )

    def build_ssh_argv(self, remote_command: str) -> list[str]:
        """Build argv for a secure non-interactive SSH invocation."""
        argv = [
            "ssh",
            "-p",
            str(self._config.port),
            "-o",
            "BatchMode=yes",
            "-o",
            "ConnectTimeout=15",
            "-o",
            "StrictHostKeyChecking=accept-new",
        ]

        if self._config.ignore_ssh_config:
            argv += ["-F", "/dev/null"]

        if self._config.identity_file:
            argv += ["-i", str(self._config.identity_file)]

        if self._config.jump_host:
            argv += ["-J", self._config.jump_host]

        argv += [self._config.host, remote_command]
        return argv

    def validate_read_only_command(self, command: str) -> None:
        """Validate that a command appears read-only and non-destructive."""
        normalized = command.strip()
        if not normalized:
            raise SSHCommandRejectedError("Command cannot be empty.")

        for token in self._BLOCKED_TOKENS:
            if token in normalized:
                raise SSHCommandRejectedError(
                    f"Command contains blocked shell operator: {token}"
                )

        parts = shlex.split(normalized)
        if not parts:
            raise SSHCommandRejectedError("Command cannot be empty.")

        base = parts[0]
        if base not in self._READ_ONLY_COMMANDS:
            raise SSHCommandRejectedError(
                f"Command '{base}' is not allowed by read-only policy."
            )

        if base == "systemctl":
            if len(parts) < 2:
                raise SSHCommandRejectedError("systemctl requires a subcommand.")
            subcommand = parts[1]
            if subcommand not in self._READ_ONLY_SYSTEMCTL_SUBCOMMANDS:
                raise SSHCommandRejectedError(
                    f"systemctl subcommand '{subcommand}' is not read-only."
                )

    async def run_read_only_command(
        self,
        command: str,
        *,
        timeout_seconds: float = 30.0,
    ) -> SSHCommandResult:
        """Run a validated read-only command over SSH."""
        self.validate_read_only_command(command)
        return await self._run_ssh(command, timeout_seconds=timeout_seconds)

    async def probe(self) -> SSHCommandResult:
        """Probe connectivity using a harmless remote command."""
        return await self._run_ssh("uname -a", timeout_seconds=15.0)

    async def _run_ssh(
        self,
        command: str,
        *,
        timeout_seconds: float,
    ) -> SSHCommandResult:
        argv = self.build_ssh_argv(command)
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout_seconds,
            )
        except TimeoutError as exc:
            proc.kill()
            await proc.wait()
            raise TimeoutError(
                f"SSH command timed out after {timeout_seconds} seconds: {command}"
            ) from exc

        if proc.returncode is None:
            raise RuntimeError("SSH process did not provide an exit code.")

        return SSHCommandResult(
            command=command,
            exit_code=proc.returncode,
            stdout=stdout_bytes.decode("utf-8", errors="replace").strip(),
            stderr=stderr_bytes.decode("utf-8", errors="replace").strip(),
        )
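Review note: as committed, `validate_read_only_command` appears to accept some inputs that are not read-only once a remote shell interprets them. A single `&`, backticks, and `$(...)` are not in `_BLOCKED_TOKENS`, and allowlisted tools such as `sed`, `find`, `dmesg`, and `journalctl` have flags that write or delete. The sketch below shows one way an extra check could be layered on top. Everything in it is an assumption for illustration: `extra_policy_violation`, the three tables, and the choice of flags are hypothetical, and the flag lists are not exhaustive (for example, a `sed` script can still use its `w` command).

```python
import shlex
from typing import Optional

# Assumed extra shell syntax to reject: background operator, command
# substitution, and line breaks. "&" also covers "&&".
EXTRA_BLOCKED_SUBSTRINGS = ("&", "`", "$(", "\n", "\r")

# Assumed deny rules for allowlisted tools that can still modify state.
WRITE_FLAGS_EXACT: dict[str, tuple[str, ...]] = {
    "find": ("-delete", "-exec", "-execdir", "-ok", "-okdir"),
    "dmesg": ("-c", "-C", "--clear", "--read-clear"),
    "journalctl": ("--rotate", "--flush"),
}
WRITE_FLAGS_PREFIX: dict[str, tuple[str, ...]] = {
    "sed": ("-i", "--in-place"),
    "find": ("-fprint", "-fls"),
    "journalctl": ("--vacuum-",),
}


def extra_policy_violation(command: str) -> Optional[str]:
    """Return a reason if the command fails the extra checks, else None."""
    for needle in EXTRA_BLOCKED_SUBSTRINGS:
        if needle in command:
            return f"blocked shell syntax: {needle!r}"

    try:
        parts = shlex.split(command)
    except ValueError as exc:
        # shlex raises ValueError on unbalanced quotes.
        return f"unparseable command: {exc}"
    if not parts:
        return "empty command"

    base, args = parts[0], parts[1:]
    exact = WRITE_FLAGS_EXACT.get(base, ())
    prefixes = WRITE_FLAGS_PREFIX.get(base, ())
    for arg in args:
        # str.startswith accepts a tuple; an empty tuple never matches.
        if arg in exact or arg.startswith(prefixes):
            return f"'{base}' argument '{arg}' can modify the target"
    return None
```

In use, a non-`None` result would be turned into an `SSHCommandRejectedError` before the existing allowlist check runs. Exact matching is used for `find` so that read-only tests such as `-executable` are not rejected by accident. A deny list of this kind is easy to outgrow, so an allowlist of known-safe argument shapes per command may be the more durable direction.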
86
tests/test_cli.py
Normal file
@@ -0,0 +1,86 @@
from typer.testing import CliRunner

from tai.cli import app
from tai.ssh_client import SSHCommandResult


def test_run_command_prints_scaffold_summary() -> None:
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "web01",
            "--port",
            "5566",
            "--no-probe",
            "--path",
            "/etc/apache2",
            "--jump-host",
            "bastion01",
            "--ignore-ssh-config",
        ],
    )

    assert result.exit_code == 0
    assert "tai scaffold ready" in result.stdout
    assert "host=web01" in result.stdout
    assert "port=5566" in result.stdout


def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    async def fake_probe(self) -> SSHCommandResult:  # type: ignore[no-untyped-def]
        return SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux ssh 6.12.0",
            stderr="",
        )

    monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "ssh.archflux.net",
            "--port",
            "5566",
            "--probe",
        ],
    )

    assert result.exit_code == 0
    assert "Probe succeeded" in result.stdout
    assert "Linux ssh 6.12.0" in result.stdout


def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    async def fake_probe(self) -> SSHCommandResult:  # type: ignore[no-untyped-def]
        return SSHCommandResult(
            command="uname -a",
            exit_code=255,
            stdout="",
            stderr="Permission denied (publickey,password).",
        )

    monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "ssh.archflux.net",
            "--port",
            "5566",
            "--probe",
        ],
    )

    assert result.exit_code == 1
    assert "Probe failed" in result.stdout
65
tests/test_input_parser.py
Normal file
@@ -0,0 +1,65 @@
from pathlib import Path

import pytest

from tai.input_parser import InputValidationError, build_request


def test_build_request_normalizes_values() -> None:
    req = build_request(
        issue=" apache fails to start ",
        host=" web01 ",
        port=5566,
        target_paths=["/etc/apache2", "~/logs"],
        identity_file="~/.ssh/id_ed25519",
        jump_host=" bastion01 ",
        ignore_ssh_config=True,
    )

    assert req.issue == "apache fails to start"
    assert req.host == "web01"
    assert req.port == 5566
    assert req.target_paths[0] == Path("/etc/apache2")
    assert req.target_paths[1] == Path("~/logs").expanduser()
    assert req.identity_file == Path("~/.ssh/id_ed25519").expanduser()
    assert req.jump_host == "bastion01"
    assert req.ignore_ssh_config is True


def test_build_request_rejects_empty_issue() -> None:
    with pytest.raises(InputValidationError):
        build_request(
            issue=" ",
            host="web01",
            port=22,
            target_paths=[],
            identity_file=None,
            jump_host=None,
            ignore_ssh_config=False,
        )


def test_build_request_rejects_empty_host() -> None:
    with pytest.raises(InputValidationError):
        build_request(
            issue="apache down",
            host=" ",
            port=22,
            target_paths=[],
            identity_file=None,
            jump_host=None,
            ignore_ssh_config=False,
        )


def test_build_request_rejects_invalid_port() -> None:
    with pytest.raises(InputValidationError):
        build_request(
            issue="apache down",
            host="web01",
            port=70000,
            target_paths=[],
            identity_file=None,
            jump_host=None,
            ignore_ssh_config=False,
        )
93
tests/test_ssh_client.py
Normal file
@@ -0,0 +1,93 @@
from pathlib import Path

import pytest

from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig


def _client(**kwargs: object) -> SSHClient:
    host = str(kwargs.get("host", "root@ssh.archflux.net"))
    port_value = kwargs.get("port", 22)
    if not isinstance(port_value, int):
        raise TypeError("port must be an int")
    port = port_value
    identity_file = kwargs.get("identity_file")
    jump_host = kwargs.get("jump_host")
    ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False))

    if identity_file is not None and not isinstance(identity_file, Path):
        raise TypeError("identity_file must be a Path or None")

    if jump_host is not None and not isinstance(jump_host, str):
        raise TypeError("jump_host must be a string or None")

    return SSHClient(
        SSHConnectionConfig(
            host=host,
            port=port,
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    )


def test_summary_includes_expected_defaults() -> None:
    client = _client()
    text = client.summary()

    assert "host=root@ssh.archflux.net" in text
    assert "port=22" in text
    assert "key=auto" in text
    assert "jump=none" in text
    assert "mode=use ssh config" in text


def test_build_ssh_argv_respects_flags() -> None:
    client = _client(
        identity_file=Path("/root/.ssh/id_ed25519"),
        jump_host="bastion.archflux.net",
        ignore_ssh_config=True,
    )

    argv = client.build_ssh_argv("uname -a")

    assert argv[0] == "ssh"
    assert "-p" in argv
    assert "22" in argv
    assert "-F" in argv
    assert "/dev/null" in argv
    assert "-i" in argv
    assert "/root/.ssh/id_ed25519" in argv
    assert "-J" in argv
    assert "bastion.archflux.net" in argv
    assert argv[-2] == "root@ssh.archflux.net"
    assert argv[-1] == "uname -a"


def test_rejects_destructive_or_shell_operator_commands() -> None:
    client = _client()

    for command in ["rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id"]:
        with pytest.raises(SSHCommandRejectedError):
            client.validate_read_only_command(command)


def test_allows_expected_read_only_commands() -> None:
    client = _client()

    for command in [
        "uname -a",
        "journalctl -n 100",
        "systemctl status apache2",
        "cat /etc/hosts",
        "ss -lntp",
    ]:
        client.validate_read_only_command(command)


def test_rejects_non_read_only_systemctl_subcommand() -> None:
    client = _client()

    with pytest.raises(SSHCommandRejectedError):
        client.validate_read_only_command("systemctl restart apache2")