From fdcde37e46878be284d5cd86f0ce1a35110425f8 Mon Sep 17 00:00:00 2001 From: zphinx Date: Mon, 4 May 2026 05:58:26 +0200 Subject: [PATCH] feat(cli): support conversational AI follow-ups in interactive mode --- src/tai/ai_client.py | 15 ++++++++++ src/tai/cli.py | 70 +++++++++++++++++++++++++++++++++++++++++--- tests/test_ai.py | 28 ++++++++++++++++++ tests/test_cli.py | 26 ++++++++++++++-- 4 files changed, 133 insertions(+), 6 deletions(-) diff --git a/src/tai/ai_client.py b/src/tai/ai_client.py index c80103e..a50457f 100644 --- a/src/tai/ai_client.py +++ b/src/tai/ai_client.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Iterator from dataclasses import dataclass, field +from typing import Any, cast from openai import OpenAI @@ -88,6 +89,20 @@ class AIClient: if delta: yield delta + def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]: + """Stream a completion from an explicit chat history.""" + stream = self._client.chat.completions.create( + model=self._config.model, + max_tokens=self._config.max_tokens, + stream=True, + messages=cast(Any, messages), + ) + + for chunk in cast(Iterator[Any], stream): + delta = chunk.choices[0].delta.content + if delta: + yield delta + def summary(self) -> str: """Human-readable description of the AI config.""" return f"host={self._config.host} model={self._config.model}" diff --git a/src/tai/cli.py b/src/tai/cli.py index 4f543d7..1b3a0a7 100644 --- a/src/tai/cli.py +++ b/src/tai/cli.py @@ -182,8 +182,28 @@ async def _interactive_loop( ai_config: AIConfig, report: CollectionReport | None, ) -> None: - """Run a tiny follow-up loop for collecting and analyzing on demand.""" - console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit") + """Run a follow-up loop for collecting and conversational analysis.""" + console.print( + "[cyan]Interactive mode:[/cyan] " + "ask questions directly, or use /collect, /analyze, /help, /quit" + ) + + ai = AIClient(ai_config) + messages: list[dict[str, str]] | None = None + + def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None: + if current_report is None: + return None + return [ + {"role": "system", "content": build_system_prompt()}, + { + "role": "user", + "content": build_user_message(req.issue, current_report), + }, + ] + + if report is not None: + messages = _reset_messages(report) while True: try: @@ -201,6 +221,7 @@ async def _interactive_loop( if command == "/help": console.print("Commands: /collect, /analyze, /help, /quit") + console.print("Tip: any non-slash text is treated as a follow-up AI question.") continue if command == "/collect": @@ -208,6 +229,7 @@ async def _interactive_loop( console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") report = await collect_from_plan(session, plan) _handle_collection_report(report) + messages = _reset_messages(report) continue if command == "/analyze": @@ -216,10 +238,35 @@ async def _interactive_loop( console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") report = await collect_from_plan(session, plan) _handle_collection_report(report) - _run_analysis(ai_config, req.issue, report) + messages = _reset_messages(report) + if messages is None: + console.print("[red]No diagnostics available to analyze.[/red]") + continue + + messages.append( + { + "role": "user", + "content": "Provide an updated diagnosis from the current diagnostics.", + } + ) + response = _stream_conversation(ai, messages) + messages.append({"role": "assistant", "content": response}) continue - console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help") + if report is None: + plan = plan_from_request(req) + console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") + report = await collect_from_plan(session, plan) + _handle_collection_report(report) + messages = _reset_messages(report) + + if messages is None: + console.print("[red]No diagnostics available to analyze.[/red]") + continue + + messages.append({"role": "user", "content": command}) + response = _stream_conversation(ai, messages) + messages.append({"role": "assistant", "content": response}) def _handle_probe_result(result: SSHCommandResult) -> None: @@ -262,6 +309,21 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) -> raise typer.Exit(code=1) from exc +def _stream_conversation(ai: AIClient, messages: list[dict[str, str]]) -> str: + """Stream a multi-turn AI response and return the final text.""" + console.print("[cyan]Analyzing...[/cyan]\n") + try: + chunks: list[str] = [] + for chunk in ai.stream_messages(messages): + chunks.append(chunk) + response = "".join(chunks) + console.print(Markdown(response)) + return response + except Exception as exc: # noqa: BLE001 + console.print(f"[red]AI analysis failed:[/red] {exc}") + raise typer.Exit(code=1) from exc + + def main() -> None: """Console script entrypoint.""" app() diff --git a/tests/test_ai.py b/tests/test_ai.py index 08d1510..b45c5f5 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None: assert result == ["Root ", "cause ", "found."] +def test_stream_messages_yields_chunks() -> None: + config = AIConfig() + client = AIClient(config) + + def _make_chunk(text: str | None) -> MagicMock: + delta = MagicMock() + delta.content = text + choice = MagicMock() + choice.delta = delta + chunk = MagicMock() + chunk.choices = [choice] + return chunk + + mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")] + + with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)): + result = list( + client.stream_messages( + [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "question"}, + ] + ) + ) + + assert result == ["A", "B"] + + # --------------------------------------------------------------------------- # prompt_builder # --------------------------------------------------------------------------- diff --git a/tests/test_cli.py b/tests/test_cli.py index 6b28d64..a49b3f5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -188,7 +188,28 @@ def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no- def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def] _mock_session(monkeypatch) - commands = iter(["/wat", "/quit"]) + async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def] + return CollectionReport( + host="ssh.archflux.net", + items=[ + CollectedItem( + name="kernel", + result=SSHCommandResult( + command="uname -a", + exit_code=0, + stdout="Linux test", + stderr="", + ), + ), + ], + ) + + commands = iter(["what should I check next?", "/quit"]) + monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) + monkeypatch.setattr( + "tai.cli.AIClient.stream_messages", + lambda *_args, **_kwargs: iter(["Check logs."]), + ) monkeypatch.setattr("builtins.input", lambda _prompt: next(commands)) runner = CliRunner() @@ -206,4 +227,5 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ) assert result.exit_code == 0 - assert "Unknown command" in result.stdout + assert "Analyzing..." in result.stdout + assert "Check logs." in result.stdout