Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions winrm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,24 @@ def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None:
self.url = self._build_url(target, kwargs.get("transport", "plaintext"))
self.protocol = Protocol(self.url, username=username, password=password, **kwargs)

def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response:
def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = (), stdin_input: str | bytes | None = None) -> Response:
# TODO optimize perf. Do not call open/close shell every time
shell_id = self.protocol.open_shell()
command_id = self.protocol.run_command(shell_id, command, args)
if stdin_input is not None:
self.protocol.send_command_input(shell_id, command_id, stdin_input, True)
rs = Response(self.protocol.get_command_output(shell_id, command_id))
self.protocol.cleanup_command(shell_id, command_id)
self.protocol.close_shell(shell_id)
return rs

def run_ps(self, script: str) -> Response:
def run_ps(self, script: str, stdin_input: str | bytes | None = None) -> Response:
"""base64 encodes a Powershell script and executes the powershell
encoded script command
"""
# must use utf16 little endian on windows
encoded_ps = b64encode(script.encode("utf_16_le")).decode("ascii")
rs = self.run_cmd("powershell -encodedcommand {0}".format(encoded_ps))
rs = self.run_cmd("powershell -encodedcommand {0}".format(encoded_ps), stdin_input=stdin_input)
if len(rs.std_err):
# if there was an error message, clean it it up and make it human
# readable
Expand Down
15 changes: 15 additions & 0 deletions winrm/tests/test_integration_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ def test_run_command_taking_more_than_operation_timeout_sec(protocol_real):
protocol_real.close_shell(shell_id)


def test_run_command_with_stdin_input(protocol_real):
shell_id = protocol_real.open_shell()
command_id = protocol_real.run_command(shell_id, "more")
stdin_text = "Hello, stdin input"
protocol_real.send_command_input(shell_id, command_id, stdin_text, end=True)
std_out, std_err, status_code = protocol_real.get_command_output(shell_id, command_id)

assert status_code == 0
assert len(std_err) == 0
assert std_out.decode().strip() == stdin_text.strip()

protocol_real.cleanup_command(shell_id, command_id)
protocol_real.close_shell(shell_id)


@xfail()
def test_set_timeout(protocol_real):
raise NotImplementedError()
Expand Down