diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 08ca86d..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: CI - -on: - push: - pull_request: - -jobs: - test: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Install package and dev dependencies - run: | - python -m pip install --upgrade pip - pip install -e .[dev] - - - name: Lint - run: ruff check . - - - name: Lint Markdown - run: mdformat --check README.md ROADMAP.md CHANGELOG.md - - - name: Lint YAML - run: yamllint . - - - name: Type-check - run: mypy src - - - name: Test - run: pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b57691..cfd146d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,12 +24,18 @@ ______________________________________________________________________ - Implemented SSH module with real key-based command execution via system `ssh` - Added explicit SSH port support across CLI, input parsing, request model, and SSH client (`--port`, e.g. 5566) - Added live SSH connectivity probe (`uname -a`) enabled by default, with `--no-probe` opt-out and non-zero exit on failure +- Added baseline diagnostics collection via `--collect`, including service, journal, disk, and network checks - Read-only command policy enforcement (allowlist + blocked shell operators) +- Added byte-limited SSH output capture with truncation markers for large command output - Test scaffold (`pytest`) with initial parser and CLI coverage - SSH test coverage for policy checks, SSH argument construction, and config summary behavior - CI workflow for lint (`ruff`), type-check (`mypy`), and tests (`pytest`) - CI coverage expanded with Markdown formatting checks (`mdformat --check`) and YAML linting (`yamllint`) +### Removed + +- `.github/workflows/ci.yml` — GitHub Actions workflow removed; CI is now Gitea-only + ### Decided - Implementation language: **Python** diff --git a/ROADMAP.md b/ROADMAP.md index 45dbc1d..c20e242 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -56,7 +56,7 @@ Basic project scaffolding and connectivity. - [x] Define SSH config model and probe interface scaffold - [x] Connect to remote host - [x] Execute read-only commands (e.g. `journalctl`, `systemctl status`, `cat`) - - [ ] Stream or collect command output safely + - [x] Stream or collect command output safely (byte-limited output with truncation marker) - [x] Implement basic input parsing (ticket text, hostname, target directories) - [x] Write unit tests for SSH and input modules - [x] Input parser and CLI tests added diff --git a/src/tai/cli.py b/src/tai/cli.py index 5a3e269..aa5baa3 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -8,7 +8,10 @@ from typing import Annotated import typer from rich.console import Console +from tai.collectors import CollectionReport, collect_from_plan from tai.input_parser import InputValidationError, build_request +from tai.models import TroubleshootRequest +from tai.plan import plan_from_request from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig app = typer.Typer(no_args_is_help=True, add_completion=False) @@ -46,6 +49,13 @@ def run( help="Enable or disable live SSH connectivity probe (uname -a).", ), ] = True, + collect: Annotated[ + bool, + typer.Option( + "--collect/--no-collect", + help="Collect baseline diagnostics after probe.", + ), + ] = False, ) -> None: """Start an interactive troubleshooting session scaffold.""" try: @@ -77,8 +87,13 @@ def run( if req.target_paths: console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") + client = SSHClient(config) + if probe: - _run_probe(SSHClient(config)) + _run_probe(client) + + if collect: + _run_collection(client, req) def _run_probe(client: SSHClient) -> None: @@ -96,6 +111,35 @@ def _run_probe(client: SSHClient) -> None: _handle_probe_result(result) +def _run_collection(client: SSHClient, request: TroubleshootRequest) -> None: + """Run issue-aware collection and print a compact summary.""" + plan = plan_from_request(request) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + try: + report = asyncio.run(collect_from_plan(client, plan)) + except TimeoutError as exc: + console.print(f"[red]Collection failed:[/red] {exc}") + raise typer.Exit(code=1) from exc + except OSError as exc: + console.print(f"[red]Collection failed:[/red] unable to execute ssh: {exc}") + raise typer.Exit(code=1) from exc + + _handle_collection_report(report) + + +def _handle_collection_report(report: CollectionReport) -> None: + """Render collected command status and truncation hints.""" + console.print( + f"[bold]Collection complete:[/bold] {report.total} commands, {report.failed} failed" + ) + for item in report.items: + status = "ok" if item.result.exit_code == 0 else f"exit {item.result.exit_code}" + trunc = "" + if item.result.stdout_truncated or item.result.stderr_truncated: + trunc = " (truncated)" + console.print(f"- {item.name}: {status}{trunc}") + + def _handle_probe_result(result: SSHCommandResult) -> None: """Handle and render probe output for success or failure.""" if result.exit_code != 0: diff --git a/src/tai/collectors.py b/src/tai/collectors.py new file mode 100644 index 0000000..9ad4754 --- /dev/null +++ b/src/tai/collectors.py @@ -0,0 +1,50 @@ +"""Data collection routines built on top of the SSH client.""" + +from dataclasses import dataclass + +from tai.plan import CollectionPlan +from tai.ssh_client import SSHClient, SSHCommandResult + + +@dataclass(slots=True) +class CollectedItem: + """Single collected diagnostic command result.""" + + name: str + result: SSHCommandResult + + +@dataclass(slots=True) +class CollectionReport: + """Collection summary for a batch of diagnostics.""" + + host: str + items: list[CollectedItem] + + @property + def total(self) -> int: + return len(self.items) + + @property + def failed(self) -> int: + return sum(1 for item in self.items if item.result.exit_code != 0) + + +async def collect_from_plan( + client: SSHClient, + plan: CollectionPlan, + *, + max_output_bytes: int = 32768, +) -> CollectionReport: + """Execute all commands in *plan* and return a :class:`CollectionReport`.""" + items: list[CollectedItem] = [] + + for name, command in plan.commands: + result = await client.run_read_only_command( + command, + timeout_seconds=30.0, + max_output_bytes=max_output_bytes, + ) + items.append(CollectedItem(name=name, result=result)) + + return CollectionReport(host=client.summary(), items=items) diff --git a/src/tai/plan.py b/src/tai/plan.py new file mode 100644 index 0000000..e3c76a8 --- /dev/null +++ b/src/tai/plan.py @@ -0,0 +1,244 @@ +"""Collection plan builder — decides what to collect based on the issue.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field + +from tai.models import TroubleshootRequest + +# --------------------------------------------------------------------------- +# Keyword sets for issue classification +# --------------------------------------------------------------------------- + +_SERVICE_KEYWORDS: frozenset[str] = frozenset( + { + "service", + "unit", + "daemon", + "failed", + "dead", + "inactive", + "crash", + "crashed", + "start", + "stop", + "restart", + "status", + "systemd", + "systemctl", + } +) + +_NETWORK_KEYWORDS: frozenset[str] = frozenset( + { + "network", + "port", + "connect", + "connection", + "listen", + "firewall", + "route", + "routing", + "interface", + "dns", + "http", + "https", + "tcp", + "udp", + "socket", + "unreachable", + "refused", + "timeout", + "latency", + "bandwidth", + "packet", + } +) + +_DISK_KEYWORDS: frozenset[str] = frozenset( + { + "disk", + "space", + "storage", + "inode", + "full", + "mount", + "filesystem", + "partition", + "quota", + "usage", + "capacity", + } +) + +# --------------------------------------------------------------------------- +# Known service names and their candidate config paths +# --------------------------------------------------------------------------- + +_KNOWN_SERVICES: list[str] = [ + "apache2", + "httpd", + "nginx", + "mysql", + "mysqld", + "mariadb", + "postgresql", + "redis", + "redis-server", + "mongodb", + "mongod", + "docker", + "containerd", + "kubelet", + "sshd", + "postfix", + "dovecot", + "sendmail", + "php-fpm", + "elasticsearch", + "rabbitmq", + "rabbitmq-server", + "celery", + "gunicorn", + "ufw", + "fail2ban", + "cron", + "crond", + "rsyslog", + "auditd", + "firewalld", + "haproxy", + "varnish", + "memcached", +] + +_SERVICE_CONFIGS: dict[str, list[str]] = { + "apache2": ["/etc/apache2/apache2.conf"], + "httpd": ["/etc/httpd/conf/httpd.conf"], + "nginx": ["/etc/nginx/nginx.conf"], + "mysql": ["/etc/mysql/mysql.conf.d/mysqld.cnf"], + "mysqld": ["/etc/my.cnf"], + "mariadb": ["/etc/mysql/mariadb.conf.d/50-server.cnf"], + "postgresql": ["/etc/postgresql"], + "sshd": ["/etc/ssh/sshd_config"], + "postfix": ["/etc/postfix/main.cf"], + "haproxy": ["/etc/haproxy/haproxy.cfg"], + "redis": ["/etc/redis/redis.conf"], + "redis-server": ["/etc/redis/redis.conf"], + "fail2ban": ["/etc/fail2ban/jail.conf"], + "ufw": ["/etc/ufw/ufw.conf"], +} + +# --------------------------------------------------------------------------- +# Command sets +# --------------------------------------------------------------------------- + +_ALWAYS: list[tuple[str, str]] = [ + ("kernel", "uname -a"), + ("uptime", "cat /proc/uptime"), + ("disk-usage", "df -h"), + ("memory", "cat /proc/meminfo"), + ("running-services", "systemctl list-units --type=service --state=running --no-pager"), +] + +_SERVICE_EXTRA: list[tuple[str, str]] = [ + ("failed-services", "systemctl list-units --type=service --state=failed --no-pager"), + ("journal-errors", "journalctl -p err -n 100 --no-pager"), +] + +_NETWORK_EXTRA: list[tuple[str, str]] = [ + ("listening-ports", "ss -lntp"), + ("ip-addresses", "ip addr show"), + ("ip-routes", "ip route show"), + ("ip-stats", "ip -s link show"), +] + +_DISK_EXTRA: list[tuple[str, str]] = [ + ("disk-inodes", "df -i"), + ("dmesg-disk", "dmesg -T --level=err,warn"), + ("large-dirs", "du -sh /var /tmp /home /opt"), +] + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class CollectionPlan: + """Ordered list of (name, command) pairs to execute on a remote host.""" + + commands: list[tuple[str, str]] = field(default_factory=list) + + def add(self, name: str, command: str) -> None: + self.commands.append((name, command)) + + def __len__(self) -> int: + return len(self.commands) + + +def plan_from_request(request: TroubleshootRequest) -> CollectionPlan: + """Build a :class:`CollectionPlan` tailored to *request*.""" + plan = CollectionPlan(commands=list(_ALWAYS)) + keywords = _issue_words(request.issue) + + # --- category expansions ------------------------------------------- + if keywords & _SERVICE_KEYWORDS: + plan.commands.extend(_SERVICE_EXTRA) + + if keywords & _NETWORK_KEYWORDS: + plan.commands.extend(_NETWORK_EXTRA) + + if keywords & _DISK_KEYWORDS: + plan.commands.extend(_DISK_EXTRA) + + # --- named service detection --------------------------------------- + services = _extract_services(request.issue) + seen: set[str] = set() + for svc in services: + if svc in seen: + continue + seen.add(svc) + plan.add(f"service-{svc}", f"systemctl status {svc}") + plan.add(f"journal-{svc}", f"journalctl -u {svc} -n 100 --no-pager") + for cfg_path in _SERVICE_CONFIGS.get(svc, []): + plan.add(f"config-{svc}", f"cat {cfg_path}") + + # --- user-specified paths ----------------------------------------- + for path in request.target_paths: + plan.add(f"ls-{path.name}", f"ls -la {path}") + if "log" in str(path).lower(): + plan.add( + f"find-logs-{path.name}", + f"find {path} -maxdepth 2 -type f -name '*.log'", + ) + else: + plan.add( + f"find-files-{path.name}", + f"find {path} -maxdepth 2 -type f", + ) + + return plan + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _issue_words(issue: str) -> set[str]: + """Return the set of lowercase words in *issue*.""" + return set(re.findall(r"\b\w+\b", issue.lower())) + + +def _extract_services(issue: str) -> list[str]: + """Return known service names mentioned in *issue*.""" + words = _issue_words(issue) + found: list[str] = [] + for svc in _KNOWN_SERVICES: + # Match the service name or its stem (strip trailing 'd', e.g. 'apache' → 'apache2') + svc_words = {svc, svc.rstrip("d"), svc.replace("-", ""), svc.replace("-server", "")} + if words & svc_words: + found.append(svc) + return found diff --git a/src/tai/ssh_client.py b/src/tai/ssh_client.py index c690b6e..4f7164a 100644 --- a/src/tai/ssh_client.py +++ b/src/tai/ssh_client.py @@ -25,6 +25,8 @@ class SSHCommandResult: exit_code: int stdout: str stderr: str + stdout_truncated: bool = False + stderr_truncated: bool = False class SSHCommandRejectedError(ValueError): @@ -148,20 +150,30 @@ class SSHClient: command: str, *, timeout_seconds: float = 30.0, + max_output_bytes: int = 32768, ) -> SSHCommandResult: """Run a validated read-only command over SSH.""" self.validate_read_only_command(command) - return await self._run_ssh(command, timeout_seconds=timeout_seconds) + return await self._run_ssh( + command, + timeout_seconds=timeout_seconds, + max_output_bytes=max_output_bytes, + ) async def probe(self) -> SSHCommandResult: """Probe connectivity using a harmless remote command.""" - return await self._run_ssh("uname -a", timeout_seconds=15.0) + return await self._run_ssh( + "uname -a", + timeout_seconds=15.0, + max_output_bytes=4096, + ) async def _run_ssh( self, command: str, *, timeout_seconds: float, + max_output_bytes: int, ) -> SSHCommandResult: argv = self.build_ssh_argv(command) proc = await asyncio.create_subprocess_exec( @@ -185,9 +197,37 @@ class SSHClient: if proc.returncode is None: raise RuntimeError("SSH process did not provide an exit code.") + stdout_text, stdout_truncated = self._truncate_output( + stdout_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + stderr_text, stderr_truncated = self._truncate_output( + stderr_bytes.decode("utf-8", errors="replace"), + max_output_bytes=max_output_bytes, + ) + return SSHCommandResult( command=command, exit_code=proc.returncode, - stdout=stdout_bytes.decode("utf-8", errors="replace").strip(), - stderr=stderr_bytes.decode("utf-8", errors="replace").strip(), + stdout=stdout_text, + stderr=stderr_text, + stdout_truncated=stdout_truncated, + stderr_truncated=stderr_truncated, ) + + @staticmethod + def _truncate_output(text: str, *, max_output_bytes: int) -> tuple[str, bool]: + """Trim output to a maximum byte length while preserving UTF-8 validity.""" + if max_output_bytes < 256: + raise ValueError("max_output_bytes must be at least 256.") + + encoded = text.encode("utf-8", errors="replace") + if len(encoded) <= max_output_bytes: + return text.strip(), False + + marker = "\n...[truncated]" + marker_bytes = marker.encode("utf-8") + keep = max_output_bytes - len(marker_bytes) + trimmed_bytes = encoded[:keep] + trimmed_text = trimmed_bytes.decode("utf-8", errors="ignore").rstrip() + return f"{trimmed_text}{marker}", True diff --git a/tests/test_cli.py b/tests/test_cli.py index 68f013b..9bac274 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ from typer.testing import CliRunner from tai.cli import app +from tai.collectors import CollectedItem, CollectionReport from tai.ssh_client import SSHCommandResult @@ -84,3 +85,52 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no assert result.exit_code == 1 assert "Probe failed" in result.stdout + + +def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def] + async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + CollectedItem( + name="journal", + result=SSHCommandResult( + command="journalctl -n 200", + exit_code=0, + stdout="...", + stderr="", + stdout_truncated=True, + ), + ), + ], + ) + + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--collect", + ], + ) + + assert result.exit_code == 0 + assert "Collection complete" in result.stdout + assert "kernel: ok" in result.stdout + assert "journal: ok (truncated)" in result.stdout diff --git a/tests/test_plan.py b/tests/test_plan.py new file mode 100644 index 0000000..94016e2 --- /dev/null +++ b/tests/test_plan.py @@ -0,0 +1,169 @@ +"""Tests for the collection plan builder.""" + +from pathlib import Path + +from tai.models import TroubleshootRequest +from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request + + +def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest: + return TroubleshootRequest( + issue=issue, + host="root@testhost", + target_paths=[Path(p) for p in (paths or [])], + ) + + +def _commands(plan: CollectionPlan) -> list[str]: + """Return flat list of command strings from plan.""" + return [cmd for _, cmd in plan.commands] + + +def _names(plan: CollectionPlan) -> list[str]: + return [name for name, _ in plan.commands] + + +# --------------------------------------------------------------------------- +# Always-present commands +# --------------------------------------------------------------------------- + + +def test_plan_always_has_baseline_commands() -> None: + plan = plan_from_request(_req("some generic issue")) + cmds = _commands(plan) + assert any("uname -a" in c for c in cmds) + assert any("df -h" in c for c in cmds) + assert any("proc/meminfo" in c for c in cmds) + assert any("systemctl list-units" in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Keyword-based category expansion +# --------------------------------------------------------------------------- + + +def test_service_keywords_add_failed_services_check() -> None: + plan = plan_from_request(_req("service failed to start")) + cmds = _commands(plan) + assert any("--state=failed" in c for c in cmds) + assert any("journalctl -p err" in c for c in cmds) + + +def test_network_keywords_add_network_commands() -> None: + plan = plan_from_request(_req("connection refused on port 80")) + cmds = _commands(plan) + assert any("ss -lntp" in c for c in cmds) + assert any("ip addr show" in c for c in cmds) + assert any("ip route show" in c for c in cmds) + + +def test_disk_keywords_add_disk_commands() -> None: + plan = plan_from_request(_req("disk full filesystem usage critical")) + cmds = _commands(plan) + assert any("df -i" in c for c in cmds) + assert any("dmesg" in c for c in cmds) + assert any("du -sh" in c for c in cmds) + + +def test_unrelated_issue_does_not_add_network_commands() -> None: + plan = plan_from_request(_req("apache service crashed")) + cmds = _commands(plan) + assert not any("ip route show" in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Named service detection +# --------------------------------------------------------------------------- + + +def test_nginx_in_issue_adds_nginx_service_commands() -> None: + plan = plan_from_request(_req("nginx is failing to start")) + names = _names(plan) + cmds = _commands(plan) + assert "service-nginx" in names + assert "journal-nginx" in names + assert any("systemctl status nginx" in c for c in cmds) + assert any("journalctl -u nginx" in c for c in cmds) + + +def test_apache2_adds_config_cat() -> None: + plan = plan_from_request(_req("apache2 service check")) + cmds = _commands(plan) + assert any("cat /etc/apache2/apache2.conf" in c for c in cmds) + + +def test_sshd_adds_config_cat() -> None: + plan = plan_from_request(_req("sshd connection problems")) + cmds = _commands(plan) + assert any("cat /etc/ssh/sshd_config" in c for c in cmds) + + +def test_unknown_service_name_no_config_cat() -> None: + plan = plan_from_request(_req("myweirdapp service crashed")) + cmds = _commands(plan) + assert not any("cat /etc" in c for c in cmds) + + +def test_duplicate_service_name_not_repeated() -> None: + plan = plan_from_request(_req("nginx nginx nginx")) + names = _names(plan) + assert names.count("service-nginx") == 1 + + +# --------------------------------------------------------------------------- +# Target path handling +# --------------------------------------------------------------------------- + + +def test_target_path_adds_ls_and_find() -> None: + plan = plan_from_request(_req("app crash", paths=["/opt/myapp"])) + cmds = _commands(plan) + assert any("ls -la /opt/myapp" in c for c in cmds) + assert any("find /opt/myapp" in c for c in cmds) + + +def test_log_path_uses_log_find_pattern() -> None: + plan = plan_from_request(_req("app errors", paths=["/var/log/myapp"])) + cmds = _commands(plan) + assert any("*.log" in c for c in cmds) + + +def test_non_log_path_uses_generic_find() -> None: + plan = plan_from_request(_req("config issue", paths=["/etc/myapp"])) + cmds = _commands(plan) + assert any("find /etc/myapp" in c and "*.log" not in c for c in cmds) + + +# --------------------------------------------------------------------------- +# Helper unit tests +# --------------------------------------------------------------------------- + + +def test_issue_words_lowercases_and_splits() -> None: + words = _issue_words("Apache Service FAILED") + assert "apache" in words + assert "service" in words + assert "failed" in words + + +def test_extract_services_finds_nginx() -> None: + assert "nginx" in _extract_services("nginx is down") + + +def test_extract_services_finds_nothing_for_unknown() -> None: + assert _extract_services("the widget is broken") == [] + + +def test_extract_services_case_insensitive() -> None: + assert "nginx" in _extract_services("NGINX failed") + + +# --------------------------------------------------------------------------- +# Plan length sanity +# --------------------------------------------------------------------------- + + +def test_plain_issue_has_only_always_commands() -> None: + plan = plan_from_request(_req("something went wrong")) + # Only _ALWAYS (5 commands), no category expansion, no service, no paths + assert len(plan) == 5 diff --git a/tests/test_ssh_client.py b/tests/test_ssh_client.py index 37425ee..fcad417 100644 --- a/tests/test_ssh_client.py +++ b/tests/test_ssh_client.py @@ -91,3 +91,18 @@ def test_rejects_non_read_only_systemctl_subcommand() -> None: with pytest.raises(SSHCommandRejectedError): client.validate_read_only_command("systemctl restart apache2") + + +def test_truncate_output_marks_and_limits_content() -> None: + text = "a" * 400 + rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=256) + + assert truncated is True + assert rendered.endswith("...[truncated]") + + +def test_truncate_output_keeps_short_content() -> None: + rendered, truncated = SSHClient._truncate_output("short output", max_output_bytes=256) + + assert truncated is False + assert rendered == "short output"