From 67a0cb3e697633af8b5884e8e98733c34a17edf3 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 05:54:15 +0200 Subject: [PATCH] feat(cli): add interactive follow-up loop with slash commands --- src/tai/cli.py | 76 +++++++++++++++++++++++++++++++++++++++++++---- tests/test_cli.py | 68 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 5 deletions(-) diff --git a/src/tai/cli.py b/src/tai/cli.py index 4481be6..4f543d7 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -15,7 +15,7 @@ from tai.input_parser import InputValidationError, build_request from tai.models import TroubleshootRequest from tai.plan import plan_from_request from tai.prompt_builder import build_system_prompt, build_user_message -from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig +from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession app = typer.Typer(no_args_is_help=True, add_completion=False) console = Console() @@ -66,6 +66,13 @@ def run( help="Send collected diagnostics to AI for analysis.", ), ] = False, + interactive: Annotated[ + bool, + typer.Option( + "--interactive/--no-interactive", + help="Start interactive follow-up mode (/collect, /analyze, /quit).", + ), + ] = False, ai_host: Annotated[ str, typer.Option("--ai-host", help="OpenAI-compatible AI backend URL."), @@ -109,16 +116,25 @@ def run( if req.target_paths: console.print(f"Paths: {', '.join(str(p) for p in req.target_paths)}") - if not (probe or collect or analyze): + if not (probe or collect or analyze or interactive): return # nothing SSH-related requested ai_config = AIConfig(host=ai_host, model=model, api_key=ai_key) - if analyze: + if analyze or interactive: console.print(f"[cyan]AI:[/cyan] {AIClient(ai_config).summary()}") try: - asyncio.run(_async_main(config, req, probe=probe, collect=collect, analyze=analyze, - ai_config=ai_config)) + asyncio.run( + _async_main( + config, + req, + probe=probe, + collect=collect, + analyze=analyze, + interactive=interactive, + ai_config=ai_config, + ) + ) except typer.Exit: raise except TimeoutError as exc: @@ -136,6 +152,7 @@ async def _async_main( probe: bool, collect: bool, analyze: bool, + interactive: bool, ai_config: AIConfig, ) -> None: """Open a single SSH session and run probe / collection / analysis through it.""" @@ -155,6 +172,55 @@ async def _async_main( if analyze and report is not None: _run_analysis(ai_config, req.issue, report) + if interactive: + await _interactive_loop(session, req, ai_config, report) + + +async def _interactive_loop( + session: SSHSession, + req: TroubleshootRequest, + ai_config: AIConfig, + report: CollectionReport | None, +) -> None: + """Run a tiny follow-up loop for collecting and analyzing on demand.""" + console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit") + + while True: + try: + command = input("tai> ").strip() + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Exiting interactive mode.[/yellow]") + return + + if not command: + continue + + if command in {"/quit", "/exit"}: + console.print("[green]Bye.[/green]") + return + + if command == "/help": + console.print("Commands: /collect, /analyze, /help, /quit") + continue + + if command == "/collect": + plan = plan_from_request(req) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + report = await collect_from_plan(session, plan) + _handle_collection_report(report) + continue + + if command == "/analyze": + if report is None: + plan = plan_from_request(req) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + report = await collect_from_plan(session, plan) + _handle_collection_report(report) + _run_analysis(ai_config, req.issue, report) + continue + + console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help") + def _handle_probe_result(result: SSHCommandResult) -> None: """Handle and render probe output for success or failure.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index c26716f..6b28d64 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -139,3 +139,71 @@ def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no assert "Collection complete" in result.stdout assert "kernel: ok" in result.stdout assert "journal: ok (truncated)" in result.stdout + + +def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + ], + ) + + commands = iter(["/collect", "/quit"]) + + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + monkeypatch.setattr("builtins.input", lambda _prompt: next(commands)) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--interactive", + ], + ) + + assert result.exit_code == 0 + assert "Interactive mode" in result.stdout + assert "Collection complete" in result.stdout + assert "Bye." in result.stdout + + +def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def] + _mock_session(monkeypatch) + + commands = iter(["/wat", "/quit"]) + monkeypatch.setattr("builtins.input", lambda _prompt: next(commands)) + + runner = CliRunner() + result = runner.invoke( + app, + [ + "apache failed", + "--host", + "ssh.archflux.net", + "--port", + "5566", + "--no-probe", + "--interactive", + ], + ) + + assert result.exit_code == 0 + assert "Unknown command" in result.stdout