feat(cli): add response guardrails and grounded followup re-anchoring

This commit is contained in:
2026-05-04 06:11:55 +02:00
parent 2662d1b253
commit 530be62185
6 changed files with 155 additions and 42 deletions

36
src/tai/ai_guardrails.py Normal file
View File

@@ -0,0 +1,36 @@
"""Heuristic checks for AI response quality and safety."""
from __future__ import annotations
import re
_RISKY_ACTION_PATTERNS = [
r"\bsystemctl\s+(restart|stop|start)\b",
r"\b(edit|modify|change)\s+/etc/",
r"\bpasswd\b",
r"\bapt\s+install\b",
r"\bdnf\s+install\b",
r"\byum\s+install\b",
]
def validate_ai_response(response: str) -> list[str]:
"""Return warning messages for potentially unsafe or weakly grounded output."""
warnings: list[str] = []
if "Evidence" not in response:
warnings.append("Response is missing an Evidence section.")
if "`" not in response:
warnings.append("Response does not include quoted evidence snippets.")
lower_response = response.lower()
for pattern in _RISKY_ACTION_PATTERNS:
if re.search(pattern, lower_response):
warnings.append(
"Response suggests potentially modifying actions; "
"prefer read-only verification unless remediation was explicitly requested."
)
break
return warnings

View File

@@ -10,11 +10,12 @@ from rich.console import Console
from rich.markdown import Markdown
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.ai_guardrails import validate_ai_response
from tai.collectors import CollectionReport, collect_from_plan
from tai.input_parser import InputValidationError, build_request
from tai.models import TroubleshootRequest
from tai.plan import plan_from_request
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.session_log import SessionLogger
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
@@ -230,22 +231,7 @@ async def _interactive_loop(
"ask questions directly, or use /collect, /analyze, /help, /quit"
)
ai = AIClient(ai_config)
messages: list[dict[str, str]] | None = None
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
if current_report is None:
return None
return [
{"role": "system", "content": build_system_prompt()},
{
"role": "user",
"content": build_user_message(req.issue, current_report),
},
]
if report is not None:
messages = _reset_messages(report)
prior_questions: list[str] = []
while True:
try:
@@ -275,7 +261,6 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
if logger is not None:
logger.log_event(
"collection_summary",
@@ -292,19 +277,21 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
if messages is None:
if report is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
messages.append(
{
"role": "user",
"content": "Provide an updated diagnosis from the current diagnostics.",
}
_run_followup_analysis(
ai_config,
req.issue,
report,
"Provide an updated diagnosis from the current diagnostics.",
prior_questions,
logger=logger,
)
response = _stream_conversation(ai, messages, logger=logger)
messages.append({"role": "assistant", "content": response})
prior_questions.append("/analyze")
if logger is not None:
logger.log_event("interactive_followup", {"question": "/analyze"})
continue
if report is None:
@@ -312,15 +299,22 @@ async def _interactive_loop(
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
report = await collect_from_plan(session, plan)
_handle_collection_report(report)
messages = _reset_messages(report)
if messages is None:
if report is None:
console.print("[red]No diagnostics available to analyze.[/red]")
continue
messages.append({"role": "user", "content": command})
response = _stream_conversation(ai, messages, logger=logger)
messages.append({"role": "assistant", "content": response})
_run_followup_analysis(
ai_config,
req.issue,
report,
command,
prior_questions,
logger=logger,
)
prior_questions.append(command)
if logger is not None:
logger.log_event("interactive_followup", {"question": command})
def _handle_probe_result(result: SSHCommandResult) -> None:
@@ -365,12 +359,18 @@ def _run_analysis(
chunks.append(chunk)
response = "".join(chunks)
console.print(Markdown(response))
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None:
logger.log_event(
"analysis_response",
{
"issue": issue,
"response": response,
"guardrail_warnings": warnings,
},
)
except Exception as exc: # noqa: BLE001
@@ -380,33 +380,46 @@ def _run_analysis(
raise typer.Exit(code=1) from exc
def _stream_conversation(
ai: AIClient,
messages: list[dict[str, str]],
def _run_followup_analysis(
ai_config: AIConfig,
issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
*,
logger: SessionLogger | None,
) -> str:
"""Stream a multi-turn AI response and return the final text."""
"""Run grounded follow-up analysis re-anchored to current diagnostics."""
console.print("[cyan]Analyzing...[/cyan]\n")
ai = AIClient(ai_config)
system_prompt = build_system_prompt()
user_message = build_followup_message(issue, report, question, prior_questions)
try:
chunks: list[str] = []
for chunk in ai.stream_messages(messages):
for chunk in ai.stream(system_prompt, user_message):
chunks.append(chunk)
response = "".join(chunks)
console.print(Markdown(response))
if logger is not None and messages:
warnings = validate_ai_response(response)
for item in warnings:
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
if logger is not None:
logger.log_event(
"analysis_response",
{
"last_user_message": messages[-1].get("content", ""),
"last_user_message": question,
"response": response,
"guardrail_warnings": warnings,
},
)
return response
except Exception as exc: # noqa: BLE001
console.print(f"[red]AI analysis failed:[/red] {exc}")
if logger is not None:
logger.log_event("analysis_error", {"error": str(exc)})
logger.log_event("analysis_error", {"error": str(exc), "question": question})
raise typer.Exit(code=1) from exc

View File

@@ -15,12 +15,15 @@ Your job:
Important rules:
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
rejected that specific command — it is not evidence about the service or system state.
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
additional commands or log files would be needed.
- Keep the response short. Skip sections that have nothing useful to say.
- Never suggest commands that modify the system unless explicitly asked.
- Default to read-only verification steps. Do not suggest restarting services or editing configs
unless the user explicitly asks for remediation actions.
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
"""
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
)
return "\n".join(lines)
def build_followup_message(
issue: str,
report: CollectionReport,
question: str,
prior_questions: list[str],
) -> str:
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
base = build_user_message(issue, report)
lines: list[str] = [base, "## Follow-up"]
if prior_questions:
lines.append("\nRecent user follow-up questions:")
for idx, item in enumerate(prior_questions[-5:], start=1):
lines.append(f"{idx}. {item}")
lines.append("\nCurrent follow-up question:")
lines.append(question)
lines.append(
"\nAnswer strictly from the collected diagnostics above. "
"If evidence is insufficient, explicitly say so."
)
return "\n".join(lines)

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult
# ---------------------------------------------------------------------------
@@ -218,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
msg = build_user_message("test", report)
assert "no output" in msg
def test_build_followup_message_includes_question_context() -> None:
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
msg = build_followup_message(
"nginx is failing",
report,
"what should I check next?",
["is nginx running?", "show me logs"],
)
assert "Current follow-up question" in msg
assert "what should I check next?" in msg
assert "Recent user follow-up questions" in msg

View File

@@ -0,0 +1,24 @@
"""Tests for AI response guardrails."""
from tai.ai_guardrails import validate_ai_response
def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
warnings = validate_ai_response("Root cause only, no structure.")
assert any("Evidence section" in item for item in warnings)
assert any("quoted evidence" in item for item in warnings)
def test_validate_ai_response_flags_risky_actions() -> None:
text = "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
warnings = validate_ai_response(text)
assert any("modifying actions" in item for item in warnings)
def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
text = (
"Evidence: `PasswordAuthentication no`\n"
"Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
)
warnings = validate_ai_response(text)
assert not warnings

View File

@@ -207,7 +207,7 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
commands = iter(["what should I check next?", "/quit"])
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
monkeypatch.setattr(
"tai.cli.AIClient.stream_messages",
"tai.cli.AIClient.stream",
lambda *_args, **_kwargs: iter(["Check logs."]),
)
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))