from pathlib import Path import pytest from tai.ssh_client import SSHClient, SSHCommandRejectedError, SSHConnectionConfig def _client(**kwargs: object) -> SSHClient: host = str(kwargs.get("host", "root@ssh.archflux.net")) port_value = kwargs.get("port", 22) if not isinstance(port_value, int): raise TypeError("port must be an int") port = port_value identity_file = kwargs.get("identity_file") jump_host = kwargs.get("jump_host") ignore_ssh_config = bool(kwargs.get("ignore_ssh_config", False)) if identity_file is not None and not isinstance(identity_file, Path): raise TypeError("identity_file must be a Path or None") if jump_host is not None and not isinstance(jump_host, str): raise TypeError("jump_host must be a string or None") return SSHClient( SSHConnectionConfig( host=host, port=port, identity_file=identity_file, jump_host=jump_host, ignore_ssh_config=ignore_ssh_config, ) ) def test_summary_includes_expected_defaults() -> None: client = _client() text = client.summary() assert "host=root@ssh.archflux.net" in text assert "port=22" in text assert "key=auto" in text assert "jump=none" in text assert "mode=use ssh config" in text def test_build_ssh_argv_respects_flags() -> None: client = _client( identity_file=Path("/root/.ssh/id_ed25519"), jump_host="bastion.archflux.net", ignore_ssh_config=True, ) argv = client.build_ssh_argv("uname -a") assert argv[0] == "ssh" assert "-p" in argv assert "22" in argv assert "-F" in argv assert "/dev/null" in argv assert "-i" in argv assert "/root/.ssh/id_ed25519" in argv assert "-J" in argv assert "bastion.archflux.net" in argv assert argv[-2] == "root@ssh.archflux.net" assert argv[-1] == "uname -a" def test_rejects_destructive_or_shell_operator_commands() -> None: client = _client() for command in ["rm -rf /tmp/x", "cat /etc/hosts | grep localhost", "uname -a; id"]: with pytest.raises(SSHCommandRejectedError): client.validate_read_only_command(command) def test_allows_expected_read_only_commands() -> None: client = _client() for command in [ "uname -a", "journalctl -n 100", "systemctl status apache2", "cat /etc/hosts", "ss -lntp", ]: client.validate_read_only_command(command) def test_rejects_non_read_only_systemctl_subcommand() -> None: client = _client() with pytest.raises(SSHCommandRejectedError): client.validate_read_only_command("systemctl restart apache2") def test_truncate_output_marks_and_limits_content() -> None: text = "a" * 400 rendered, truncated = SSHClient._truncate_output(text, max_output_bytes=256) assert truncated is True assert rendered.endswith("...[truncated]") def test_truncate_output_keeps_short_content() -> None: rendered, truncated = SSHClient._truncate_output("short output", max_output_bytes=256) assert truncated is False assert rendered == "short output"