diff --git a/setup.py b/setup.py index 412e8823..4f8837c0 100755 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ readme = f.read() setup( - version='1.10.1', + version='1.10.2', name='testgres', packages=['testgres', 'testgres.operations', 'testgres.helpers'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/node.py b/testgres/node.py index e5e8fd5f..7726e731 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -529,7 +529,9 @@ def get_auth_method(t): u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), - u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\tall\t{}\n".format(auth_host), + u"host\treplication\tall\tall\t{}\n".format(auth_host) ] # yapf: disable # write missing lines diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 697b4258..3483d256 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,8 +1,11 @@ -import logging import os +import socket import subprocess import tempfile import platform +import time + +from ..utils import reserve_port # we support both pg8000 and psycopg2 try: @@ -14,6 +17,7 @@ raise ImportError("You must have psycopg2 or pg8000 modules installed") from ..exceptions import ExecUtilException +from ..utils import reserve_port from .os_ops import OsOperations, ConnectionParams, get_default_encoding error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] @@ -46,6 +50,7 @@ def __init__(self, conn_params: ConnectionParams): self.host = conn_params.host self.port = conn_params.port self.ssh_key = conn_params.ssh_key + self.ssh_cmd = ["-o StrictHostKeyChecking=no"] self.ssh_args = [] if self.ssh_key: self.ssh_args += ["-i", self.ssh_key] @@ -53,8 +58,8 @@ def __init__(self, conn_params: ConnectionParams): self.ssh_args += ["-p", self.port] self.remote = True self.username = conn_params.username or self.get_user() - self.add_known_host(self.host) self.tunnel_process = None + self.tunnel_port = None def __enter__(self): return self @@ -62,31 +67,43 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close_ssh_tunnel() - def establish_ssh_tunnel(self, local_port, remote_port): + @staticmethod + def is_port_open(host, port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) # Таймаут для попытки соединения + try: + sock.connect((host, port)) + return True + except socket.error: + return False + + def establish_ssh_tunnel(self, local_port, remote_port, host): """ Establish an SSH tunnel from a local port to a remote PostgreSQL port. """ - ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] + if host != 'localhost': + ssh_cmd = ['-N', '-L', f"localhost:{local_port}:{host}:{remote_port}"] + else: + ssh_cmd = ['-N', '-L', f"{local_port}:{host}:{remote_port}"] self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) + timeout = 10 + start_time = time.time() + while time.time() - start_time < timeout: + if self.is_port_open('localhost', local_port): + print("SSH tunnel established.") + return + time.sleep(0.5) + raise Exception("Failed to establish SSH tunnel within the timeout period.") def close_ssh_tunnel(self): - if hasattr(self, 'tunnel_process'): + if self.tunnel_process: self.tunnel_process.terminate() self.tunnel_process.wait() + print("SSH tunnel closed.") del self.tunnel_process else: print("No active tunnel to close.") - def add_known_host(self, host): - known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") - cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path) - - try: - subprocess.check_call(cmd, shell=True) - logging.info("Successfully added %s to known_hosts." % host) - except subprocess.CalledProcessError as e: - raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e))) - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=None, timeout=None): @@ -295,6 +312,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: # For scp the port is specified by a "-P" option scp_args = ['-P' if x == '-p' else x for x in self.ssh_args] + if not truncate: scp_cmd = ['scp'] + scp_args + [f"{self.username}@{self.host}:{filename}", tmp_file.name] subprocess.run(scp_cmd, check=False) # The file might not exist yet @@ -394,17 +412,25 @@ def get_process_children(self, pid): # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): """ - Established SSH tunnel and Connects to a PostgreSQL + Establish SSH tunnel and connect to a PostgreSQL database. """ - self.establish_ssh_tunnel(local_port=port, remote_port=5432) + local_port = reserve_port() + self.tunnel_port = local_port + self.establish_ssh_tunnel(local_port=local_port, remote_port=port, host=host) try: conn = pglib.connect( - host=host, - port=port, + host='localhost', + port=local_port, database=dbname, user=user, password=password, + timeout=10 ) + print("Database connection established successfully.") return conn except Exception as e: - raise Exception(f"Could not connect to the database. Error: {e}") + print(f"Error connecting to the database: {str(e)}") + if self.tunnel_process: + self.tunnel_process.terminate() + print("SSH tunnel closed due to connection failure.") + raise diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/app.py b/testgres/plugins/pg_probackup2/pg_probackup2/app.py index 1a4ca9e7..000d4f9b 100644 --- a/testgres/plugins/pg_probackup2/pg_probackup2/app.py +++ b/testgres/plugins/pg_probackup2/pg_probackup2/app.py @@ -43,14 +43,14 @@ def __str__(self): class ProbackupApp: def __init__(self, test_class: unittest.TestCase, - pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir): + pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir, probackup_path=None): self.test_class = test_class self.pg_node = pg_node self.pb_log_path = pb_log_path self.test_env = test_env self.auto_compress_alg = auto_compress_alg self.backup_dir = backup_dir - self.probackup_path = init_params.probackup_path + self.probackup_path = probackup_path or init_params.probackup_path self.probackup_old_path = init_params.probackup_old_path self.remote = init_params.remote self.verbose = init_params.verbose