Skip to content

Commit c9aaf8e

Browse files
authored
MLCOMPUTE-949 | Pick spark ui port from a preferred port range (#128)
* Pick spark ui port from a preferred port range * Load port range from srv configs * Bump version * Fix typo and add more comments
1 parent 653ea7e commit c9aaf8e

File tree

4 files changed

+73
-31
lines changed

4 files changed

+73
-31
lines changed

service_configuration_lib/spark_config.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
from urllib.parse import urlparse
2121

2222
import boto3
23-
import ephemeral_port_reserve
2423
import requests
2524
import yaml
2625
from boto3 import Session
2726

27+
from service_configuration_lib import utils
2828
from service_configuration_lib.text_colors import TextColors
29-
from service_configuration_lib.utils import load_spark_srv_conf
3029

3130
AWS_CREDENTIALS_DIR = '/etc/boto_cfg/'
3231
AWS_ENV_CREDENTIALS_PROVIDER = 'com.amazonaws.auth.EnvironmentVariableCredentialsProvider'
@@ -78,7 +77,6 @@
7877

7978
SUPPORTED_CLUSTER_MANAGERS = ['kubernetes', 'local']
8079
DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
81-
PREFERRED_SPARK_UI_PORT = 39091
8280

8381
log = logging.Logger(__name__)
8482
log.setLevel(logging.INFO)
@@ -184,11 +182,6 @@ def assume_aws_role(
184182
return resp['Credentials']
185183

186184

187-
def _pick_random_port(preferred_port: int = 0) -> int:
188-
"""Return a random port. """
189-
return ephemeral_port_reserve.reserve('0.0.0.0', preferred_port)
190-
191-
192185
def _get_k8s_docker_volumes_conf(
193186
volumes: Optional[List[Mapping[str, str]]] = None,
194187
):
@@ -418,7 +411,7 @@ def __init__(self):
418411
(
419412
self.spark_srv_conf, self.spark_constants, self.default_spark_srv_conf,
420413
self.mandatory_default_spark_srv_conf, self.spark_costs,
421-
) = load_spark_srv_conf()
414+
) = utils.load_spark_srv_conf()
422415
except Exception as e:
423416
log.error(f'Failed to load Spark srv configs: {e}')
424417

@@ -1075,9 +1068,15 @@ def get_spark_conf(
10751068
spark_app_base_name
10761069
)
10771070

1071+
# Pick a port from a pre-defined port range, which will then be used by our Jupyter
1072+
# server metric aggregator API. The aggregator API collects Prometheus metrics from multiple
1073+
# Spark sessions and exposes them through a single endpoint.
10781074
ui_port = int(
10791075
(spark_opts_from_env or {}).get('spark.ui.port') or
1080-
_pick_random_port(PREFERRED_SPARK_UI_PORT),
1076+
utils.ephemeral_port_reserve_range(
1077+
self.spark_constants.get('preferred_spark_ui_port_start'),
1078+
self.spark_constants.get('preferred_spark_ui_port_end'),
1079+
),
10811080
)
10821081

10831082
spark_conf = {**(spark_opts_from_env or {}), **_filter_user_spark_opts(user_spark_opts)}

service_configuration_lib/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
import contextlib
2+
import errno
13
import logging
4+
from socket import error as SocketError
5+
from socket import SO_REUSEADDR
6+
from socket import socket
7+
from socket import SOL_SOCKET
28
from typing import Mapping
39
from typing import Tuple
410

511
import yaml
612

13+
714
DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
815

916
log = logging.Logger(__name__)
@@ -28,3 +35,47 @@ def load_spark_srv_conf(preset_values=None) -> Tuple[Mapping, Mapping, Mapping,
2835
except Exception as e:
2936
log.warning(f'Failed to load {DEFAULT_SPARK_RUN_CONFIG}: {e}')
3037
raise e
38+
39+
40+
def ephemeral_port_reserve_range(preferred_port_start: int, preferred_port_end: int, ip='127.0.0.1') -> int:
    """
    Pick an available port from the preferred port range. If all ports from the
    port range are unavailable, pick a random available ephemeral port.

    Implementation referenced from upstream:
    https://github.com/Yelp/ephemeral-port-reserve/blob/master/ephemeral_port_reserve.py

    This function is used to pick a Spark UI (API) port from a pre-defined port range which is used by
    our Jupyter server metric aggregator. The aggregator API collects Prometheus metrics from multiple
    Spark sessions and exposes them through a single endpoint.

    :param preferred_port_start: inclusive lower bound of the preferred range
    :param preferred_port_end: inclusive upper bound of the preferred range
    :param ip: interface address to bind on; defaults to loopback
    :return: the reserved port number, left in TIME_WAIT so a later bind with
        SO_REUSEADDR (e.g. by the Spark UI) will succeed
    """
    assert preferred_port_start <= preferred_port_end

    with contextlib.closing(socket()) as s:
        bound = False
        for port in range(preferred_port_start, preferred_port_end + 1):
            # SO_REUSEADDR lets us grab a port that is merely in TIME_WAIT.
            s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            try:
                s.bind((ip, port))
                bound = True
                break
            except SocketError as e:
                # socket.error: EADDRINUSE Address already in use
                if e.errno == errno.EADDRINUSE:
                    continue
                else:
                    raise
        if not bound:
            # Whole preferred range is busy: port 0 asks the kernel for any
            # free ephemeral port.
            s.bind((ip, 0))

        # the connect below deadlocks on kernel >= 4.4.0 unless this arg is greater than zero
        s.listen(1)

        sockname = s.getsockname()

        # these three are necessary just to get the port into a TIME_WAIT state
        with contextlib.closing(socket()) as s2:
            s2.connect(sockname)
            sock, _ = s.accept()
            with contextlib.closing(sock):
                return sockname[1]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
setup(
1919
name='service-configuration-lib',
20-
version='2.18.6',
20+
version='2.18.7',
2121
provides=['service_configuration_lib'],
2222
description='Start, stop, and inspect Yelp SOA services',
2323
url='https://github.com/Yelp/service_configuration_lib',

tests/spark_config_test.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,6 @@ def test_fail(self, tmpdir):
130130
spark_config.get_aws_credentials(aws_credentials_yaml=str(fp))
131131

132132

133-
def test_pick_random_port():
134-
with mock.patch('ephemeral_port_reserve.reserve') as mock_reserve:
135-
preferred_port = 33123 # Any ephemeral port for testing
136-
port = spark_config._pick_random_port(preferred_port)
137-
(host, prefer_port), _ = mock_reserve.call_args
138-
assert host == '0.0.0.0'
139-
assert prefer_port >= 33000
140-
assert port == mock_reserve.return_value
141-
142-
143133
class MockConfigFunction:
144134

145135
def __init__(self, mock_obj, mock_func, return_value):
@@ -1092,13 +1082,13 @@ def test_convert_user_spark_opts_value_str(self):
10921082
}
10931083

10941084
@pytest.fixture
1095-
def mock_pick_random_port(self):
1085+
def mock_ephemeral_port_reserve_range(self):
10961086
port = '12345'
1097-
with mock.patch.object(spark_config, '_pick_random_port', return_value=port):
1087+
with mock.patch.object(utils, 'ephemeral_port_reserve_range', return_value=port):
10981088
yield port
10991089

11001090
@pytest.fixture(params=[None, '23456'])
1101-
def ui_port(self, request, mock_pick_random_port):
1091+
def ui_port(self, request):
11021092
return request.param
11031093

11041094
@pytest.fixture(params=[None, 'test_app_name_from_env'])
@@ -1111,8 +1101,8 @@ def spark_opts_from_env(self, request, ui_port):
11111101
return spark_opts or None
11121102

11131103
@pytest.fixture
1114-
def assert_ui_port(self, spark_opts_from_env, ui_port, mock_pick_random_port):
1115-
expected_output = ui_port if ui_port else mock_pick_random_port
1104+
def assert_ui_port(self, ui_port, mock_ephemeral_port_reserve_range):
1105+
expected_output = ui_port or mock_ephemeral_port_reserve_range
11161106

11171107
def verify(output):
11181108
key = 'spark.ui.port'
@@ -1125,13 +1115,13 @@ def user_spark_opts(self, request):
11251115
return request.param
11261116

11271117
@pytest.fixture
1128-
def assert_app_name(self, spark_opts_from_env, user_spark_opts, ui_port, mock_pick_random_port):
1118+
def assert_app_name(self, spark_opts_from_env, user_spark_opts, ui_port, mock_ephemeral_port_reserve_range):
11291119
expected_output = (spark_opts_from_env or {}).get('spark.app.name')
11301120
if not expected_output:
11311121
expected_output = (
11321122
(user_spark_opts or {}).get('spark.app.name') or
11331123
self.spark_app_base_name
1134-
) + '_' + (ui_port or mock_pick_random_port) + '_123'
1124+
) + '_' + (ui_port or mock_ephemeral_port_reserve_range) + '_123'
11351125

11361126
def verify(output):
11371127
key = 'spark.app.name'
@@ -1189,8 +1179,8 @@ def _get_k8s_base_volumes(self):
11891179
]
11901180

11911181
@pytest.fixture
1192-
def assert_kubernetes_conf(self, base_volumes, ui_port, mock_pick_random_port):
1193-
expected_ui_port = ui_port if ui_port else mock_pick_random_port
1182+
def assert_kubernetes_conf(self, base_volumes, ui_port, mock_ephemeral_port_reserve_range):
1183+
expected_ui_port = ui_port if ui_port else mock_ephemeral_port_reserve_range
11941184

11951185
expected_output = {
11961186
'spark.master': f'k8s://https://k8s.{self.cluster}.paasta:6443',
@@ -1238,7 +1228,6 @@ def test_leaders_get_spark_conf_kubernetes(
12381228
self,
12391229
user_spark_opts,
12401230
spark_opts_from_env,
1241-
ui_port,
12421231
base_volumes,
12431232
mock_append_spark_prometheus_conf,
12441233
mock_append_event_log_conf,
@@ -1248,6 +1237,7 @@ def test_leaders_get_spark_conf_kubernetes(
12481237
mock_get_dra_configs,
12491238
mock_update_spark_srv_configs,
12501239
mock_spark_srv_conf_file,
1240+
mock_ephemeral_port_reserve_range,
12511241
mock_time,
12521242
assert_ui_port,
12531243
assert_app_name,
@@ -1341,6 +1331,7 @@ def test_show_console_progress_jupyter(
13411331
mock_adjust_spark_requested_resources_kubernetes,
13421332
mock_get_dra_configs,
13431333
mock_spark_srv_conf_file,
1334+
mock_ephemeral_port_reserve_range,
13441335
mock_time,
13451336
assert_ui_port,
13461337
assert_app_name,
@@ -1382,6 +1373,7 @@ def test_local_spark(
13821373
mock_get_dra_configs,
13831374
mock_update_spark_srv_configs,
13841375
mock_spark_srv_conf_file,
1376+
mock_ephemeral_port_reserve_range,
13851377
mock_time,
13861378
assert_ui_port,
13871379
assert_app_name,

0 commit comments

Comments
 (0)