Files
tai/tests/test_cli.py
zphinx 3be14f8f6f
All checks were successful
CI / test (push) Successful in 27s
commit all of this
2026-05-14 20:00:38 +02:00

860 lines
27 KiB
Python

import json
import os
import re
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from typer.testing import CliRunner
from tai.cli import (
_download_markdown_url,
_inject_url_credentials,
_load_env_file,
_materialize_runbook_add_path,
_materialize_runbooks_sync_path,
_resolve_secret,
app,
)
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk
from tai.ssh_client import SSHCommandResult
def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch ``tai.cli.SSHClient.connect`` so every connection yields one mock session.

    The returned ``MagicMock`` works as an async context manager
    (``__aenter__`` returns the session itself) and exposes an async ``probe``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; the patch is undone at
            test teardown.
        probe_result: Value returned by ``await session.probe(...)`` when
            ``probe_raises`` is not given.
        probe_raises: Exception raised by ``await session.probe(...)``; takes
            precedence over ``probe_result``.

    Returns:
        The mock session, so tests can inspect the calls made on it.
    """
    session = MagicMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)
    # Compare with None explicitly instead of relying on exception truthiness.
    if probe_raises is not None:
        session.probe = AsyncMock(side_effect=probe_raises)
    else:
        session.probe = AsyncMock(return_value=probe_result)
    # Accept any call shape (bound self, positional, or keyword arguments) so
    # the fake does not depend on how the CLI invokes connect().
    monkeypatch.setattr(
        "tai.cli.SSHClient.connect", lambda *_args, **_kwargs: session
    )
    return session
def test_run_command_prints_scaffold_summary() -> None:
    """``tai run --no-probe`` prints a summary echoing the target host and port."""
    cli_args = [
        "run", "apache failed",
        "--host", "web01",
        "--port", "5566",
        "--no-probe",
        "--path", "/etc/apache2",
        "--jump-host", "bastion01",
        "--ignore-ssh-config",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for expected in ("tai", "host=web01", "port=5566"):
        assert expected in outcome.stdout
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A successful ``--probe`` reports success and echoes the remote output."""
    uname = SSHCommandResult(
        command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
    )
    _mock_session(monkeypatch, probe_result=uname)
    cli_args = ["run", "apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for expected in ("Probe succeeded", "Linux ssh 6.12.0"):
        assert expected in outcome.stdout
def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe whose remote command exits non-zero makes the CLI exit with 1."""
    denied = SSHCommandResult(
        command="uname -a",
        exit_code=255,
        stdout="",
        stderr="Permission denied (publickey,password).",
    )
    _mock_session(monkeypatch, probe_result=denied)
    cli_args = ["run", "apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 1
    assert "Probe failed" in outcome.stdout
def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--collect`` prints a summary naming each item and flagging truncation."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        kernel = CollectedItem(
            name="kernel",
            result=SSHCommandResult(
                command="uname -a", exit_code=0, stdout="Linux test", stderr=""
            ),
        )
        # The journal output is marked truncated so the summary must mention it.
        journal = CollectedItem(
            name="journal",
            result=SSHCommandResult(
                command="journalctl -n 200",
                exit_code=0,
                stdout="...",
                stderr="",
                stdout_truncated=True,
            ),
        )
        return CollectionReport(host="ssh.archflux.net", items=[kernel, journal])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--collect",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for fragment in ("Collection complete", "kernel", "journal", "truncated"):
        assert fragment in outcome.stdout
def test_interactive_collect_then_quit(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Interactive mode handles ``/collect`` and then exits cleanly on ``/quit``."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    # Scripted REPL input, consumed one line per prompt.
    scripted_inputs = iter(["/collect", "/quit"])
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_inputs))
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    lowered = outcome.stdout.lower()
    assert "ask questions directly" in lowered
    assert "collection complete" in lowered
    assert "Bye." in outcome.stdout
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A free-form question in interactive mode is answered by the (stubbed) AI.

    NOTE(review): despite its name, this test sends a plain question rather
    than an unknown slash command, and asserts that the AI reply is rendered.
    Consider renaming it or adding a separate unknown-command case.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    scripted_inputs = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_inputs))
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: SimpleNamespace(content="Check logs."),
    )
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for fragment in ("AI Response", "Check logs."):
        assert fragment in outcome.stdout
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """When report embedding fails, interactive mode says RAG is off but still answers."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    scripted_inputs = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_inputs))
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: SimpleNamespace(content="Check logs."),
    )
    # Simulate an indexing failure: no chunks, an error message, and a timing value.
    monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for fragment in ("RAG unavailable (indexing failed)", "AI Response"):
        assert fragment in outcome.stdout
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--rag-debug`` makes interactive mode print its retrieval diagnostics."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def fake_try_embed_report(*_args):  # type: ignore[no-untyped-def]
        # One indexed chunk whose embedding matches the stubbed query embedding.
        indexed = EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])
        return [indexed], None, 1.0

    scripted_inputs = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_inputs))
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: SimpleNamespace(content="Check logs."),
    )
    monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
    monkeypatch.setattr("tai.cli._try_embed_report", fake_try_embed_report)
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
        "--rag-debug",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "RAG retrieve:" in outcome.stdout
def test_history_command_lists_sessions(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``tai history --host`` lists the stored sessions recorded for that host."""

    class FakeStore:
        """Stand-in for RunHistoryStore that only has a session for ``web01``."""

        def __init__(self, _path: str, **_kwargs) -> None:
            pass

        def list_recent(self, *, host: str | None = None, limit: int = 20):
            del limit  # the fake ignores the limit
            if host != "web01":
                return []
            entry = SimpleNamespace(
                session_id="20260507T120000Z",
                host="web01",
                issue="nginx down",
                summary="Root cause: bad config",
            )
            return [entry]

    monkeypatch.setattr("tai.cli.RunHistoryStore", FakeStore)
    cli_args = ["history", "--history-db", "~/.tai/history.db", "--host", "web01"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for fragment in ("session(s)", "20260507T120000Z"):
        assert fragment in outcome.stdout
def test_history_command_exports_markdown(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``tai history --export`` writes the listed sessions to a Markdown file."""

    class FakeStore:
        """Stand-in for RunHistoryStore returning one session for any query."""

        def __init__(self, _path: str, **_kwargs) -> None:
            pass

        def list_recent(self, *, host: str | None = None, limit: int = 20):
            del host, limit  # the fake ignores its filters
            entry = SimpleNamespace(
                session_id="20260507T120000Z",
                host="web01",
                issue="nginx down",
                summary="Root cause: bad config",
            )
            return [entry]

    monkeypatch.setattr("tai.cli.RunHistoryStore", FakeStore)
    markdown_file = tmp_path / "history.md"
    cli_args = ["history", "--history-db", "~/.tai/history.db", "--export", str(markdown_file)]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "Exported" in outcome.stdout
    exported = markdown_file.read_text(encoding="utf-8")
    for fragment in ("# tai session history", "nginx down"):
        assert fragment in exported
def test_interactive_history_without_store_shows_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``--no-history``, the interactive ``/history`` command says the DB is off."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    scripted_inputs = iter(["/history", "/quit"])
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted_inputs))
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
        "--no-history",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "History DB is disabled" in outcome.stdout
def test_run_analyze_writes_output_file(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``--analyze --output-file`` writes the model's analysis text to the given path."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux test", stderr=""
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    # A single canned completion object, returned for every AI call.
    canned = SimpleNamespace(content="Root Cause\n\nEvidence\n\nRecommended Actions")
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.AIClient.complete", lambda *_args, **_kwargs: canned)

    report_file = tmp_path / "analysis.md"
    cli_args = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--analyze",
        "--output-file", str(report_file),
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    assert "Wrote analysis output" in outcome.stdout
    assert report_file.exists()
    assert "Root Cause" in report_file.read_text(encoding="utf-8")
def test_run_analyze_writes_json_output_and_strips_ansi(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``--output-format json`` writes a ``tai.analysis.v1`` payload with ANSI removed.

    The stubbed model reply is wrapped in an SGR colour sequence
    (``ESC[31m ... ESC[0m``). The test checks that the ESC byte and the rest of
    each escape sequence are both gone, so a half-done strip that only drops
    ESC fails.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        return CollectionReport(
            host="ssh.archflux.net",
            items=[
                CollectedItem(
                    name="kernel",
                    result=SSHCommandResult(
                        command="uname -a",
                        exit_code=0,
                        stdout="Linux test",
                        stderr="",
                    ),
                ),
            ],
        )

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: SimpleNamespace(
            content="\x1b[31mRoot Cause\x1b[0m\n\nEvidence\n\nRecommended Actions"
        ),
    )
    output_path = tmp_path / "analysis.json"
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "run", "apache failed",
            "--host", "ssh.archflux.net",
            "--port", "5566",
            "--no-probe",
            "--analyze",
            "--output-file", str(output_path),
            "--output-format", "json",
        ],
    )
    assert result.exit_code == 0
    payload = json.loads(output_path.read_text(encoding="utf-8"))
    assert payload["schema"] == "tai.analysis.v1"
    assert "generated_at" in payload
    assert payload["issue"] == "apache failed"
    assert payload["host"] == "ssh.archflux.net"
    assert payload["collection"] == {"total": 1, "failed": 0, "succeeded": 1}
    # The stubbed completion carries no usage data, so every counter is None.
    assert payload["token_usage"] == {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
    }
    analysis = payload["analysis"]
    assert "Root Cause" in analysis
    assert "\u001b" not in analysis
    # The parameter bytes must also be gone, not just the ESC introducer.
    assert "[31m" not in analysis
    assert "[0m" not in analysis
def test_run_analyze_writes_history_db_record(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``--analyze --history-db`` stores the analysis payload in the SQLite history.

    The test reads the newest ``run_history`` row back and checks its host,
    issue and the ``tai.analysis.v1`` JSON payload.
    """
    import sqlite3
    from contextlib import closing

    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        return CollectionReport(
            host="ssh.archflux.net",
            items=[
                CollectedItem(
                    name="kernel",
                    result=SSHCommandResult(
                        command="uname -a",
                        exit_code=0,
                        stdout="Linux test",
                        stderr="",
                    ),
                ),
            ],
        )

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    response = SimpleNamespace(content="Root Cause\n\nEvidence\n\nRecommended Actions")
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: response,
    )
    history_db = tmp_path / "history.db"
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "run", "apache failed",
            "--host", "ssh.archflux.net",
            "--port", "5566",
            "--no-probe",
            "--analyze",
            "--history-db", str(history_db),
        ],
    )
    assert result.exit_code == 0
    # A sqlite3 Connection used as a context manager only commits or rolls
    # back; it does not close. closing() makes sure the handle is released.
    with closing(sqlite3.connect(str(history_db))) as conn:
        row = conn.execute(
            "SELECT host, issue, payload_json FROM run_history ORDER BY id DESC LIMIT 1"
        ).fetchone()
    assert row is not None
    assert row[0] == "ssh.archflux.net"
    assert row[1] == "apache failed"
    payload = json.loads(row[2])
    assert payload["schema"] == "tai.analysis.v1"
    assert payload["host"] == "ssh.archflux.net"
def test_materialize_runbooks_sync_path_http_webroot(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """An HTTP directory listing is crawled and each linked ``.md`` file is downloaded.

    NOTE(review): ``tmp_path`` is unused, and the returned ``temp_dir`` is never
    cleaned up explicitly here. Confirm whether the helper's return value
    self-cleans.
    """
    root_url = "https://kb.example/runbooks/"
    listing = '<html><body><a href="nginx.md">nginx</a><a href="ssh.md">ssh</a></body></html>'
    runbook_bodies = {
        "nginx.md": "---\nservice: nginx\n---\nbody",
        "ssh.md": "---\nservice: ssh\n---\nbody",
    }

    def fake_download(url: str) -> str:
        if url == root_url:
            return listing
        for suffix, body in runbook_bodies.items():
            if url.endswith(suffix):
                return body
        raise AssertionError(url)

    monkeypatch.setattr("tai.cli._download_text_url", fake_download)
    source_dir, label, temp_dir = _materialize_runbooks_sync_path(
        root_url,
        identity_file=None,
        jump_host=None,
        ignore_ssh_config=False,
    )
    assert label == "https://kb.example/runbooks/"
    assert temp_dir is not None
    for filename in ("nginx.md", "ssh.md"):
        assert (source_dir / filename).is_file()
def test_materialize_runbook_add_path_http_url(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A single ``.md`` URL is downloaded into a temp file named after the URL."""
    markdown = "---\nservice: nginx\n---\nbody"
    monkeypatch.setattr("tai.cli._download_markdown_url", lambda _url: markdown)
    source_url = "https://kb.example/runbooks/nginx.md"
    source_file, label, temp_dir = _materialize_runbook_add_path(
        source_url,
        identity_file=None,
        jump_host=None,
        ignore_ssh_config=False,
    )
    assert label == "https://kb.example/runbooks/nginx.md"
    assert temp_dir is not None
    assert source_file.name == "nginx.md"
    written = source_file.read_text(encoding="utf-8")
    assert written.startswith("---")
def test_download_markdown_url_rejects_html(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """An HTML page served at a ``.md`` URL is rejected instead of treated as Markdown."""
    html_page = "<!DOCTYPE html><html><body>not markdown</body></html>"
    monkeypatch.setattr("tai.cli._download_text_url", lambda _url: html_page)
    expected_error = "does not appear to be a Markdown payload"
    with pytest.raises(ValueError, match=expected_error):
        _download_markdown_url("https://kb.example/runbooks/nginx.md")
def test_materialize_runbooks_sync_path_http_skips_html_wrappers(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Linked ``.md`` URLs that actually return HTML are skipped during a sync."""
    root_url = "https://kb.example/runbooks/"
    listing = '<html><body><a href="nginx.md">nginx</a><a href="ssh.md">ssh</a></body></html>'
    responses = {
        "nginx.md": "---\nservice: nginx\n---\nbody",
        # ssh.md is served as an HTML wrapper page and must be dropped.
        "ssh.md": "<!DOCTYPE html><html><body>wrapper</body></html>",
    }

    def fake_download(url: str) -> str:
        if url == root_url:
            return listing
        for suffix, body in responses.items():
            if url.endswith(suffix):
                return body
        raise AssertionError(url)

    monkeypatch.setattr("tai.cli._download_text_url", fake_download)
    source_dir, _label, temp_dir = _materialize_runbooks_sync_path(
        root_url,
        identity_file=None,
        jump_host=None,
        ignore_ssh_config=False,
    )
    assert temp_dir is not None
    assert (source_dir / "nginx.md").is_file()
    assert not (source_dir / "ssh.md").exists()
def test_materialize_runbook_add_path_http_requires_md_suffix() -> None:
    """``runbooks add`` rejects an HTTP URL that does not end in ``.md``."""
    ssh_options = {"identity_file": None, "jump_host": None, "ignore_ssh_config": False}
    directory_url = "https://kb.example/runbooks/"
    with pytest.raises(ValueError, match="must point to a .md file"):
        _materialize_runbook_add_path(directory_url, **ssh_options)
def test_runbooks_sync_accepts_ssh_source(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``runbooks sync --path ssh://...`` syncs and reports the original SSH label."""
    local_copy = tmp_path / "remote-runbooks"
    local_copy.mkdir(parents=True)
    (local_copy / "nginx.md").write_text("---\nservice: nginx\n---\nbody", encoding="utf-8")

    # Pretend the SSH source was already copied into local_copy.
    monkeypatch.setattr(
        "tai.cli._materialize_runbooks_sync_path",
        lambda *_args, **_kwargs: (local_copy, "ssh://ops@host/runbooks", None),
    )

    class FakeStore:
        """Stand-in for RunbookStore whose sync always reports one runbook."""

        def __init__(self, _path: str, **_kwargs) -> None:
            pass

        def sync(self, _dir: Path, _ai):
            return 1

    monkeypatch.setattr("tai.cli.RunbookStore", FakeStore)
    monkeypatch.setattr("tai.cli.AIClient", lambda *_a, **_k: object())

    cli_args = [
        "runbooks", "sync",
        "--path", "ssh://ops@host/runbooks",
        "--store", "~/.tai/runbooks",
    ]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
    for fragment in ("Synced 1 runbook(s)", "ssh://ops@host/runbooks"):
        assert fragment in outcome.stdout
def test_runbooks_add_accepts_https_source(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``runbooks add https://...`` indexes the file and echoes the URL label.

    URL materialization is stubbed to return a local temp file, and the store
    and AI client are replaced with fakes, so no network or real index is used.
    """
    import tempfile

    fd, temp_name = tempfile.mkstemp(prefix="tai-runbook-test-", suffix=".md")
    os.close(fd)
    runbook_file = Path(temp_name)
    # try/finally removes the temp file even when an assertion fails.
    # Previously it was only unlinked on the success path.
    try:
        runbook_file.write_text("---\nservice: nginx\n---\nbody", encoding="utf-8")
        monkeypatch.setattr(
            "tai.cli._materialize_runbook_add_path",
            lambda *_args, **_kwargs: (runbook_file, "https://kb.example/nginx.md", None),
        )

        class FakeStore:
            def __init__(self, _path: str, **_kwargs) -> None:
                pass

            def sync_single(self, _path: Path, _ai):
                return None

        monkeypatch.setattr("tai.cli.RunbookStore", FakeStore)
        monkeypatch.setattr("tai.cli.AIClient", lambda *_a, **_k: object())
        runner = CliRunner()
        result = runner.invoke(
            app,
            [
                "runbooks",
                "add",
                "https://kb.example/nginx.md",
                "--store",
                "~/.tai/runbooks",
            ],
        )
        assert result.exit_code == 0
        assert "Indexed" in result.stdout
        assert "https://kb.example/nginx.md" in result.stdout
    finally:
        runbook_file.unlink(missing_ok=True)
def test_inject_url_credentials_postgres() -> None:
    """Credentials are inserted into a postgres URL between the scheme and host."""
    rendered = _inject_url_credentials(
        "postgresql://db.example.com:5432/tai",
        user="tai_user",
        password="secret",
        schemes={"postgres", "postgresql"},
    )
    expected_prefix = "postgresql://tai_user:secret@db.example.com:5432/tai"
    assert rendered.startswith(expected_prefix)
def test_inject_url_credentials_ignores_non_matching_scheme() -> None:
    """A target without one of the listed schemes (a plain path) is returned unchanged."""
    local_path = "~/.tai/history.db"
    assert (
        _inject_url_credentials(
            local_path,
            user="tai_user",
            password="secret",
            schemes={"postgres", "postgresql"},
        )
        == local_path
    )
def test_load_env_file_and_resolve_secret(tmp_path: Path, monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``.env`` parsing works and secret resolution follows a fixed precedence.

    As pinned here, an explicit CLI value beats the ``.env`` file value, and the
    ``.env`` value beats the process environment.
    NOTE(review): file-over-environment is unusual; confirm it is intended.
    """
    dotenv = tmp_path / ".env"
    dotenv.write_text(
        "\n".join(
            ["TAI_HISTORY_DB_USER=from_file", "TAI_HISTORY_DB_PASSWORD=from_file_pw"]
        )
        + "\n",
        encoding="utf-8",
    )
    parsed = _load_env_file(str(dotenv))
    expected = {
        "TAI_HISTORY_DB_USER": "from_file",
        "TAI_HISTORY_DB_PASSWORD": "from_file_pw",
    }
    for key, value in expected.items():
        assert parsed[key] == value

    monkeypatch.setenv("TAI_HISTORY_DB_USER", "from_env")
    assert _resolve_secret(None, "TAI_HISTORY_DB_USER", parsed) == "from_file"
    assert _resolve_secret("from_cli", "TAI_HISTORY_DB_USER", parsed) == "from_cli"
def test_man_page_covers_cli_long_options() -> None:
    """Every long option shown in ``--help`` output must be mentioned in docs/tai.1.

    Options are scraped from the help text of each subcommand with a regex, and
    each must appear in the man page text.
    NOTE(review): the check is a plain substring test, so an option that is a
    prefix of a documented one (e.g. ``--foo`` vs ``--foo-bar``) passes even if
    it is undocumented.
    """
    runner = CliRunner()
    help_invocations = [
        ["run", "--help"],
        ["history", "--help"],
        ["runbooks", "sync", "--help"],
        ["runbooks", "list", "--help"],
        ["runbooks", "add", "--help"],
    ]
    man_page = Path("docs/tai.1")
    if not man_page.is_file():
        # If pytest was not started from the repository root, look next to the
        # tests directory instead.
        # NOTE(review): assumes <repo>/tests/test_cli.py and <repo>/docs/tai.1.
        man_page = Path(__file__).resolve().parents[1] / "docs" / "tai.1"
    documented = man_page.read_text(encoding="utf-8")
    discovered: set[str] = set()
    for args in help_invocations:
        result = runner.invoke(app, args)
        assert result.exit_code == 0, f"help command failed for: {' '.join(args)}"
        discovered.update(re.findall(r"--[a-z0-9][a-z0-9-]*", result.stdout))
    discovered.discard("--help")
    missing = sorted(option for option in discovered if option not in documented)
    assert missing == [], f"Missing options in docs/tai.1: {', '.join(missing)}"