From 530be62185ef41b6b941e78722451d9f49c1210e Mon Sep 17 00:00:00 2001
From: zphinx
Date: Mon, 4 May 2026 06:11:55 +0200
Subject: [PATCH] feat(cli): add response guardrails and grounded followup re-anchoring

---
Reviewer notes (placed after the "---" marker so they stay out of the commit
message):
- The line structure of this patch was restored from a whitespace-collapsed
  copy. Indentation and the position of blank context lines were inferred from
  the hunk line counts and from Python syntax; run `git apply --check` against
  the target tree before applying.
- validate_ai_response() uses case-sensitive substring checks for "Evidence"
  and for a backtick, so a reply that only says "evidence" in lower case is
  reported as missing the section.
- _run_followup_analysis() raises typer.Exit(code=1) on any exception from the
  AI call. The hunks here do not show the interactive loop handling that;
  confirm that one failed follow-up is meant to end the whole session.

 src/tai/ai_guardrails.py    | 36 ++++++++++++++
 src/tai/cli.py              | 93 +++++++++++++++++++++----------------
 src/tai/prompt_builder.py   | 27 +++++++++++
 tests/test_ai.py            | 15 +++++-
 tests/test_ai_guardrails.py | 24 ++++++++++
 tests/test_cli.py           |  2 +-
 6 files changed, 155 insertions(+), 42 deletions(-)
 create mode 100644 src/tai/ai_guardrails.py
 create mode 100644 tests/test_ai_guardrails.py

diff --git a/src/tai/ai_guardrails.py b/src/tai/ai_guardrails.py
new file mode 100644
index 0000000..2789d0d
--- /dev/null
+++ b/src/tai/ai_guardrails.py
@@ -0,0 +1,36 @@
+"""Heuristic checks for AI response quality and safety."""
+
+from __future__ import annotations
+
+import re
+
+_RISKY_ACTION_PATTERNS = [
+    r"\bsystemctl\s+(restart|stop|start)\b",
+    r"\b(edit|modify|change)\s+/etc/",
+    r"\bpasswd\b",
+    r"\bapt\s+install\b",
+    r"\bdnf\s+install\b",
+    r"\byum\s+install\b",
+]
+
+
+def validate_ai_response(response: str) -> list[str]:
+    """Return warning messages for potentially unsafe or weakly grounded output."""
+    warnings: list[str] = []
+
+    if "Evidence" not in response:
+        warnings.append("Response is missing an Evidence section.")
+
+    if "`" not in response:
+        warnings.append("Response does not include quoted evidence snippets.")
+
+    lower_response = response.lower()
+    for pattern in _RISKY_ACTION_PATTERNS:
+        if re.search(pattern, lower_response):
+            warnings.append(
+                "Response suggests potentially modifying actions; "
+                "prefer read-only verification unless remediation was explicitly requested."
+            )
+            break
+
+    return warnings
diff --git a/src/tai/cli.py b/src/tai/cli.py
index 16952b3..09e121d 100644
--- a/src/tai/cli.py
+++ b/src/tai/cli.py
@@ -10,11 +10,12 @@ from rich.console import Console
 from rich.markdown import Markdown
 
 from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
+from tai.ai_guardrails import validate_ai_response
 from tai.collectors import CollectionReport, collect_from_plan
 from tai.input_parser import InputValidationError, build_request
 from tai.models import TroubleshootRequest
 from tai.plan import plan_from_request
-from tai.prompt_builder import build_system_prompt, build_user_message
+from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
 from tai.session_log import SessionLogger
 from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
 
@@ -230,22 +231,7 @@ async def _interactive_loop(
         "ask questions directly, or use /collect, /analyze, /help, /quit"
     )
 
-    ai = AIClient(ai_config)
-    messages: list[dict[str, str]] | None = None
-
-    def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
-        if current_report is None:
-            return None
-        return [
-            {"role": "system", "content": build_system_prompt()},
-            {
-                "role": "user",
-                "content": build_user_message(req.issue, current_report),
-            },
-        ]
-
-    if report is not None:
-        messages = _reset_messages(report)
+    prior_questions: list[str] = []
 
     while True:
         try:
@@ -275,7 +261,6 @@ async def _interactive_loop(
             console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
             report = await collect_from_plan(session, plan)
             _handle_collection_report(report)
-            messages = _reset_messages(report)
             if logger is not None:
                 logger.log_event(
                     "collection_summary",
@@ -292,19 +277,21 @@ async def _interactive_loop(
                 console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
                 report = await collect_from_plan(session, plan)
                 _handle_collection_report(report)
-                messages = _reset_messages(report)
-            if messages is None:
+            if report is None:
                 console.print("[red]No diagnostics available to analyze.[/red]")
                 continue
 
-            messages.append(
-                {
-                    "role": "user",
-                    "content": "Provide an updated diagnosis from the current diagnostics.",
-                }
+            _run_followup_analysis(
+                ai_config,
+                req.issue,
+                report,
+                "Provide an updated diagnosis from the current diagnostics.",
+                prior_questions,
+                logger=logger,
             )
-            response = _stream_conversation(ai, messages, logger=logger)
-            messages.append({"role": "assistant", "content": response})
+            prior_questions.append("/analyze")
+            if logger is not None:
+                logger.log_event("interactive_followup", {"question": "/analyze"})
             continue
 
         if report is None:
@@ -312,15 +299,22 @@ async def _interactive_loop(
             console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
             report = await collect_from_plan(session, plan)
             _handle_collection_report(report)
-            messages = _reset_messages(report)
 
-        if messages is None:
+        if report is None:
             console.print("[red]No diagnostics available to analyze.[/red]")
             continue
 
-        messages.append({"role": "user", "content": command})
-        response = _stream_conversation(ai, messages, logger=logger)
-        messages.append({"role": "assistant", "content": response})
+        _run_followup_analysis(
+            ai_config,
+            req.issue,
+            report,
+            command,
+            prior_questions,
+            logger=logger,
+        )
+        prior_questions.append(command)
+        if logger is not None:
+            logger.log_event("interactive_followup", {"question": command})
 
 
 def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -365,12 +359,18 @@ def _run_analysis(
             chunks.append(chunk)
         response = "".join(chunks)
         console.print(Markdown(response))
+
+        warnings = validate_ai_response(response)
+        for item in warnings:
+            console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
+
         if logger is not None:
             logger.log_event(
                 "analysis_response",
                 {
                     "issue": issue,
                     "response": response,
+                    "guardrail_warnings": warnings,
                 },
             )
     except Exception as exc:  # noqa: BLE001
@@ -380,33 +380,46 @@ def _run_analysis(
         raise typer.Exit(code=1) from exc
 
 
-def _stream_conversation(
-    ai: AIClient,
-    messages: list[dict[str, str]],
+def _run_followup_analysis(
+    ai_config: AIConfig,
+    issue: str,
+    report: CollectionReport,
+    question: str,
+    prior_questions: list[str],
     *,
     logger: SessionLogger | None,
 ) -> str:
-    """Stream a multi-turn AI response and return the final text."""
+    """Run grounded follow-up analysis re-anchored to current diagnostics."""
     console.print("[cyan]Analyzing...[/cyan]\n")
+    ai = AIClient(ai_config)
+    system_prompt = build_system_prompt()
+    user_message = build_followup_message(issue, report, question, prior_questions)
+
     try:
         chunks: list[str] = []
-        for chunk in ai.stream_messages(messages):
+        for chunk in ai.stream(system_prompt, user_message):
             chunks.append(chunk)
         response = "".join(chunks)
         console.print(Markdown(response))
-        if logger is not None and messages:
+
+        warnings = validate_ai_response(response)
+        for item in warnings:
+            console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
+
+        if logger is not None:
             logger.log_event(
                 "analysis_response",
                 {
-                    "last_user_message": messages[-1].get("content", ""),
+                    "last_user_message": question,
                     "response": response,
+                    "guardrail_warnings": warnings,
                 },
             )
         return response
     except Exception as exc:  # noqa: BLE001
         console.print(f"[red]AI analysis failed:[/red] {exc}")
         if logger is not None:
-            logger.log_event("analysis_error", {"error": str(exc)})
+            logger.log_event("analysis_error", {"error": str(exc), "question": question})
         raise typer.Exit(code=1) from exc
 
 
diff --git a/src/tai/prompt_builder.py b/src/tai/prompt_builder.py
index 360cfb7..e4a87f2 100644
--- a/src/tai/prompt_builder.py
+++ b/src/tai/prompt_builder.py
@@ -15,12 +15,15 @@ Your job:
 
 Important rules:
 - Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
+- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
 - If a command shows "could not be executed (SSH error)" it means the remote host blocked or
   rejected that specific command — it is not evidence about the service or system state.
 - If there is not enough data to diagnose the issue, say so plainly and list exactly what
   additional commands or log files would be needed.
 - Keep the response short. Skip sections that have nothing useful to say.
 - Never suggest commands that modify the system unless explicitly asked.
+- Default to read-only verification steps. Do not suggest restarting services or editing configs
+  unless the user explicitly asks for remediation actions.
 - Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
 """
 
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
         )
 
     return "\n".join(lines)
+
+
+def build_followup_message(
+    issue: str,
+    report: CollectionReport,
+    question: str,
+    prior_questions: list[str],
+) -> str:
+    """Build a grounded follow-up message that re-anchors to diagnostics each turn."""
+    base = build_user_message(issue, report)
+    lines: list[str] = [base, "## Follow-up"]
+
+    if prior_questions:
+        lines.append("\nRecent user follow-up questions:")
+        for idx, item in enumerate(prior_questions[-5:], start=1):
+            lines.append(f"{idx}. {item}")
+
+    lines.append("\nCurrent follow-up question:")
+    lines.append(question)
+    lines.append(
+        "\nAnswer strictly from the collected diagnostics above. "
+        "If evidence is insufficient, explicitly say so."
+    )
+    return "\n".join(lines)
diff --git a/tests/test_ai.py b/tests/test_ai.py
index b45c5f5..9446823 100644
--- a/tests/test_ai.py
+++ b/tests/test_ai.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
 
 from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
 from tai.collectors import CollectedItem, CollectionReport
-from tai.prompt_builder import build_system_prompt, build_user_message
+from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
 from tai.ssh_client import SSHCommandResult
 
 # ---------------------------------------------------------------------------
@@ -218,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
     report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
     msg = build_user_message("test", report)
     assert "no output" in msg
+
+
+def test_build_followup_message_includes_question_context() -> None:
+    report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
+    msg = build_followup_message(
+        "nginx is failing",
+        report,
+        "what should I check next?",
+        ["is nginx running?", "show me logs"],
+    )
+    assert "Current follow-up question" in msg
+    assert "what should I check next?" in msg
+    assert "Recent user follow-up questions" in msg
diff --git a/tests/test_ai_guardrails.py b/tests/test_ai_guardrails.py
new file mode 100644
index 0000000..cc20e49
--- /dev/null
+++ b/tests/test_ai_guardrails.py
@@ -0,0 +1,24 @@
+"""Tests for AI response guardrails."""
+
+from tai.ai_guardrails import validate_ai_response
+
+
+def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
+    warnings = validate_ai_response("Root cause only, no structure.")
+    assert any("Evidence section" in item for item in warnings)
+    assert any("quoted evidence" in item for item in warnings)
+
+
+def test_validate_ai_response_flags_risky_actions() -> None:
+    text = "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
+    warnings = validate_ai_response(text)
+    assert any("modifying actions" in item for item in warnings)
+
+
+def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
+    text = (
+        "Evidence: `PasswordAuthentication no`\n"
+        "Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
+    )
+    warnings = validate_ai_response(text)
+    assert not warnings
diff --git a/tests/test_cli.py b/tests/test_cli.py
index a49b3f5..9046ede 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -207,7 +207,7 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None:  # type:
     commands = iter(["what should I check next?", "/quit"])
     monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
     monkeypatch.setattr(
-        "tai.cli.AIClient.stream_messages",
+        "tai.cli.AIClient.stream",
         lambda *_args, **_kwargs: iter(["Check logs."]),
     )
     monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))