diff --git a/winrm/__init__.py b/winrm/__init__.py index ec31a32..532d30a 100644 --- a/winrm/__init__.py +++ b/winrm/__init__.py @@ -39,22 +39,24 @@ def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None: self.url = self._build_url(target, kwargs.get("transport", "plaintext")) self.protocol = Protocol(self.url, username=username, password=password, **kwargs) - def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response: + def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = (), stdin_input: str | bytes | None = None) -> Response: # TODO optimize perf. Do not call open/close shell every time shell_id = self.protocol.open_shell() command_id = self.protocol.run_command(shell_id, command, args) + if stdin_input is not None: + self.protocol.send_command_input(shell_id, command_id, stdin_input, True) rs = Response(self.protocol.get_command_output(shell_id, command_id)) self.protocol.cleanup_command(shell_id, command_id) self.protocol.close_shell(shell_id) return rs - def run_ps(self, script: str) -> Response: + def run_ps(self, script: str, stdin_input: str | bytes | None = None) -> Response: """base64 encodes a Powershell script and executes the powershell encoded script command """ # must use utf16 little endian on windows encoded_ps = b64encode(script.encode("utf_16_le")).decode("ascii") - rs = self.run_cmd("powershell -encodedcommand {0}".format(encoded_ps)) + rs = self.run_cmd("powershell -encodedcommand {0}".format(encoded_ps), stdin_input=stdin_input) if len(rs.std_err): # if there was an error message, clean it it up and make it human # readable diff --git a/winrm/tests/test_integration_protocol.py b/winrm/tests/test_integration_protocol.py index df0ac6a..26964c0 100644 --- a/winrm/tests/test_integration_protocol.py +++ b/winrm/tests/test_integration_protocol.py @@ -85,6 +85,21 @@ def test_run_command_taking_more_than_operation_timeout_sec(protocol_real): protocol_real.close_shell(shell_id) +def test_run_command_with_stdin_input(protocol_real): + shell_id = protocol_real.open_shell() + command_id = protocol_real.run_command(shell_id, "more") + stdin_text = "Hello, stdin input" + protocol_real.send_command_input(shell_id, command_id, stdin_text, end=True) + std_out, std_err, status_code = protocol_real.get_command_output(shell_id, command_id) + + assert status_code == 0 + assert len(std_err) == 0 + assert std_out.decode().strip() == stdin_text.strip() + + protocol_real.cleanup_command(shell_id, command_id) + protocol_real.close_shell(shell_id) + + @xfail() def test_set_timeout(protocol_real): raise NotImplementedError()