feat(cli): support conversational AI follow-ups in interactive mode

This commit is contained in:
2026-05-04 05:58:26 +02:00
parent 67a0cb3e69
commit fdcde37e46
4 changed files with 133 additions and 6 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, cast
from openai import OpenAI from openai import OpenAI
@@ -88,6 +89,20 @@ class AIClient:
if delta: if delta:
yield delta yield delta
def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]:
"""Stream a completion from an explicit chat history."""
stream = self._client.chat.completions.create(
model=self._config.model,
max_tokens=self._config.max_tokens,
stream=True,
messages=cast(Any, messages),
)
for chunk in cast(Iterator[Any], stream):
delta = chunk.choices[0].delta.content
if delta:
yield delta
def summary(self) -> str: def summary(self) -> str:
"""Human-readable description of the AI config.""" """Human-readable description of the AI config."""
return f"host={self._config.host} model={self._config.model}" return f"host={self._config.host} model={self._config.model}"

View File

@@ -182,8 +182,28 @@ async def _interactive_loop(
ai_config: AIConfig, ai_config: AIConfig,
report: CollectionReport | None, report: CollectionReport | None,
) -> None: ) -> None:
"""Run a tiny follow-up loop for collecting and analyzing on demand.""" """Run a follow-up loop for collecting and conversational analysis."""
console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit") console.print(
"[cyan]Interactive mode:[/cyan] "
"ask questions directly, or use /collect, /analyze, /help, /quit"
)
ai = AIClient(ai_config)
messages: list[dict[str, str]] | None = None
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
if current_report is None:
return None
return [
{"role": "system", "content": build_system_prompt()},
{
"role": "user",
"content": build_user_message(req.issue, current_report),
},
]
if report is not None:
messages = _reset_messages(report)
while True: while True:
try: try:
@@ -201,6 +221,7 @@ async def _interactive_loop(
if command == "/help": if command == "/help":
console.print("Commands: /collect, /analyze, /help, /quit") console.print("Commands: /collect, /analyze, /help, /quit")
console.print("Tip: any non-slash text is treated as a follow-up AI question.")
continue continue
if command == "/collect": if command == "/collect":
@@ -208,6 +229,7 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
messages = _reset_messages(report)
continue continue
if command == "/analyze": if command == "/analyze":
@@ -216,10 +238,35 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
_run_analysis(ai_config, req.issue, report) messages = _reset_messages(report)
if messages is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue continue
console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help") messages.append(
{
"role": "user",
"content": "Provide an updated diagnosis from the current diagnostics.",
}
)
response = _stream_conversation(ai, messages)
messages.append({"role": "assistant", "content": response})
continue
if report is None:
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
if messages is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
messages.append({"role": "user", "content": command})
response = _stream_conversation(ai, messages)
messages.append({"role": "assistant", "content": response})
def _handle_probe_result(result: SSHCommandResult) -> None: def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -262,6 +309,21 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) ->
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc
def _stream_conversation(ai: AIClient, messages: list[dict[str, str]]) -> str:
"""Stream a multi-turn AI response and return the final text."""
console.print("[cyan]Analyzing...[/cyan]\n")
try:
chunks: list[str] = []
for chunk in ai.stream_messages(messages):
chunks.append(chunk)
response = "".join(chunks)
console.print(Markdown(response))
return response
except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}")
raise typer.Exit(code=1) from exc
def main() -> None: def main() -> None:
"""Console script entrypoint.""" """Console script entrypoint."""
app() app()

View File

@@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None:
assert result == ["Root ", "cause ", "found."] assert result == ["Root ", "cause ", "found."]
def test_stream_messages_yields_chunks() -> None:
config = AIConfig()
client = AIClient(config)
def _make_chunk(text: str | None) -> MagicMock:
delta = MagicMock()
delta.content = text
choice = MagicMock()
choice.delta = delta
chunk = MagicMock()
chunk.choices = [choice]
return chunk
mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")]
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
result = list(
client.stream_messages(
[
{"role": "system", "content": "sys"},
{"role": "user", "content": "question"},
]
)
)
assert result == ["A", "B"]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# prompt_builder # prompt_builder
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -188,7 +188,28 @@ def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def] def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch) _mock_session(monkeypatch)
commands = iter(["/wat", "/quit"]) async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr(
"tai.cli.AIClient.stream_messages",
lambda *_args, **_kwargs: iter(["Check logs."]),
)
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands)) monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
runner = CliRunner() runner = CliRunner()
@@ -206,4 +227,5 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
) )
assert result.exit_code == 0 assert result.exit_code == 0
assert "Unknown command" in result.stdout assert "Analyzing..." in result.stdout
assert "Check logs." in result.stdout