"""CLI tests for ``tai``: scaffold summary, probe, collection and interactive flows.

Where a test needs them, the SSH session, remote collection, report embedding
and AI streaming are replaced with stubs via ``monkeypatch``.
"""

from unittest.mock import AsyncMock, MagicMock

from typer.testing import CliRunner

from tai.cli import app
from tai.collectors import CollectedItem, CollectionReport
from tai.rag_retriever import Chunk, EmbeddedChunk
from tai.ssh_client import SSHCommandResult

# Host name used by every test that runs against the mocked SSH session.
_HOST = "ssh.archflux.net"

# Shared CLI prefix: problem description, target host and port.
# Kept as a tuple so each test builds its own fresh argv list from it.
_BASE_ARGS = ("apache failed", "--host", _HOST, "--port", "5566")

# Arguments that open an interactive session without the initial probe.
_INTERACTIVE_ARGS = (*_BASE_ARGS, "--no-probe", "--interactive")


def _mock_session(
    monkeypatch,  # type: ignore[no-untyped-def]
    *,
    probe_result: SSHCommandResult | None = None,
    probe_raises: Exception | None = None,
) -> MagicMock:
    """Patch ``SSHClient.connect`` to return a mock async-context-manager session.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        probe_result: Value returned by the awaited ``session.probe()``.
        probe_raises: If given, ``session.probe()`` raises this instead.

    Returns:
        The mock session, so callers can inspect or further configure it.
    """
    session = MagicMock()
    # ``async with client.connect(...) as session`` must yield the mock itself.
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)
    if probe_raises is not None:
        session.probe = AsyncMock(side_effect=probe_raises)
    else:
        session.probe = AsyncMock(return_value=probe_result)
    monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session)
    return session


def _kernel_item() -> CollectedItem:
    """Return a successful ``uname -a`` collection item named ``kernel``."""
    return CollectedItem(
        name="kernel",
        result=SSHCommandResult(
            command="uname -a",
            exit_code=0,
            stdout="Linux test",
            stderr="",
        ),
    )


def _patch_collect(monkeypatch, items: list[CollectedItem]) -> None:  # type: ignore[no-untyped-def]
    """Replace ``tai.cli.collect_from_plan`` with a stub reporting *items* for ``_HOST``."""

    async def fake_collect_from_plan(_session, _plan) -> CollectionReport:  # type: ignore[no-untyped-def]
        # Fresh list per call, matching a real collector building a new report.
        return CollectionReport(host=_HOST, items=list(items))

    monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)


def _patch_interactive(monkeypatch, commands: list[str]) -> None:  # type: ignore[no-untyped-def]
    """Feed *commands* to the REPL prompt one at a time and report stdin as a TTY."""
    queued = iter(commands)
    monkeypatch.setattr("tai.cli.console.input", lambda _prompt: next(queued))
    monkeypatch.setattr("tai.cli._stdin_is_tty", lambda: True)


def test_run_command_prints_scaffold_summary() -> None:
    """Without probe/collect, the CLI prints a summary of the parsed target."""
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "apache failed",
            "--host",
            "web01",
            "--port",
            "5566",
            "--no-probe",
            "--path",
            "/etc/apache2",
            "--jump-host",
            "bastion01",
            "--ignore-ssh-config",
        ],
    )

    assert result.exit_code == 0
    assert "tai" in result.stdout
    assert "host=web01" in result.stdout
    assert "port=5566" in result.stdout


def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A zero-exit probe reports success and echoes the remote stdout."""
    _mock_session(
        monkeypatch,
        probe_result=SSHCommandResult(
            command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
        ),
    )

    runner = CliRunner()
    result = runner.invoke(app, [*_BASE_ARGS, "--probe"])

    assert result.exit_code == 0
    assert "Probe succeeded" in result.stdout
    assert "Linux ssh 6.12.0" in result.stdout


def test_probe_failure_returns_non_zero(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """A non-zero probe exit code makes the CLI exit with status 1."""
    _mock_session(
        monkeypatch,
        probe_result=SSHCommandResult(
            command="uname -a",
            exit_code=255,
            stdout="",
            stderr="Permission denied (publickey,password).",
        ),
    )

    runner = CliRunner()
    result = runner.invoke(app, [*_BASE_ARGS, "--probe"])

    assert result.exit_code == 1
    assert "Probe failed" in result.stdout


def test_collect_success_prints_summary(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """``--collect`` lists every collected item and flags truncated output."""
    _mock_session(monkeypatch)
    _patch_collect(
        monkeypatch,
        [
            _kernel_item(),
            CollectedItem(
                name="journal",
                result=SSHCommandResult(
                    command="journalctl -n 200",
                    exit_code=0,
                    stdout="...",
                    stderr="",
                    stdout_truncated=True,
                ),
            ),
        ],
    )

    runner = CliRunner()
    result = runner.invoke(app, [*_BASE_ARGS, "--no-probe", "--collect"])

    assert result.exit_code == 0
    assert "Collection complete" in result.stdout
    assert "kernel" in result.stdout
    assert "journal" in result.stdout
    assert "truncated" in result.stdout


def test_interactive_collect_then_quit(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """The REPL runs ``/collect`` and then exits cleanly on ``/quit``."""
    _mock_session(monkeypatch)
    _patch_collect(monkeypatch, [_kernel_item()])
    _patch_interactive(monkeypatch, ["/collect", "/quit"])

    runner = CliRunner()
    result = runner.invoke(app, list(_INTERACTIVE_ARGS))

    assert result.exit_code == 0
    assert "ask questions directly" in result.stdout.lower()
    assert "collection complete" in result.stdout.lower()
    assert "Bye." in result.stdout


def test_interactive_unknown_command_prints_hint(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """Free-form (non-slash) input is answered with the streamed AI response.

    NOTE(review): despite its name, this test sends a plain question rather
    than an unrecognised ``/command``; consider renaming it.
    """
    _mock_session(monkeypatch)
    _patch_collect(monkeypatch, [_kernel_item()])
    monkeypatch.setattr(
        "tai.cli.AIClient.stream",
        lambda *_args, **_kwargs: iter(["Check logs."]),
    )
    _patch_interactive(monkeypatch, ["what should I check next?", "/quit"])

    runner = CliRunner()
    result = runner.invoke(app, list(_INTERACTIVE_ARGS))

    assert result.exit_code == 0
    assert "AI Response" in result.stdout
    assert "Check logs." in result.stdout


def test_interactive_prints_rag_fallback_notice_on_index_failure(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """When indexing the report fails, a RAG-unavailable notice is shown and the AI still answers."""
    _mock_session(monkeypatch)
    _patch_collect(monkeypatch, [_kernel_item()])
    # (chunks, error, elapsed): no chunks plus an error message simulates failure.
    monkeypatch.setattr("tai.cli._try_embed_report", lambda *_args: (None, "embed failed", 1.0))
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    _patch_interactive(monkeypatch, ["what should I check next?", "/quit"])

    runner = CliRunner()
    result = runner.invoke(app, list(_INTERACTIVE_ARGS))

    assert result.exit_code == 0
    assert "RAG unavailable (indexing failed)" in result.stdout
    assert "AI Response" in result.stdout


def test_interactive_rag_debug_prints_retrieval_scores(monkeypatch) -> None:  # type: ignore[no-untyped-def]
    """With ``--rag-debug``, retrieval diagnostics are printed for each question."""
    _mock_session(monkeypatch)
    _patch_collect(monkeypatch, [_kernel_item()])
    monkeypatch.setattr(
        "tai.cli._try_embed_report",
        lambda *_args: (
            [EmbeddedChunk(chunk=Chunk(name="kernel", content="content"), embedding=[1.0, 0.0])],
            None,
            1.0,
        ),
    )
    # Query embedding identical to the indexed chunk's embedding.
    monkeypatch.setattr("tai.cli.AIClient.embed", lambda *_args, **_kwargs: [1.0, 0.0])
    monkeypatch.setattr("tai.cli.AIClient.stream", lambda *_args, **_kwargs: iter(["Check logs."]))
    _patch_interactive(monkeypatch, ["what should I check next?", "/quit"])

    runner = CliRunner()
    result = runner.invoke(app, [*_INTERACTIVE_ARGS, "--rag-debug"])

    assert result.exit_code == 0
    assert "RAG retrieve:" in result.stdout