feat(cli): add response guardrails and grounded followup re-anchoring

This commit is contained in:
2026-05-04 06:11:55 +02:00
parent 2662d1b253
commit 530be62185
6 changed files with 155 additions and 42 deletions

36
src/tai/ai_guardrails.py Normal file
View File

@@ -0,0 +1,36 @@
"""Heuristic checks for AI response quality and safety."""
from __future__ import annotations
import re
_RISKY_ACTION_PATTERNS = [
r"\bsystemctl\s+(restart|stop|start)\b",
r"\b(edit|modify|change)\s+/etc/",
r"\bpasswd\b",
r"\bapt\s+install\b",
r"\bdnf\s+install\b",
r"\byum\s+install\b",
]
def validate_ai_response(response: str) -> list[str]:
"""Return warning messages for potentially unsafe or weakly grounded output."""
warnings: list[str] = []
if "Evidence" not in response:
warnings.append("Response is missing an Evidence section.")
if "`" not in response:
warnings.append("Response does not include quoted evidence snippets.")
lower_response = response.lower()
for pattern in _RISKY_ACTION_PATTERNS:
if re.search(pattern, lower_response):
warnings.append(
"Response suggests potentially modifying actions; "
"prefer read-only verification unless remediation was explicitly requested."
)
break
return warnings

View File

@@ -10,11 +10,12 @@ from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.ai_guardrails import validate_ai_response
from tai.collectors import CollectionReport, collect_from_plan from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest from tai.models import TroubleshootRequest
from tai.plan import plan_from_request from tai.plan import plan_from_request
from tai.prompt_builder import build_system_prompt, build_user_message from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.session_log import SessionLogger from tai.session_log import SessionLogger
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
@@ -230,22 +231,7 @@ async def _interactive_loop(
"ask questions directly, or use /collect, /analyze, /help, /quit" "ask questions directly, or use /collect, /analyze, /help, /quit"
) )
ai = AIClient(ai_config) prior_questions: list[str] = []
messages: list[dict[str, str]] | None = None
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
if current_report is None:
return None
return [
{"role": "system", "content": build_system_prompt()},
{
"role": "user",
"content": build_user_message(req.issue, current_report),
},
]
if report is not None:
messages = _reset_messages(report)
while True: while True:
try: try:
@@ -275,7 +261,6 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
messages = _reset_messages(report)
if logger is not None: if logger is not None:
logger.log_event( logger.log_event(
"collection_summary", "collection_summary",
@@ -292,19 +277,21 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
messages = _reset_messages(report) if report is None:
if messages is None:
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
continue continue
messages.append( _run_followup_analysis(
{ ai_config,
"role": "user", req.issue,
"content": "Provide an updated diagnosis from the current diagnostics.", report,
} "Provide an updated diagnosis from the current diagnostics.",
prior_questions,
logger=logger,
) )
response = _stream_conversation(ai, messages, logger=logger) prior_questions.append("/analyze")
messages.append({"role": "assistant", "content": response}) if logger is not None:
logger.log_event("interactive_followup", {"question": "/analyze"})
continue continue
if report is None: if report is None:
@@ -312,15 +299,22 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands") console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan) report = await collect_from_plan(session, plan)
_handle_collection_report(report) _handle_collection_report(report)
messages = _reset_messages(report)
if messages is None: if report is None:
console.print("[red]No diagnostics available to analyze.[/red]") console.print("[red]No diagnostics available to analyze.[/red]")
continue continue
messages.append({"role": "user", "content": command}) _run_followup_analysis(
response = _stream_conversation(ai, messages, logger=logger) ai_config,
messages.append({"role": "assistant", "content": response}) req.issue,
report,
command,
prior_questions,
logger=logger,
)
prior_questions.append(command)
if logger is not None:
logger.log_event("interactive_followup", {"question": command})
def _handle_probe_result(result: SSHCommandResult) -> None: def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -365,12 +359,18 @@ def _run_analysis(
chunks.append(chunk) chunks.append(chunk)
response = "".join(chunks) response = "".join(chunks)
console.print(Markdown(response)) console.print(Markdown(response))
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None: if logger is not None:
logger.log_event( logger.log_event(
"analysis_response", "analysis_response",
{ {
"issue": issue, "issue": issue,
"response": response, "response": response,
"guardrail_warnings": warnings,
}, },
) )
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
@@ -380,33 +380,46 @@ def _run_analysis(
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc
def _stream_conversation( def _run_followup_analysis(
ai: AIClient, ai_config: AIConfig,
messages: list[dict[str, str]], issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
*, *,
logger: SessionLogger | None, logger: SessionLogger | None,
) -> str: ) -> str:
"""Stream a multi-turn AI response and return the final text.""" """Run grounded follow-up analysis re-anchored to current diagnostics."""
console.print("[cyan]Analyzing...[/cyan]\n") console.print("[cyan]Analyzing...[/cyan]\n")
ai = AIClient(ai_config)
system_prompt = build_system_prompt()
user_message = build_followup_message(issue, report, question, prior_questions)
try: try:
chunks: list[str] = [] chunks: list[str] = []
for chunk in ai.stream_messages(messages): for chunk in ai.stream(system_prompt, user_message):
chunks.append(chunk) chunks.append(chunk)
response = "".join(chunks) response = "".join(chunks)
console.print(Markdown(response)) console.print(Markdown(response))
if logger is not None and messages:
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None:
logger.log_event( logger.log_event(
"analysis_response", "analysis_response",
{ {
"last_user_message": messages[-1].get("content", ""), "last_user_message": question,
"response": response, "response": response,
"guardrail_warnings": warnings,
}, },
) )
return response return response
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}") console.print(f"[red]AI analysis failed:[/red] {exc}")
if logger is not None: if logger is not None:
logger.log_event("analysis_error", {"error": str(exc)}) logger.log_event("analysis_error", {"error": str(exc), "question": question})
raise typer.Exit(code=1) from exc raise typer.Exit(code=1) from exc

View File

@@ -15,12 +15,15 @@ Your job:
Important rules: Important rules:
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence. - Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or - If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state. rejected that specific command — it is not evidence about the service or system state.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what - If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed. additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say. - Keep the response short. Skip sections that have nothing useful to say.
- Never suggest commands that modify the system unless explicitly asked. - Never suggest commands that modify the system unless explicitly asked.
- Default to read-only verification steps. Do not suggest restarting services or editing configs
unless the user explicitly asks for remediation actions.
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**. - Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
""" """
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
) )
return "\n".join(lines) return "\n".join(lines)
def build_followup_message(
issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
) -> str:
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
base = build_user_message(issue, report)
lines: list[str] = [base, "## Follow-up"]
if prior_questions:
lines.append("\nRecent user follow-up questions:")
for idx, item in enumerate(prior_questions[-5:], start=1):
lines.append(f"{idx}. {item}")
lines.append("\nCurrent follow-up question:")
lines.append(question)
lines.append(
"\nAnswer strictly from the collected diagnostics above. "
"If evidence is insufficient, explicitly say so."
)
return "\n".join(lines)

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult from tai.ssh_client import SSHCommandResult
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -218,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
report = _make_report([("empty", "cat /nonexistent", 1, "", "")]) report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
msg = build_user_message("test", report) msg = build_user_message("test", report)
assert "no output" in msg assert "no output" in msg
def test_build_followup_message_includes_question_context() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
msg = build_followup_message(
"nginx is failing",
report,
"what should I check next?",
["is nginx running?", "show me logs"],
)
assert "Current follow-up question" in msg
assert "what should I check next?" in msg
assert "Recent user follow-up questions" in msg

View File

@@ -0,0 +1,24 @@
"""Tests for AI response guardrails."""
from tai.ai_guardrails import validate_ai_response
def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
warnings = validate_ai_response("Root cause only, no structure.")
assert any("Evidence section" in item for item in warnings)
assert any("quoted evidence" in item for item in warnings)
def test_validate_ai_response_flags_risky_actions() -> None:
text = "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
warnings = validate_ai_response(text)
assert any("modifying actions" in item for item in warnings)
def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
text = (
"Evidence: `PasswordAuthentication no`\n"
"Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
)
warnings = validate_ai_response(text)
assert not warnings

View File

@@ -207,7 +207,7 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
commands = iter(["what should I check next?", "/quit"]) commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan) monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr( monkeypatch.setattr(
"tai.cli.AIClient.stream_messages", "tai.cli.AIClient.stream",
lambda *_args, **_kwargs: iter(["Check logs."]), lambda *_args, **_kwargs: iter(["Check logs."]),
) )
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands)) monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))