diff --git a/nettacker/core/lib/socket.py b/nettacker/core/lib/socket.py index fbf26b3e4..57066870a 100644 --- a/nettacker/core/lib/socket.py +++ b/nettacker/core/lib/socket.py @@ -6,11 +6,11 @@ import re import select import socket -import ssl import struct import time from nettacker.core.lib.base import BaseEngine, BaseLibrary +from nettacker.core.lib.ssl import wrap_socket_to_ssl from nettacker.core.utils.common import reverse_and_regex_condition log = logging.getLogger(__name__) @@ -21,21 +21,15 @@ def create_tcp_socket(host, port, timeout): socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_connection.settimeout(timeout) socket_connection.connect((host, port)) - ssl_flag = False except ConnectionRefusedError: return None try: - socket_connection = ssl.wrap_socket(socket_connection) - ssl_flag = True + return wrap_socket_to_ssl(socket_connection), True except Exception: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - # finally: - # socket_connection.shutdown() + pass - return socket_connection, ssl_flag + return socket_connection, False class SocketLibrary(BaseLibrary): diff --git a/nettacker/core/lib/ssl.py b/nettacker/core/lib/ssl.py index 1988a411a..0df080e01 100644 --- a/nettacker/core/lib/ssl.py +++ b/nettacker/core/lib/ssl.py @@ -107,24 +107,9 @@ def test_single_cipher(host, port, cipher, timeout): return supported_ciphers, False -def create_tcp_socket(host, port, timeout): - try: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - ssl_flag = False - except ConnectionRefusedError: - return None - - try: - socket_connection = ssl.wrap_socket(socket_connection) - ssl_flag = True - except Exception: - socket_connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_connection.settimeout(timeout) - socket_connection.connect((host, port)) - - return socket_connection, ssl_flag +def wrap_socket_to_ssl(socket_connection): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # noqa + return context.wrap_socket(socket_connection) def get_cert_info(cert): @@ -154,6 +139,8 @@ def get_cert_info(cert): class SslLibrary(BaseLibrary): def ssl_certificate_scan(self, host, port, timeout): + from nettacker.core.lib.socket import create_tcp_socket + tcp_socket = create_tcp_socket(host, port, timeout) if tcp_socket is None: return None @@ -175,6 +162,8 @@ def ssl_certificate_scan(self, host, port, timeout): return scan_info def ssl_version_and_cipher_scan(self, host, port, timeout): + from nettacker.core.lib.socket import create_tcp_socket + tcp_socket = create_tcp_socket(host, port, timeout) if tcp_socket is None: return None diff --git a/tests/core/lib/test_socket.py b/tests/core/lib/test_socket.py index f6ad745cc..32cc42ff3 100644 --- a/tests/core/lib/test_socket.py +++ b/tests/core/lib/test_socket.py @@ -114,7 +114,7 @@ class Substeps: class TestSocketMethod(TestCase): @patch("socket.socket") - @patch("ssl.wrap_socket") + @patch("nettacker.core.lib.socket.wrap_socket_to_ssl") def test_create_tcp_socket(self, mock_wrap, mock_socket): HOST = "example.com" PORT = 80 diff --git a/tests/core/lib/test_ssl.py b/tests/core/lib/test_ssl.py index 781a706f0..eeed4ca25 100644 --- a/tests/core/lib/test_ssl.py +++ b/tests/core/lib/test_ssl.py @@ -1,10 +1,10 @@ import ssl from unittest.mock import patch +from nettacker.core.lib.socket import create_tcp_socket from nettacker.core.lib.ssl import ( SslEngine, SslLibrary, - create_tcp_socket, is_weak_hash_algo, is_weak_ssl_version, is_weak_cipher_suite, @@ -153,7 +153,7 @@ class Substeps: class TestSocketMethod(TestCase): @patch("socket.socket") - @patch("ssl.wrap_socket") + @patch("nettacker.core.lib.socket.wrap_socket_to_ssl") def test_create_tcp_socket(self, mock_wrap, mock_socket): HOST = "example.com" PORT = 80 @@ -167,7 +167,7 @@ def test_create_tcp_socket(self, mock_wrap, mock_socket): @patch("nettacker.core.lib.ssl.is_weak_cipher_suite") @patch("nettacker.core.lib.ssl.is_weak_ssl_version") - @patch("nettacker.core.lib.ssl.create_tcp_socket") + @patch("nettacker.core.lib.socket.create_tcp_socket") def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock_cipher_check): library = SslLibrary() HOST = "example.com" @@ -222,7 +222,7 @@ def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock }, ) - @patch("nettacker.core.lib.ssl.create_tcp_socket") + @patch("nettacker.core.lib.socket.create_tcp_socket") @patch("nettacker.core.lib.ssl.is_weak_hash_algo") @patch("nettacker.core.lib.ssl.crypto.load_certificate") @patch("nettacker.core.lib.ssl.ssl.get_server_certificate")