update
All checks were successful
CI / test (push) Successful in 15s

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-05-04 04:22:58 +02:00
parent 65c74dde5a
commit e589240c67
10 changed files with 624 additions and 44 deletions

View File

@@ -1,6 +1,7 @@
from typer.testing import CliRunner
from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport
from tai.ssh_client import SSHCommandResult
@@ -84,3 +85,52 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no
assert result.exit_code == 1
assert "Probe failed" in result.stdout
def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def]
async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
CollectedItem(
name="journal",
result=SSHCommandResult(
command="journalctl -n 200",
exit_code=0,
stdout="...",
stderr="",
stdout_truncated=True,
),
),
],
)
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
runner = CliRunner()
result = runner.invoke(
app,
[
"apache failed",
"--host",
"ssh.archflux.net",
"--port",
"5566",
"--no-probe",
"--collect",
],
)
assert result.exit_code == 0
assert "Collection complete" in result.stdout
assert "kernel: ok" in result.stdout
assert "journal: ok (truncated)" in result.stdout

169
tests/test_plan.py Normal file
View File

@@ -0,0 +1,169 @@
"""Tests for the collection plan builder."""
from pathlib import Path
from tai.models import TroubleshootRequest
from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request
def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest:
return TroubleshootRequest(
issue=issue,
host="root@testhost",
target_paths=[Path(p) for p in (paths or [])],
)
def _commands(plan: CollectionPlan) -> list[str]:
"""Return flat list of command strings from plan."""
return [cmd for _, cmd in plan.commands]
def _names(plan: CollectionPlan) -> list[str]:
return [name for name, _ in plan.commands]
# ---------------------------------------------------------------------------
# Always-present commands
# ---------------------------------------------------------------------------
def test_plan_always_has_baseline_commands() -> None:
plan = plan_from_request(_req("some generic issue"))
cmds = _commands(plan)
assert any("uname -a" in c for c in cmds)
assert any("df -h" in c for c in cmds)
assert any("proc/meminfo" in c for c in cmds)
assert any("systemctl list-units" in c for c in cmds)
# ---------------------------------------------------------------------------
# Keyword-based category expansion
# ---------------------------------------------------------------------------
def test_service_keywords_add_failed_services_check() -> None:
plan = plan_from_request(_req("service failed to start"))
cmds = _commands(plan)
assert any("--state=failed" in c for c in cmds)
assert any("journalctl -p err" in c for c in cmds)
def test_network_keywords_add_network_commands() -> None:
plan = plan_from_request(_req("connection refused on port 80"))
cmds = _commands(plan)
assert any("ss -lntp" in c for c in cmds)
assert any("ip addr show" in c for c in cmds)
assert any("ip route show" in c for c in cmds)
def test_disk_keywords_add_disk_commands() -> None:
plan = plan_from_request(_req("disk full filesystem usage critical"))
cmds = _commands(plan)
assert any("df -i" in c for c in cmds)
assert any("dmesg" in c for c in cmds)
assert any("du -sh" in c for c in cmds)
def test_unrelated_issue_does_not_add_network_commands() -> None:
plan = plan_from_request(_req("apache service crashed"))
cmds = _commands(plan)
assert not any("ip route show" in c for c in cmds)
# ---------------------------------------------------------------------------
# Named service detection
# ---------------------------------------------------------------------------
def test_nginx_in_issue_adds_nginx_service_commands() -> None:
plan = plan_from_request(_req("nginx is failing to start"))
names = _names(plan)
cmds = _commands(plan)
assert "service-nginx" in names
assert "journal-nginx" in names
assert any("systemctl status nginx" in c for c in cmds)
assert any("journalctl -u nginx" in c for c in cmds)
def test_apache2_adds_config_cat() -> None:
plan = plan_from_request(_req("apache2 service check"))
cmds = _commands(plan)
assert any("cat /etc/apache2/apache2.conf" in c for c in cmds)
def test_sshd_adds_config_cat() -> None:
plan = plan_from_request(_req("sshd connection problems"))
cmds = _commands(plan)
assert any("cat /etc/ssh/sshd_config" in c for c in cmds)
def test_unknown_service_name_no_config_cat() -> None:
plan = plan_from_request(_req("myweirdapp service crashed"))
cmds = _commands(plan)
assert not any("cat /etc" in c for c in cmds)
def test_duplicate_service_name_not_repeated() -> None:
plan = plan_from_request(_req("nginx nginx nginx"))
names = _names(plan)
assert names.count("service-nginx") == 1
# ---------------------------------------------------------------------------
# Target path handling
# ---------------------------------------------------------------------------
def test_target_path_adds_ls_and_find() -> None:
plan = plan_from_request(_req("app crash", paths=["/opt/myapp"]))
cmds = _commands(plan)
assert any("ls -la /opt/myapp" in c for c in cmds)
assert any("find /opt/myapp" in c for c in cmds)
def test_log_path_uses_log_find_pattern() -> None:
plan = plan_from_request(_req("app errors", paths=["/var/log/myapp"]))
cmds = _commands(plan)
assert any("*.log" in c for c in cmds)
def test_non_log_path_uses_generic_find() -> None:
plan = plan_from_request(_req("config issue", paths=["/etc/myapp"]))
cmds = _commands(plan)
assert any("find /etc/myapp" in c and "*.log" not in c for c in cmds)
# ---------------------------------------------------------------------------
# Helper unit tests
# ---------------------------------------------------------------------------
def test_issue_words_lowercases_and_splits() -> None:
words = _issue_words("Apache Service FAILED")
assert "apache" in words
assert "service" in words
assert "failed" in words
def test_extract_services_finds_nginx() -> None:
assert "nginx" in _extract_services("nginx is down")
def test_extract_services_finds_nothing_for_unknown() -> None:
assert _extract_services("the widget is broken") == []
def test_extract_services_case_insensitive() -> None:
assert "nginx" in _extract_services("NGINX failed")
# ---------------------------------------------------------------------------
# Plan length sanity
# ---------------------------------------------------------------------------
def test_plain_issue_has_only_always_commands() -> None:
plan = plan_from_request(_req("something went wrong"))
# Only _ALWAYS (5 commands), no category expansion, no service, no paths
assert len(plan) == 5

View File

@@ -91,3 +91,18 @@ def test_rejects_non_read_only_systemctl_subcommand() -> None:
with pytest.raises(SSHCommandRejectedError):
client.validate_read_only_command("systemctl restart apache2")
def test_truncate_output_marks_and_limits_content() -> None:
text = "a" * 400
rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=256)
assert truncated is True
assert rendered.endswith("...[truncated]")
def test_truncate_output_keeps_short_content() -> None:
rendered, truncated = SSHClient._truncate_output("short output", max_output_bytes=256)
assert truncated is False
assert rendered == "short output"