From e0ef65e75078ef7b79bfe326961fe567c6405688 Mon Sep 17 00:00:00 2001 From: tanaydin Date: Tue, 10 Dec 2024 04:51:59 +0100 Subject: [PATCH] Refactor SSL socket handling to use `wrap_socket_to_ssl` method Update `test_ssl.py`, `socket.py`, and `ssl.py` to replace direct calls to `ssl.wrap_socket` with a new `wrap_socket_to_ssl` function. Adjust test cases accordingly to ensure proper integration and functionality. This improves code organization and makes SSL socket wrapping more explicit.Refactor SSL socket wrapping logic Moved SSL socket wrapping to a dedicated function `wrap_socket_to_ssl` within the `ssl.py` module, improving code modularity and readability. Updated the imports and references across the codebase to reflect this change, enhancing maintainability and reducing redundancy. --- nettacker/core/lib/socket.py | 14 ++++---------- nettacker/core/lib/ssl.py | 25 +++++++------------------ tests/core/lib/test_socket.py | 2 +- tests/core/lib/test_ssl.py | 8 ++++---- 4 files changed, 16 insertions(+), 33 deletions(-) diff --git a/nettacker/core/lib/socket.py b/nettacker/core/lib/socket.py index fbf26b3e4..57066870a 100644 --- a/nettacker/core/lib/socket.py +++ b/nettacker/core/lib/socket.py @@ -6,11 +6,11 @@ import re import select import socket -import ssl import struct import time from nettacker.core.lib.base import BaseEngine, BaseLibrary +from nettacker.core.lib.ssl import wrap_socket_to_ssl from nettacker.core.utils.common import reverse_and_regex_condition log = logging.getLogger(__name__) @@ -21,21 +21,15 @@ def create_tcp_socket(host, port, timeout): socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_connection.settimeout(timeout) socket_connection.connect((host, port)) - ssl_flag = False except ConnectionRefusedError: return None try: - socket_connection = ssl.wrap_socket(socket_connection) - ssl_flag = True + return wrap_socket_to_ssl(socket_connection), True except Exception: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - # finally: - # socket_connection.shutdown() + pass - return socket_connection, ssl_flag + return socket_connection, False class SocketLibrary(BaseLibrary): diff --git a/nettacker/core/lib/ssl.py b/nettacker/core/lib/ssl.py index 1988a411a..0df080e01 100644 --- a/nettacker/core/lib/ssl.py +++ b/nettacker/core/lib/ssl.py @@ -107,24 +107,9 @@ def test_single_cipher(host, port, cipher, timeout): return supported_ciphers, False -def create_tcp_socket(host, port, timeout): - try: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - ssl_flag = False - except ConnectionRefusedError: - return None - - try: - socket_connection = ssl.wrap_socket(socket_connection) - ssl_flag = True - except Exception: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - - return socket_connection, ssl_flag +def wrap_socket_to_ssl(socket_connection): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # noqa + return context.wrap_socket(socket_connection) def get_cert_info(cert): @@ -154,6 +139,8 @@ def get_cert_info(cert): class SslLibrary(BaseLibrary): def ssl_certificate_scan(self, host, port, timeout): + from nettacker.core.lib.socket import create_tcp_socket + tcp_socket = create_tcp_socket(host, port, timeout) if tcp_socket is None: return None @@ -175,6 +162,8 @@ def ssl_certificate_scan(self, host, port, timeout): return scan_info def ssl_version_and_cipher_scan(self, host, port, timeout): + from nettacker.core.lib.socket import create_tcp_socket + tcp_socket = create_tcp_socket(host, port, timeout) if tcp_socket is None: return None diff --git a/tests/core/lib/test_socket.py b/tests/core/lib/test_socket.py index f6ad745cc..32cc42ff3 100644 --- a/tests/core/lib/test_socket.py +++ b/tests/core/lib/test_socket.py @@ -114,7 +114,7 @@ class Substeps: class TestSocketMethod(TestCase): @patch("socket.socket") - @patch("ssl.wrap_socket") + @patch("nettacker.core.lib.socket.wrap_socket_to_ssl") def test_create_tcp_socket(self, mock_wrap, mock_socket): HOST = "example.com" PORT = 80 diff --git a/tests/core/lib/test_ssl.py b/tests/core/lib/test_ssl.py index 781a706f0..eeed4ca25 100644 --- a/tests/core/lib/test_ssl.py +++ b/tests/core/lib/test_ssl.py @@ -1,10 +1,10 @@ import ssl from unittest.mock import patch +from nettacker.core.lib.socket import create_tcp_socket from nettacker.core.lib.ssl import ( SslEngine, SslLibrary, - create_tcp_socket, is_weak_hash_algo, is_weak_ssl_version, is_weak_cipher_suite, @@ -153,7 +153,7 @@ class Substeps: class TestSocketMethod(TestCase): @patch("socket.socket") - @patch("ssl.wrap_socket") + @patch("nettacker.core.lib.socket.wrap_socket_to_ssl") def test_create_tcp_socket(self, mock_wrap, mock_socket): HOST = "example.com" PORT = 80 @@ -167,7 +167,7 @@ def test_create_tcp_socket(self, mock_wrap, mock_socket): @patch("nettacker.core.lib.ssl.is_weak_cipher_suite") @patch("nettacker.core.lib.ssl.is_weak_ssl_version") - @patch("nettacker.core.lib.ssl.create_tcp_socket") + @patch("nettacker.core.lib.socket.create_tcp_socket") def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock_cipher_check): library = SslLibrary() HOST = "example.com" @@ -222,7 +222,7 @@ def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock }, ) - @patch("nettacker.core.lib.ssl.create_tcp_socket") + @patch("nettacker.core.lib.socket.create_tcp_socket") @patch("nettacker.core.lib.ssl.is_weak_hash_algo") @patch("nettacker.core.lib.ssl.crypto.load_certificate") @patch("nettacker.core.lib.ssl.ssl.get_server_certificate")