Skip to content

Commit 5fd4f94

Browse files
zigaLuksicMatic Lubej
and
Matic Lubej
authored
Adapt Config to be a dataclass (#437)
* switch to dataclass implementation * adapt rest of code * fix tests * fix init parameters and add tests for hierarchy * remove redundant pytint comment * Update sentinelhub/config.py Co-authored-by: Matic Lubej <[email protected]> * adjust precedence --------- Co-authored-by: Matic Lubej <[email protected]>
1 parent 6bba467 commit 5fd4f94

File tree

5 files changed

+80
-125
lines changed

5 files changed

+80
-125
lines changed

Diff for: sentinelhub/commands.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def main_help() -> None:
3131

3232
def _config_options(func: FC) -> FC:
3333
"""A helper function which joins click.option functions of each parameter from config.json"""
34-
for param in SHConfig().get_params()[-1::-1]:
34+
for param in list(SHConfig().to_dict())[-1::-1]:
3535
func = click.option(f"--{param}", param, help=f"Set new values to configuration parameter `{param}`")(func)
3636
return func
3737

@@ -66,12 +66,9 @@ def config(show: bool, profile: str, **params: Any) -> None:
6666

6767
sh_config.save(profile=profile)
6868

69-
for param in sh_config.get_params():
70-
value = getattr(sh_config, param)
69+
for param, value in sh_config.to_dict(mask_credentials=False).items():
7170
if value != getattr(old_config, param):
72-
if isinstance(value, str):
73-
value = f"'{value}'"
74-
click.echo(f"The value of parameter `{param}` was updated to {value}")
71+
click.echo(f"The value of parameter `{param}` was updated to {repr(value)}")
7572

7673
if show:
7774
unmasked_str_repr = json.dumps(sh_config.to_dict(mask_credentials=False), indent=2)

Diff for: sentinelhub/config.py

+56-106
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
import copy
77
import json
8-
import numbers
98
import os
9+
from dataclasses import asdict, dataclass
1010
from pathlib import Path
11-
from typing import Dict, Tuple, Union
11+
from typing import Any, Dict, Optional, Union
1212

1313
import tomli
1414
import tomli_w
@@ -19,7 +19,37 @@
1919
SH_CLIENT_SECRET_ENV_VAR = "SH_CLIENT_SECRET"
2020

2121

22-
class SHConfig: # pylint: disable=too-many-instance-attributes
22+
@dataclass(repr=False)
23+
class _SHConfig:
24+
instance_id: str = ""
25+
sh_client_id: str = ""
26+
sh_client_secret: str = ""
27+
sh_base_url: str = "https://services.sentinel-hub.com"
28+
sh_auth_base_url: str = "https://services.sentinel-hub.com"
29+
geopedia_wms_url: str = "https://service.geopedia.world"
30+
geopedia_rest_url: str = "https://www.geopedia.world/rest"
31+
aws_access_key_id: str = ""
32+
aws_secret_access_key: str = ""
33+
aws_session_token: str = ""
34+
aws_metadata_url: str = "https://roda.sentinel-hub.com"
35+
aws_s3_l1c_bucket: str = "sentinel-s2-l1c"
36+
aws_s3_l2a_bucket: str = "sentinel-s2-l2a"
37+
opensearch_url: str = "http://opensearch.sentinel-hub.com/resto/api/collections/Sentinel2"
38+
max_wfs_records_per_query: int = 100
39+
max_opensearch_records_per_query: int = 500 # pylint: disable=invalid-name
40+
max_download_attempts: int = 4
41+
download_sleep_time: float = 5.0
42+
download_timeout_seconds: float = 120.0
43+
number_of_download_processes: int = 1
44+
45+
def __post_init__(self) -> None:
46+
if self.max_wfs_records_per_query > 100:
47+
raise ValueError("Value of config parameter `max_wfs_records_per_query` must be at most 100")
48+
if self.max_opensearch_records_per_query > 500:
49+
raise ValueError("Value of config parameter `max_opensearch_records_per_query` must be at most 500")
50+
51+
52+
class SHConfig(_SHConfig):
2353
"""A sentinelhub-py package configuration class.
2454
2555
The class reads the configurable settings from ``config.toml`` file on initialization:
@@ -51,14 +81,7 @@ class SHConfig: # pylint: disable=too-many-instance-attributes
5181
- `download_timeout_seconds`: Maximum number of seconds before download attempt is canceled.
5282
- `number_of_download_processes`: Number of download processes, used to calculate rate-limit sleep time.
5383
54-
For manual modification of `config.toml` you can see the expected location of the file via
55-
`SHConfig.get_config_location()`.
56-
57-
Usage in the code:
58-
59-
* ``SHConfig().sh_base_url``
60-
* ``SHConfig().instance_id``
61-
84+
The location of `config.toml` for manual modification can be found with `SHConfig.get_config_location()`.
6285
"""
6386

6487
CREDENTIALS = (
@@ -69,132 +92,64 @@ class SHConfig: # pylint: disable=too-many-instance-attributes
6992
"aws_secret_access_key",
7093
"aws_session_token",
7194
)
72-
OTHER_PARAMS = (
73-
"sh_base_url",
74-
"sh_auth_base_url",
75-
"geopedia_wms_url",
76-
"geopedia_rest_url",
77-
"aws_metadata_url",
78-
"aws_s3_l1c_bucket",
79-
"aws_s3_l2a_bucket",
80-
"opensearch_url",
81-
"max_wfs_records_per_query",
82-
"max_opensearch_records_per_query",
83-
"max_download_attempts",
84-
"download_sleep_time",
85-
"download_timeout_seconds",
86-
"number_of_download_processes",
87-
)
8895

89-
def __init__(self, profile: str = DEFAULT_PROFILE, *, use_defaults: bool = False):
96+
def __init__(self, profile: Optional[str] = None, *, use_defaults: bool = False, **kwargs: Any):
9097
"""
91-
:param profile: Specifies which profile to load form the configuration file. The environment variable
92-
SH_USER_PROFILE has precedence.
98+
:param profile: Specifies which profile to load from the configuration file. Has precedence over the environment
99+
variable `SH_USER_PROFILE`.
93100
:param use_defaults: Does not load the configuration file, returns config object with defaults only.
101+
:param kwargs: Any fields of `SHConfig` to be updated. Overrides settings from `config.toml` and environment.
94102
"""
95-
96-
self.instance_id: str = ""
97-
self.sh_client_id: str = ""
98-
self.sh_client_secret: str = ""
99-
self.sh_base_url: str = "https://services.sentinel-hub.com"
100-
self.sh_auth_base_url: str = "https://services.sentinel-hub.com"
101-
self.geopedia_wms_url: str = "https://service.geopedia.world"
102-
self.geopedia_rest_url: str = "https://www.geopedia.world/rest"
103-
self.aws_access_key_id: str = ""
104-
self.aws_secret_access_key: str = ""
105-
self.aws_session_token: str = ""
106-
self.aws_metadata_url: str = "https://roda.sentinel-hub.com"
107-
self.aws_s3_l1c_bucket: str = "sentinel-s2-l1c"
108-
self.aws_s3_l2a_bucket: str = "sentinel-s2-l2a"
109-
self.opensearch_url: str = "http://opensearch.sentinel-hub.com/resto/api/collections/Sentinel2"
110-
self.max_wfs_records_per_query: int = 100
111-
self.max_opensearch_records_per_query: int = 500 # pylint: disable=invalid-name
112-
self.max_download_attempts: int = 4
113-
self.download_sleep_time: float = 5.0
114-
self.download_timeout_seconds: float = 120.0
115-
self.number_of_download_processes: int = 1
116-
117-
profile = os.environ.get(SH_PROFILE_ENV_VAR, default=profile)
103+
if profile is None:
104+
profile = os.environ.get(SH_PROFILE_ENV_VAR, default=DEFAULT_PROFILE)
118105

119106
if not use_defaults:
107+
env_kwargs = {
108+
"sh_client_id": os.environ.get(SH_CLIENT_ID_ENV_VAR),
109+
"sh_client_secret": os.environ.get(SH_CLIENT_SECRET_ENV_VAR),
110+
}
111+
env_kwargs = {k: v for k, v in env_kwargs.items() if v is not None}
112+
120113
# load from config.toml
121-
loaded_instance = SHConfig.load(profile=profile) # user parameters validated in here already
122-
for param in SHConfig.get_params():
123-
setattr(self, param, getattr(loaded_instance, param))
124-
125-
# check env
126-
self.sh_client_id = os.environ.get(SH_CLIENT_ID_ENV_VAR, default=self.sh_client_id)
127-
self.sh_client_secret = os.environ.get(SH_CLIENT_SECRET_ENV_VAR, default=self.sh_client_secret)
128-
129-
def _validate_values(self) -> None:
130-
"""Ensures that the values are aligned with expectations."""
131-
default = SHConfig(use_defaults=True)
132-
133-
for param in self.get_params():
134-
value = getattr(self, param)
135-
default_value = getattr(default, param)
136-
param_type = type(default_value)
137-
138-
if isinstance(value, str) and value.startswith("http"):
139-
value = value.rstrip("/")
140-
if (param_type is float) and isinstance(value, numbers.Number):
141-
continue
142-
if not isinstance(value, param_type):
143-
raise ValueError(f"Value of parameter `{param}` must be of type {param_type.__name__}")
144-
if self.max_wfs_records_per_query > 100:
145-
raise ValueError("Value of config parameter `max_wfs_records_per_query` must be at most 100")
146-
if self.max_opensearch_records_per_query > 500:
147-
raise ValueError("Value of config parameter `max_opensearch_records_per_query` must be at most 500")
114+
loaded_kwargs = SHConfig.load(profile=profile).to_dict(mask_credentials=False)
115+
116+
kwargs = {**loaded_kwargs, **env_kwargs, **kwargs} # precedence: init params > env > loaded
117+
118+
super().__init__(**kwargs)
148119

149120
def __str__(self) -> str:
150-
"""Content of SHConfig in json schema. Credentials are masked for safety."""
121+
"""Content of `SHConfig` in json schema. Credentials are masked for safety."""
151122
return json.dumps(self.to_dict(mask_credentials=True), indent=2)
152123

153124
def __repr__(self) -> str:
154-
"""Representation of SHConfig parameters. Credentials are masked for safety."""
125+
"""Representation of `SHConfig`. Credentials are masked for safety."""
155126
config_dict = self.to_dict(mask_credentials=True)
156127
content = ",\n ".join(f"{key}={repr(value)}" for key, value in config_dict.items())
157128
return f"{self.__class__.__name__}(\n {content},\n)"
158129

159-
def __eq__(self, other: object) -> bool:
160-
"""Two instances of `SHConfig` are equal if all values of their parameters are equal."""
161-
if not isinstance(other, SHConfig):
162-
return False
163-
return all(getattr(self, param) == getattr(other, param) for param in self.get_params())
164-
165130
@classmethod
166131
def load(cls, profile: str = DEFAULT_PROFILE) -> SHConfig:
167132
"""Loads configuration parameters from the config file at `SHConfig.get_config_location()`.
168133
169134
:param profile: Which profile to load from the configuration file.
170135
"""
171-
config = cls(use_defaults=True)
172-
173136
filename = cls.get_config_location()
174137
if not os.path.exists(filename):
175-
config.save(profile) # store default configuration to standard location
138+
cls(use_defaults=True).save(profile) # store default configuration to standard location
176139

177140
with open(filename, "rb") as cfg_file:
178141
configurations_dict = tomli.load(cfg_file)
179142

180143
if profile not in configurations_dict:
181144
raise KeyError(f"Profile {profile} not found in configuration file.")
182145

183-
config_fields = cls.get_params()
184-
for param, value in configurations_dict[profile].items():
185-
if param in config_fields:
186-
setattr(config, param, value)
187-
188-
config._validate_values()
189-
return config
146+
return cls(use_defaults=True, **configurations_dict[profile])
190147

191148
def save(self, profile: str = DEFAULT_PROFILE) -> None:
192149
"""Saves configuration parameters to the config file at `SHConfig.get_config_location()`.
193150
194151
:param profile: Under which profile to save the configuration.
195152
"""
196-
self._validate_values()
197-
198153
file_path = Path(self.get_config_location())
199154
file_path.parent.mkdir(parents=True, exist_ok=True)
200155

@@ -218,18 +173,13 @@ def copy(self) -> SHConfig:
218173
"""Makes a copy of an instance of `SHConfig`"""
219174
return copy.copy(self)
220175

221-
@classmethod
222-
def get_params(cls) -> Tuple[str, ...]:
223-
"""Returns a list of parameter names."""
224-
return cls.CREDENTIALS + cls.OTHER_PARAMS
225-
226176
def to_dict(self, mask_credentials: bool = True) -> Dict[str, Union[str, float]]:
227177
"""Get a dictionary representation of the `SHConfig` class.
228178
229179
:param hide_credentials: Wether to mask fields containing credentials.
230180
:return: A dictionary with configuration parameters
231181
"""
232-
config_params = {param: getattr(self, param) for param in self.get_params()}
182+
config_params = asdict(self)
233183

234184
if mask_credentials:
235185
for param in self.CREDENTIALS:

Diff for: tests/api/test_byoc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
@pytest.fixture(name="config")
1818
def config_fixture() -> SHConfig:
1919
config = SHConfig()
20-
for param in config.get_params():
20+
for param in config.to_dict():
2121
env_variable = param.upper()
2222
if os.environ.get(env_variable):
2323
setattr(config, param, os.environ.get(env_variable))

Diff for: tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
def pytest_configure(config: Config) -> None:
2222
shconfig = SHConfig()
23-
for param in shconfig.get_params():
23+
for param in shconfig.to_dict():
2424
env_variable = param.upper()
2525
if os.environ.get(env_variable):
2626
setattr(shconfig, param, os.environ.get(env_variable))

Diff for: tests/test_config.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ def test_save(restore_config_file: None) -> None:
7171
config = SHConfig()
7272
old_value = config.download_timeout_seconds
7373

74-
config.download_timeout_seconds = "abcd" # type: ignore[assignment]
75-
with pytest.raises(ValueError):
76-
config.save()
77-
7874
new_value = 150.5
7975
config.download_timeout_seconds = new_value
8076

@@ -102,10 +98,22 @@ def test_environment_variables(restore_config_file: None, monkeypatch) -> None:
10298
assert config.sh_client_secret == "bees-are-very-friendly"
10399

104100

101+
@pytest.mark.dependency(depends=["test_user_config_is_masked"])
102+
def test_initialization_with_params(monkeypatch) -> None:
103+
loaded_config = SHConfig()
104+
monkeypatch.setenv(SH_CLIENT_ID_ENV_VAR, "beekeeper")
105+
106+
config = SHConfig(sh_client_id="me", instance_id="beep", geopedia_wms_url="123")
107+
assert config.instance_id == "beep"
108+
assert config.geopedia_wms_url != loaded_config.geopedia_wms_url, "Should override settings from config.toml"
109+
assert config.sh_client_id == "me", "Should override environment variable."
110+
111+
105112
@pytest.mark.dependency(depends=["test_user_config_is_masked"])
106113
def test_profiles(restore_config_file: None) -> None:
107114
config = SHConfig()
108115
config.instance_id = "beepbeep"
116+
config.sh_client_id = "beepbeep" # also some tests with a credentials field
109117
config.save(profile="beep")
110118

111119
config.instance_id = "boopboop"
@@ -116,10 +124,10 @@ def test_profiles(restore_config_file: None) -> None:
116124
assert SHConfig.load(profile="boop").instance_id == "boopboop"
117125

118126
# save an existing profile
119-
beep_config.instance_id = "bap"
120-
assert SHConfig(profile="beep").instance_id == "beepbeep"
127+
beep_config.sh_client_id = "bap"
128+
assert SHConfig(profile="beep").sh_client_id == "beepbeep"
121129
beep_config.save(profile="beep")
122-
assert SHConfig(profile="beep").instance_id == "bap"
130+
assert SHConfig(profile="beep").sh_client_id == "bap"
123131

124132

125133
@pytest.mark.dependency(depends=["test_user_config_is_masked"])
@@ -134,7 +142,7 @@ def test_profiles_from_env(restore_config_file: None, monkeypatch) -> None:
134142

135143
monkeypatch.setenv(SH_PROFILE_ENV_VAR, "beekeeper")
136144
assert SHConfig().instance_id == "bee", "Environment profile is not used."
137-
assert SHConfig(profile=DEFAULT_PROFILE).instance_id == "bee", "Environment should override explicit profile."
145+
assert SHConfig(profile=DEFAULT_PROFILE).instance_id == "", "Explicit profile overrides environment."
138146

139147

140148
def test_loading_unknown_profile_fails() -> None:
@@ -175,8 +183,9 @@ def test_config_repr() -> None:
175183
assert config.instance_id not in config_repr, "Credentials are not masked properly."
176184
assert "*" * 16 + "a" * 4 in config_repr, "Credentials are not masked properly."
177185

178-
for param in SHConfig.OTHER_PARAMS:
179-
assert f"{param}={repr(getattr(config, param))}" in config_repr
186+
for param in config.to_dict():
187+
if param not in SHConfig.CREDENTIALS:
188+
assert f"{param}={repr(getattr(config, param))}" in config_repr
180189

181190

182191
@pytest.mark.dependency(depends=["test_user_config_is_masked"])
@@ -188,7 +197,6 @@ def test_transformation_to_dict(hide_credentials: bool) -> None:
188197

189198
config_dict = config.to_dict(hide_credentials)
190199
assert isinstance(config_dict, dict)
191-
assert tuple(config_dict) == config.get_params()
192200

193201
if hide_credentials:
194202
assert config_dict["sh_client_secret"] == "*" * 11 + "x" * 4

0 commit comments

Comments
 (0)