Skip to content

Commit 40d0880

Browse files
Align syntax with node-oracledb and other drivers.
1 parent 6e4214c commit 40d0880

File tree

3 files changed

+129
-73
lines changed

3 files changed

+129
-73
lines changed

Diff for: doc/src/user_guide/connection_handling.rst

+14-13
Original file line numberDiff line numberDiff line change
@@ -1227,8 +1227,8 @@ syntax is::
12271227
{
12281228
"user": "scott",
12291229
"password": {
1230-
"type": "oci-vault",
1231-
"uri": "oci.vaultsecret.my-secret-id"
1230+
"type": "ocivault",
1231+
"value": "oci.vaultsecret.my-secret-id"
12321232
"authentication": {
12331233
"method": "OCI_INSTANCE_PRINCIPAL"
12341234
}
@@ -1241,15 +1241,16 @@ syntax is::
12411241
}
12421242
}
12431243

1244-
Passwords can optionally be stored using the Azure Key Vault. To do this, you
1245-
must define the Azure Key Vault credentials in the ``password`` key. In
1246-
this, the ``azure_client_id`` and ``azure_tenant_id`` must be specified. Also,
1247-
either ``azure_client_secret`` or ``azure_client_certificate_path`` should be
1248-
specified. For example::
1244+
Passwords can optionally be stored using the Azure Key Vault. To do this,
1245+
you must ``import oracledb.plugins.azure_config_provider`` in your application and you must
1246+
define the Azure Key Vault credentials in the ``password`` key.
1247+
In this, the ``azure_client_id`` and ``azure_tenant_id`` must be specified.
1248+
Also, either ``azure_client_secret`` or ``azure_client_certificate_path``
1249+
should be specified. For example::
12491250

12501251
"password": {
1251-
"type": "azure-vault",
1252-
"uri": "<Azure Key Vault URI>",
1252+
"type": "azurevault",
1253+
"value": "<Azure Key Vault URI>",
12531254
"authentication": {
12541255
"azure_tenant_id": "<tenant_id>",
12551256
"azure_client_id": "<client_id>",
@@ -1260,8 +1261,8 @@ specified. For example::
12601261
Or::
12611262

12621263
"password": {
1263-
"type": "azure-vault",
1264-
"uri": "<Azure Key Vault URI>",
1264+
"type": "azurevault",
1265+
"value": "<Azure Key Vault URI>",
12651266
"authentication": {
12661267
"azure_tenant_id": "<tenant_id>",
12671268
"azure_client_id": "<client_id>",
@@ -1277,7 +1278,7 @@ configuration provider is:
12771278

12781279
.. code-block:: python
12791280
1280-
configociurl = "config-ociobject://abc.oraclecloud.com/n/abcnamespace/b/abcbucket/o/abcobject?oci_tenancy=abc123&oci_user=ociuser1&oci_fingerprint=ab:14:ba:13&oci_key_file=ociabc/ocikeyabc.pem"
1281+
configociurl = "config-ociobject://abc.oraclecloud.com/n/abcnamespace/b/abcbucket/o/abcobject?authentication=oci_default&oci_tenancy=abc123&oci_user=ociuser1&oci_fingerprint=ab:14:ba:13&oci_key_file=ociabc/ocikeyabc.pem"
12811282
12821283
To create a :ref:`standalone connection <standaloneconnection>` you could use
12831284
this like:
@@ -1286,7 +1287,7 @@ this like:
12861287
12871288
import oracledb.plugins.oci_config_provider
12881289
1289-
configociurl = "config-ociobject://abc.oraclecloud.com/n/abcnamespace/b/abcbucket/o/abcobject?oci_tenancy=abc123&oci_user=ociuser1&oci_fingerprint=ab:14:ba:13&oci_key_file=ociabc/ocikeyabc.pem"
1290+
configociurl = "config-ociobject://abc.oraclecloud.com/n/abcnamespace/b/abcbucket/o/abcobject?authentication=oci_default&oci_tenancy=abc123&oci_user=ociuser1&oci_fingerprint=ab:14:ba:13&oci_key_file=ociabc/ocikeyabc.pem"
12901291
12911292
connection = oracledb.connect(dsn=configociurl)
12921293

Diff for: src/oracledb/plugins/azure_config_provider.py

+75-25
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,25 @@
4747
)
4848

4949

50+
def _get_authentication_method(parameters):
51+
auth_method = parameters.get("authentication", parameters.get("method"))
52+
if auth_method is not None:
53+
auth_method = auth_method.upper()
54+
if auth_method == "AZURE_DEFAULT":
55+
auth_method = None
56+
return auth_method
57+
58+
5059
def _get_credential(parameters):
5160
"""
5261
Returns the appropriate credential given the input supplied by the original
5362
connect string.
5463
"""
5564

5665
tokens = []
57-
auth = parameters.get("authentication")
58-
if auth is not None:
59-
auth = auth.upper()
60-
if auth == "AZURE_DEFAULT":
61-
auth = None
66+
auth_method = _get_authentication_method(parameters)
6267

63-
if auth is None or auth == "AZURE_SERVICE_PRINCIPAL":
68+
if auth_method is None or auth_method == "AZURE_SERVICE_PRINCIPAL":
6469
if "azure_client_secret" in parameters:
6570
tokens.append(
6671
ClientSecretCredential(
@@ -69,7 +74,7 @@ def _get_credential(parameters):
6974
_get_required_parameter(parameters, "azure_client_secret"),
7075
)
7176
)
72-
if "azure_client_certificate_path" in parameters:
77+
elif "azure_client_certificate_path" in parameters:
7378
tokens.append(
7479
CertificateCredential(
7580
_get_required_parameter(parameters, "azure_tenant_id"),
@@ -79,25 +84,79 @@ def _get_credential(parameters):
7984
),
8085
)
8186
)
82-
if auth is None or auth == "AZURE_MANAGED_IDENTITY":
87+
if auth_method is None or auth_method == "AZURE_MANAGED_IDENTITY":
8388
client_id = parameters.get("azure_managed_identity_client_id")
8489
if client_id is not None:
8590
tokens.append(ManagedIdentityCredential(client_id=client_id))
8691

8792
if len(tokens) == 0:
88-
message = "Authentication options not available in Connection String"
93+
message = (
94+
"Authentication options were not available in Connection String"
95+
)
8996
raise Exception(message)
9097
elif len(tokens) == 1:
9198
return tokens[0]
9299
tokens.append(EnvironmentCredential())
93100
return ChainedTokenCredential(*tokens)
94101

95102

96-
def _get_required_parameter(parameters, name):
103+
def _get_password(pwd_string, parameters):
104+
try:
105+
pwd = json.loads(pwd_string)
106+
except json.JSONDecodeError:
107+
message = (
108+
"Password is expected to be JSON"
109+
" containing Azure Vault details."
110+
)
111+
raise Exception(message)
112+
113+
pwd["value"] = pwd.pop("uri")
114+
pwd["type"] = "azurevault"
115+
116+
# make authentication section
117+
pwd["authentication"] = authentication = {}
118+
119+
authentication["method"] = auth_method = _get_authentication_method(
120+
parameters
121+
)
122+
123+
if auth_method is None or auth_method == "AZURE_SERVICE_PRINCIPAL":
124+
if "azure_client_secret" in parameters:
125+
authentication["azure_tenant_id"] = _get_required_parameter(
126+
parameters, "azure_tenant_id"
127+
)
128+
authentication["azure_client_id"] = _get_required_parameter(
129+
parameters, "azure_client_id"
130+
)
131+
authentication["azure_client_secret"] = _get_required_parameter(
132+
parameters, "azure_client_secret"
133+
)
134+
135+
elif "azure_client_certificate_path" in parameters:
136+
authentication["azure_tenant_id"] = (
137+
_get_required_parameter(parameters, "azure_tenant_id"),
138+
)
139+
authentication["azure_client_id"] = (
140+
_get_required_parameter(parameters, "azure_client_id"),
141+
)
142+
authentication["azure_client_certificate_path"] = (
143+
_get_required_parameter(
144+
parameters, "azure_client_certificate_path"
145+
)
146+
)
147+
148+
if auth_method is None or auth_method == "AZURE_MANAGED_IDENTITY":
149+
authentication["azure_managed_identity_client_id"] = parameters.get(
150+
"azure_managed_identity_client_id"
151+
)
152+
return pwd
153+
154+
155+
def _get_required_parameter(parameters, name, location="connection string"):
97156
try:
98157
return parameters[name]
99158
except KeyError:
100-
message = f'Parameter named "{name}" missing from connection string'
159+
message = f'Parameter named "{name}" is missing from {location}'
101160
raise Exception(message) from None
102161

103162

@@ -134,7 +193,7 @@ def _parse_parameters(protocol_arg: str) -> dict:
134193

135194

136195
def password_type_azure_vault_hook(args):
137-
uri = _get_required_parameter(args, "uri")
196+
uri = _get_required_parameter(args, "value", '"password" key section')
138197
credential = args.get("credential")
139198

140199
if credential is None:
@@ -144,7 +203,7 @@ def password_type_azure_vault_hook(args):
144203
auth = args.get("authentication")
145204
if auth is None:
146205
raise Exception(
147-
"Azure Vault authentication details are not provided."
206+
"Azure Vault authentication details were not provided."
148207
)
149208
credential = _get_credential(auth)
150209

@@ -182,17 +241,8 @@ def _process_config(parameters, connect_params):
182241
config["user"] = _get_setting(client, key, "user", label, required=False)
183242
pwd = _get_setting(client, key, "password", label, required=False)
184243
if pwd is not None:
185-
try:
186-
pwd = json.loads(pwd)
187-
pwd["type"] = "azure-vault"
188-
pwd["credential"] = credential
189-
except json.JSONDecodeError:
190-
message = (
191-
"Password is expected to be JSON"
192-
" containing Azure Vault details."
193-
)
194-
raise Exception(message)
195-
config["password"] = pwd
244+
config["password"] = _get_password(pwd, parameters)
245+
196246
config["config_time_to_live"] = _get_setting(
197247
client, key, "config_time_to_live", label, required=False
198248
)
@@ -217,5 +267,5 @@ def config_azure_hook(protocol, protocol_arg, connect_params):
217267
_process_config(parameters, connect_params)
218268

219269

220-
oracledb.register_password_type("azure-vault", password_type_azure_vault_hook)
270+
oracledb.register_password_type("azurevault", password_type_azure_vault_hook)
221271
oracledb.register_protocol("config-azure", config_azure_hook)

Diff for: src/oracledb/plugins/oci_config_provider.py

+40-35
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,11 @@ def _get_config(parameters, connect_params):
9090
if connect_params.user is None:
9191
config["user"] = settings.get("user")
9292
if "password" in settings:
93-
pwd = settings["password"]
94-
if settings["password"]["type"] == "oci-vault":
95-
pwd["credential"] = credential
96-
pwd["auth"] = auth_method
97-
98-
# password should be stored in JSON and not plain text.
99-
config["password"] = pwd
93+
config["password"] = pwd = settings["password"]
94+
if pwd["type"] == "ocivault":
95+
authentication = pwd.setdefault("authentication", {})
96+
authentication.setdefault("method", auth_method)
97+
authentication["credential"] = credential
10098

10199
# config cache settings
102100
config["config_time_to_live"] = settings.get("config_time_to_live")
@@ -116,12 +114,18 @@ def _get_credential(parameters):
116114
Returns the appropriate credential given the input supplied by the original
117115
connect string.
118116
"""
119-
auth = parameters.get("authentication")
120-
if auth is not None:
121-
auth = auth.upper()
117+
auth_method = parameters.get("authentication", parameters.get("method"))
118+
119+
if auth_method is not None:
120+
auth_method = auth_method.upper()
121+
122+
# if region is not in connection string, retrieve from object server name.
123+
region = parameters.get(
124+
"oci_region", _retrieve_region(parameters.get("objservername"))
125+
)
122126

123127
try:
124-
if auth is None or auth == "OCI_DEFAULT":
128+
if auth_method is None or auth_method == "OCI_DEFAULT":
125129
# Default Authentication
126130
# default path ~/.oci/config
127131
return oci_from_file(), None
@@ -136,30 +140,30 @@ def _get_credential(parameters):
136140
fingerprint=parameters["oci_fingerprint"],
137141
key_file=parameters["oci_key_file"],
138142
private_key_content=public_key,
139-
region=_retrieve_region(parameters.get("objservername")),
143+
region=region,
140144
)
141145
return provider, None
142146

143-
if auth == "OCI_INSTANCE_PRINCIPAL":
147+
if auth_method == "OCI_INSTANCE_PRINCIPAL":
144148
signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
145149
return (
146-
dict(region=_retrieve_region(parameters.get("objservername"))),
150+
dict(region=region),
147151
signer,
148152
)
149153

150-
elif auth == "OCI_RESOURCE_PRINCIPAL":
151-
rps = oci.auth.signers.get_resource_principals_signer()
152-
return {}, rps
154+
elif auth_method == "OCI_RESOURCE_PRINCIPAL":
155+
signer = oci.auth.signers.get_resource_principals_signer()
156+
return {}, signer
153157
else:
154-
msg = "Authentication options not available in Connection String"
158+
msg = "Authentication options were not available in Connection String"
155159
raise Exception(msg)
156160

157161

158-
def _get_required_parameter(parameters, name):
162+
def _get_required_parameter(parameters, name, location="connection string"):
159163
try:
160164
return parameters[name]
161165
except KeyError:
162-
message = f'Parameter named "{name}" missing from connect string'
166+
message = f'Parameter named "{name}" is missing from {location}'
163167
raise Exception(message) from None
164168

165169

@@ -186,23 +190,24 @@ def _parse_parameters(protocol_arg: str) -> dict:
186190

187191

188192
def password_type_oci_vault_hook(args):
189-
secret_id = args.get("uri")
190-
credential = args.get("credential")
193+
secret_id = _get_required_parameter(
194+
args, "value", '"password" key section'
195+
)
196+
authentication = args.get("authentication")
197+
if authentication is None:
198+
raise Exception(
199+
"OCI Key Vault authentication details were not provided."
200+
)
191201

192202
# if credentials are not present, create credentials with given
193203
# authentication details.
204+
credential = authentication.get("credential")
194205
if credential is None:
195-
auth = _get_required_parameter(args, "auth")
196-
if auth is None:
197-
raise Exception(
198-
"OCI Key Vault authentication details are not provided."
199-
)
200-
credential, signer = _get_credential(auth)
201-
auth_method = args.get("auth")
206+
credential, signer = _get_credential(authentication)
202207

208+
auth_method = authentication.get("method")
203209
if auth_method is not None:
204210
auth_method = auth_method.upper()
205-
206211
if auth_method is None or auth_method == "OCI_DEFAULT":
207212
secret_client_oci = oci_secrets_client(credential)
208213
elif auth_method == "OCI_INSTANCE_PRINCIPAL":
@@ -216,18 +221,18 @@ def password_type_oci_vault_hook(args):
216221
config=credential, signer=signer
217222
)
218223

219-
get_secret_bundle_request = {"secret_id": secret_id}
220224
get_secret_bundle_response = secret_client_oci.get_secret_bundle(
221-
**get_secret_bundle_request
225+
secret_id=secret_id
222226
)
223227
# decoding the vault content
224228
b64content = get_secret_bundle_response.data.secret_bundle_content.content
225229
return base64.b64decode(b64content).decode()
226230

227231

228232
def _retrieve_region(objservername):
229-
arr = objservername.split(".")
230-
return arr[1].lower().replace("_", "-")
233+
if objservername is not None:
234+
arr = objservername.split(".")
235+
return arr[1].lower().replace("_", "-")
231236

232237

233238
def _stream_to_string(stream):
@@ -244,5 +249,5 @@ def config_oci_hook(
244249
_get_config(parameters, connect_params)
245250

246251

247-
oracledb.register_password_type("oci-vault", password_type_oci_vault_hook)
252+
oracledb.register_password_type("ocivault", password_type_oci_vault_hook)
248253
oracledb.register_protocol("config-ociobject", config_oci_hook)

0 commit comments

Comments
 (0)