feat(cli): add response guardrails and grounded followup re-anchoring
This commit is contained in:
36
src/tai/ai_guardrails.py
Normal file
36
src/tai/ai_guardrails.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Heuristic checks for AI response quality and safety."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
# Regular expressions for commands or instructions that change system state.
# They are searched against the lower-cased response text.
_RISKY_ACTION_PATTERNS = [
    r"\bsystemctl\s+(restart|stop|start)\b",
    r"\b(edit|modify|change)\s+/etc/",
    r"\bpasswd\b",
    r"\bapt\s+install\b",
    r"\bdnf\s+install\b",
    r"\byum\s+install\b",
]


def validate_ai_response(response: str) -> list[str]:
    """Collect heuristic guardrail warnings for an AI response.

    Three independent checks are evaluated in a fixed order:

    1. the literal, case-sensitive text ``Evidence`` must appear somewhere;
    2. at least one backtick must appear (taken as a quoted snippet);
    3. no entry of ``_RISKY_ACTION_PATTERNS`` may match the lower-cased text.

    Returns:
        The warning messages for every failed check, in the order above.
        The list is empty when all checks pass, and it contains at most one
        risky-action warning no matter how many patterns match.
    """
    checks = [
        ("Evidence" not in response, "Response is missing an Evidence section."),
        ("`" not in response, "Response does not include quoted evidence snippets."),
    ]

    lowered = response.lower()
    suggests_changes = any(
        re.search(pattern, lowered) is not None for pattern in _RISKY_ACTION_PATTERNS
    )
    checks.append(
        (
            suggests_changes,
            "Response suggests potentially modifying actions; "
            "prefer read-only verification unless remediation was explicitly requested.",
        )
    )

    return [message for failed, message in checks if failed]
|
||||
@@ -10,11 +10,12 @@ from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||
from tai.ai_guardrails import validate_ai_response
|
||||
from tai.collectors import CollectionReport, collect_from_plan
|
||||
from tai.input_parser import InputValidationError, build_request
|
||||
from tai.models import TroubleshootRequest
|
||||
from tai.plan import plan_from_request
|
||||
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
||||
from tai.session_log import SessionLogger
|
||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||
|
||||
@@ -230,22 +231,7 @@ async def _interactive_loop(
|
||||
"ask questions directly, or use /collect, /analyze, /help, /quit"
|
||||
)
|
||||
|
||||
ai = AIClient(ai_config)
|
||||
messages: list[dict[str, str]] | None = None
|
||||
|
||||
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
|
||||
if current_report is None:
|
||||
return None
|
||||
return [
|
||||
{"role": "system", "content": build_system_prompt()},
|
||||
{
|
||||
"role": "user",
|
||||
"content": build_user_message(req.issue, current_report),
|
||||
},
|
||||
]
|
||||
|
||||
if report is not None:
|
||||
messages = _reset_messages(report)
|
||||
prior_questions: list[str] = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -275,7 +261,6 @@ async def _interactive_loop(
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
messages = _reset_messages(report)
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"collection_summary",
|
||||
@@ -292,19 +277,21 @@ async def _interactive_loop(
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
messages = _reset_messages(report)
|
||||
if messages is None:
|
||||
if report is None:
|
||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||
continue
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Provide an updated diagnosis from the current diagnostics.",
|
||||
}
|
||||
_run_followup_analysis(
|
||||
ai_config,
|
||||
req.issue,
|
||||
report,
|
||||
"Provide an updated diagnosis from the current diagnostics.",
|
||||
prior_questions,
|
||||
logger=logger,
|
||||
)
|
||||
response = _stream_conversation(ai, messages, logger=logger)
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
prior_questions.append("/analyze")
|
||||
if logger is not None:
|
||||
logger.log_event("interactive_followup", {"question": "/analyze"})
|
||||
continue
|
||||
|
||||
if report is None:
|
||||
@@ -312,15 +299,22 @@ async def _interactive_loop(
|
||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||
report = await collect_from_plan(session, plan)
|
||||
_handle_collection_report(report)
|
||||
messages = _reset_messages(report)
|
||||
|
||||
if messages is None:
|
||||
if report is None:
|
||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": command})
|
||||
response = _stream_conversation(ai, messages, logger=logger)
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
_run_followup_analysis(
|
||||
ai_config,
|
||||
req.issue,
|
||||
report,
|
||||
command,
|
||||
prior_questions,
|
||||
logger=logger,
|
||||
)
|
||||
prior_questions.append(command)
|
||||
if logger is not None:
|
||||
logger.log_event("interactive_followup", {"question": command})
|
||||
|
||||
|
||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||
@@ -365,12 +359,18 @@ def _run_analysis(
|
||||
chunks.append(chunk)
|
||||
response = "".join(chunks)
|
||||
console.print(Markdown(response))
|
||||
|
||||
warnings = validate_ai_response(response)
|
||||
for item in warnings:
|
||||
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
|
||||
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"analysis_response",
|
||||
{
|
||||
"issue": issue,
|
||||
"response": response,
|
||||
"guardrail_warnings": warnings,
|
||||
},
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -380,33 +380,46 @@ def _run_analysis(
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
def _stream_conversation(
|
||||
ai: AIClient,
|
||||
messages: list[dict[str, str]],
|
||||
def _run_followup_analysis(
|
||||
ai_config: AIConfig,
|
||||
issue: str,
|
||||
report: CollectionReport,
|
||||
question: str,
|
||||
prior_questions: list[str],
|
||||
*,
|
||||
logger: SessionLogger | None,
|
||||
) -> str:
|
||||
"""Stream a multi-turn AI response and return the final text."""
|
||||
"""Run grounded follow-up analysis re-anchored to current diagnostics."""
|
||||
console.print("[cyan]Analyzing...[/cyan]\n")
|
||||
ai = AIClient(ai_config)
|
||||
system_prompt = build_system_prompt()
|
||||
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||
|
||||
try:
|
||||
chunks: list[str] = []
|
||||
for chunk in ai.stream_messages(messages):
|
||||
for chunk in ai.stream(system_prompt, user_message):
|
||||
chunks.append(chunk)
|
||||
response = "".join(chunks)
|
||||
console.print(Markdown(response))
|
||||
if logger is not None and messages:
|
||||
|
||||
warnings = validate_ai_response(response)
|
||||
for item in warnings:
|
||||
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
|
||||
|
||||
if logger is not None:
|
||||
logger.log_event(
|
||||
"analysis_response",
|
||||
{
|
||||
"last_user_message": messages[-1].get("content", ""),
|
||||
"last_user_message": question,
|
||||
"response": response,
|
||||
"guardrail_warnings": warnings,
|
||||
},
|
||||
)
|
||||
return response
|
||||
except Exception as exc: # noqa: BLE001
|
||||
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||
if logger is not None:
|
||||
logger.log_event("analysis_error", {"error": str(exc)})
|
||||
logger.log_event("analysis_error", {"error": str(exc), "question": question})
|
||||
raise typer.Exit(code=1) from exc
|
||||
|
||||
|
||||
|
||||
@@ -15,12 +15,15 @@ Your job:
|
||||
|
||||
Important rules:
|
||||
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
|
||||
- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
|
||||
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
||||
rejected that specific command — it is not evidence about the service or system state.
|
||||
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
||||
additional commands or log files would be needed.
|
||||
- Keep the response short. Skip sections that have nothing useful to say.
|
||||
- Never suggest commands that modify the system unless explicitly asked.
|
||||
- Default to read-only verification steps. Do not suggest restarting services or editing configs
|
||||
unless the user explicitly asks for remediation actions.
|
||||
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
|
||||
"""
|
||||
|
||||
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_followup_message(
    issue: str,
    report: CollectionReport,
    question: str,
    prior_questions: list[str],
) -> str:
    """Compose a follow-up prompt that restates the diagnostics on every turn.

    The message starts with the full output of ``build_user_message`` for
    ``issue`` and ``report``, followed by a ``## Follow-up`` section. That
    section lists up to the five most recent entries of ``prior_questions``
    (numbered from 1, omitted entirely when the list is empty), then the
    current ``question``, then a closing instruction to answer only from the
    collected diagnostics.

    Returns:
        All parts joined with single newlines.
    """
    parts: list[str] = [build_user_message(issue, report), "## Follow-up"]

    if prior_questions:
        parts.append("\nRecent user follow-up questions:")
        # Only the tail of the history is replayed to keep the prompt bounded.
        recent = prior_questions[-5:]
        parts.extend(f"{number}. {asked}" for number, asked in enumerate(recent, start=1))

    parts.extend(
        [
            "\nCurrent follow-up question:",
            question,
            "\nAnswer strictly from the collected diagnostics above. "
            "If evidence is insufficient, explicitly say so.",
        ]
    )
    return "\n".join(parts)
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -218,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
|
||||
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
|
||||
msg = build_user_message("test", report)
|
||||
assert "no output" in msg
|
||||
|
||||
|
||||
def test_build_followup_message_includes_question_context() -> None:
    """The follow-up prompt contains the current question and the history header."""
    diagnostics = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
    earlier = ["is nginx running?", "show me logs"]

    rendered = build_followup_message(
        "nginx is failing",
        diagnostics,
        "what should I check next?",
        earlier,
    )

    expected_fragments = (
        "Current follow-up question",
        "what should I check next?",
        "Recent user follow-up questions",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
24
tests/test_ai_guardrails.py
Normal file
24
tests/test_ai_guardrails.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Tests for AI response guardrails."""
|
||||
|
||||
from tai.ai_guardrails import validate_ai_response
|
||||
|
||||
|
||||
def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
    """Unstructured text yields both the Evidence and the quoted-snippet warning."""
    found = validate_ai_response("Root cause only, no structure.")
    for fragment in ("Evidence section", "quoted evidence"):
        assert any(fragment in message for message in found)
|
||||
|
||||
|
||||
def test_validate_ai_response_flags_risky_actions() -> None:
    """A suggested service restart produces the modifying-actions warning."""
    found = validate_ai_response(
        "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
    )
    assert any("modifying actions" in message for message in found)
|
||||
|
||||
|
||||
def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
    """A quoted, evidence-backed answer with a read-only command is warning-free."""
    reply = (
        "Evidence: `PasswordAuthentication no`\n"
        "Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
    )
    assert not validate_ai_response(reply)
|
||||
@@ -207,7 +207,7 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
||||
commands = iter(["what should I check next?", "/quit"])
|
||||
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||
monkeypatch.setattr(
|
||||
"tai.cli.AIClient.stream_messages",
|
||||
"tai.cli.AIClient.stream",
|
||||
lambda *_args, **_kwargs: iter(["Check logs."]),
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
||||
|
||||
Reference in New Issue
Block a user