5
5
6
6
import copy
7
7
import json
8
- import numbers
9
8
import os
9
+ from dataclasses import asdict , dataclass
10
10
from pathlib import Path
11
- from typing import Dict , Tuple , Union
11
+ from typing import Any , Dict , Optional , Union
12
12
13
13
import tomli
14
14
import tomli_w
19
19
SH_CLIENT_SECRET_ENV_VAR = "SH_CLIENT_SECRET"
20
20
21
21
22
- class SHConfig : # pylint: disable=too-many-instance-attributes
22
+ @dataclass (repr = False )
23
+ class _SHConfig :
24
+ instance_id : str = ""
25
+ sh_client_id : str = ""
26
+ sh_client_secret : str = ""
27
+ sh_base_url : str = "https://services.sentinel-hub.com"
28
+ sh_auth_base_url : str = "https://services.sentinel-hub.com"
29
+ geopedia_wms_url : str = "https://service.geopedia.world"
30
+ geopedia_rest_url : str = "https://www.geopedia.world/rest"
31
+ aws_access_key_id : str = ""
32
+ aws_secret_access_key : str = ""
33
+ aws_session_token : str = ""
34
+ aws_metadata_url : str = "https://roda.sentinel-hub.com"
35
+ aws_s3_l1c_bucket : str = "sentinel-s2-l1c"
36
+ aws_s3_l2a_bucket : str = "sentinel-s2-l2a"
37
+ opensearch_url : str = "http://opensearch.sentinel-hub.com/resto/api/collections/Sentinel2"
38
+ max_wfs_records_per_query : int = 100
39
+ max_opensearch_records_per_query : int = 500 # pylint: disable=invalid-name
40
+ max_download_attempts : int = 4
41
+ download_sleep_time : float = 5.0
42
+ download_timeout_seconds : float = 120.0
43
+ number_of_download_processes : int = 1
44
+
45
+ def __post_init__ (self ) -> None :
46
+ if self .max_wfs_records_per_query > 100 :
47
+ raise ValueError ("Value of config parameter `max_wfs_records_per_query` must be at most 100" )
48
+ if self .max_opensearch_records_per_query > 500 :
49
+ raise ValueError ("Value of config parameter `max_opensearch_records_per_query` must be at most 500" )
50
+
51
+
52
+ class SHConfig (_SHConfig ):
23
53
"""A sentinelhub-py package configuration class.
24
54
25
55
The class reads the configurable settings from ``config.toml`` file on initialization:
@@ -51,14 +81,7 @@ class SHConfig: # pylint: disable=too-many-instance-attributes
51
81
- `download_timeout_seconds`: Maximum number of seconds before download attempt is canceled.
52
82
- `number_of_download_processes`: Number of download processes, used to calculate rate-limit sleep time.
53
83
54
- For manual modification of `config.toml` you can see the expected location of the file via
55
- `SHConfig.get_config_location()`.
56
-
57
- Usage in the code:
58
-
59
- * ``SHConfig().sh_base_url``
60
- * ``SHConfig().instance_id``
61
-
84
+ The location of `config.toml` for manual modification can be found with `SHConfig.get_config_location()`.
62
85
"""
63
86
64
87
CREDENTIALS = (
@@ -69,132 +92,64 @@ class SHConfig: # pylint: disable=too-many-instance-attributes
69
92
"aws_secret_access_key" ,
70
93
"aws_session_token" ,
71
94
)
72
- OTHER_PARAMS = (
73
- "sh_base_url" ,
74
- "sh_auth_base_url" ,
75
- "geopedia_wms_url" ,
76
- "geopedia_rest_url" ,
77
- "aws_metadata_url" ,
78
- "aws_s3_l1c_bucket" ,
79
- "aws_s3_l2a_bucket" ,
80
- "opensearch_url" ,
81
- "max_wfs_records_per_query" ,
82
- "max_opensearch_records_per_query" ,
83
- "max_download_attempts" ,
84
- "download_sleep_time" ,
85
- "download_timeout_seconds" ,
86
- "number_of_download_processes" ,
87
- )
88
95
89
- def __init__ (self , profile : str = DEFAULT_PROFILE , * , use_defaults : bool = False ):
96
+ def __init__ (self , profile : Optional [ str ] = None , * , use_defaults : bool = False , ** kwargs : Any ):
90
97
"""
91
- :param profile: Specifies which profile to load form the configuration file. The environment variable
92
- SH_USER_PROFILE has precedence .
98
+ :param profile: Specifies which profile to load from the configuration file. Has precedence over the environment
99
+ variable `SH_USER_PROFILE` .
93
100
:param use_defaults: Does not load the configuration file, returns config object with defaults only.
101
+ :param kwargs: Any fields of `SHConfig` to be updated. Overrides settings from `config.toml` and environment.
94
102
"""
95
-
96
- self .instance_id : str = ""
97
- self .sh_client_id : str = ""
98
- self .sh_client_secret : str = ""
99
- self .sh_base_url : str = "https://services.sentinel-hub.com"
100
- self .sh_auth_base_url : str = "https://services.sentinel-hub.com"
101
- self .geopedia_wms_url : str = "https://service.geopedia.world"
102
- self .geopedia_rest_url : str = "https://www.geopedia.world/rest"
103
- self .aws_access_key_id : str = ""
104
- self .aws_secret_access_key : str = ""
105
- self .aws_session_token : str = ""
106
- self .aws_metadata_url : str = "https://roda.sentinel-hub.com"
107
- self .aws_s3_l1c_bucket : str = "sentinel-s2-l1c"
108
- self .aws_s3_l2a_bucket : str = "sentinel-s2-l2a"
109
- self .opensearch_url : str = "http://opensearch.sentinel-hub.com/resto/api/collections/Sentinel2"
110
- self .max_wfs_records_per_query : int = 100
111
- self .max_opensearch_records_per_query : int = 500 # pylint: disable=invalid-name
112
- self .max_download_attempts : int = 4
113
- self .download_sleep_time : float = 5.0
114
- self .download_timeout_seconds : float = 120.0
115
- self .number_of_download_processes : int = 1
116
-
117
- profile = os .environ .get (SH_PROFILE_ENV_VAR , default = profile )
103
+ if profile is None :
104
+ profile = os .environ .get (SH_PROFILE_ENV_VAR , default = DEFAULT_PROFILE )
118
105
119
106
if not use_defaults :
107
+ env_kwargs = {
108
+ "sh_client_id" : os .environ .get (SH_CLIENT_ID_ENV_VAR ),
109
+ "sh_client_secret" : os .environ .get (SH_CLIENT_SECRET_ENV_VAR ),
110
+ }
111
+ env_kwargs = {k : v for k , v in env_kwargs .items () if v is not None }
112
+
120
113
# load from config.toml
121
- loaded_instance = SHConfig .load (profile = profile ) # user parameters validated in here already
122
- for param in SHConfig .get_params ():
123
- setattr (self , param , getattr (loaded_instance , param ))
124
-
125
- # check env
126
- self .sh_client_id = os .environ .get (SH_CLIENT_ID_ENV_VAR , default = self .sh_client_id )
127
- self .sh_client_secret = os .environ .get (SH_CLIENT_SECRET_ENV_VAR , default = self .sh_client_secret )
128
-
129
- def _validate_values (self ) -> None :
130
- """Ensures that the values are aligned with expectations."""
131
- default = SHConfig (use_defaults = True )
132
-
133
- for param in self .get_params ():
134
- value = getattr (self , param )
135
- default_value = getattr (default , param )
136
- param_type = type (default_value )
137
-
138
- if isinstance (value , str ) and value .startswith ("http" ):
139
- value = value .rstrip ("/" )
140
- if (param_type is float ) and isinstance (value , numbers .Number ):
141
- continue
142
- if not isinstance (value , param_type ):
143
- raise ValueError (f"Value of parameter `{ param } ` must be of type { param_type .__name__ } " )
144
- if self .max_wfs_records_per_query > 100 :
145
- raise ValueError ("Value of config parameter `max_wfs_records_per_query` must be at most 100" )
146
- if self .max_opensearch_records_per_query > 500 :
147
- raise ValueError ("Value of config parameter `max_opensearch_records_per_query` must be at most 500" )
114
+ loaded_kwargs = SHConfig .load (profile = profile ).to_dict (mask_credentials = False )
115
+
116
+ kwargs = {** loaded_kwargs , ** env_kwargs , ** kwargs } # precedence: init params > env > loaded
117
+
118
+ super ().__init__ (** kwargs )
148
119
149
120
def __str__ (self ) -> str :
150
- """Content of SHConfig in json schema. Credentials are masked for safety."""
121
+ """Content of ` SHConfig` in json schema. Credentials are masked for safety."""
151
122
return json .dumps (self .to_dict (mask_credentials = True ), indent = 2 )
152
123
153
124
def __repr__ (self ) -> str :
154
- """Representation of SHConfig parameters . Credentials are masked for safety."""
125
+ """Representation of ` SHConfig` . Credentials are masked for safety."""
155
126
config_dict = self .to_dict (mask_credentials = True )
156
127
content = ",\n " .join (f"{ key } ={ repr (value )} " for key , value in config_dict .items ())
157
128
return f"{ self .__class__ .__name__ } (\n { content } ,\n )"
158
129
159
- def __eq__ (self , other : object ) -> bool :
160
- """Two instances of `SHConfig` are equal if all values of their parameters are equal."""
161
- if not isinstance (other , SHConfig ):
162
- return False
163
- return all (getattr (self , param ) == getattr (other , param ) for param in self .get_params ())
164
-
165
130
@classmethod
166
131
def load (cls , profile : str = DEFAULT_PROFILE ) -> SHConfig :
167
132
"""Loads configuration parameters from the config file at `SHConfig.get_config_location()`.
168
133
169
134
:param profile: Which profile to load from the configuration file.
170
135
"""
171
- config = cls (use_defaults = True )
172
-
173
136
filename = cls .get_config_location ()
174
137
if not os .path .exists (filename ):
175
- config .save (profile ) # store default configuration to standard location
138
+ cls ( use_defaults = True ) .save (profile ) # store default configuration to standard location
176
139
177
140
with open (filename , "rb" ) as cfg_file :
178
141
configurations_dict = tomli .load (cfg_file )
179
142
180
143
if profile not in configurations_dict :
181
144
raise KeyError (f"Profile { profile } not found in configuration file." )
182
145
183
- config_fields = cls .get_params ()
184
- for param , value in configurations_dict [profile ].items ():
185
- if param in config_fields :
186
- setattr (config , param , value )
187
-
188
- config ._validate_values ()
189
- return config
146
+ return cls (use_defaults = True , ** configurations_dict [profile ])
190
147
191
148
def save (self , profile : str = DEFAULT_PROFILE ) -> None :
192
149
"""Saves configuration parameters to the config file at `SHConfig.get_config_location()`.
193
150
194
151
:param profile: Under which profile to save the configuration.
195
152
"""
196
- self ._validate_values ()
197
-
198
153
file_path = Path (self .get_config_location ())
199
154
file_path .parent .mkdir (parents = True , exist_ok = True )
200
155
@@ -218,18 +173,13 @@ def copy(self) -> SHConfig:
218
173
"""Makes a copy of an instance of `SHConfig`"""
219
174
return copy .copy (self )
220
175
221
- @classmethod
222
- def get_params (cls ) -> Tuple [str , ...]:
223
- """Returns a list of parameter names."""
224
- return cls .CREDENTIALS + cls .OTHER_PARAMS
225
-
226
176
def to_dict (self , mask_credentials : bool = True ) -> Dict [str , Union [str , float ]]:
227
177
"""Get a dictionary representation of the `SHConfig` class.
228
178
229
179
:param hide_credentials: Wether to mask fields containing credentials.
230
180
:return: A dictionary with configuration parameters
231
181
"""
232
- config_params = { param : getattr (self , param ) for param in self . get_params ()}
182
+ config_params = asdict (self )
233
183
234
184
if mask_credentials :
235
185
for param in self .CREDENTIALS :
0 commit comments