Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
93
tests/test_ssh_client.py
Normal file
93
tests/test_ssh_client.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig
|
||||
|
||||
|
||||
def _client(**kwargs: object) -> SSHClient:
|
||||
host = str(kwargs.get("host", "root@ssh.archflux.net"))
|
||||
port_value = kwargs.get("port", 22)
|
||||
if not isinstance(port_value, int):
|
||||
raise TypeError("port must be an int")
|
||||
port = port_value
|
||||
identity_file = kwargs.get("identity_file")
|
||||
jump_host = kwargs.get("jump_host")
|
||||
ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False))
|
||||
|
||||
if identity_file is not None and not isinstance(identity_file, Path):
|
||||
raise TypeError("identity_file must be a Path or None")
|
||||
|
||||
if jump_host is not None and not isinstance(jump_host, str):
|
||||
raise TypeError("jump_host must be a string or None")
|
||||
|
||||
return SSHClient(
|
||||
SSHConnectionConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
identity_file=identity_file,
|
||||
jump_host=jump_host,
|
||||
ignore_ssh_config=ignore_ssh_config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_summary_includes_expected_defaults() -> None:
|
||||
client = _client()
|
||||
text = client.summary()
|
||||
|
||||
assert "host=root@ssh.archflux.net" in text
|
||||
assert "port=22" in text
|
||||
assert "key=auto" in text
|
||||
assert "jump=none" in text
|
||||
assert "mode=use ssh config" in text
|
||||
|
||||
|
||||
def test_build_ssh_argv_respects_flags() -> None:
|
||||
client = _client(
|
||||
identity_file=Path("/root/.ssh/id_ed25519"),
|
||||
jump_host="bastion.archflux.net",
|
||||
ignore_ssh_config=True,
|
||||
)
|
||||
|
||||
argv = client.build_ssh_argv("uname -a")
|
||||
|
||||
assert argv[0] == "ssh"
|
||||
assert "-p" in argv
|
||||
assert "22" in argv
|
||||
assert "-F" in argv
|
||||
assert "/dev/null" in argv
|
||||
assert "-i" in argv
|
||||
assert "/root/.ssh/id_ed25519" in argv
|
||||
assert "-J" in argv
|
||||
assert "bastion.archflux.net" in argv
|
||||
assert argv[-2] == "root@ssh.archflux.net"
|
||||
assert argv[-1] == "uname -a"
|
||||
|
||||
|
||||
def test_rejects_destructive_or_shell_operator_commands() -> None:
|
||||
client = _client()
|
||||
|
||||
for command in ["rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id"]:
|
||||
with pytest.raises(SSHCommandRejectedError):
|
||||
client.validate_read_only_command(command)
|
||||
|
||||
|
||||
def test_allows_expected_read_only_commands() -> None:
|
||||
client = _client()
|
||||
|
||||
for command in [
|
||||
"uname -a",
|
||||
"journalctl -n 100",
|
||||
"systemctl status apache2",
|
||||
"cat /etc/hosts",
|
||||
"ss -lntp",
|
||||
]:
|
||||
client.validate_read_only_command(command)
|
||||
|
||||
|
||||
def test_rejects_non_read_only_systemctl_subcommand() -> None:
|
||||
client = _client()
|
||||
|
||||
with pytest.raises(SSHCommandRejectedError):
|
||||
client.validate_read_only_command("systemctl restart apache2")
|
||||
Reference in New Issue
Block a user