"""Tests for the AI client and prompt builder."""

from unittest.mock import MagicMock, patch

from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
from tai.collectors import CollectedItem, CollectionReport
from tai.prompt_builder import build_system_prompt, build_user_message
from tai.ssh_client import SSHCommandResult

# ---------------------------------------------------------------------------
# AIConfig defaults
# ---------------------------------------------------------------------------


def test_ai_config_defaults() -> None:
    """A bare AIConfig uses the module defaults and the "ollama" placeholder key."""
    config = AIConfig()
    assert config.host == DEFAULT_AI_HOST
    assert config.model == DEFAULT_MODEL
    assert config.api_key == "ollama"


def test_ai_config_custom_values() -> None:
    """Explicit constructor arguments are stored unchanged."""
    config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test")
    assert config.host == "https://api.openai.com/v1"
    assert config.model == "gpt-4o"
    assert config.api_key == "sk-test"


# ---------------------------------------------------------------------------
# AIClient.summary
# ---------------------------------------------------------------------------


def test_ai_client_summary_contains_host_and_model() -> None:
    """The human-readable summary mentions both the endpoint and the model."""
    config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b")
    client = AIClient(config)
    summary = client.summary()
    assert "http://myserver:11434/v1" in summary
    assert "llama3.1:8b" in summary


# ---------------------------------------------------------------------------
# AIClient.complete (mocked)
# ---------------------------------------------------------------------------


def _make_mock_response(content: str | None, model: str = "gemma3:4b") -> MagicMock:
    """Build a fake chat-completions response with a single choice.

    The mock exposes ``choices[0].message.content`` (set to *content*, which
    may be ``None``), ``model``, and ``usage`` with ``prompt_tokens=10`` and
    ``completion_tokens=20``.
    """
    usage = MagicMock()
    usage.prompt_tokens = 10
    usage.completion_tokens = 20

    message = MagicMock()
    message.content = content

    choice = MagicMock()
    choice.message = message

    response = MagicMock()
    response.choices = [choice]
    response.model = model
    response.usage = usage
    return response


def test_complete_returns_ai_response() -> None:
    """complete() surfaces the message text and the token accounting."""
    config = AIConfig()
    client = AIClient(config)
    mock_response = _make_mock_response("The root cause is X.")

    with patch.object(client._client.chat.completions, "create", return_value=mock_response):
        result = client.complete("system prompt", "user message")

    assert result.content == "The root cause is X."
    assert result.prompt_tokens == 10
    assert result.completion_tokens == 20
    # total_tokens is expected to be prompt + completion (10 + 20).
    assert result.total_tokens == 30


def test_complete_handles_empty_content() -> None:
    """A ``None`` message content is reported as an empty string."""
    config = AIConfig()
    client = AIClient(config)
    mock_response = _make_mock_response(None)

    with patch.object(client._client.chat.completions, "create", return_value=mock_response):
        result = client.complete("system", "user")

    assert result.content == ""


# ---------------------------------------------------------------------------
# AIClient.stream (mocked)
# ---------------------------------------------------------------------------


def test_stream_yields_chunks() -> None:
    """stream() yields each text delta in order and drops ``None`` deltas."""
    config = AIConfig()
    client = AIClient(config)

    def _make_chunk(text: str | None) -> MagicMock:
        """Build a fake streaming chunk whose first choice carries *text*."""
        delta = MagicMock()
        delta.content = text
        choice = MagicMock()
        choice.delta = delta
        chunk = MagicMock()
        chunk.choices = [choice]
        return chunk

    mock_chunks = [
        _make_chunk("Root "),
        _make_chunk("cause "),
        # A chunk with no text content must not appear in the output.
        _make_chunk(None),
        _make_chunk("found."),
    ]

    with patch.object(
        client._client.chat.completions, "create", return_value=iter(mock_chunks)
    ):
        result = list(client.stream("system", "user"))

    assert result == ["Root ", "cause ", "found."]


# ---------------------------------------------------------------------------
# prompt_builder
# ---------------------------------------------------------------------------


def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
    """Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples.

    The report's host is always ``"root@testhost"``.
    """
    return CollectionReport(
        host="root@testhost",
        items=[
            CollectedItem(
                name=name,
                result=SSHCommandResult(
                    command=command,
                    exit_code=exit_code,
                    stdout=stdout,
                    stderr=stderr,
                ),
            )
            for name, command, exit_code, stdout, stderr in items
        ],
    )


def test_build_system_prompt_contains_key_instructions() -> None:
    """The system prompt names the expected report sections and the read-only rule."""
    prompt = build_system_prompt()
    assert "Root Cause" in prompt
    assert "Evidence" in prompt
    assert "Recommended Actions" in prompt
    assert "read-only" in prompt.lower()


def test_build_user_message_contains_issue_and_host() -> None:
    """The user message includes the issue description and the target host."""
    report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
    msg = build_user_message("nginx is failing", report)
    assert "nginx is failing" in msg
    assert "root@testhost" in msg


def test_build_user_message_includes_command_output() -> None:
    """Each collected command and its stdout are rendered into the message."""
    report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")])
    msg = build_user_message("test issue", report)
    assert "uname -a" in msg
    assert "Linux web01 6.1.0" in msg


def test_build_user_message_shows_stderr() -> None:
    """stderr from a failed command is rendered into the message."""
    report = _make_report(
        [("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")]
    )
    msg = build_user_message("nginx not found", report)
    assert "Unit nginx.service not found." in msg


def test_build_user_message_notes_truncation() -> None:
    """A result flagged as stdout-truncated is called out in the message."""
    result = SSHCommandResult(
        command="journalctl -n 100 --no-pager",
        exit_code=0,
        stdout="...",
        stderr="",
        # Built directly (not via _make_report) to set the truncation flag.
        stdout_truncated=True,
    )
    report = CollectionReport(
        host="root@testhost",
        items=[CollectedItem(name="journal", result=result)],
    )
    msg = build_user_message("disk issue", report)
    assert "truncated" in msg


def test_build_user_message_handles_no_output() -> None:
    """A command with empty stdout and stderr is rendered with a "no output" marker."""
    report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
    msg = build_user_message("test", report)
    assert "no output" in msg