Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
192
tests/test_ai.py
Normal file
192
tests/test_ai.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for the AI client and prompt builder."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tai.ai_client import DEFAULT_AI_HOST, DEFAULT_MODEL, AIClient, AIConfig
|
||||
from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.prompt_builder import build_system_prompt, build_user_message
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIConfig defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ai_config_defaults() -> None:
|
||||
config = AIConfig()
|
||||
assert config.host == DEFAULT_AI_HOST
|
||||
assert config.model == DEFAULT_MODEL
|
||||
assert config.api_key == "ollama"
|
||||
|
||||
|
||||
def test_ai_config_custom_values() -> None:
|
||||
config = AIConfig(host="https://api.openai.com/v1", model="gpt-4o", api_key="sk-test")
|
||||
assert config.host == "https://api.openai.com/v1"
|
||||
assert config.model == "gpt-4o"
|
||||
assert config.api_key == "sk-test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIClient.summary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ai_client_summary_contains_host_and_model() -> None:
|
||||
config = AIConfig(host="http://myserver:11434/v1", model="llama3.1:8b")
|
||||
client = AIClient(config)
|
||||
summary = client.summary()
|
||||
assert "http://myserver:11434/v1" in summary
|
||||
assert "llama3.1:8b" in summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIClient.complete (mocked)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_response(content: str, model: str = "gemma3:4b") -> MagicMock:
|
||||
usage = MagicMock()
|
||||
usage.prompt_tokens = 10
|
||||
usage.completion_tokens = 20
|
||||
|
||||
message = MagicMock()
|
||||
message.content = content
|
||||
|
||||
choice = MagicMock()
|
||||
choice.message = message
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [choice]
|
||||
response.model = model
|
||||
response.usage = usage
|
||||
return response
|
||||
|
||||
|
||||
def test_complete_returns_ai_response() -> None:
|
||||
config = AIConfig()
|
||||
client = AIClient(config)
|
||||
mock_response = _make_mock_response("The root cause is X.")
|
||||
|
||||
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
|
||||
result = client.complete("system prompt", "user message")
|
||||
|
||||
assert result.content == "The root cause is X."
|
||||
assert result.prompt_tokens == 10
|
||||
assert result.completion_tokens == 20
|
||||
assert result.total_tokens == 30
|
||||
|
||||
|
||||
def test_complete_handles_empty_content() -> None:
|
||||
config = AIConfig()
|
||||
client = AIClient(config)
|
||||
mock_response = _make_mock_response(None) # type: ignore[arg-type]
|
||||
mock_response.choices[0].message.content = None
|
||||
|
||||
with patch.object(client._client.chat.completions, "create", return_value=mock_response):
|
||||
result = client.complete("system", "user")
|
||||
|
||||
assert result.content == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIClient.stream (mocked)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stream_yields_chunks() -> None:
|
||||
config = AIConfig()
|
||||
client = AIClient(config)
|
||||
|
||||
def _make_chunk(text: str | None) -> MagicMock:
|
||||
delta = MagicMock()
|
||||
delta.content = text
|
||||
choice = MagicMock()
|
||||
choice.delta = delta
|
||||
chunk = MagicMock()
|
||||
chunk.choices = [choice]
|
||||
return chunk
|
||||
|
||||
mock_chunks = [
|
||||
_make_chunk("Root "), _make_chunk("cause "), _make_chunk(None), _make_chunk("found."),
|
||||
]
|
||||
|
||||
with patch.object(client._client.chat.completions, "create", return_value=iter(mock_chunks)):
|
||||
result = list(client.stream("system", "user"))
|
||||
|
||||
assert result == ["Root ", "cause ", "found."]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_report(items: list[tuple[str, str, int, str, str]]) -> CollectionReport:
|
||||
"""Build a CollectionReport from (name, command, exit_code, stdout, stderr) tuples."""
|
||||
return CollectionReport(
|
||||
host="root@testhost",
|
||||
items=[
|
||||
CollectedItem(
|
||||
name=name,
|
||||
result=SSHCommandResult(
|
||||
command=command,
|
||||
exit_code=exit_code,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
),
|
||||
)
|
||||
for name, command, exit_code, stdout, stderr in items
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_build_system_prompt_contains_key_instructions() -> None:
|
||||
prompt = build_system_prompt()
|
||||
assert "Root Cause" in prompt
|
||||
assert "Evidence" in prompt
|
||||
assert "Recommended Actions" in prompt
|
||||
assert "read-only" in prompt.lower()
|
||||
|
||||
|
||||
def test_build_user_message_contains_issue_and_host() -> None:
|
||||
report = _make_report([("kernel", "uname -a", 0, "Linux web01", "")])
|
||||
msg = build_user_message("nginx is failing", report)
|
||||
assert "nginx is failing" in msg
|
||||
assert "root@testhost" in msg
|
||||
|
||||
|
||||
def test_build_user_message_includes_command_output() -> None:
|
||||
report = _make_report([("kernel", "uname -a", 0, "Linux web01 6.1.0", "")])
|
||||
msg = build_user_message("test issue", report)
|
||||
assert "uname -a" in msg
|
||||
assert "Linux web01 6.1.0" in msg
|
||||
|
||||
|
||||
def test_build_user_message_shows_stderr() -> None:
|
||||
report = _make_report(
|
||||
[("svc", "systemctl status nginx", 3, "", "Unit nginx.service not found.")]
|
||||
)
|
||||
msg = build_user_message("nginx not found", report)
|
||||
assert "Unit nginx.service not found." in msg
|
||||
|
||||
|
||||
def test_build_user_message_notes_truncation() -> None:
|
||||
result = SSHCommandResult(
|
||||
command="journalctl -n 100 --no-pager",
|
||||
exit_code=0,
|
||||
stdout="...",
|
||||
stderr="",
|
||||
stdout_truncated=True,
|
||||
)
|
||||
report = CollectionReport(
|
||||
host="root@testhost",
|
||||
items=[CollectedItem(name="journal", result=result)],
|
||||
)
|
||||
msg = build_user_message("disk issue", report)
|
||||
assert "truncated" in msg
|
||||
|
||||
|
||||
def test_build_user_message_handles_no_output() -> None:
|
||||
report = _make_report([("empty", "cat /nonexistent", 1, "", "")])
|
||||
msg = build_user_message("test", report)
|
||||
assert "no output" in msg
|
||||
@@ -1,3 +1,5 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from tai.cli import app
|
||||
@@ -5,6 +7,24 @@ from tai.collectors import CollectedItem, CollectionReport
|
||||
from tai.ssh_client import SSHCommandResult
|
||||
|
||||
|
||||
def _mock_session(
|
||||
monkeypatch, # type: ignore[no-untyped-def]
|
||||
*,
|
||||
probe_result: SSHCommandResult | None = None,
|
||||
probe_raises: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Patch SSHClient.connect to return a mock session."""
|
||||
session = MagicMock()
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=None)
|
||||
if probe_raises:
|
||||
session.probe = AsyncMock(side_effect=probe_raises)
|
||||
else:
|
||||
session.probe = AsyncMock(return_value=probe_result)
|
||||
monkeypatch.setattr("tai.cli.SSHClient.connect", lambda _self, **kw: session)
|
||||
return session
|
||||
|
||||
|
||||
def test_run_command_prints_scaffold_summary() -> None:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
@@ -25,33 +45,23 @@ def test_run_command_prints_scaffold_summary() -> None:
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "tai scaffold ready" in result.stdout
|
||||
assert "tai" in result.stdout
|
||||
assert "host=web01" in result.stdout
|
||||
assert "port=5566" in result.stdout
|
||||
|
||||
|
||||
def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
|
||||
return SSHCommandResult(
|
||||
command="uname -a",
|
||||
exit_code=0,
|
||||
stdout="Linux ssh 6.12.0",
|
||||
stderr="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
|
||||
_mock_session(
|
||||
monkeypatch,
|
||||
probe_result=SSHCommandResult(
|
||||
command="uname -a", exit_code=0, stdout="Linux ssh 6.12.0", stderr=""
|
||||
),
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"apache failed",
|
||||
"--host",
|
||||
"ssh.archflux.net",
|
||||
"--port",
|
||||
"5566",
|
||||
"--probe",
|
||||
],
|
||||
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
@@ -60,27 +70,20 @@ def test_probe_success_prints_remote_output_by_default(monkeypatch) -> None: #
|
||||
|
||||
|
||||
def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
async def fake_probe(self) -> SSHCommandResult: # type: ignore[no-untyped-def]
|
||||
return SSHCommandResult(
|
||||
_mock_session(
|
||||
monkeypatch,
|
||||
probe_result=SSHCommandResult(
|
||||
command="uname -a",
|
||||
exit_code=255,
|
||||
stdout="",
|
||||
stderr="Permission denied (publickey,password).",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("tai.cli.SSHClient.probe", fake_probe)
|
||||
),
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"apache failed",
|
||||
"--host",
|
||||
"ssh.archflux.net",
|
||||
"--port",
|
||||
"5566",
|
||||
"--probe",
|
||||
],
|
||||
["apache failed", "--host", "ssh.archflux.net", "--port", "5566", "--probe"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 1
|
||||
@@ -88,7 +91,9 @@ def test_probe_failure_returns_non_zero(monkeypatch) -> None: # type: ignore[no
|
||||
|
||||
|
||||
def test_collect_success_prints_summary(monkeypatch) -> None: # type: ignore[no-untyped-def]
|
||||
async def fake_collect_from_plan(_client, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||
_mock_session(monkeypatch)
|
||||
|
||||
async def fake_collect_from_plan(_session, _plan) -> CollectionReport: # type: ignore[no-untyped-def]
|
||||
return CollectionReport(
|
||||
host="ssh.archflux.net",
|
||||
items=[
|
||||
|
||||
Reference in New Issue
Block a user