Skip to content

Commit caff8f7

Browse files
authored
Merge pull request #57 from databrickslabs/databricks-m2m-oauth
Added support for Databricks M2M OAuth
2 parents c79b9c5 + d2311d0 commit caff8f7

19 files changed

Lines changed: 4112 additions & 95 deletions

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
run: pip install -r requirements-dev.txt
2323
- name: Run tests and generate coverage
2424
working-directory: ./
25-
run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml tests
25+
run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml --cov-fail-under=80 tests
2626
- name: Publish test coverage
2727
uses: codecov/codecov-action@v1
2828

app/README.md

Lines changed: 61 additions & 14 deletions
Large diffs are not rendered by default.

app/appserver/static/js/build/custom/auth_select_hook.js

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ class AuthSelectHook {
99

1010
onChange(field, value, dataDict) {
1111
if (field == 'auth_type') {
12-
if (value == 'AAD') {
13-
this.toggleAADFields(true);
14-
} else {
15-
this.toggleAADFields(false);
16-
}
12+
this.toggleAuthFields(value);
1713
}
1814
if (field == 'config_for_dbquery') {
1915
if (value == 'interactive_cluster') {
@@ -26,11 +22,7 @@ class AuthSelectHook {
2622

2723
onRender() {
2824
var selected_auth = this.state.data.auth_type.value;
29-
if (selected_auth == 'AAD') {
30-
this.toggleAADFields(true);
31-
} else {
32-
this.toggleAADFields(false);
33-
}
25+
this.toggleAuthFields(selected_auth);
3426
}
3527

3628
hideWarehouseField(state) {
@@ -41,13 +33,25 @@ class AuthSelectHook {
4133
});
4234
}
4335

44-
toggleAADFields(state) {
36+
toggleAuthFields(authType) {
4537
this.util.setState((prevState) => {
4638
let data = {...prevState.data };
47-
data.aad_client_id.display = state;
48-
data.aad_tenant_id.display = state;
49-
data.aad_client_secret.display = state;
50-
data.databricks_pat.display = !state;
39+
40+
// OAuth M2M fields
41+
const showOAuth = (authType === 'OAUTH_M2M');
42+
data.oauth_client_id.display = showOAuth;
43+
data.oauth_client_secret.display = showOAuth;
44+
45+
// AAD fields
46+
const showAAD = (authType === 'AAD');
47+
data.aad_client_id.display = showAAD;
48+
data.aad_tenant_id.display = showAAD;
49+
data.aad_client_secret.display = showAAD;
50+
51+
// PAT field
52+
const showPAT = (authType === 'PAT');
53+
data.databricks_pat.display = showPAT;
54+
5155
return { data }
5256
});
5357
}

app/appserver/static/js/build/globalConfig.json

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@
146146
{
147147
"label": "Azure Active Directory",
148148
"value": "AAD"
149+
},
150+
{
151+
"label": "Databricks OAuth (M2M)",
152+
"value": "OAUTH_M2M"
149153
}
150154
]
151155
},
@@ -205,6 +209,42 @@
205209
"placeholder": "required"
206210
}
207211
},
212+
{
213+
"field": "oauth_client_id",
214+
"label": "OAuth Client ID",
215+
"type": "text",
216+
"help": "Enter the Client ID from your Databricks service principal.",
217+
"required": false,
218+
"defaultValue": "",
219+
"encrypted": false,
220+
"validators": [{
221+
"type": "string",
222+
"minLength": 0,
223+
"maxLength": 200,
224+
"errorMsg": "Max length of OAuth Client ID is 200"
225+
}],
226+
"options": {
227+
"placeholder": "required"
228+
}
229+
},
230+
{
231+
"field": "oauth_client_secret",
232+
"label": "OAuth Client Secret",
233+
"type": "text",
234+
"help": "Enter the OAuth secret from your Databricks service principal.",
235+
"required": false,
236+
"defaultValue": "",
237+
"encrypted": true,
238+
"validators": [{
239+
"type": "string",
240+
"minLength": 0,
241+
"maxLength": 500,
242+
"errorMsg": "Max length of OAuth Client Secret is 500"
243+
}],
244+
"options": {
245+
"placeholder": "required"
246+
}
247+
},
208248
{
209249
"field": "databricks_pat",
210250
"label": "Databricks Access Token",
@@ -318,7 +358,7 @@
318358
"field": "use_for_oauth",
319359
"label": "Use Proxy for OAuth",
320360
"defaultValue": 0,
321-
"tooltip": "Check this box if you want to use proxy just for AAD token generation (https://login.microsoftonline.com/). All other network calls will skip the proxy even if it's enabled.",
361+
"tooltip": "Check this box if you want to use proxy just for Azure AD token generation (https://login.microsoftonline.com/). All other network calls (including Databricks API and OAuth M2M) will skip the proxy even if it's enabled.",
322362
"type": "checkbox"
323363
}
324364
],

app/bin/TA_Databricks_rh_account.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,35 @@
6262
default='',
6363
validator=None
6464
),
65+
field.RestField(
66+
'oauth_client_id',
67+
required=False,
68+
encrypted=False,
69+
default='',
70+
validator=validator.String(
71+
min_len=0,
72+
max_len=200,
73+
)
74+
),
75+
field.RestField(
76+
'oauth_client_secret',
77+
required=False,
78+
encrypted=True,
79+
default='',
80+
validator=None
81+
),
82+
field.RestField(
83+
'oauth_access_token',
84+
required=False,
85+
encrypted=True
86+
),
87+
field.RestField(
88+
'oauth_token_expiration',
89+
required=False,
90+
encrypted=False,
91+
default='',
92+
validator=None
93+
),
6594
field.RestField(
6695
'databricks_pat',
6796
required=False,

app/bin/databricks_com.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,17 @@ def __init__(self, account_name, session_key):
4444
self.session.timeout = const.TIMEOUT
4545
if self.auth_type == "PAT":
4646
self.databricks_token = databricks_configs.get("databricks_pat")
47-
else:
47+
elif self.auth_type == "AAD":
4848
self.databricks_token = databricks_configs.get("aad_access_token")
4949
self.aad_client_id = databricks_configs.get("aad_client_id")
5050
self.aad_tenant_id = databricks_configs.get("aad_tenant_id")
5151
self.aad_client_secret = databricks_configs.get("aad_client_secret")
52+
elif self.auth_type == "OAUTH_M2M":
53+
self.databricks_token = databricks_configs.get("oauth_access_token")
54+
self.oauth_client_id = databricks_configs.get("oauth_client_id")
55+
self.oauth_client_secret = databricks_configs.get("oauth_client_secret")
56+
oauth_token_expiration_str = databricks_configs.get("oauth_token_expiration")
57+
self.oauth_token_expiration = float(oauth_token_expiration_str) if oauth_token_expiration_str else 0
5258

5359
if not all([databricks_instance, self.databricks_token]):
5460
raise Exception("Addon is not configured. Navigate to addon's configuration page to configure the addon.")
@@ -98,6 +104,48 @@ def get_requests_retry_session(self):
98104
session.mount("https://", adapter)
99105
return session
100106

107+
def should_refresh_oauth_token(self):
108+
"""
109+
Check if OAuth token should be refreshed proactively.
110+
111+
:return: Boolean - True if token expires within 5 minutes
112+
"""
113+
if not hasattr(self, 'oauth_token_expiration'):
114+
return False
115+
116+
import time
117+
current_time = time.time()
118+
time_until_expiry = self.oauth_token_expiration - current_time
119+
120+
# Refresh if token expires within 5 minutes (300 seconds)
121+
return time_until_expiry < 300
122+
123+
def _refresh_oauth_token(self):
124+
"""Refresh OAuth M2M access token and update session headers."""
125+
databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name)
126+
proxy_config = databricks_configs.get("proxy_uri")
127+
128+
result = utils.get_oauth_access_token(
129+
self.session_key,
130+
self.account_name,
131+
self.databricks_instance_url.replace("https://", ""),
132+
self.oauth_client_id,
133+
self.oauth_client_secret,
134+
proxy_config,
135+
retry=const.RETRIES,
136+
conf_update=True
137+
)
138+
139+
if isinstance(result, tuple) and result[1] == False:
140+
raise Exception(result[0])
141+
142+
access_token, expires_in = result
143+
self.databricks_token = access_token
144+
import time
145+
self.oauth_token_expiration = time.time() + expires_in
146+
self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
147+
self.session.headers.update(self.request_headers)
148+
101149
def databricks_api(self, method, endpoint, data=None, args=None):
102150
"""
103151
Common method to hit the API of Databricks instance.
@@ -108,6 +156,11 @@ def databricks_api(self, method, endpoint, data=None, args=None):
108156
:param args: Arguments to be add into the url
109157
:return: response in the form of dictionary
110158
"""
159+
# Proactive OAuth token refresh
160+
if self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token():
161+
_LOGGER.info("OAuth token expiring soon, refreshing proactively.")
162+
self._refresh_oauth_token()
163+
111164
run_again = True
112165
request_url = "{}{}".format(self.databricks_instance_url, endpoint)
113166
try:
@@ -141,6 +194,11 @@ def databricks_api(self, method, endpoint, data=None, args=None):
141194
self.databricks_token = db_token
142195
self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
143196
self.session.headers.update(self.request_headers)
197+
elif status_code == 403 and self.auth_type == "OAUTH_M2M" and run_again:
198+
response = None
199+
run_again = False
200+
_LOGGER.info("Refreshing OAuth M2M token.")
201+
self._refresh_oauth_token()
144202
elif status_code != 200:
145203
response.raise_for_status()
146204
else:

app/bin/databricks_common_utils.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,137 @@ def save_databricks_aad_access_token(account_name, session_key, access_token, cl
110110
raise Exception("Exception while saving AAD access token.")
111111

112112

113+
def save_databricks_oauth_access_token(account_name, session_key, access_token, expires_in, client_secret):
114+
"""
115+
Method to store new OAuth access token with expiration timestamp.
116+
117+
:param account_name: Account name
118+
:param session_key: Splunk session key
119+
:param access_token: OAuth access token
120+
:param expires_in: Token lifetime in seconds
121+
:param client_secret: OAuth client secret
122+
:return: None
123+
"""
124+
import time
125+
new_creds = {
126+
"name": account_name,
127+
"oauth_client_secret": client_secret,
128+
"oauth_access_token": access_token,
129+
"oauth_token_expiration": str(time.time() + expires_in),
130+
"update_token": True
131+
}
132+
try:
133+
_LOGGER.info("Saving databricks OAuth access token.")
134+
rest.simpleRequest(
135+
"/databricks_get_credentials",
136+
sessionKey=session_key,
137+
postargs=new_creds,
138+
raiseAllErrors=True,
139+
)
140+
_LOGGER.info("Saved OAuth access token successfully.")
141+
except Exception as e:
142+
_LOGGER.error("Exception while saving OAuth access token: {}".format(str(e)))
143+
_LOGGER.debug(traceback.format_exc())
144+
raise Exception("Exception while saving OAuth access token.")
145+
146+
147+
def get_oauth_access_token(
148+
session_key,
149+
account_name,
150+
databricks_instance,
151+
oauth_client_id,
152+
oauth_client_secret,
153+
proxy_settings=None,
154+
retry=1,
155+
conf_update=False,
156+
):
157+
"""
158+
Method to acquire OAuth M2M access token for Databricks service principal.
159+
160+
:param session_key: Splunk session key
161+
:param account_name: Account name for configuration storage
162+
:param databricks_instance: Databricks workspace instance URL
163+
:param oauth_client_id: OAuth client ID from service principal
164+
:param oauth_client_secret: OAuth client secret from service principal
165+
:param proxy_settings: Proxy configuration dict
166+
:param retry: Number of retry attempts
167+
:param conf_update: If True, store token in configuration
168+
:return: tuple (access_token, expires_in) or (error_message, False)
169+
"""
170+
import time
171+
from requests.auth import HTTPBasicAuth
172+
173+
token_url = "https://{}/oidc/v1/token".format(databricks_instance)
174+
headers = {
175+
"Content-Type": "application/x-www-form-urlencoded",
176+
"User-Agent": "{}".format(const.USER_AGENT_CONST),
177+
}
178+
_LOGGER.debug("Request made to the Databricks from Splunk user: {}".format(get_current_user(session_key)))
179+
data_dict = {"grant_type": "client_credentials", "scope": "all-apis"}
180+
data_encoded = urlencode(data_dict)
181+
182+
# Handle proxy settings for OAuth M2M
183+
# Note: "use_for_oauth" means "use proxy ONLY for AAD token generation"
184+
# Since OAuth M2M endpoint is on the Databricks instance (not AAD),
185+
# we should skip proxy when use_for_oauth is true
186+
if proxy_settings:
187+
if is_true(proxy_settings.get("use_for_oauth")):
188+
_LOGGER.info(
189+
"Skipping the usage of proxy for OAuth M2M as 'Use Proxy for OAuth' parameter is checked."
190+
)
191+
proxy_settings_copy = None
192+
else:
193+
proxy_settings_copy = proxy_settings.copy()
194+
proxy_settings_copy.pop("use_for_oauth", None)
195+
else:
196+
proxy_settings_copy = None
197+
198+
while retry:
199+
try:
200+
resp = requests.post(
201+
token_url,
202+
headers=headers,
203+
data=data_encoded,
204+
auth=HTTPBasicAuth(oauth_client_id, oauth_client_secret),
205+
proxies=proxy_settings_copy,
206+
verify=const.VERIFY_SSL,
207+
timeout=const.TIMEOUT
208+
)
209+
resp.raise_for_status()
210+
response = resp.json()
211+
oauth_access_token = response.get("access_token")
212+
expires_in = response.get("expires_in", 3600)
213+
if conf_update:
214+
save_databricks_oauth_access_token(
215+
account_name, session_key, oauth_access_token, expires_in, oauth_client_secret
216+
)
217+
return oauth_access_token, expires_in
218+
except Exception as e:
219+
retry -= 1
220+
if "resp" in locals():
221+
error_code = resp.json().get("error")
222+
if error_code and error_code in list(const.ERROR_CODE.keys()):
223+
msg = const.ERROR_CODE[error_code]
224+
elif str(resp.status_code) in list(const.ERROR_CODE.keys()):
225+
msg = const.ERROR_CODE[str(resp.status_code)]
226+
elif resp.status_code not in (200, 201):
227+
msg = (
228+
"Response status: {}. Unable to validate OAuth credentials. "
229+
"Check logs for more details.".format(str(resp.status_code))
230+
)
231+
else:
232+
msg = (
233+
"Unable to request Databricks instance. "
234+
"Please validate the provided Databricks and "
235+
"Proxy configurations or check the network connectivity."
236+
)
237+
_LOGGER.error("Error while trying to generate OAuth access token: {}".format(str(e)))
238+
_LOGGER.debug(traceback.format_exc())
239+
_LOGGER.error(msg)
240+
if retry == 0:
241+
return msg, False
242+
243+
113244
def get_proxy_clear_password(session_key):
114245
"""
115246
Get clear password from splunk passwords.conf.

0 commit comments

Comments
 (0)