feat(cli): support conversational AI follow-ups in interactive mode

This commit is contained in:
2026-05-04 05:58:26 +02:00
parent 67a0cb3e69
commit fdcde37e46
4 changed files with 133 additions and 6 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any, cast
from openai import OpenAI
@@ -88,6 +89,20 @@ class AIClient:
if delta:
yield delta
def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]:
"""Stream a completion from an explicit chat history."""
stream = self._client.chat.completions.create(
model=self._config.model,
max_tokens=self._config.max_tokens,
stream=True,
messages=cast(Any, messages),
)
for chunk in cast(Iterator[Any], stream):
delta = chunk.choices[0].delta.content
if delta:
yield delta
def summary(self) -> str:
"""Human-readable description of the AI config."""
return f"host={self._config.host} model={self._config.model}"

View File

@@ -182,8 +182,28 @@ async def _interactive_loop(
ai_config: AIConfig,
report: CollectionReport | None,
) -> None:
"""Run a tiny follow-up loop for collecting and analyzing on demand."""
console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit")
"""Run a follow-up loop for collecting and conversational analysis."""
console.print(
"[cyan]Interactive mode:[/cyan] "
"ask questions directly, or use /collect, /analyze, /help, /quit"
)
ai = AIClient(ai_config)
messages: list[dict[str, str]] | None = None
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
if current_report is None:
return None
return [
{"role": "system", "content": build_system_prompt()},
{
"role": "user",
"content": build_user_message(req.issue, current_report),
},
]
if report is not None:
messages = _reset_messages(report)
while True:
try:
@@ -201,6 +221,7 @@ async def _interactive_loop(
if command == "/help":
console.print("Commands: /collect, /analyze, /help, /quit")
console.print("Tip: any non-slash text is treated as a follow-up AI question.")
continue
if command == "/collect":
@@ -208,6 +229,7 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
continue
if command == "/analyze":
@@ -216,10 +238,35 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
_run_analysis(ai_config, req.issue, report)
messages = _reset_messages(report)
if messages is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
messages.append(
{
"role": "user",
"content": "Provide an updated diagnosis from the current diagnostics.",
}
)
response = _stream_conversation(ai, messages)
messages.append({"role": "assistant", "content": response})
continue
console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help")
if report is None:
plan = plan_from_request(req)
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
if messages is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
messages.append({"role": "user", "content": command})
response = _stream_conversation(ai, messages)
messages.append({"role": "assistant", "content": response})
def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -262,6 +309,21 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) ->
raise typer.Exit(code=1) from exc
def _stream_conversation(ai: AIClient, messages: list[dict[str, str]]) -> str:
"""Stream a multi-turn AI response and return the final text."""
console.print("[cyan]Analyzing...[/cyan]\n")
try:
chunks: list[str] = []
for chunk in ai.stream_messages(messages):
chunks.append(chunk)
response = "".join(chunks)
console.print(Markdown(response))
return response
except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}")
raise typer.Exit(code=1) from exc
def main() -> None:
"""Console script entrypoint."""
app()

View File

@@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None:
assert result == ["Root ", "cause ", "found."]
def test_stream_messages_yields_chunks() -> None:
config = AIConfig()
client = AIClient(config)
def _make_chunk(text: str | None) -> MagicMock:
delta = MagicMock()
delta.content = text
choice = MagicMock()
choice.delta = delta
chunk = MagicMock()
chunk.choices = [choice]
return chunk
mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")]
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
result = list(
client.stream_messages(
[
{"role": "system", "content": "sys"},
{"role": "user", "content": "question"},
]
)
)
assert result == ["A", "B"]
# ---------------------------------------------------------------------------
# prompt_builder
# ---------------------------------------------------------------------------

View File

@@ -188,7 +188,28 @@ def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
_mock_session(monkeypatch)
commands = iter(["/wat", "/quit"])
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
return CollectionReport(
host="ssh.archflux.net",
items=[
CollectedItem(
name="kernel",
result=SSHCommandResult(
command="uname -a",
exit_code=0,
stdout="Linux test",
stderr="",
),
),
],
)
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr(
"tai.cli.AIClient.stream_messages",
lambda *_args, **_kwargs: iter(["Check logs."]),
)
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
runner = CliRunner()
@@ -206,4 +227,5 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
)
assert result.exit_code == 0
assert "Unknown command" in result.stdout
assert "Analyzing..." in result.stdout
assert "Check logs." in result.stdout