88import asyncssh
99import paramiko
1010import pytest
11+ from paramiko .server import InteractiveQuery
1112from pytest_mock import MockerFixture
1213from pytest_test_utils .waiters import wait_until
1314
@@ -52,13 +53,24 @@ class Server(paramiko.ServerInterface):
5253 """http://docs.paramiko.org/en/2.4/api/server.html."""
5354
5455 def __init__ (self , commands , * args , ** kwargs ) -> None :
55- super ().__init__ (* args , ** kwargs )
56+ super ().__init__ ()
5657 self .commands = commands
58+ self .allowed_auths = kwargs .get ("allowed_auths" , "publickey,password" )
5759
5860 def check_channel_exec_request (self , channel , command ):
5961 self .commands .append (command )
6062 return True
6163
64+ def check_auth_interactive (self , username : str , submethods : str ):
65+ return InteractiveQuery (
66+ "Password" , "Enter the password" , f"Password for user { USER } :"
67+ )
68+
69+ def check_auth_interactive_response (self , responses ):
70+ if responses [0 ] == PASSWORD :
71+ return paramiko .AUTH_SUCCESSFUL
72+ return paramiko .AUTH_FAILED
73+
6274 def check_auth_password (self , username , password ):
6375 if username == USER and password == PASSWORD :
6476 return paramiko .AUTH_SUCCESSFUL
@@ -76,12 +88,12 @@ def check_channel_request(self, kind, chanid):
7688 return paramiko .OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
7789
7890 def get_allowed_auths (self , username ):
79- return "password,publickey"
91+ return self . allowed_auths
8092
8193
8294@pytest .fixture
8395def ssh_conn (request : pytest .FixtureRequest ) -> dict [str , Any ]:
84- server = Server ([])
96+ server = Server ([], ** getattr ( request , "param" , {}) )
8597
8698 socket .setdefaulttimeout (10 )
8799 request .addfinalizer (lambda : socket .setdefaulttimeout (None ))
@@ -133,7 +145,8 @@ def test_run_command_password(server: Server, ssh_port: int):
133145 assert b"test_run_command_password" in server .commands
134146
135147
136- def test_run_command_no_password (server : Server , ssh_port : int ):
148+ @pytest .mark .parametrize ("ssh_conn" , [{"allowed_auths" : "publickey" }], indirect = True )
149+ def test_run_command_no_password (ssh_port : int ):
137150 vendor = AsyncSSHVendor ()
138151 with pytest .raises (AuthError ):
139152 vendor .run_command (
@@ -145,6 +158,28 @@ def test_run_command_no_password(server: Server, ssh_port: int):
145158 )
146159
147160
161+ @pytest .mark .parametrize (
162+ "ssh_conn" ,
163+ [{"allowed_auths" : "password" }, {"allowed_auths" : "keyboard-interactive" }],
164+ indirect = True ,
165+ ids = ["password" , "interactive" ],
166+ )
167+ def test_should_prompt_for_password_when_no_password_passed (
168+ mocker : MockerFixture , server : Server , ssh_port : int
169+ ):
170+ mocked_getpass = mocker .patch ("getpass.getpass" , return_value = PASSWORD )
171+ vendor = AsyncSSHVendor ()
172+ vendor .run_command (
173+ "127.0.0.1" ,
174+ "test_run_command_password" ,
175+ username = USER ,
176+ port = ssh_port ,
177+ password = None ,
178+ )
179+ assert server .commands == [b"test_run_command_password" ]
180+ mocked_getpass .asssert_called_once ()
181+
182+
148183def test_run_command_with_privkey (server : Server , ssh_port : int ):
149184 key = asyncssh .import_private_key (CLIENT_KEY )
150185
0 commit comments