341 lines
10 KiB
Python
341 lines
10 KiB
Python
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from typer.testing import CliRunner
|
|
|
|
from tai.cli import app
|
|
from tai.collectors import CollectedItem, CollectionReport
|
|
from tai.rag_retriever import Chunk, EmbeddedChunk
|
|
from tai.ssh_client import SSHCommandResult
|
|
|
|
|
|
def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch ``tai.cli.SSHClient.connect`` so it returns a mock session.

    The session is a ``MagicMock`` that also acts as an async context
    manager: ``async with ... as s`` yields the very same mock.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; the patch is undone
            automatically when the test finishes.
        probe_result: Value produced when ``session.probe()`` is awaited.
        probe_raises: When given, awaiting ``session.probe()`` raises this
            exception instead; it takes precedence over ``probe_result``.

    Returns:
        The mock session, so callers can inspect calls made on it.
    """
    session = MagicMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)

    # Test against None explicitly: an exception type that defines
    # __bool__/__len__ could be falsy and would otherwise be silently ignored.
    if probe_raises is not None:
        session.probe = AsyncMock(side_effect=probe_raises)
    else:
        session.probe = AsyncMock(return_value=probe_result)

    # Accept any call shape so the stub keeps working whether the CLI passes
    # connection settings positionally or as keywords.
    monkeypatch.setattr(
        "tai.cli.SSHClient.connect",
        lambda _self, *_args, **_kwargs: session,
    )
    return session
|
|
|
|
|
|
def test_run_command_prints_scaffold_summary() -> None:
    """``run --no-probe`` exits cleanly and prints a summary containing host and port."""
    cli_args = [
        "run",
        "apache failed",
        "--host", "web01",
        "--port", "5566",
        "--no-probe",
        "--path", "/etc/apache2",
        "--jump-host", "bastion01",
        "--ignore-ssh-config",
    ]

    outcome = CliRunner().invoke(app, cli_args)

    assert outcome.exit_code == 0
    for expected in ("tai", "host=web01", "port=5566"):
        assert expected in outcome.stdout
|
|
|
|
|
|
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A successful ``--probe`` reports success and shows the remote command's stdout."""
    uname_ok = SSHCommandResult(
        command="uname -a",
        exit_code=0,
        stdout="Linux ssh 6.12.0",
        stderr="",
    )
    _mock_session(monkeypatch, probe_result=uname_ok)

    argv = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--probe",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "Probe succeeded" in outcome.stdout
    assert "Linux ssh 6.12.0" in outcome.stdout
|
|
|
|
|
|
def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A probe whose remote command exits non-zero makes the CLI exit with status 1."""
    denied = SSHCommandResult(
        command="uname -a",
        exit_code=255,
        stdout="",
        stderr="Permission denied (publickey,password).",
    )
    _mock_session(monkeypatch, probe_result=denied)

    argv = ["run", "apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 1
    assert "Probe failed" in outcome.stdout
|
|
|
|
|
|
def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--collect`` prints a per-item summary and flags items whose output was truncated."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # One normal item and one whose stdout is marked as truncated.
        kernel = CollectedItem(
            name="kernel",
            result=SSHCommandResult(
                command="uname -a", exit_code=0, stdout="Linux test", stderr=""
            ),
        )
        journal = CollectedItem(
            name="journal",
            result=SSHCommandResult(
                command="journalctl -n 200",
                exit_code=0,
                stdout="...",
                stderr="",
                stdout_truncated=True,
            ),
        )
        return CollectionReport(host="ssh.archflux.net", items=[kernel, journal])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)

    argv = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--collect",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    for needle in ("Collection complete", "kernel", "journal", "truncated"):
        assert needle in outcome.stdout
|
|
|
|
|
|
def test_interactive_collect_then_quit(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """In interactive mode, ``/collect`` runs a collection and ``/quit`` ends with "Bye."."""
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        uname = SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux test", stderr="")
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=uname)],
        )

    # Scripted REPL input, consumed one line per prompt.
    scripted = iter(["/collect", "/quit"])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted))
    # Pretend stdin is a terminal (the CLI presumably gates interactive mode on this).
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    lowered = outcome.stdout.lower()
    assert "ask questions directly" in lowered
    assert "collection complete" in lowered
    assert "Bye." in outcome.stdout
|
|
|
|
|
|
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Free-form (non-slash) input in interactive mode is answered by the AI.

    Despite the test name, the scripted input is a plain question rather than
    an unknown ``/command``. The test checks that the reply from the stubbed
    ``AIClient.complete`` is printed under "AI Response".
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        return CollectionReport(
            host="ssh.archflux.net",
            items=[
                CollectedItem(
                    name="kernel",
                    result=SSHCommandResult(
                        command="uname -a",
                        exit_code=0,
                        stdout="Linux test",
                        stderr="",
                    ),
                ),
            ],
        )

    commands = iter(["what should I check next?", "/quit"])
    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    # Stub report indexing and query embedding, as this module's RAG-specific
    # tests do, so asking a question never reaches a real embedding backend.
    monkeypatch.setattr(
        "tai.cli._try_embed_report",
        lambda *_args, **_kwargs: (
            [EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])],
            None,
            1.0,
        ),
    )
    monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
    monkeypatch.setattr(
        "tai.cli.AIClient.complete",
        lambda *_args, **_kwargs: SimpleNamespace(content="Check logs."),
    )
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(commands))
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "run", "apache failed",
            "--host",
            "ssh.archflux.net",
            "--port",
            "5566",
            "--no-probe",
            "--interactive",
        ],
    )

    assert result.exit_code == 0
    assert "AI Response" in result.stdout
    assert "Check logs." in result.stdout
|
|
|
|
|
|
def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """When report indexing fails, the REPL says RAG is unavailable yet still answers.

    ``_try_embed_report`` is stubbed to return no chunks, an error message and
    a float (presumably elapsed seconds), simulating an embedding failure.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        kernel_result = SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux test", stderr="")
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=kernel_result)],
        )

    def canned_completion(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Check logs.")

    scripted = iter(["what should I check next?", "/quit"])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
    monkeypatch.setattr("tai.cli.AIClient.complete", canned_completion)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted))
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "RAG unavailable (indexing failed)" in outcome.stdout
    assert "AI Response" in outcome.stdout
|
|
|
|
|
|
def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``--rag-debug``, asking a question prints a "RAG retrieve:" diagnostic.

    Indexing is stubbed to yield one embedded chunk, and the query embedding
    is stubbed to the same vector as that chunk's embedding.
    """
    _mock_session(monkeypatch)

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        kernel_result = SSHCommandResult(command="uname -a", exit_code=0, stdout="Linux test", stderr="")
        return CollectionReport(
            host="ssh.archflux.net",
            items=[CollectedItem(name="kernel", result=kernel_result)],
        )

    def canned_completion(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return SimpleNamespace(content="Check logs.")

    def canned_query_embedding(*_args, **_kwargs):  # type: ignore[no-untyped-def]
        return [1.0, 0.0]

    def indexed_report(*_args):  # type: ignore[no-untyped-def]
        chunk = Chunk(name="kernel", content="content")
        return ([EmbeddedChunk(chunk=chunk, embedding=[1.0, 0.0])], None, 1.0)

    scripted = iter(["what should I check next?", "/quit"])

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
    monkeypatch.setattr("tai.cli._try_embed_report", indexed_report)
    monkeypatch.setattr("tai.cli.AIClient.embed", canned_query_embedding)
    monkeypatch.setattr("tai.cli.AIClient.complete", canned_completion)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(scripted))
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)

    argv = [
        "run", "apache failed",
        "--host", "ssh.archflux.net",
        "--port", "5566",
        "--no-probe",
        "--interactive",
        "--rag-debug",
    ]
    outcome = CliRunner().invoke(app, argv)

    assert outcome.exit_code == 0
    assert "RAG retrieve:" in outcome.stdout
|