feat(cli): add response guardrails and grounded followup re-anchoring
This commit is contained in:
36
src/tai/ai_guardrails.py
Normal file
36
src/tai/ai_guardrails.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Heuristic checks for AI response quality and safety."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
_RISKY_ACTION_PATTERNS = [
|
||||||
|
r"\bsystemctl\s+(restart|stop|start)\b",
|
||||||
|
r"\b(edit|modify|change)\s+/etc/",
|
||||||
|
r"\bpasswd\b",
|
||||||
|
r"\bapt\s+install\b",
|
||||||
|
r"\bdnf\s+install\b",
|
||||||
|
r"\byum\s+install\b",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ai_response(response: str) -> list[str]:
|
||||||
|
"""Return warning messages for potentially unsafe or weakly grounded output."""
|
||||||
|
warnings: list[str] = []
|
||||||
|
|
||||||
|
if "Evidence" not in response:
|
||||||
|
warnings.append("Response is missing an Evidence section.")
|
||||||
|
|
||||||
|
if "`" not in response:
|
||||||
|
warnings.append("Response does not include quoted evidence snippets.")
|
||||||
|
|
||||||
|
lower_response = response.lower()
|
||||||
|
for pattern in _RISKY_ACTION_PATTERNS:
|
||||||
|
if re.search(pattern, lower_response):
|
||||||
|
warnings.append(
|
||||||
|
"Response suggests potentially modifying actions; "
|
||||||
|
"prefer read-only verification unless remediation was explicitly requested."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
return warnings
|
||||||
@@ -10,11 +10,12 @@ from rich.console import Console
|
|||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
|
from tai.ai_guardrails import validate_ai_response
|
||||||
from tai.collectors import CollectionReport, collect_from_plan
|
from tai.collectors import CollectionReport, collect_from_plan
|
||||||
from tai.input_parser import InputValidationError, build_request
|
from tai.input_parser import InputValidationError, build_request
|
||||||
from tai.models import TroubleshootRequest
|
from tai.models import TroubleshootRequest
|
||||||
from tai.plan import plan_from_request
|
from tai.plan import plan_from_request
|
||||||
from tai.prompt_builder import build_system_prompt, build_user_message
|
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
||||||
from tai.session_log import SessionLogger
|
from tai.session_log import SessionLogger
|
||||||
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
from tai.ssh_client import SSHClient, SSHCommandResult, SSHConnectionConfig, SSHSession
|
||||||
|
|
||||||
@@ -230,22 +231,7 @@ async def _interactive_loop(
|
|||||||
"ask questions directly, or use /collect, /analyze, /help, /quit"
|
"ask questions directly, or use /collect, /analyze, /help, /quit"
|
||||||
)
|
)
|
||||||
|
|
||||||
ai = AIClient(ai_config)
|
prior_questions: list[str] = []
|
||||||
messages: list[dict[str, str]] | None = None
|
|
||||||
|
|
||||||
def _reset_messages(current_report: CollectionReport | None) -> list[dict[str, str]] | None:
|
|
||||||
if current_report is None:
|
|
||||||
return None
|
|
||||||
return [
|
|
||||||
{"role": "system", "content": build_system_prompt()},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": build_user_message(req.issue, current_report),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
if report is not None:
|
|
||||||
messages = _reset_messages(report)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -275,7 +261,6 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
messages = _reset_messages(report)
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event(
|
logger.log_event(
|
||||||
"collection_summary",
|
"collection_summary",
|
||||||
@@ -292,19 +277,21 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
messages = _reset_messages(report)
|
if report is None:
|
||||||
if messages is None:
|
|
||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
messages.append(
|
_run_followup_analysis(
|
||||||
{
|
ai_config,
|
||||||
"role": "user",
|
req.issue,
|
||||||
"content": "Provide an updated diagnosis from the current diagnostics.",
|
report,
|
||||||
}
|
"Provide an updated diagnosis from the current diagnostics.",
|
||||||
|
prior_questions,
|
||||||
|
logger=logger,
|
||||||
)
|
)
|
||||||
response = _stream_conversation(ai, messages, logger=logger)
|
prior_questions.append("/analyze")
|
||||||
messages.append({"role": "assistant", "content": response})
|
if logger is not None:
|
||||||
|
logger.log_event("interactive_followup", {"question": "/analyze"})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if report is None:
|
if report is None:
|
||||||
@@ -312,15 +299,22 @@ async def _interactive_loop(
|
|||||||
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
console.print(f"[cyan]Collecting diagnostics:[/cyan] {len(plan)} commands")
|
||||||
report = await collect_from_plan(session, plan)
|
report = await collect_from_plan(session, plan)
|
||||||
_handle_collection_report(report)
|
_handle_collection_report(report)
|
||||||
messages = _reset_messages(report)
|
|
||||||
|
|
||||||
if messages is None:
|
if report is None:
|
||||||
console.print("[red]No diagnostics available to analyze.[/red]")
|
console.print("[red]No diagnostics available to analyze.[/red]")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
messages.append({"role": "user", "content": command})
|
_run_followup_analysis(
|
||||||
response = _stream_conversation(ai, messages, logger=logger)
|
ai_config,
|
||||||
messages.append({"role": "assistant", "content": response})
|
req.issue,
|
||||||
|
report,
|
||||||
|
command,
|
||||||
|
prior_questions,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
prior_questions.append(command)
|
||||||
|
if logger is not None:
|
||||||
|
logger.log_event("interactive_followup", {"question": command})
|
||||||
|
|
||||||
|
|
||||||
def _handle_probe_result(result: SSHCommandResult) -> None:
|
def _handle_probe_result(result: SSHCommandResult) -> None:
|
||||||
@@ -365,12 +359,18 @@ def _run_analysis(
|
|||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
response = "".join(chunks)
|
response = "".join(chunks)
|
||||||
console.print(Markdown(response))
|
console.print(Markdown(response))
|
||||||
|
|
||||||
|
warnings = validate_ai_response(response)
|
||||||
|
for item in warnings:
|
||||||
|
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
|
||||||
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event(
|
logger.log_event(
|
||||||
"analysis_response",
|
"analysis_response",
|
||||||
{
|
{
|
||||||
"issue": issue,
|
"issue": issue,
|
||||||
"response": response,
|
"response": response,
|
||||||
|
"guardrail_warnings": warnings,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
@@ -380,33 +380,46 @@ def _run_analysis(
|
|||||||
raise typer.Exit(code=1) from exc
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
def _stream_conversation(
|
def _run_followup_analysis(
|
||||||
ai: AIClient,
|
ai_config: AIConfig,
|
||||||
messages: list[dict[str, str]],
|
issue: str,
|
||||||
|
report: CollectionReport,
|
||||||
|
question: str,
|
||||||
|
prior_questions: list[str],
|
||||||
*,
|
*,
|
||||||
logger: SessionLogger | None,
|
logger: SessionLogger | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Stream a multi-turn AI response and return the final text."""
|
"""Run grounded follow-up analysis re-anchored to current diagnostics."""
|
||||||
console.print("[cyan]Analyzing...[/cyan]\n")
|
console.print("[cyan]Analyzing...[/cyan]\n")
|
||||||
|
ai = AIClient(ai_config)
|
||||||
|
system_prompt = build_system_prompt()
|
||||||
|
user_message = build_followup_message(issue, report, question, prior_questions)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chunks: list[str] = []
|
chunks: list[str] = []
|
||||||
for chunk in ai.stream_messages(messages):
|
for chunk in ai.stream(system_prompt, user_message):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
response = "".join(chunks)
|
response = "".join(chunks)
|
||||||
console.print(Markdown(response))
|
console.print(Markdown(response))
|
||||||
if logger is not None and messages:
|
|
||||||
|
warnings = validate_ai_response(response)
|
||||||
|
for item in warnings:
|
||||||
|
console.print(f"[yellow]Guardrail warning:[/yellow] {item}")
|
||||||
|
|
||||||
|
if logger is not None:
|
||||||
logger.log_event(
|
logger.log_event(
|
||||||
"analysis_response",
|
"analysis_response",
|
||||||
{
|
{
|
||||||
"last_user_message": messages[-1].get("content", ""),
|
"last_user_message": question,
|
||||||
"response": response,
|
"response": response,
|
||||||
|
"guardrail_warnings": warnings,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
console.print(f"[red]AI analysis failed:[/red] {exc}")
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.log_event("analysis_error", {"error": str(exc)})
|
logger.log_event("analysis_error", {"error": str(exc), "question": question})
|
||||||
raise typer.Exit(code=1) from exc
|
raise typer.Exit(code=1) from exc
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,15 @@ Your job:
|
|||||||
|
|
||||||
Important rules:
|
Important rules:
|
||||||
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
|
- Only draw conclusions from data that is actually present. Do not speculate or invent evidence.
|
||||||
|
- For every root-cause claim, quote at least one exact snippet from collected output in backticks.
|
||||||
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
- If a command shows "could not be executed (SSH error)" it means the remote host blocked or
|
||||||
rejected that specific command — it is not evidence about the service or system state.
|
rejected that specific command — it is not evidence about the service or system state.
|
||||||
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
- If there is not enough data to diagnose the issue, say so plainly and list exactly what
|
||||||
additional commands or log files would be needed.
|
additional commands or log files would be needed.
|
||||||
- Keep the response short. Skip sections that have nothing useful to say.
|
- Keep the response short. Skip sections that have nothing useful to say.
|
||||||
- Never suggest commands that modify the system unless explicitly asked.
|
- Never suggest commands that modify the system unless explicitly asked.
|
||||||
|
- Default to read-only verification steps. Do not suggest restarting services or editing configs
|
||||||
|
unless the user explicitly asks for remediation actions.
|
||||||
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
|
- Format with clear sections: **Root Cause**, **Evidence**, **Recommended Actions**.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -72,3 +75,27 @@ def build_user_message(issue: str, report: CollectionReport) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def build_followup_message(
|
||||||
|
issue: str,
|
||||||
|
report: CollectionReport,
|
||||||
|
question: str,
|
||||||
|
prior_questions: list[str],
|
||||||
|
) -> str:
|
||||||
|
"""Build a grounded follow-up message that re-anchors to diagnostics each turn."""
|
||||||
|
base = build_user_message(issue, report)
|
||||||
|
lines: list[str] = [base, "## Follow-up"]
|
||||||
|
|
||||||
|
if prior_questions:
|
||||||
|
lines.append("\nRecent user follow-up questions:")
|
||||||
|
for idx, item in enumerate(prior_questions[-5:], start=1):
|
||||||
|
lines.append(f"{idx}. {item}")
|
||||||
|
|
||||||
|
lines.append("\nCurrent follow-up question:")
|
||||||
|
lines.append(question)
|
||||||
|
lines.append(
|
||||||
|
"\nAnswer strictly from the collected diagnostics above. "
|
||||||
|
"If evidence is insufficient, explicitly say so."
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||||
from tai.collectors import CollectedItem, CollectionReport
|
from tai.collectors import CollectedItem, CollectionReport
|
||||||
from tai.prompt_builder import build_system_prompt, build_user_message
|
from tai.prompt_builder import build_followup_message, build_system_prompt, build_user_message
|
||||||
from tai.ssh_client import SSHCommandResult
|
from tai.ssh_client import SSHCommandResult
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -218,3 +218,16 @@ def test_build_user_message_handles_no_output() -> None:
|
|||||||
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
|
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
|
||||||
msg = build_user_message("test", report)
|
msg = build_user_message("test", report)
|
||||||
assert "no output" in msg
|
assert "no output" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_followup_message_includes_question_context() -> None:
|
||||||
|
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
||||||
|
msg = build_followup_message(
|
||||||
|
"nginx is failing",
|
||||||
|
report,
|
||||||
|
"what should I check next?",
|
||||||
|
["is nginx running?", "show me logs"],
|
||||||
|
)
|
||||||
|
assert "Current follow-up question" in msg
|
||||||
|
assert "what should I check next?" in msg
|
||||||
|
assert "Recent user follow-up questions" in msg
|
||||||
|
|||||||
24
tests/test_ai_guardrails.py
Normal file
24
tests/test_ai_guardrails.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Tests for AI response guardrails."""
|
||||||
|
|
||||||
|
from tai.ai_guardrails import validate_ai_response
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_ai_response_flags_missing_evidence_and_quotes() -> None:
|
||||||
|
warnings = validate_ai_response("Root cause only, no structure.")
|
||||||
|
assert any("Evidence section" in item for item in warnings)
|
||||||
|
assert any("quoted evidence" in item for item in warnings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_ai_response_flags_risky_actions() -> None:
|
||||||
|
text = "Evidence: `PasswordAuthentication no`\nRun systemctl restart sshd now."
|
||||||
|
warnings = validate_ai_response(text)
|
||||||
|
assert any("modifying actions" in item for item in warnings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_ai_response_allows_grounded_read_only_answer() -> None:
|
||||||
|
text = (
|
||||||
|
"Evidence: `PasswordAuthentication no`\n"
|
||||||
|
"Recommended Actions: run `journalctl -u sshd -n 200 --no-pager`"
|
||||||
|
)
|
||||||
|
warnings = validate_ai_response(text)
|
||||||
|
assert not warnings
|
||||||
@@ -207,7 +207,7 @@ def test_interactive_unknown_command_prints_hint(monkeypatch) -> None: # type:
|
|||||||
commands = iter(["what should I check next?", "/quit"])
|
commands = iter(["what should I check next?", "/quit"])
|
||||||
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
monkeypatch.setattr("tai.cli.collect_from_plan", fake_collect_from_plan)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"tai.cli.AIClient.stream_messages",
|
"tai.cli.AIClient.stream",
|
||||||
lambda *_args, **_kwargs: iter(["Check logs."]),
|
lambda *_args, **_kwargs: iter(["Check logs."]),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
monkeypatch.setattr("builtins.input", lambda _prompt: next(commands))
|
||||||
|
|||||||
Reference in New Issue
Block a user