"""Tests for the collection plan builder."""

from pathlib import Path

from tai.models import TroubleshootRequest
from tai.plan import CollectionPlan, _extract_services, _issue_words, plan_from_request


def _req(issue: str, paths: list[str] | None = None) -> TroubleshootRequest:
    """Build a request for ``root@testhost`` with *issue* and optional target paths."""
    return TroubleshootRequest(
        issue=issue,
        host="root@testhost",
        target_paths=[Path(p) for p in (paths or [])],
    )


def _commands(plan: CollectionPlan) -> list[str]:
    """Return flat list of command strings from plan."""
    return [cmd for _, cmd in plan.commands]


def _names(plan: CollectionPlan) -> list[str]:
    """Return flat list of command names from plan."""
    return [name for name, _ in plan.commands]


# ---------------------------------------------------------------------------
# Always-present commands
# ---------------------------------------------------------------------------


def test_plan_always_has_baseline_commands() -> None:
    plan = plan_from_request(_req("some generic issue"))
    cmds = _commands(plan)
    assert any("uname -a" in c for c in cmds)
    assert any("df -h" in c for c in cmds)
    assert any("proc/meminfo" in c for c in cmds)
    assert any("systemctl list-units" in c for c in cmds)


# ---------------------------------------------------------------------------
# Keyword-based category expansion
# ---------------------------------------------------------------------------


def test_service_keywords_add_failed_services_check() -> None:
    plan = plan_from_request(_req("service failed to start"))
    cmds = _commands(plan)
    assert any("--state=failed" in c for c in cmds)
    assert any("journalctl -p err" in c for c in cmds)


def test_network_keywords_add_network_commands() -> None:
    plan = plan_from_request(_req("connection refused on port 80"))
    cmds = _commands(plan)
    assert any("ss -lntp" in c for c in cmds)
    assert any("ip addr show" in c for c in cmds)
    assert any("ip route show" in c for c in cmds)


def test_disk_keywords_add_disk_commands() -> None:
    plan = plan_from_request(_req("disk full filesystem usage critical"))
    cmds = _commands(plan)
    assert any("df -i" in c for c in cmds)
    assert any("dmesg" in c for c in cmds)
    assert any("du -sh" in c for c in cmds)


def test_unrelated_issue_does_not_add_network_commands() -> None:
    plan = plan_from_request(_req("apache service crashed"))
    cmds = _commands(plan)
    assert not any("ip route show" in c for c in cmds)


# ---------------------------------------------------------------------------
# Named service detection
# ---------------------------------------------------------------------------


def test_nginx_in_issue_adds_nginx_service_commands() -> None:
    plan = plan_from_request(_req("nginx is failing to start"))
    names = _names(plan)
    cmds = _commands(plan)
    assert "service-nginx" in names
    assert "journal-nginx" in names
    assert any("systemctl status nginx" in c for c in cmds)
    assert any("journalctl -u nginx" in c for c in cmds)


def test_apache2_adds_config_cat() -> None:
    plan = plan_from_request(_req("apache2 service check"))
    cmds = _commands(plan)
    assert any("cat /etc/apache2/apache2.conf" in c for c in cmds)


def test_sshd_adds_config_cat() -> None:
    plan = plan_from_request(_req("sshd connection problems"))
    cmds = _commands(plan)
    assert any("cat /etc/ssh/sshd_config" in c for c in cmds)


def test_unknown_service_name_no_config_cat() -> None:
    plan = plan_from_request(_req("myweirdapp service crashed"))
    cmds = _commands(plan)
    assert not any("cat /etc" in c for c in cmds)


def test_duplicate_service_name_not_repeated() -> None:
    plan = plan_from_request(_req("nginx nginx nginx"))
    names = _names(plan)
    # Both per-service entries (status and journal) must be emitted only once.
    assert names.count("service-nginx") == 1
    assert names.count("journal-nginx") == 1


# ---------------------------------------------------------------------------
# Target path handling
# ---------------------------------------------------------------------------


def test_target_path_adds_ls_and_find() -> None:
    plan = plan_from_request(_req("app crash", paths=["/opt/myapp"]))
    cmds = _commands(plan)
    assert any("ls -la /opt/myapp" in c for c in cmds)
    assert any("find /opt/myapp" in c for c in cmds)


def test_log_path_uses_log_find_pattern() -> None:
    plan = plan_from_request(_req("app errors", paths=["/var/log/myapp"]))
    cmds = _commands(plan)
    # Require the *.log pattern on the find for this path, so an unrelated
    # command that happens to mention *.log cannot satisfy the check.
    assert any("find /var/log/myapp" in c and "*.log" in c for c in cmds)


def test_non_log_path_uses_generic_find() -> None:
    plan = plan_from_request(_req("config issue", paths=["/etc/myapp"]))
    cmds = _commands(plan)
    assert any("find /etc/myapp" in c and "*.log" not in c for c in cmds)


# ---------------------------------------------------------------------------
# Helper unit tests
# ---------------------------------------------------------------------------


def test_issue_words_lowercases_and_splits() -> None:
    words = _issue_words("Apache Service FAILED")
    assert "apache" in words
    assert "service" in words
    assert "failed" in words


def test_extract_services_finds_nginx() -> None:
    assert "nginx" in _extract_services("nginx is down")


def test_extract_services_finds_nothing_for_unknown() -> None:
    assert _extract_services("the widget is broken") == []


def test_extract_services_case_insensitive() -> None:
    assert "nginx" in _extract_services("NGINX failed")


# ---------------------------------------------------------------------------
# Plan length sanity
# ---------------------------------------------------------------------------


def test_plain_issue_has_only_always_commands() -> None:
    plan = plan_from_request(_req("something went wrong"))
    # Only _ALWAYS (5 commands), no category expansion, no service, no paths.
    # NOTE(review): update this count if the _ALWAYS baseline list changes.
    assert len(plan) == 5