feat(cli): support conversational AI follow-ups in interactive mode
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -88,6 +89,20 @@ class AIClient:
|
|||||||
if delta:
|
if delta:
|
||||||
yield delta
|
yield delta
|
||||||
|
|
||||||
|
def stream_messages(self, messages: list[dict[str, str]]) -> Iterator[str]:
|
||||||
|
"""Stream a completion from an explicit chat history."""
|
||||||
|
stream = self._client.chat.completions.create(
|
||||||
|
model=self._config.model,
|
||||||
|
max_tokens=self._config.max_tokens,
|
||||||
|
stream=True,
|
||||||
|
messages=cast(Any, messages),
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in cast(Iterator[Any], stream):
|
||||||
|
delta = chunk.choices[0].delta.content
|
||||||
|
if delta:
|
||||||
|
yield delta
|
||||||
|
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
"""Human-readable description of the AI config."""
|
"""Human-readable description of the AI config."""
|
||||||
return f"host={self._config.host} model={self._config.model}"
|
return f"host={self._config.host} model={self._config.model}"
|
||||||
|
|||||||
@@ -182,8 +182,28 @@ async def _interactive_loop(
|
|||||||
ai_config: AIConfig,
|
ai_config: AIConfig,
|
||||||
report: CollectionReport | None,
|
report: CollectionReport | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a tiny follow-up loop for collecting and analyzing on demand."""
|
"""Run a follow-up loop for collecting and conversational analysis."""
|
||||||
console.print("[cyan]Interactive mode:[/cyan] /collect, /analyze, /help, /quit")
|
console.print(
|
||||||
|
"[cyan]Interactive mode:[/cyan] "
|
||||||
|
"ask questions directly, or use /collect, /analyze, /help, /quit"
|
||||||
|
)
|
||||||
|
|
||||||
|
ai = AIClient(ai_config)
|
||||||
|
messages: list[dict[str, str]] | None = None
|
||||||
|
|
||||||
|
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
|
||||||
|
if current_report is None:
|
||||||
|
return None
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": build_system_prompt()},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": build_user_message(req.issue, current_report),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if report is not None:
|
||||||
|
messages = _reset_messages(report)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -201,6 +221,7 @@ async def _interactive_loop(
|
|||||||
|
|
||||||
if command == "/help":
|
if command == "/help":
|
||||||
console.print("Commands: /collect, /analyze, /help, /quit")
|
console.print("Commands: /collect, /analyze, /help, /quit")
|
||||||
|
console.print("Tip: any non-slash text is treated as a follow-up AI question.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if command == "/collect":
|
if command == "/collect":
|
||||||
@@ -208,6 +229,7 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
|
messages = _reset_messages(report)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if command == "/analyze":
|
if command == "/analyze":
|
||||||
@@ -216,10 +238,35 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
_run_analysis(ai_config, req.issue, report)
|
messages = _reset_messages(report)
|
||||||
|
if messages is None:
|
||||||
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
console.print(f"[yellow]Unknown command:[/yellow] {command}. Try /help")
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Provide an updated diagnosis from the current diagnostics.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = _stream_conversation(ai, messages)
|
||||||
|
messages.append({"role": "assistant", "content": response})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if report is None:
|
||||||
|
plan = plan_from_request(req)
|
||||||
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
|
report = await collect_from_plan(session, plan)
|
||||||
|
_handle_collection_report(report)
|
||||||
|
messages = _reset_messages(report)
|
||||||
|
|
||||||
|
if messages is None:
|
||||||
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": command})
|
||||||
|
response = _stream_conversation(ai, messages)
|
||||||
|
messages.append({"role": "assistant", "content": response})
|
||||||
|
|
||||||
|
|
||||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||||
@@ -262,6 +309,21 @@ def _run_analysis(ai_config: AIConfig, issue: str, report: CollectionReport) ->
|
|||||||
raise typer.Exit(code=1) from exc
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_conversation(ai: AIClient, messages: list[dict[str, str]]) -> str:
|
||||||
|
"""Stream a multi-turn AI response and return the final text."""
|
||||||
|
console.print("[cyan]Analyzing...[/cyan]\n")
|
||||||
|
try:
|
||||||
|
chunks: list[str] = []
|
||||||
|
for chunk in ai.stream_messages(messages):
|
||||||
|
chunks.append(chunk)
|
||||||
|
response = "".join(chunks)
|
||||||
|
console.print(Markdown(response))
|
||||||
|
return response
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||||
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Console script entrypoint."""
|
"""Console script entrypoint."""
|
||||||
app()
|
app()
|
||||||
|
|||||||
@@ -116,6 +116,34 @@ def test_stream_yields_chunks() -> None:
|
|||||||
assert result == ["Root ", "cause ", "found."]
|
assert result == ["Root ", "cause ", "found."]
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_messages_yields_chunks() -> None:
|
||||||
|
config = AIConfig()
|
||||||
|
client = AIClient(config)
|
||||||
|
|
||||||
|
def _make_chunk(text: str | None) -> MagicMock:
|
||||||
|
delta = MagicMock()
|
||||||
|
delta.content = text
|
||||||
|
choice = MagicMock()
|
||||||
|
choice.delta = delta
|
||||||
|
chunk = MagicMock()
|
||||||
|
chunk.choices = [choice]
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
mock_chunks = [_make_chunk("A"), _make_chunk(None), _make_chunk("B")]
|
||||||
|
|
||||||
|
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
|
||||||
|
result = list(
|
||||||
|
client.stream_messages(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "question"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == ["A", "B"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# prompt_builder
|
# prompt_builder
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -188,7 +188,28 @@ def test_interactive_collect_then_quit(monkeypatch) -> None: # type: ignore[no-
|
|||||||
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||||
_mock_session(monkeypatch)
|
_mock_session(monkeypatch)
|
||||||
|
|
||||||
commands = iter(["/wat", "/quit"])
|
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||||
|
return CollectionReport(
|
||||||
|
host="ssh.archflux.net",
|
||||||
|
items=[
|
||||||
|
CollectedItem(
|
||||||
|
name="kernel",
|
||||||
|
result=SSHCommandResult(
|
||||||
|
command="uname -a",
|
||||||
|
exit_code=0,
|
||||||
|
stdout="Linux test",
|
||||||
|
stderr="",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
commands = iter(["what should I check next?", "/quit"])
|
||||||
|
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"tai.cli.AIClient.stream_messages",
|
||||||
|
lambda *_args, **_kwargs: iter(["Check logs."]),
|
||||||
|
)
|
||||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
@@ -206,4 +227,5 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "Unknown command" in result.stdout
|
assert "Analyzing..." in result.stdout
|
||||||
|
assert "Check logs." in result.stdout
|
||||||
|
|||||||
Reference in New Issue
Block a user