Files
tai/tests/test_ai.py
zphinx 61d3e2c4e6
Some checks failed
CI / test (push) Failing after 15s
update
Co-authored-by: Copilot <copilot@github.com>
2026-05-04 04:51:48 +02:00

193 lines
6.2 KiB
Python

"""Tests for the AI client and prompt builder."""
from unittest.mock import MagicMock, patch
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult
# ---------------------------------------------------------------------------
# AIConfig defaults
# ---------------------------------------------------------------------------
def test_ai_config_defaults() -> None:
    """A bare AIConfig falls back to the module defaults and the "ollama" key."""
    cfg = AIConfig()
    expected = (DEFAULT_AI_HOST, DEFAULT_MODEL, "ollama")
    assert (cfg.host, cfg.model, cfg.api_key) == expected
def test_ai_config_custom_values() -> None:
    """Explicit constructor arguments are stored verbatim on the config."""
    overrides = {
        "host": "https://api.openai.com/v1",
        "model": "gpt-4o",
        "api_key": "sk-test",
    }
    cfg = AIConfig(**overrides)
    for field_name, value in overrides.items():
        assert getattr(cfg, field_name) == value
# ---------------------------------------------------------------------------
# AIClient.summary
# ---------------------------------------------------------------------------
def test_ai_client_summary_contains_host_and_model() -> None:
    """The client summary string mentions both the configured host and model."""
    host = "http://myserver:11434/v1"
    model = "llama3.1:8b"
    text = AIClient(AIConfig(host=host, model=model)).summary()
    assert host in text
    assert model in text
# ---------------------------------------------------------------------------
# AIClient.complete (mocked)
# ---------------------------------------------------------------------------
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
    """Build a fake chat-completion response holding a single choice.

    The mock exposes ``choices[0].message.content`` (set to *content*),
    ``model`` (set to *model*) and ``usage`` with ``prompt_tokens=10`` and
    ``completion_tokens=20``.
    """
    response = MagicMock()
    response.model = model
    response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)
    response.choices = [MagicMock(message=MagicMock(content=content))]
    return response
def test_complete_returns_ai_response() -> None:
    """complete() exposes the response content and its token usage counts."""
    client = AIClient(AIConfig())
    fake = _make_mock_response("The root cause is X.")
    completions = client._client.chat.completions
    with patch.object(completions, "create", return_value=fake):
        answer = client.complete("system prompt", "user message")
    assert answer.content == "The root cause is X."
    assert (answer.prompt_tokens, answer.completion_tokens) == (10, 20)
    assert answer.total_tokens == 30
def test_complete_handles_empty_content() -> None:
    """complete() turns a ``None`` message content into an empty string."""
    config = AIConfig()
    client = AIClient(config)
    # Build the mock with a valid ``str`` (matching the helper's signature),
    # then null the content once, explicitly, to simulate a response that
    # carries no message text. This avoids the ``# type: ignore`` hack.
    mock_response = _make_mock_response("")
    mock_response.choices[0].message.content = None
    with patch.object(client._client.chat.completions, "create", return_value=mock_response):
        result = client.complete("system", "user")
    assert result.content == ""
# ---------------------------------------------------------------------------
# AIClient.stream (mocked)
# ---------------------------------------------------------------------------
def test_stream_yields_chunks() -> None:
    """stream() yields each delta's text in order and skips ``None`` deltas."""
    client = AIClient(AIConfig())

    def chunk(text: str | None) -> MagicMock:
        # One streamed chunk: a single choice whose delta carries ``text``.
        return MagicMock(choices=[MagicMock(delta=MagicMock(content=text))])

    pieces = ["Root ", "cause ", None, "found."]
    fake_stream = iter([chunk(piece) for piece in pieces])
    with patch.object(client._client.chat.completions, "create", return_value=fake_stream):
        received = list(client.stream("system", "user"))
    assert received == ["Root ", "cause ", "found."]
# ---------------------------------------------------------------------------
# prompt_builder
# ---------------------------------------------------------------------------
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
    """Return a CollectionReport for host ``root@testhost`` built from tuples.

    Each tuple is ``(name, command, exit_code, stdout, stderr)`` and becomes
    one CollectedItem wrapping an SSHCommandResult with those field values.
    """
    collected: list[CollectedItem] = []
    for name, command, exit_code, stdout, stderr in items:
        ssh_result = SSHCommandResult(
            command=command,
            exit_code=exit_code,
            stdout=stdout,
            stderr=stderr,
        )
        collected.append(CollectedItem(name=name, result=ssh_result))
    return CollectionReport(host="root@testhost", items=collected)
def test_build_system_prompt_contains_key_instructions() -> None:
    """The system prompt contains the expected section names and read-only rule."""
    prompt = build_system_prompt()
    for heading in ("Root Cause", "Evidence", "Recommended Actions"):
        assert heading in prompt
    assert "read-only" in prompt.lower()
def test_build_user_message_contains_issue_and_host() -> None:
    """The user message echoes both the issue text and the report's host."""
    issue = "nginx is failing"
    report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
    text = build_user_message(issue, report)
    assert issue in text
    assert "root@testhost" in text
def test_build_user_message_includes_command_output() -> None:
    """Each collected command and its stdout appear in the user message."""
    command, output = "uname -a", "Linux web01 6.1.0"
    report = _make_report([("kernel", command, 0, output, "")])
    text = build_user_message("test issue", report)
    assert command in text
    assert output in text
def test_build_user_message_shows_stderr() -> None:
    """stderr from a failed command is included in the user message."""
    stderr_text = "Unit nginx.service not found."
    report = _make_report([("svc", "systemctl status nginx", 3, "", stderr_text)])
    assert stderr_text in build_user_message("nginx not found", report)
def test_build_user_message_notes_truncation() -> None:
    """A result flagged ``stdout_truncated`` is called out as truncated."""
    journal = CollectedItem(
        name="journal",
        result=SSHCommandResult(
            command="journalctl -n 100 --no-pager",
            exit_code=0,
            stdout="...",
            stderr="",
            stdout_truncated=True,
        ),
    )
    report = CollectionReport(host="root@testhost", items=[journal])
    assert "truncated" in build_user_message("disk issue", report)
def test_build_user_message_handles_no_output() -> None:
    """A command with empty stdout and stderr is reported as "no output"."""
    silent = _make_report([("empty", "cat /nonexistent", 1, "", "")])
    assert "no output" in build_user_message("test", silent)