Some checks failed
CI / test (push) Failing after 1s
Co-authored-by: Copilot <copilot@github.com>
94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig
|
|
|
|
|
|
def _client(**kwargs: object) -> SSHClient:
    """Build an ``SSHClient`` for tests, overriding selected connection settings.

    Accepted keyword arguments (all optional):

    * ``host`` (str) -- defaults to ``"root@ssh.archflux.net"``.
    * ``port`` (int) -- defaults to ``22``.
    * ``identity_file`` (Path | None) -- defaults to ``None``.
    * ``jump_host`` (str | None) -- defaults to ``None``.
    * ``ignore_ssh_config`` (bool) -- defaults to ``False``.

    Returns:
        An ``SSHClient`` wrapping an ``SSHConnectionConfig`` built from the
        values above.

    Raises:
        TypeError: if an unknown keyword is passed or a value has the wrong type.
    """
    allowed = {"host", "port", "identity_file", "jump_host", "ignore_ssh_config"}
    unknown = sorted(set(kwargs) - allowed)
    if unknown:
        # A misspelled option would otherwise be dropped silently and the test
        # would exercise the defaults instead of the intended configuration.
        raise TypeError(f"unexpected _client option(s): {', '.join(unknown)}")

    host = kwargs.get("host", "root@ssh.archflux.net")
    if not isinstance(host, str):
        raise TypeError("host must be a string")

    port = kwargs.get("port", 22)
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(port, bool) or not isinstance(port, int):
        raise TypeError("port must be an int")

    identity_file = kwargs.get("identity_file")
    if identity_file is not None and not isinstance(identity_file, Path):
        raise TypeError("identity_file must be a Path or None")

    jump_host = kwargs.get("jump_host")
    if jump_host is not None and not isinstance(jump_host, str):
        raise TypeError("jump_host must be a string or None")

    ignore_ssh_config = kwargs.get("ignore_ssh_config", False)
    # Checked rather than coerced: bool("false") would be True.
    if not isinstance(ignore_ssh_config, bool):
        raise TypeError("ignore_ssh_config must be a bool")

    return SSHClient(
        SSHConnectionConfig(
            host=host,
            port=port,
            identity_file=identity_file,
            jump_host=jump_host,
            ignore_ssh_config=ignore_ssh_config,
        )
    )
|
|
|
|
|
|
def test_summary_includes_expected_defaults() -> None:
    """A client built with no overrides reports the default settings in its summary."""
    summary = _client().summary()

    expected_fragments = (
        "host=root@ssh.archflux.net",
        "port=22",
        "key=auto",
        "jump=none",
        "mode=use ssh config",
    )
    for fragment in expected_fragments:
        assert fragment in summary
|
|
|
|
|
|
def test_build_ssh_argv_respects_flags() -> None:
    """Every configured option appears in the argv, paired with its own value.

    The client is built with an explicit identity file, a jump host and
    ``ignore_ssh_config=True``. The argv must start with ``ssh`` and end with
    the destination followed by the remote command.
    """
    client = _client(
        identity_file=Path("/root/.ssh/id_ed25519"),
        jump_host="bastion.archflux.net",
        ignore_ssh_config=True,
    )

    argv = client.build_ssh_argv("uname -a")

    assert argv[0] == "ssh"

    # ssh takes an option's argument from the element immediately after the
    # flag, so check adjacency rather than mere presence. A bare
    # ``"22" in argv`` would still pass if the value were attached to the
    # wrong flag or to none at all.
    expected_options = {
        "-p": "22",
        "-F": "/dev/null",
        "-i": "/root/.ssh/id_ed25519",
        "-J": "bastion.archflux.net",
    }
    for flag, value in expected_options.items():
        assert flag in argv
        position = argv.index(flag)
        assert position + 1 < len(argv)
        assert argv[position + 1] == value

    assert argv[-2] == "root@ssh.archflux.net"
    assert argv[-1] == "uname -a"
|
|
|
|
|
|
def test_rejects_destructive_or_shell_operator_commands() -> None:
    """A destructive command, a pipe and a ``;`` command chain are each refused."""
    ssh = _client()

    rejected = (
        "rm -rf /tmp/x",
        "cat /etc/hosts | grep localhost",
        "uname -a; id",
    )
    for candidate in rejected:
        with pytest.raises(SSHCommandRejectedError):
            ssh.validate_read_only_command(candidate)
|
|
|
|
|
|
def test_allows_expected_read_only_commands() -> None:
    """Common read-only diagnostic commands pass validation without raising."""
    ssh = _client()

    read_only_commands = (
        "uname -a",
        "journalctl -n 100",
        "systemctl status apache2",
        "cat /etc/hosts",
        "ss -lntp",
    )
    for candidate in read_only_commands:
        # Returning normally is the pass condition; a rejection would raise.
        ssh.validate_read_only_command(candidate)
|
|
|
|
|
|
def test_rejects_non_read_only_systemctl_subcommand() -> None:
    """A state-changing ``systemctl`` subcommand (``restart``) is refused."""
    ssh = _client()
    mutating_command = "systemctl restart apache2"

    with pytest.raises(SSHCommandRejectedError):
        ssh.validate_read_only_command(mutating_command)
|