feat(cli): support conversational AI follow-ups in interactive mode
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -88,6 +89,20 @@ class AIClient:
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]:
|
||||
"""Stream a completion from an explicit chat history."""
|
||||
stream = self._client.chat.completions.create(
|
||||
model=self._config.model,
|
||||
max_tokens=self._config.max_tokens,
|
||||
stream=True,
|
||||
messages=cast(Any, messages),
|
||||
)
|
||||
|
||||
for chunk in cast(Iterator[Any], stream):
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
yield delta
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Human-readable description of the AI config."""
|
||||
return f"host={self._config.host} model={self._config.model}"
|
||||
|
||||
@@ -182,8 +182,28 @@ async def _interactive_loop(
|
||||
ai_config: AIConfig,
|
||||
report: CollectionReport | None,
|
||||
) -> None:
|
||||
"""Run a tiny follow-up loop for collecting and analyzing on demand."""
|
||||
console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit")
|
||||
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||
console.print(
|
||||
"[cyan]Interactive mode:[/cyan] "
|
||||
"ask questions directly, or use /collect, /analyze, /help, /quit"
|
||||
)
|
||||
|
||||
ai = AIClient(ai_config)
|
||||
messages: list[dict[str, str]] | None = None
|
||||
|
||||
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
|
||||
if current_report is None:
|
||||
return None
|
||||
return [
|
||||
{"role": "system", "content": build_system_prompt()},
|
||||
{
|
||||
"role": "user",
|
||||
"content": build_user_message(req.issue, current_report),
|
||||
},
|
||||
]
|
||||
|
||||
if report is not None:
|
||||
messages = _reset_messages(report)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -201,6 +221,7 @@ async def _interactive_loop(
|
||||
|
||||
if command == "/help":
|
||||
console.print("Commands: /collect, /analyze, /help, /quit")
|
||||
console.print("Tip: any non-slash text is treated as a follow-up AI question.")
|
||||
continue
|
||||
|
||||
if command == "/collect":
|
||||
@@ -208,6 +229,7 @@ async def _interactive_loop(
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
messages = _reset_messages(report)
|
||||
continue
|
||||
|
||||
if command == "/analyze":
|
||||
@@ -216,10 +238,35 @@ async def _interactive_loop(
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
_run_analysis(ai_config, req.issue, report)
|
||||
messages = _reset_messages(report)
|
||||
if messages is None:
|
||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||
continue
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Provide an updated diagnosis from the current diagnostics.",
|
||||
}
|
||||
)
|
||||
response = _stream_conversation(ai, messages)
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
continue
|
||||
|
||||
console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help")
|
||||
if report is None:
|
||||
plan = plan_from_request(req)
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
messages = _reset_messages(report)
|
||||
|
||||
if messages is None:
|
||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": command})
|
||||
response = _stream_conversation(ai, messages)
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||
@@ -262,6 +309,21 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) ->
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
def _stream_conversation(ai: AIClient, messages: list[dict[str, str]]) -> str:
|
||||
"""Stream a multi-turn AI response and return the final text."""
|
||||
console.print("[cyan]Analyzing...[/cyan]\n")
|
||||
try:
|
||||
chunks: list[str] = []
|
||||
for chunk in ai.stream_messages(messages):
|
||||
chunks.append(chunk)
|
||||
response = "".join(chunks)
|
||||
console.print(Markdown(response))
|
||||
return response
|
||||
except Exception as exc: # noqa: BLE001
|
||||
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Console script entrypoint."""
|
||||
app()
|
||||
|
||||
@@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None:
|
||||
assert result == ["Root ", "cause ", "found."]
|
||||
|
||||
|
||||
def test_stream_messages_yields_chunks() -> None:
|
||||
config = AIConfig()
|
||||
client = AIClient(config)
|
||||
|
||||
def _make_chunk(text: str | None) -> MagicMock:
|
||||
delta = MagicMock()
|
||||
delta.content = text
|
||||
choice = MagicMock()
|
||||
choice.delta = delta
|
||||
chunk = MagicMock()
|
||||
chunk.choices = [choice]
|
||||
return chunk
|
||||
|
||||
mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")]
|
||||
|
||||
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
|
||||
result = list(
|
||||
client.stream_messages(
|
||||
[
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "question"},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["A", "B"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -188,7 +188,28 @@ def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-
|
||||
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
_mock_session(monkeypatch)
|
||||
|
||||
commands = iter(["/wat", "/quit"])
|
||||
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||
return CollectionReport(
|
||||
host="ssh.archflux.net",
|
||||
items=[
|
||||
CollectedItem(
|
||||
name="kernel",
|
||||
result=SSHCommandResult(
|
||||
command="uname -a",
|
||||
exit_code=0,
|
||||
stdout="Linux test",
|
||||
stderr="",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
commands = iter(["what should I check next?", "/quit"])
|
||||
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||
monkeypatch.setattr(
|
||||
"tai.cli.AIClient.stream_messages",
|
||||
lambda *_args, **_kwargs: iter(["Check logs."]),
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
||||
|
||||
runner = CliRunner()
|
||||
@@ -206,4 +227,5 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Unknown command" in result.stdout
|
||||
assert "Analyzing..." in result.stdout
|
||||
assert "Check logs." in result.stdout
|
||||
|
||||
Reference in New Issue
Block a user