Files
tai/tests/test_ssh_client.py
zphinx e589240c67
All checks were successful
CI / test (push) Successful in 15s
update
Co-authored-by: Copilot <copilot@github.com>
2026-05-04 04:22:58 +02:00

109 lines
3.1 KiB
Python

from pathlib import Path
import pytest
from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig
def _client(**kwargs: object) -> SSHClient:
    """Build an ``SSHClient`` for tests from optional keyword overrides.

    Recognised keywords (all optional):
        host: SSH destination string; defaults to ``"root@ssh.archflux.net"``.
        port: SSH port as an ``int``; defaults to ``22``.
        identity_file: ``Path`` to a key file, or ``None`` (the default).
        jump_host: jump host string, or ``None`` (the default).
        ignore_ssh_config: ``bool``; defaults to ``False``.

    Raises:
        TypeError: if an unrecognised keyword is passed (so a misspelled
            override cannot silently fall back to a default), or if any value
            has the wrong type.
    """
    allowed = {"host", "port", "identity_file", "jump_host", "ignore_ssh_config"}
    unknown = sorted(set(kwargs) - allowed)
    if unknown:
        raise TypeError(f"unexpected keyword argument(s): {', '.join(unknown)}")

    host = kwargs.get("host", "root@ssh.archflux.net")
    if not isinstance(host, str):
        raise TypeError("host must be a string")

    port = kwargs.get("port", 22)
    # bool is a subclass of int, so it has to be excluded explicitly.
    if not isinstance(port, int) or isinstance(port, bool):
        raise TypeError("port must be an int")

    identity_file = kwargs.get("identity_file")
    if identity_file is not None and not isinstance(identity_file, Path):
        raise TypeError("identity_file must be a Path or None")

    jump_host = kwargs.get("jump_host")
    if jump_host is not None and not isinstance(jump_host, str):
        raise TypeError("jump_host must be a string or None")

    ignore_ssh_config = kwargs.get("ignore_ssh_config", False)
    # Reject truthy non-bools (e.g. the string "False") instead of coercing.
    if not isinstance(ignore_ssh_config, bool):
        raise TypeError("ignore_ssh_config must be a bool")

    return SSHClient(
        SSHConnectionConfig(
            host=host,
            port=port,
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    )
def test_summary_includes_expected_defaults() -> None:
    """A client built without overrides reports its default settings in summary()."""
    summary_text = _client().summary()
    expected_fragments = (
        "host=root@ssh.archflux.net",
        "port=22",
        "key=auto",
        "jump=none",
        "mode=use ssh config",
    )
    for fragment in expected_fragments:
        assert fragment in summary_text
def test_build_ssh_argv_respects_flags() -> None:
    """build_ssh_argv() pairs each configured option flag with its value.

    Checks that ``-p``, ``-F``, ``-i`` and ``-J`` are each immediately
    followed by the expected value. The argv must also start with ``ssh``
    and end with the destination host followed by the remote command.
    """
    client = _client(
        identity_file=Path("/root/.ssh/id_ed25519"),
        jump_host="bastion.archflux.net",
        ignore_ssh_config=True,
    )
    argv = client.build_ssh_argv("uname -a")

    def value_after(flag: str) -> str:
        # Separate membership checks would still pass if values were attached
        # to the wrong flags, so inspect the element right after the flag.
        assert flag in argv, f"{flag} missing from {argv!r}"
        index = argv.index(flag)
        assert index + 1 < len(argv), f"{flag} has no value in {argv!r}"
        return argv[index + 1]

    assert argv[0] == "ssh"
    assert value_after("-p") == "22"
    assert value_after("-F") == "/dev/null"
    assert value_after("-i") == "/root/.ssh/id_ed25519"
    assert value_after("-J") == "bastion.archflux.net"
    assert argv[-2] == "root@ssh.archflux.net"
    assert argv[-1] == "uname -a"
def test_rejects_destructive_or_shell_operator_commands() -> None:
    """Destructive commands and shell operators (pipe, semicolon) are refused."""
    validator = _client().validate_read_only_command
    forbidden = (
        "rm -rf /tmp/x",
        "cat /etc/hosts | grep localhost",
        "uname -a; id",
    )
    for cmd in forbidden:
        with pytest.raises(SSHCommandRejectedError):
            validator(cmd)
def test_allows_expected_read_only_commands() -> None:
    """Common inspection commands pass validation without raising."""
    validator = _client().validate_read_only_command
    permitted = (
        "uname -a",
        "journalctl -n 100",
        "systemctl status apache2",
        "cat /etc/hosts",
        "ss -lntp",
    )
    for cmd in permitted:
        # Any SSHCommandRejectedError raised here fails the test.
        validator(cmd)
def test_rejects_non_read_only_systemctl_subcommand() -> None:
    """A state-changing systemctl subcommand (restart) is refused."""
    validator = _client().validate_read_only_command
    with pytest.raises(SSHCommandRejectedError):
        validator("systemctl restart apache2")
def test_truncate_output_marks_and_limits_content() -> None:
    """Output over the byte budget is shortened, flagged, and marked as truncated."""
    marker = "...[truncated]"
    text = "a" * 400
    rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=256)
    assert truncated is True
    assert rendered.endswith(marker)
    # The test name promises the content is limited; without this the test
    # would pass even if the full input were returned with the marker appended.
    assert len(rendered) < len(text)
def test_truncate_output_keeps_short_content() -> None:
    """Output within the byte budget is returned unchanged and not flagged."""
    original = "short output"
    rendered, was_truncated = SSHClient._truncate_output(original, max_output_bytes=256)
    assert was_truncated is False
    assert rendered == original