# File: tai/tests/test_cli.py (Python, 550 lines, 17 KiB)

import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from typer.testing import CliRunner
from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk
from tai.ssh_client import SSHCommandResult
def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch ``SSHClient.connect`` to return a mock session.

    The mock works as an async context manager (entering it yields the mock
    itself) and exposes an awaitable ``probe``.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used to install the patch.
        probe_result: Value produced by awaiting ``session.probe()``.
        probe_raises: When given, awaiting ``session.probe()`` raises this
            exception instead of returning ``probe_result``.

    Returns:
        The mock session handed back by the patched ``connect``.
    """
    session = MagicMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)
    # Compare against None explicitly: the choice must depend on whether an
    # exception was supplied, not on that object's truthiness.
    if probe_raises is not None:
        session.probe = AsyncMock(side_effect=probe_raises)
    else:
        session.probe = AsyncMock(return_value=probe_result)
    # Accept any call shape so the stub keeps working whether the CLI passes
    # connection options positionally or by keyword.
    monkeypatch.setattr("tai.cli.SSHClient.connect", lambda *_args, **_kwargs: session)
    return session
def test_run_command_prints_scaffold_summary() -> None:
    """``run`` with ``--no-probe`` exits 0 and echoes the target host and port."""
    argv = [
        "run",
        "apache failed",
        "--host",
        "web01",
        "--port",
        "5566",
        "--no-probe",
        "--path",
        "/etc/apache2",
        "--jump-host",
        "bastion01",
        "--ignore-ssh-config",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("tai", "host=web01", "port=5566"):
        assert expected in outcome.stdout
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A zero-exit probe reports success and shows the remote stdout."""
    uname_ok = SSHCommandResult(
        command="uname -a",
        exit_code=0,
        stdout="Linux ssh 6.12.0",
        stderr="",
    )
    _mock_session(monkeypatch, probe_result=uname_ok)

    argv = ["run", "apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("Probe succeeded", "Linux ssh 6.12.0"):
        assert expected in outcome.stdout
def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe result with a non-zero exit code makes the CLI exit with status 1."""
    denied = SSHCommandResult(
        command="uname -a",
        exit_code=255,
        stdout="",
        stderr="Permission denied (publickey,password).",
    )
    _mock_session(monkeypatch, probe_result=denied)

    argv = ["run", "apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 1
    assert "Probe failed" in outcome.stdout
def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--collect`` prints a summary naming each item and flagging truncation."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Two items: a plain one and one whose stdout is marked as truncated.
        kernel = CollectedItem(
            name="kernel",
            result=SSHCommandResult(
                command="uname -a",
                exit_code=0,
                stdout="Linux test",
                stderr="",
            ),
        )
        journal = CollectedItem(
            name="journal",
            result=SSHCommandResult(
                command="journalctl -n 200",
                exit_code=0,
                stdout="...",
                stderr="",
                stdout_truncated=True,
            ),
        )
        return CollectionReport(host="ssh.archflux.net", items=[kernel, journal])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--collect",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("Collection complete", "kernel", "journal", "truncated"):
        assert expected in outcome.stdout
def test_interactive_collect_then_quit(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Interactive mode handles ``/collect`` and then exits on ``/quit``."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    # Scripted user input, consumed one entry per prompt.
    scripted = iter(["/collect", "/quit"])

    def feed_input(_prompt):  # type: ignore[no-untyped-def]
        return next(scripted)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.console.input", feed_input)
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    lowered = outcome.stdout.lower()
    assert "ask questions directly" in lowered
    assert "collection complete" in lowered
    assert "Bye." in outcome.stdout
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A free-form interactive question is answered with the stubbed AI reply.

    NOTE(review): the test name mentions an unknown-command hint, but the
    assertions below check the AI response path; the name may be stale.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def fake_complete(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Check logs.")

    # Scripted user input, consumed one entry per prompt.
    scripted = iter(["what should I check next?", "/quit"])

    def feed_input(_prompt):  # type: ignore[no-untyped-def]
        return next(scripted)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.AIClient.complete", fake_complete)
    monkeypatch.setattr("tai.cli.console.input", feed_input)
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("AI Response", "Check logs."):
        assert expected in outcome.stdout
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``_try_embed_report`` stubbed to fail, the RAG notice and an AI reply both appear."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def failing_embed(*_args):  # type: ignore[no-untyped-def]
        # No embedded chunks plus an error message stands in for a failed index.
        return (None, "embed failed", 1.0)

    def fake_complete(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Check logs.")

    # Scripted user input, consumed one entry per prompt.
    scripted = iter(["what should I check next?", "/quit"])

    def feed_input(_prompt):  # type: ignore[no-untyped-def]
        return next(scripted)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", failing_embed)
    monkeypatch.setattr("tai.cli.AIClient.complete", fake_complete)
    monkeypatch.setattr("tai.cli.console.input", feed_input)
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("RAG unavailable (indexing failed)", "AI Response"):
        assert expected in outcome.stdout
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--rag-debug`` in interactive mode emits a ``RAG retrieve:`` line."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def fake_embed_report(*_args):  # type: ignore[no-untyped-def]
        # One embedded chunk whose vector equals the stubbed query embedding.
        kernel_chunk = EmbeddedChunk(
            chunk=Chunk(name="kernel", content="content"),
            embedding=[1.0, 0.0],
        )
        return ([kernel_chunk], None, 1.0)

    def fake_embed(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return [1.0, 0.0]

    def fake_complete(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Check logs.")

    # Scripted user input, consumed one entry per prompt.
    scripted = iter(["what should I check next?", "/quit"])

    def feed_input(_prompt):  # type: ignore[no-untyped-def]
        return next(scripted)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", fake_embed_report)
    monkeypatch.setattr("tai.cli.AIClient.embed", fake_embed)
    monkeypatch.setattr("tai.cli.AIClient.complete", fake_complete)
    monkeypatch.setattr("tai.cli.console.input", feed_input)
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
        "--rag-debug",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "RAG retrieve:" in outcome.stdout
def test_history_command_lists_sessions(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``history --host web01`` lists the session returned by the store."""

    class FakeStore:
        """Stand-in for ``SessionStore`` that only knows about host ``web01``."""

        def __init__(self, _path: str) -> None:
            pass

        def list_recent(self, *, host: str | None = None, limit: int = 20):
            del limit
            if host != "web01":
                return []
            record = SimpleNamespace(
                session_id="20260507T120000Z",
                host="web01",
                issue="nginx down",
                summary="Root cause: bad config",
            )
            return [record]

    monkeypatch.setattr("tai.cli.SessionStore", FakeStore)

    argv = ["history", "--session-memory", "~/.tai/sessions", "--host", "web01"]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for expected in ("session(s)", "20260507T120000Z"):
        assert expected in outcome.stdout
def test_history_command_exports_markdown(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``history --export`` writes a markdown file containing the stored session."""

    class FakeStore:
        """Stand-in for ``SessionStore`` that always returns one session."""

        def __init__(self, _path: str) -> None:
            pass

        def list_recent(self, *, host: str | None = None, limit: int = 20):
            del host, limit
            record = SimpleNamespace(
                session_id="20260507T120000Z",
                host="web01",
                issue="nginx down",
                summary="Root cause: bad config",
            )
            return [record]

    monkeypatch.setattr("tai.cli.SessionStore", FakeStore)
    export_path = tmp_path / "history.md"

    argv = ["history", "--session-memory", "~/.tai/sessions", "--export", str(export_path)]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "Exported" in outcome.stdout
    exported = export_path.read_text(encoding="utf-8")
    for expected in ("# tai session history", "nginx down"):
        assert expected in exported
def test_interactive_history_without_store_shows_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``/history`` with no session-memory option prints the disabled notice."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    # Scripted user input, consumed one entry per prompt.
    scripted = iter(["/history", "/quit"])

    def feed_input(_prompt):  # type: ignore[no-untyped-def]
        return next(scripted)

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.console.input", feed_input)
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "Session memory is disabled" in outcome.stdout
def test_run_analyze_writes_output_file(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``--analyze --output-file`` writes the stubbed AI analysis to disk."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def fake_complete(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Root Cause\n\nEvidence\n\nRecommended Actions")

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.AIClient.complete", fake_complete)
    output_path = tmp_path / "analysis.md"

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--analyze",
        "--output-file",
        str(output_path),
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "Wrote analysis output" in outcome.stdout
    assert output_path.exists()
    written = output_path.read_text(encoding="utf-8")
    assert "Root Cause" in written
def test_run_analyze_writes_json_output_and_strips_ansi(monkeypatch, tmp_path: Path) -> None:  # type: ignore[no-untyped-def]
    """``--output-format json`` writes issue, host and analysis without ANSI escapes."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        )
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    def fake_complete(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        # The reply carries colour escape codes that must not reach the JSON file.
        return SimpleNamespace(
            content="\x1b[31mRoot Cause\x1b[0m\n\nEvidence\n\nRecommended Actions"
        )

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.AIClient.complete", fake_complete)
    output_path = tmp_path / "analysis.json"

    argv = [
        "run",
        "apache failed",
        "--host",
        "ssh.archflux.net",
        "--port",
        "5566",
        "--no-probe",
        "--analyze",
        "--output-file",
        str(output_path),
        "--output-format",
        "json",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    raw_json = output_path.read_text(encoding="utf-8")
    payload = json.loads(raw_json)
    assert payload["issue"] == "apache failed"
    assert payload["host"] == "ssh.archflux.net"
    analysis = payload["analysis"]
    assert "Root Cause" in analysis
    assert "\u001b" not in analysis