From 323ae6d35b067b0dfc6381c51141771b6c63e024 Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Tue, 18 Nov 2025 22:53:46 -0800 Subject: [PATCH 1/6] feat(cloud): support workload identity auth for azure --- .../backend-operator/templates/_helpers.tpl | 2 + .../templates/backend-listener-rbac.yaml | 6 ++ .../templates/backend-test-runner-rbac.yaml | 6 ++ .../templates/backend-worker-rbac.yaml | 6 ++ .../charts/backend-operator/values.yaml | 33 +++++-- .../charts/router/templates/_helpers.tpl | 5 +- .../router/templates/router-service.yaml | 4 +- .../router/templates/service-account.yaml | 9 +- deployments/charts/router/values.yaml | 24 ++++- .../charts/service/templates/_helpers.tpl | 16 +++- .../service/templates/agent-service.yaml | 3 +- .../charts/service/templates/api-service.yaml | 3 +- .../templates/delayed-job-monitor.yaml | 3 +- .../service/templates/logger-service.yaml | 3 +- .../service/templates/service-account.yaml | 12 ++- .../charts/service/templates/worker.yaml | 3 +- deployments/charts/service/values.yaml | 52 +++++++++- src/cli/credential.py | 2 +- src/lib/data/dataset/common.py | 2 +- src/lib/data/dataset/manager.py | 6 +- src/lib/data/dataset/migrating.py | 4 +- src/lib/data/dataset/updating.py | 4 +- src/lib/data/dataset/uploading.py | 6 +- src/lib/data/storage/backends/azure.py | 53 +++++++++-- src/lib/data/storage/backends/backends.py | 95 ++++++++++++++----- src/lib/data/storage/backends/common.py | 12 +-- src/lib/data/storage/client.py | 2 +- src/lib/data/storage/common.py | 11 ++- src/lib/data/storage/mux.py | 2 +- src/lib/utils/credentials.py | 22 ++++- src/utils/connectors/postgres.py | 4 +- 31 files changed, 323 insertions(+), 92 deletions(-) diff --git a/deployments/charts/backend-operator/templates/_helpers.tpl b/deployments/charts/backend-operator/templates/_helpers.tpl index 14dbbf1e..5147b5b0 100644 --- a/deployments/charts/backend-operator/templates/_helpers.tpl +++ b/deployments/charts/backend-operator/templates/_helpers.tpl @@ 
-73,6 +73,8 @@ Create a common service account name based on component {{- $name := or .root.Values.global.name .root.Release.Name -}} {{- if .serviceConfig.serviceAccount -}} {{- printf "%s-%s" $name .serviceConfig.serviceAccount | trunc 63 | trimSuffix "-" -}} +{{- else if .root.Values.serviceAccount.name -}} +{{- .root.Values.serviceAccount.name | trunc 63 | trimSuffix "-" -}} {{- else -}} {{- printf "%s-%s" $name .component | trunc 63 | trimSuffix "-" -}} {{- end -}} diff --git a/deployments/charts/backend-operator/templates/backend-listener-rbac.yaml b/deployments/charts/backend-operator/templates/backend-listener-rbac.yaml index 83fdbb19..346a4503 100644 --- a/deployments/charts/backend-operator/templates/backend-listener-rbac.yaml +++ b/deployments/charts/backend-operator/templates/backend-listener-rbac.yaml @@ -18,12 +18,18 @@ {{$name := or .Values.global.name .Release.Name }} # Backend Listener Service Account +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: namespace: {{ .Values.global.agentNamespace }} name: {{ include "backend-operator.listener.serviceAccountName" . 
}} + {{- if .Values.serviceAccount.annotations }} + annotations: + {{- toYaml .Values.serviceAccount.annotations | nindent 4 }} + {{- end }} --- +{{- end }} # Role for Backend Listener (Namespace-scoped events) kind: Role apiVersion: rbac.authorization.k8s.io/v1 diff --git a/deployments/charts/backend-operator/templates/backend-test-runner-rbac.yaml b/deployments/charts/backend-operator/templates/backend-test-runner-rbac.yaml index 3725f943..a05b8546 100644 --- a/deployments/charts/backend-operator/templates/backend-test-runner-rbac.yaml +++ b/deployments/charts/backend-operator/templates/backend-test-runner-rbac.yaml @@ -19,6 +19,7 @@ # Backend Test Runner Service Account {{- if .Values.global.backendTestNamespace }} +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: @@ -26,7 +27,12 @@ metadata: namespace: {{ .Values.global.backendTestNamespace }} labels: app: {{ $name }}-test-runner + {{- if .Values.serviceAccount.annotations }} + annotations: + {{- toYaml .Values.serviceAccount.annotations | nindent 4 }} + {{- end }} --- +{{- end }} # Role for Backend Test Runner kind: Role apiVersion: rbac.authorization.k8s.io/v1 diff --git a/deployments/charts/backend-operator/templates/backend-worker-rbac.yaml b/deployments/charts/backend-operator/templates/backend-worker-rbac.yaml index 618868ee..cca84c63 100644 --- a/deployments/charts/backend-operator/templates/backend-worker-rbac.yaml +++ b/deployments/charts/backend-operator/templates/backend-worker-rbac.yaml @@ -18,12 +18,18 @@ {{$name := or .Values.global.name .Release.Name }} # Backend Worker Service Account +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: namespace: {{ .Values.global.agentNamespace }} name: {{ include "backend-operator.worker.serviceAccountName" . 
}} + {{- if .Values.serviceAccount.annotations }} + annotations: + {{- toYaml .Values.serviceAccount.annotations | nindent 4 }} + {{- end }} --- +{{- end }} # ClusterRole for Backend Worker (Cluster-scoped resources) kind: ClusterRole apiVersion: rbac.authorization.k8s.io/v1 diff --git a/deployments/charts/backend-operator/values.yaml b/deployments/charts/backend-operator/values.yaml index 90753590..b24711bd 100644 --- a/deployments/charts/backend-operator/values.yaml +++ b/deployments/charts/backend-operator/values.yaml @@ -219,6 +219,24 @@ global: value: 50 description: "Schedule last. Preemptible." +## Service account configuration shared by backend listener and worker pods +## +serviceAccount: + ## Create the ServiceAccounts defined by this chart. Set to false to bind + ## pre-provisioned workload identity ServiceAccounts + ## + create: true + + ## ServiceAccount name to use when creating or for an already provisioned + ## ServiceAccount + ## + name: "" + + ## Extra annotations applied to ServiceAccounts (e.g. + ## `azure.workload.identity/client-id`) + ## + annotations: {} + ## Configuration for individual backend operators ## services: @@ -246,9 +264,10 @@ services: ## initContainers: [] - ## Kubernetes service account name for the backend listener + ## Kubernetes service account name for the backend listener. Leave empty to + ## use the chart default (`backend-listener`). ## - serviceAccount: backend-listener + serviceAccount: "" ## Maximum number of unacknowledged websocket messages ## @@ -356,9 +375,10 @@ services: ## initContainers: [] - ## Kubernetes service account name for the backend worker + ## Kubernetes service account name for the backend worker. Leave empty to + ## use the chart default (`backend-worker`). 
## - serviceAccount: backend-worker + serviceAccount: "" ## How often to write progress during task processing loops ## @@ -612,9 +632,10 @@ backendTestRunner: ## readOnlyRootFilesystem: true - ## Kubernetes service account name for test runner pods + ## Kubernetes service account name for test runner pods. Leave empty to use + ## the chart default (`test-runner`). ## - serviceAccount: test-runner + serviceAccount: "" ## Whether to automount service account token ## diff --git a/deployments/charts/router/templates/_helpers.tpl b/deployments/charts/router/templates/_helpers.tpl index 72dd28c8..c7c90064 100644 --- a/deployments/charts/router/templates/_helpers.tpl +++ b/deployments/charts/router/templates/_helpers.tpl @@ -70,9 +70,10 @@ app.kubernetes.io/instance: {{ .Release.Name }} Create the name of the service account to use */}} {{- define "router.serviceAccountName" -}} +{{- $defaultName := "router" -}} {{- if .Values.serviceAccount.create }} -{{- default (include "router.fullname" .) .Values.serviceAccount.name }} +{{- default $defaultName .Values.serviceAccount.name }} {{- else }} -{{- default "default" .Values.serviceAccount.name }} +{{- required "serviceAccount.name must be provided when serviceAccount.create is false" .Values.serviceAccount.name }} {{- end }} {{- end }} diff --git a/deployments/charts/router/templates/router-service.yaml b/deployments/charts/router/templates/router-service.yaml index 83db1c4a..df335cc1 100644 --- a/deployments/charts/router/templates/router-service.yaml +++ b/deployments/charts/router/templates/router-service.yaml @@ -14,6 +14,8 @@ # # SPDX-License-Identifier: Apache-2.0 +{{- $routerServiceAccount := default (include "router.serviceAccountName" .) 
.Values.services.service.serviceAccountName }} + apiVersion: apps/v1 kind: Deployment metadata: @@ -65,7 +67,7 @@ spec: {{- end }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ .Values.services.service.serviceAccountName | default "default" }} + serviceAccountName: {{ $routerServiceAccount }} {{- with .Values.services.service.initContainers }} initContainers: {{- toYaml . | nindent 6 }} diff --git a/deployments/charts/router/templates/service-account.yaml b/deployments/charts/router/templates/service-account.yaml index 4fe0a371..05e54f68 100644 --- a/deployments/charts/router/templates/service-account.yaml +++ b/deployments/charts/router/templates/service-account.yaml @@ -13,7 +13,14 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 +{{- $routerServiceAccount := default (include "router.serviceAccountName" .) .Values.services.service.serviceAccountName }} +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: - name: router + name: {{ $routerServiceAccount }} + {{- if .Values.serviceAccount.annotations }} + annotations: + {{- toYaml .Values.serviceAccount.annotations | nindent 4 }} + {{- end }} +{{- end }} diff --git a/deployments/charts/router/values.yaml b/deployments/charts/router/values.yaml index 5b272d66..7a45c396 100644 --- a/deployments/charts/router/values.yaml +++ b/deployments/charts/router/values.yaml @@ -51,6 +51,24 @@ global: ## k8sLogLevel: WARNING +## Service account configuration for router pods +## +serviceAccount: + ## Create the ServiceAccounts defined by this chart. Set to false to bind + ## pre-provisioned workload identity ServiceAccounts + ## + create: true + + ## ServiceAccount name to use when creating or for an already provisioned + ## ServiceAccount + ## + name: "" + + ## Additional ServiceAccount annotations (e.g. + ## `azure.workload.identity/client-id`). 
+ ## + annotations: {} + ## Configuration for individual Osmo services ## services: @@ -115,9 +133,11 @@ services: # - "--debug" # - "--config=/path/to/config" - ## Kubernetes service account name for the router service + ## Kubernetes service account name for the router service. Leave empty to + ## reuse `serviceAccount.name` (when provided); if both are empty the chart + ## falls back to the built-in `router` ServiceAccount name. ## - serviceAccountName: router + serviceAccountName: "" ## Host aliases for custom DNS resolution within the router pods ## diff --git a/deployments/charts/service/templates/_helpers.tpl b/deployments/charts/service/templates/_helpers.tpl index 4a053481..19ce7571 100644 --- a/deployments/charts/service/templates/_helpers.tpl +++ b/deployments/charts/service/templates/_helpers.tpl @@ -70,10 +70,11 @@ app.kubernetes.io/instance: {{ .Release.Name }} Create the name of the service account to use */}} {{- define "osmo.serviceAccountName" -}} +{{- $defaultName := include "osmo.fullname" . }} {{- if .Values.serviceAccount.create }} -{{- default (include "osmo.fullname" .) 
.Values.serviceAccount.name }} +{{- default $defaultName .Values.serviceAccount.name }} {{- else }} -{{- default "default" .Values.serviceAccount.name }} +{{- required "serviceAccount.name must be provided when serviceAccount.create is false" .Values.serviceAccount.name }} {{- end }} {{- end }} @@ -95,7 +96,7 @@ Service account name helper {{- if .serviceAccountName }} {{- .serviceAccountName }} {{- else }} -{{- .Values.global.serviceAccountName }} +{{- include "osmo.serviceAccountName" .root }} {{- end }} {{- end }} @@ -108,6 +109,15 @@ Extra annotations helper {{- end }} {{- end }} +{{/* +Extra labels helper +*/}} +{{- define "osmo.extra-labels" -}} +{{- if .extraPodLabels }} +{{- toYaml .extraPodLabels }} +{{- end }} +{{- end }} + {{/* Extra environment variables helper */}} diff --git a/deployments/charts/service/templates/agent-service.yaml b/deployments/charts/service/templates/agent-service.yaml index b5d756a5..06b98e9f 100644 --- a/deployments/charts/service/templates/agent-service.yaml +++ b/deployments/charts/service/templates/agent-service.yaml @@ -31,6 +31,7 @@ spec: metadata: labels: app: {{ .Values.services.agent.serviceName }} + {{- include "osmo.extra-labels" .Values.services.agent | nindent 8 }} annotations: {{- include "osmo.extra-annotations" .Values.services.agent | nindent 8 }} {{- if .Values.sidecars.otel.enabled }} @@ -64,7 +65,7 @@ spec: {{ toYaml .Values.services.agent.tolerations | nindent 8 }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.agent.serviceAccountName "Values" .Values) }} + serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.agent.serviceAccountName "root" .) }} {{- with .Values.services.agent.initContainers }} initContainers: {{- toYaml . 
| nindent 6 }} diff --git a/deployments/charts/service/templates/api-service.yaml b/deployments/charts/service/templates/api-service.yaml index 4498912d..1be46e0f 100644 --- a/deployments/charts/service/templates/api-service.yaml +++ b/deployments/charts/service/templates/api-service.yaml @@ -30,6 +30,7 @@ spec: metadata: labels: app: {{ .Values.services.service.serviceName }} + {{- include "osmo.extra-labels" .Values.services.service | nindent 8 }} annotations: {{- include "osmo.extra-annotations" .Values.services.service | nindent 8 }} {{- if .Values.sidecars.otel.enabled }} @@ -63,7 +64,7 @@ spec: {{ toYaml .Values.services.service.tolerations | nindent 8 }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.service.serviceAccountName "Values" .Values) }} + serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.service.serviceAccountName "root" .) }} {{- with .Values.services.service.initContainers }} initContainers: {{- toYaml . 
| nindent 6 }} diff --git a/deployments/charts/service/templates/delayed-job-monitor.yaml b/deployments/charts/service/templates/delayed-job-monitor.yaml index 98c7cc14..70d720e1 100644 --- a/deployments/charts/service/templates/delayed-job-monitor.yaml +++ b/deployments/charts/service/templates/delayed-job-monitor.yaml @@ -28,6 +28,7 @@ spec: metadata: labels: app: {{ .Values.services.delayedJobMonitor.serviceName }} + {{- include "osmo.extra-labels" .Values.services.delayedJobMonitor | nindent 8 }} annotations: {{- include "osmo.extra-annotations" .Values.services.delayedJobMonitor | nindent 8 }} {{- if .Values.sidecars.otel.enabled }} @@ -45,7 +46,7 @@ spec: {{ toYaml .Values.services.delayedJobMonitor.tolerations | nindent 8 }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.delayedJobMonitor.serviceAccountName "Values" .Values) }} + serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.delayedJobMonitor.serviceAccountName "root" .) }} {{- with .Values.services.delayedJobMonitor.initContainers }} initContainers: {{- toYaml . 
| nindent 6 }} diff --git a/deployments/charts/service/templates/logger-service.yaml b/deployments/charts/service/templates/logger-service.yaml index fa0e795e..df73171c 100644 --- a/deployments/charts/service/templates/logger-service.yaml +++ b/deployments/charts/service/templates/logger-service.yaml @@ -31,6 +31,7 @@ spec: metadata: labels: app: {{ .Values.services.logger.serviceName }} + {{- include "osmo.extra-labels" .Values.services.logger | nindent 8 }} annotations: {{- include "osmo.extra-annotations" .Values.services.logger | nindent 8 }} {{- if .Values.sidecars.otel.enabled }} @@ -64,7 +65,7 @@ spec: {{ toYaml .Values.services.logger.tolerations | nindent 8 }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.logger.serviceAccountName "Values" .Values) }} + serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.logger.serviceAccountName "root" .) }} {{- with .Values.services.logger.initContainers }} initContainers: {{- toYaml . | nindent 6 }} diff --git a/deployments/charts/service/templates/service-account.yaml b/deployments/charts/service/templates/service-account.yaml index b2192b2c..f94945e4 100644 --- a/deployments/charts/service/templates/service-account.yaml +++ b/deployments/charts/service/templates/service-account.yaml @@ -13,11 +13,19 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 +{{- if .Values.serviceAccount.create }} apiVersion: v1 kind: ServiceAccount metadata: - name: {{ .Values.global.serviceAccountName }} - {{- if and .Values.sidecars.logAgent.cloudwatch .Values.sidecars.logAgent.cloudwatch.enabled }} + name: {{ include "osmo.serviceAccountName" . 
}} + {{- $hasAnnotations := or .Values.serviceAccount.annotations (and .Values.sidecars.logAgent.cloudwatch .Values.sidecars.logAgent.cloudwatch.enabled) }} + {{- if $hasAnnotations }} annotations: + {{- with .Values.serviceAccount.annotations }} + {{- toYaml . | nindent 4 }} + {{- end }} + {{- if and .Values.sidecars.logAgent.cloudwatch .Values.sidecars.logAgent.cloudwatch.enabled }} eks.amazonaws.com/role-arn: {{ .Values.sidecars.logAgent.cloudwatch.role }} + {{- end }} {{- end }} +{{- end }} diff --git a/deployments/charts/service/templates/worker.yaml b/deployments/charts/service/templates/worker.yaml index cfe53833..1281b690 100644 --- a/deployments/charts/service/templates/worker.yaml +++ b/deployments/charts/service/templates/worker.yaml @@ -27,6 +27,7 @@ spec: metadata: labels: app: {{ .Values.services.worker.serviceName }} + {{- include "osmo.extra-labels" .Values.services.worker | nindent 8 }} annotations: {{- include "osmo.extra-annotations" .Values.services.worker | nindent 8 }} {{- if .Values.sidecars.otel.enabled }} @@ -56,7 +57,7 @@ spec: {{ toYaml .Values.services.worker.tolerations | nindent 8 }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} - serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.worker.serviceAccountName "Values" .Values) }} + serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.worker.serviceAccountName "root" .) }} {{- with .Values.services.worker.initContainers }} initContainers: {{- toYaml . 
| nindent 6 }} diff --git a/deployments/charts/service/values.yaml b/deployments/charts/service/values.yaml index 52df5838..2efe6873 100644 --- a/deployments/charts/service/values.yaml +++ b/deployments/charts/service/values.yaml @@ -36,10 +36,6 @@ global: # node-type: cpu # zone: us-west-2a - ## Service account name for the OSMO services - ## - serviceAccountName: osmo - ## Logging configuration for all Osmo services ## logs: @@ -55,6 +51,27 @@ global: ## k8sLogLevel: WARNING +## Service account configuration for core Osmo services +## +serviceAccount: + ## Create the chart-managed ServiceAccount. Set to false to re-use + ## an existing ServiceAccount that already has workload identity bindings. + ## When create is false you must set `serviceAccount.name` to the + ## pre-provisioned ServiceAccount name. + ## + create: true + + ## ServiceAccount name to use when you need a specific ServiceAccount. When + ## `serviceAccount.create=false` this value is required; otherwise it overrides + ## the default name when provided. + ## + name: "" + + ## Extra ServiceAccount annotations such as + ## `azure.workload.identity/client-id` for AKS workload identity. + ## + annotations: {} + ## Configuration for individual Osmo services ## services: @@ -335,6 +352,12 @@ services: ## extraPodAnnotations: {} + ## Extra pod labels for the delayed job monitor service. + ## Useful for workload identity labels such as + ## `azure.workload.identity/use: "true"`. + ## + extraPodLabels: {} + ## Extra environment variables for the delayed job monitor service container ## extraEnv: [] @@ -439,6 +462,11 @@ services: ## extraPodAnnotations: {} + ## Extra pod labels for the worker service pods (e.g. AKS + ## `azure.workload.identity/use` selectors). + ## + extraPodLabels: {} + ## Extra environment variables for the worker service container ## extraEnv: [] @@ -663,6 +691,11 @@ services: ## extraPodAnnotations: {} + ## Extra pod labels for the API service pods. 
Use this to apply + ## workload identity labels like `azure.workload.identity/use`. + ## + extraPodLabels: {} + ## Extra environment variables for the API service container ## extraEnv: [] @@ -775,6 +808,11 @@ services: ## extraPodAnnotations: {} + ## Extra pod labels for the logger service pods (e.g. workload + ## identity labels required by AKS). + ## + extraPodLabels: {} + ## Extra environment variables for the logger service container ## extraEnv: [] @@ -884,6 +922,12 @@ services: ## extraPodAnnotations: {} + ## Extra pod labels for the agent service pods. + ## Useful for projecting workload identity labels such as + ## `azure.workload.identity/use`. + ## + extraPodLabels: {} + ## Extra environment variables for the agent service container ## extraEnv: [] diff --git a/src/cli/credential.py b/src/cli/credential.py index 922e2733..d32f5fd1 100644 --- a/src/cli/credential.py +++ b/src/cli/credential.py @@ -45,7 +45,7 @@ def _save_config(data_cred: credentials.DataCredential): config['auth']['data'][data_cred.endpoint] = { 'access_key_id': data_cred.access_key_id, - 'access_key': data_cred.access_key.get_secret_value(), + 'access_key': data_cred.get_access_key_value(), 'region': data_cred.region} with open(password_file, 'w', encoding='utf-8') as file: yaml.dump(config, file) diff --git a/src/lib/data/dataset/common.py b/src/lib/data/dataset/common.py index 510294f1..6c6bf51f 100644 --- a/src/lib/data/dataset/common.py +++ b/src/lib/data/dataset/common.py @@ -308,7 +308,7 @@ def _validate_source_path( user_credentials = client_configs.get_credentials(path_components.profile) path_components.data_auth( user_credentials.access_key_id, - user_credentials.access_key.get_secret_value(), + user_credentials.get_access_key_value(), user_credentials.region, storage.AccessType.READ, ) diff --git a/src/lib/data/dataset/manager.py b/src/lib/data/dataset/manager.py index d507b09c..9b89816e 100644 --- a/src/lib/data/dataset/manager.py +++ b/src/lib/data/dataset/manager.py @@ 
-262,7 +262,7 @@ def upload_start( credentials = client_configs.get_credentials(path_components.profile) path_components.data_auth( credentials.access_key_id, - credentials.access_key.get_secret_value(), + credentials.get_access_key_value(), credentials.region, storage.AccessType.WRITE, ) @@ -404,7 +404,7 @@ def _update_dataset_start( # Validate delete access path_components.data_auth( credentials.access_key_id, - credentials.access_key.get_secret_value(), + credentials.get_access_key_value(), credentials.region, storage.AccessType.DELETE, ) @@ -412,7 +412,7 @@ def _update_dataset_start( # Validate write access path_components.data_auth( credentials.access_key_id, - credentials.access_key.get_secret_value(), + credentials.get_access_key_value(), credentials.region, storage.AccessType.WRITE, ) diff --git a/src/lib/data/dataset/migrating.py b/src/lib/data/dataset/migrating.py index 26704c92..ba38f0a9 100644 --- a/src/lib/data/dataset/migrating.py +++ b/src/lib/data/dataset/migrating.py @@ -224,12 +224,12 @@ def migrate( destination_creds = client_configs.get_credentials(destination_backend.profile) destination_region = destination_backend.region( destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), + destination_creds.get_access_key_value(), ) client_factory = destination_backend.client_factory( access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), + access_key=destination_creds.get_access_key_value(), region=destination_region, ) diff --git a/src/lib/data/dataset/updating.py b/src/lib/data/dataset/updating.py index 5318697b..31353da0 100644 --- a/src/lib/data/dataset/updating.py +++ b/src/lib/data/dataset/updating.py @@ -227,12 +227,12 @@ def update( destination_creds = client_configs.get_credentials(destination.profile) destination_region = destination.region( destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), + 
destination_creds.get_access_key_value(), ) client_factory = destination.client_factory( access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), + access_key=destination_creds.get_access_key_value(), region=destination_region, request_headers=request_headers, ) diff --git a/src/lib/data/dataset/uploading.py b/src/lib/data/dataset/uploading.py index 24d8f7e6..bb45918e 100644 --- a/src/lib/data/dataset/uploading.py +++ b/src/lib/data/dataset/uploading.py @@ -147,7 +147,7 @@ def dataset_upload_remote_file_entry_generator( url_base = storage_backend.parse_uri_to_link( storage_backend.region( data_creds.access_key_id, - data_creds.access_key.get_secret_value(), + data_creds.get_access_key_value(), ), ) @@ -384,12 +384,12 @@ def upload( destination_creds = client_configs.get_credentials(destination.profile) destination_region = destination.region( destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), + destination_creds.get_access_key_value(), ) client_factory = destination.client_factory( access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), + access_key=destination_creds.get_access_key_value(), region=destination_region, request_headers=request_headers, ) diff --git a/src/lib/data/storage/backends/azure.py b/src/lib/data/storage/backends/azure.py index 58547a88..47df4b7e 100644 --- a/src/lib/data/storage/backends/azure.py +++ b/src/lib/data/storage/backends/azure.py @@ -27,6 +27,7 @@ from typing_extensions import assert_never, override from azure.core import exceptions +from azure.identity import DefaultAzureCredential from azure.storage import blob from ..core import client, provider @@ -271,12 +272,35 @@ def __next__(self) -> bytes: def create_client( - connection_string: str, + connection_string: str | None = None, + *, + account_url: str | None = None, ) -> blob.BlobServiceClient: """ - Creates a new Azure Blob Storage client. 
+ Creates a new Azure Blob Storage client using the provided credential mode. + + Args: + connection_string: The Azure storage connection string. When provided, takes precedence. + account_url: The storage account blob service URL used for token credentials. + + Returns: + An initialized `BlobServiceClient` configured with the requested credential strategy. + + Raises: + client.OSMODataStorageClientError: If no valid credential configuration is supplied. """ - return blob.BlobServiceClient.from_connection_string(conn_str=connection_string) + if connection_string: + return blob.BlobServiceClient.from_connection_string(conn_str=connection_string) + + if account_url: + return blob.BlobServiceClient( + account_url=account_url, + credential=DefaultAzureCredential(), + ) + + raise client.OSMODataStorageClientError( + 'Azure Blob credential configuration requires either a connection string or token credentials.', + ) class AzureBlobStorageClient(client.StorageClient): @@ -306,7 +330,7 @@ def close(self) -> None: pass finally: super().close() - + @override def get_object_info( self, @@ -713,16 +737,23 @@ def _get_sas_url_for_copy(source_blob_client: blob.BlobClient) -> str: This is necessary to authorize the copy operation. 
""" - assert hasattr(source_blob_client.credential, 'account_key') + key_start_time = common.current_time().replace(tzinfo=datetime.timezone.utc) + key_expiry_time = key_start_time + _get_copy_sas_expiry_time() + + delegation_key = self._azure_client.get_user_delegation_key( + key_start_time=key_start_time, + key_expiry_time=key_expiry_time, + ) sas_token = blob.generate_blob_sas( account_name=source_blob_client.account_name, container_name=source_blob_client.container_name, blob_name=source_blob_client.blob_name, - account_key=source_blob_client.credential.account_key, permission=blob.BlobSasPermissions(read=True), - expiry=common.current_time() + _get_copy_sas_expiry_time(), + expiry=key_expiry_time, + user_delegation_key=delegation_key, ) + return f'{source_blob_client.url}?{sas_token}' def _call_api() -> client.CopyResponse: @@ -822,10 +853,14 @@ class AzureBlobStorageClientFactory(provider.StorageClientFactory): Factory for the AzureBlobStorageClient. """ - connection_string: str + connection_string: str | None = None + account_url: str | None = None @override def create(self) -> AzureBlobStorageClient: return AzureBlobStorageClient( - lambda: create_client(self.connection_string), + lambda: create_client( + self.connection_string, + account_url=self.account_url, + ), ) diff --git a/src/lib/data/storage/backends/backends.py b/src/lib/data/storage/backends/backends.py index e8b3ffa8..0413c3eb 100644 --- a/src/lib/data/storage/backends/backends.py +++ b/src/lib/data/storage/backends/backends.py @@ -132,14 +132,19 @@ def _get_extra_headers( @override def client_factory( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> s3.S3StorageClientFactory: """ Returns a factory for creating storage clients. 
""" + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for S3-compatible backends.' + ) + region = kwargs.get('region', None) or self.region(access_key_id, access_key) return s3.S3StorageClientFactory( # pylint: disable=unexpected-keyword-arg @@ -242,8 +247,8 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: common.AccessType | None = None, ): @@ -254,6 +259,11 @@ def data_auth( if _skip_data_auth(): return + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for Swift backend.' + ) + if ':' in access_key_id: namespace = access_key_id.split(':')[1] else: @@ -297,8 +307,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: """ Infer the region of the bucket via provided credentials. @@ -308,6 +318,11 @@ def region( if self._region is not None: return self._region + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for Swift backend.' + ) + s3_client = s3.create_client( access_key_id=access_key_id, access_key=access_key, @@ -407,8 +422,8 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: common.AccessType | None = None, ): @@ -418,6 +433,11 @@ def data_auth( if _skip_data_auth(): return + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for S3 backend.' 
+ ) + action = [] if access_type == common.AccessType.READ: action.append('s3:GetObject') @@ -475,8 +495,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: """ Infer the region of the bucket via provided credentials. @@ -486,6 +506,11 @@ def region( if self._region is not None: return self._region + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for S3 backend.' + ) + s3_client = s3.create_client( access_key_id=access_key_id, access_key=access_key, @@ -587,8 +612,8 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: common.AccessType | None = None, ): @@ -598,6 +623,11 @@ def data_auth( if _skip_data_auth(): return + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for GS backend.' + ) + if region is None: region = self.region(access_key_id, access_key) @@ -634,8 +664,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: # pylint: disable=unused-argument """ @@ -726,8 +756,8 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: common.AccessType | None = None, ): @@ -738,6 +768,11 @@ def data_auth( if _skip_data_auth(): return + if not access_key_id or not access_key: + raise osmo_errors.OSMOCredentialError( + 'Access key ID and secret access key are required for TOS backend.' 
+ ) + if region is None: # If region is not provided, we need to extract it from the netloc region = self.region('', '') @@ -773,8 +808,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: # pylint: disable=unused-argument # netloc = tos-s3-. @@ -866,8 +901,8 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: common.AccessType | None = None, ): @@ -878,8 +913,13 @@ def data_auth( if _skip_data_auth(): return + connection_string = access_key.strip() if access_key else None + def _validate_auth(): - with azure.create_client(access_key) as service_client: + with azure.create_client( + connection_string=connection_string, + account_url=self.endpoint, + ) as service_client: if self.container: with service_client.get_container_client(self.container) as container_client: container_client.get_container_properties() @@ -900,8 +940,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: # pylint: disable=unused-argument # Azure Blob Storage does not encode region in the URLs, we will simply @@ -916,8 +956,8 @@ def default_region(self) -> str: @override def client_factory( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> azure.AzureBlobStorageClientFactory: @@ -925,8 +965,11 @@ def client_factory( """ Returns a factory for creating storage clients. 
""" + connection_string = access_key.strip() if access_key else None + return azure.AzureBlobStorageClientFactory( # pylint: disable=unexpected-keyword-arg - connection_string=access_key, + connection_string=connection_string, + account_url=self.endpoint, ) diff --git a/src/lib/data/storage/backends/common.py b/src/lib/data/storage/backends/common.py index d10f2ea2..97bdb391 100644 --- a/src/lib/data/storage/backends/common.py +++ b/src/lib/data/storage/backends/common.py @@ -184,8 +184,8 @@ def parse_uri_to_link(self, region: str) -> str: @abc.abstractmethod def data_auth( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, region: str | None = None, access_type: AccessType | None = None, ): @@ -197,8 +197,8 @@ def data_auth( @abc.abstractmethod def region( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, ) -> str: """ Infer the region of the bucket via provided credentials. @@ -232,8 +232,8 @@ def to_storage_path( @abc.abstractmethod def client_factory( self, - access_key_id: str, - access_key: str, + access_key_id: str | None, + access_key: str | None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> provider.StorageClientFactory: diff --git a/src/lib/data/storage/client.py b/src/lib/data/storage/client.py index b6dd0642..b10b26b1 100644 --- a/src/lib/data/storage/client.py +++ b/src/lib/data/storage/client.py @@ -294,7 +294,7 @@ def storage_auth(self) -> common.StorageAuth: """ return common.StorageAuth( user=self.data_credential.access_key_id, - key=self.data_credential.access_key.get_secret_value() + key=self.data_credential.get_access_key_value(), ) def _validate_remote_path( diff --git a/src/lib/data/storage/common.py b/src/lib/data/storage/common.py index c8000d1c..4caedd74 100644 --- a/src/lib/data/storage/common.py +++ b/src/lib/data/storage/common.py @@ -48,15 +48,18 @@ class StorageAuth: """ A class for storing Data 
Storage Authentication Details. + + Both user and key are optional to support workload identity authentication + where credentials are obtained from the environment rather than explicit keys. """ - user: str = pydantic.Field( - ..., + user: str | None = pydantic.Field( + default=None, description='The user of the storage authentication.', ) - key: str = pydantic.Field( - ..., + key: str | None = pydantic.Field( + default=None, description='The key of the storage authentication.', ) diff --git a/src/lib/data/storage/mux.py b/src/lib/data/storage/mux.py index f5a5f3de..02f40532 100644 --- a/src/lib/data/storage/mux.py +++ b/src/lib/data/storage/mux.py @@ -134,7 +134,7 @@ def bind(self, storage_profile: str) -> provider.StorageClientProvider: ) client_factory = storage_backend.client_factory( access_key_id=data_cred.access_key_id, - access_key=data_cred.access_key.get_secret_value(), + access_key=data_cred.get_access_key_value(), region=data_cred.region, request_headers=self._client_factory.request_headers, **self._client_factory.kwargs, diff --git a/src/lib/utils/credentials.py b/src/lib/utils/credentials.py index 912eb2c8..f39e3751 100644 --- a/src/lib/utils/credentials.py +++ b/src/lib/utils/credentials.py @@ -39,11 +39,22 @@ class RegistryCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): class BasicDataCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): """ Authentication information for a data service without endpoint and region info. """ - access_key_id: str = pydantic.Field( + access_key_id: str | None = pydantic.Field( + default=None, description='The authentication key for the data service') - access_key: pydantic.SecretStr = pydantic.Field( + access_key: pydantic.SecretStr | None = pydantic.Field( + default=None, description='The authentication secret for the data service') + def get_access_key_value(self) -> str | None: + """ + Safely returns the access key secret value, or None if not set. 
+ + This supports workload identity authentication where credentials + are obtained from the environment rather than explicit keys. + """ + return self.access_key.get_secret_value() if self.access_key else None + class DataCredential(BasicDataCredential, extra=pydantic.Extra.forbid): """ @@ -76,8 +87,8 @@ class DecryptedDataCredential(BasicDataCredential, extra=pydantic.Extra.ignore): Basic data cred with access_key decrypted. """ - access_key: str = pydantic.Field( # type: ignore[assignment] - ..., + access_key: str | None = pydantic.Field( # type: ignore[assignment] + default=None, description='The authentication secret for the data service', ) @@ -89,5 +100,6 @@ class DecryptedDataCredential(BasicDataCredential, extra=pydantic.Extra.ignore): def decrypt(base_cred: DataCredential) -> DecryptedDataCredential: cred_dict = base_cred.dict() - cred_dict['access_key'] = cred_dict['access_key'].get_secret_value() + if cred_dict.get('access_key'): + cred_dict['access_key'] = cred_dict['access_key'].get_secret_value() return DecryptedDataCredential(**cred_dict) diff --git a/src/utils/connectors/postgres.py b/src/utils/connectors/postgres.py index 637eb9ce..9e8486a0 100644 --- a/src/utils/connectors/postgres.py +++ b/src/utils/connectors/postgres.py @@ -1228,7 +1228,7 @@ def get_data_cred(self, user: str, profile: str) -> credentials.DecryptedDataCre return credentials.DecryptedDataCredential( region=bucket.region, access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key.get_secret_value(), + access_key=bucket.default_credential.get_access_key_value(), endpoint=bucket_info.profile ) break @@ -1255,7 +1255,7 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.DecryptedDataCr user_creds[bucket_info.profile] = credentials.DecryptedDataCredential( region=bucket.region, access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key.get_secret_value(), + 
access_key=bucket.default_credential.get_access_key_value(), endpoint=bucket_info.profile ) return user_creds From 053e0479a2eb20b438d62eb49441983b17178025 Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Thu, 11 Dec 2025 11:39:48 -0800 Subject: [PATCH 2/6] fix(charts): make imagePullSecret optional, update hook-azure to include msal modules for azure --- .../router/templates/router-service.yaml | 2 + .../service/templates/agent-service.yaml | 2 + .../charts/service/templates/api-service.yaml | 2 + .../templates/delayed-job-monitor.yaml | 2 + .../service/templates/logger-service.yaml | 2 + .../charts/service/templates/worker.yaml | 2 + deployments/charts/web-ui/templates/ui.yaml | 2 + .../data/storage/extra_hooks/hook-azure.py | 55 ++++++++++++------- 8 files changed, 50 insertions(+), 19 deletions(-) diff --git a/deployments/charts/router/templates/router-service.yaml b/deployments/charts/router/templates/router-service.yaml index df335cc1..4d97de9a 100644 --- a/deployments/charts/router/templates/router-service.yaml +++ b/deployments/charts/router/templates/router-service.yaml @@ -65,8 +65,10 @@ spec: tolerations: {{- toYaml . 
| nindent 8 }} {{- end }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ $routerServiceAccount }} {{- with .Values.services.service.initContainers }} initContainers: diff --git a/deployments/charts/service/templates/agent-service.yaml b/deployments/charts/service/templates/agent-service.yaml index 06b98e9f..eb39022f 100644 --- a/deployments/charts/service/templates/agent-service.yaml +++ b/deployments/charts/service/templates/agent-service.yaml @@ -63,8 +63,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.agent.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.agent.serviceAccountName "root" .) }} {{- with .Values.services.agent.initContainers }} initContainers: diff --git a/deployments/charts/service/templates/api-service.yaml b/deployments/charts/service/templates/api-service.yaml index 1be46e0f..d78c39f8 100644 --- a/deployments/charts/service/templates/api-service.yaml +++ b/deployments/charts/service/templates/api-service.yaml @@ -62,8 +62,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.service.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.service.serviceAccountName "root" .) 
}} {{- with .Values.services.service.initContainers }} initContainers: diff --git a/deployments/charts/service/templates/delayed-job-monitor.yaml b/deployments/charts/service/templates/delayed-job-monitor.yaml index 70d720e1..c29473e3 100644 --- a/deployments/charts/service/templates/delayed-job-monitor.yaml +++ b/deployments/charts/service/templates/delayed-job-monitor.yaml @@ -44,8 +44,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.delayedJobMonitor.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.delayedJobMonitor.serviceAccountName "root" .) }} {{- with .Values.services.delayedJobMonitor.initContainers }} initContainers: diff --git a/deployments/charts/service/templates/logger-service.yaml b/deployments/charts/service/templates/logger-service.yaml index df73171c..9326c984 100644 --- a/deployments/charts/service/templates/logger-service.yaml +++ b/deployments/charts/service/templates/logger-service.yaml @@ -63,8 +63,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.logger.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.logger.serviceAccountName "root" .) 
}} {{- with .Values.services.logger.initContainers }} initContainers: diff --git a/deployments/charts/service/templates/worker.yaml b/deployments/charts/service/templates/worker.yaml index 1281b690..14cb79b3 100644 --- a/deployments/charts/service/templates/worker.yaml +++ b/deployments/charts/service/templates/worker.yaml @@ -55,8 +55,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.worker.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} serviceAccountName: {{ include "osmo.service-account-name" (dict "serviceAccountName" .Values.services.worker.serviceAccountName "root" .) }} {{- with .Values.services.worker.initContainers }} initContainers: diff --git a/deployments/charts/web-ui/templates/ui.yaml b/deployments/charts/web-ui/templates/ui.yaml index c26c8780..04e5becc 100644 --- a/deployments/charts/web-ui/templates/ui.yaml +++ b/deployments/charts/web-ui/templates/ui.yaml @@ -60,8 +60,10 @@ spec: {{- end}} tolerations: {{ toYaml .Values.services.ui.tolerations | nindent 8 }} + {{- if .Values.global.imagePullSecret }} imagePullSecrets: - name: {{ .Values.global.imagePullSecret }} + {{- end }} {{- if .Values.services.ui.initContainers }} initContainers: {{- toYaml .Values.services.ui.initContainers | nindent 6 }} diff --git a/src/lib/data/storage/extra_hooks/hook-azure.py b/src/lib/data/storage/extra_hooks/hook-azure.py index 66fb40bd..98ee538f 100644 --- a/src/lib/data/storage/extra_hooks/hook-azure.py +++ b/src/lib/data/storage/extra_hooks/hook-azure.py @@ -18,30 +18,47 @@ SPDX-License-Identifier: Apache-2.0 """ -from PyInstaller.utils import hooks # type: ignore +from PyInstaller.utils.hooks import collect_all # type: ignore -# Collect entry points -datas_set = set() -hiddenimports_set = set() +# Suppress warnings for modules that may not be installed in all environments +warn_on_missing_hiddenimports = False -data_files = ( - 'azure', - 'azure.storage', 
+# Initialize collections +datas = [] +binaries = [] +hiddenimports = [] + +# Use collect_all for comprehensive collection of azure packages and dependencies +# This handles data files, binaries, and hidden imports in one call +packages_to_collect = [ + 'azure.core', + 'azure.identity', + 'azure.storage.blob', + 'msal', + 'msal_extensions', 'isodate', -) +] -for data_file in data_files: - datas_set.update(hooks.collect_data_files(data_file, include_py_files=True)) +for package in packages_to_collect: + try: + pkg_datas, pkg_binaries, pkg_hiddenimports = collect_all(package) + datas.extend(pkg_datas) + binaries.extend(pkg_binaries) + hiddenimports.extend(pkg_hiddenimports) + except Exception: + # Package may not be installed; skip silently + pass -hiddenimports_files = ( +# Explicit hidden imports for modules not auto-discovered by collect_all +hiddenimports.extend([ + # Cryptography - needed by azure SDK 'cryptography.hazmat.primitives.ciphers.aead', 'cryptography.hazmat.primitives.padding', - 'wsgiref', -) - -# Add hidden imports -for hiddenimport_file in hiddenimports_files: - hiddenimports_set.update(hooks.collect_submodules(hiddenimport_file)) + # wsgiref stdlib - needed by azure.storage.blob._shared.policies + 'wsgiref.handlers', +]) -datas = list(datas_set) -hiddenimports = list(hiddenimports_set) +# Deduplicate +datas = list(set(datas)) +binaries = list(set(binaries)) +hiddenimports = list(set(hiddenimports)) From 24d93997811146257c7fa489ce8c817c76e14660 Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Thu, 11 Dec 2025 12:48:35 -0800 Subject: [PATCH 3/6] chore(credentials): add unit tests for BasicDataCredential class --- src/lib/utils/tests/BUILD | 8 +++ src/lib/utils/tests/test_credentials.py | 82 +++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 src/lib/utils/tests/test_credentials.py diff --git a/src/lib/utils/tests/BUILD b/src/lib/utils/tests/BUILD index 5bf0729c..8b4f6d88 100644 --- a/src/lib/utils/tests/BUILD +++ 
b/src/lib/utils/tests/BUILD @@ -42,3 +42,11 @@ osmo_py_test( "//src/lib/utils:jinja_sandbox", ] ) + +osmo_py_test( + name = "test_credentials", + srcs = ["test_credentials.py"], + deps = [ + "//src/lib/utils:credentials", + ] +) diff --git a/src/lib/utils/tests/test_credentials.py b/src/lib/utils/tests/test_credentials.py new file mode 100644 index 00000000..90cb8edc --- /dev/null +++ b/src/lib/utils/tests/test_credentials.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for the credentials module. +""" + +import unittest + +import pydantic + +from src.lib.utils import credentials + + +class TestBasicDataCredential(unittest.TestCase): + """ + Tests for BasicDataCredential class. 
+ """ + + def test_get_access_key_value_returns_value_when_set(self): + """Test that get_access_key_value returns the secret value when access_key is set.""" + # Arrange + cred = credentials.BasicDataCredential( + access_key_id='test-key-id', + access_key=pydantic.SecretStr('test-secret'), + ) + + # Act + result = cred.get_access_key_value() + + # Assert + self.assertEqual(result, 'test-secret') + + def test_get_access_key_value_returns_none_when_not_set(self): + """Test that get_access_key_value returns None when access_key is None.""" + # Arrange + cred = credentials.BasicDataCredential() + + # Act + result = cred.get_access_key_value() + + # Assert + self.assertIsNone(result) + + def test_optional_fields_default_to_none(self): + """Test that both access_key_id and access_key default to None.""" + # Arrange & Act + cred = credentials.BasicDataCredential() + + # Assert + self.assertIsNone(cred.access_key_id) + self.assertIsNone(cred.access_key) + + def test_fields_accept_explicit_values(self): + """Test that fields accept and store explicit values correctly.""" + # Arrange & Act + cred = credentials.BasicDataCredential( + access_key_id='my-key-id', + access_key=pydantic.SecretStr('my-secret'), + ) + + # Assert + self.assertEqual(cred.access_key_id, 'my-key-id') + self.assertIsNotNone(cred.access_key) + self.assertEqual(cred.get_access_key_value(), 'my-secret') + + +if __name__ == '__main__': + unittest.main() From 0522944bf26c52765bcf717de70755c5e96dbcee Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Thu, 11 Dec 2025 12:51:20 -0800 Subject: [PATCH 4/6] chore(azure): add unit tests for Azure Blob Storage backend --- .../storage/backends/tests/test_backends.py | 119 +++++++++++++++++- 1 file changed, 117 insertions(+), 2 deletions(-) diff --git a/src/lib/data/storage/backends/tests/test_backends.py b/src/lib/data/storage/backends/tests/test_backends.py index 5cffeffd..9564e6ce 100644 --- a/src/lib/data/storage/backends/tests/test_backends.py +++ 
b/src/lib/data/storage/backends/tests/test_backends.py @@ -22,8 +22,8 @@ from typing import cast from unittest import mock -from src.lib.data.storage.backends import backends, s3 -from src.lib.data.storage.core import header +from src.lib.data.storage.backends import azure, backends, s3 +from src.lib.data.storage.core import client, header class TestBackends(unittest.TestCase): @@ -145,5 +145,120 @@ def test_path_backend_contains_sub_path(self): self.assertTrue(storage_backend_1 not in storage_backend_2) +class TestAzureBackend(unittest.TestCase): + """ + Tests for Azure Blob Storage backend. + """ + + @mock.patch('src.lib.data.storage.backends.azure.blob.BlobServiceClient') + def test_azure_create_client_with_connection_string(self, mock_blob_client): + """Test that create_client uses from_connection_string when connection_string provided.""" + # Arrange + connection_string = 'DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key' + + # Act + azure.create_client(connection_string=connection_string) + + # Assert + mock_blob_client.from_connection_string.assert_called_once_with( + conn_str=connection_string + ) + + @mock.patch('src.lib.data.storage.backends.azure.DefaultAzureCredential') + @mock.patch('src.lib.data.storage.backends.azure.blob.BlobServiceClient') + def test_azure_create_client_with_workload_identity(self, mock_blob_client, mock_credential): + """Test that create_client uses DefaultAzureCredential when only account_url provided.""" + # Arrange + mock_cred_instance = mock.Mock() + mock_credential.return_value = mock_cred_instance + account_url = 'https://myaccount.blob.core.windows.net' + + # Act + azure.create_client(account_url=account_url) + + # Assert + mock_credential.assert_called_once() + mock_blob_client.assert_called_once_with( + account_url=account_url, + credential=mock_cred_instance, + ) + + def test_azure_create_client_no_credentials_raises_error(self): + """Test that create_client raises error when no credentials provided.""" + # 
Act & Assert + with self.assertRaises(client.OSMODataStorageClientError): + azure.create_client() + + def test_azure_backend_client_factory_with_connection_string(self): + """Test that Azure backend client_factory passes connection string.""" + # Arrange + azure_backend = cast( + backends.AzureBlobStorageBackend, + backends.construct_storage_backend(uri='azure://mystorageaccount/mycontainer/path'), + ) + connection_string = 'DefaultEndpointsProtocol=https;AccountName=test;AccountKey=key' + + # Act + factory = azure_backend.client_factory( + access_key_id=None, + access_key=connection_string, + ) + + # Assert + self.assertEqual(factory.connection_string, connection_string) + self.assertEqual(factory.account_url, azure_backend.endpoint) + + def test_azure_backend_client_factory_with_workload_identity(self): + """Test that Azure backend client_factory works without credentials (workload identity).""" + # Arrange + azure_backend = cast( + backends.AzureBlobStorageBackend, + backends.construct_storage_backend(uri='azure://mystorageaccount/mycontainer/path'), + ) + + # Act + factory = azure_backend.client_factory( + access_key_id=None, + access_key=None, + ) + + # Assert + self.assertIsNone(factory.connection_string) + self.assertEqual(factory.account_url, azure_backend.endpoint) + + @mock.patch('src.lib.data.storage.backends.azure.create_client') + def test_azure_backend_data_auth_with_workload_identity(self, mock_create_client): + """Test that Azure data_auth works with workload identity (no connection string).""" + # Arrange + mock_service_client = mock.Mock() + mock_create_client.return_value.__enter__ = mock.Mock(return_value=mock_service_client) + mock_create_client.return_value.__exit__ = mock.Mock(return_value=False) + + mock_container_client = mock.Mock() + mock_service_client.get_container_client.return_value.__enter__ = mock.Mock( + return_value=mock_container_client + ) + mock_service_client.get_container_client.return_value.__exit__ = mock.Mock( + 
return_value=False + ) + + azure_backend = cast( + backends.AzureBlobStorageBackend, + backends.construct_storage_backend(uri='azure://mystorageaccount/mycontainer/path'), + ) + + # Act + azure_backend.data_auth( + access_key_id=None, + access_key=None, + ) + + # Assert + mock_create_client.assert_called_once_with( + connection_string=None, + account_url=azure_backend.endpoint, + ) + + if __name__ == '__main__': unittest.main() From 09e6d3595a0b4e7aa0f32e673f5885c644bc57ae Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Thu, 11 Dec 2025 12:53:14 -0800 Subject: [PATCH 5/6] chore(backends): add credential validation tests for S3-compatible backends --- src/lib/data/storage/backends/tests/BUILD | 1 + .../storage/backends/tests/test_backends.py | 83 +++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/src/lib/data/storage/backends/tests/BUILD b/src/lib/data/storage/backends/tests/BUILD index 646363f4..0362182a 100644 --- a/src/lib/data/storage/backends/tests/BUILD +++ b/src/lib/data/storage/backends/tests/BUILD @@ -24,5 +24,6 @@ osmo_py_test( deps = [ "//src/lib/data/storage/backends", "//src/lib/data/storage/core", + "//src/lib/utils:osmo_errors", ], ) diff --git a/src/lib/data/storage/backends/tests/test_backends.py b/src/lib/data/storage/backends/tests/test_backends.py index 9564e6ce..a579f6eb 100644 --- a/src/lib/data/storage/backends/tests/test_backends.py +++ b/src/lib/data/storage/backends/tests/test_backends.py @@ -24,6 +24,7 @@ from src.lib.data.storage.backends import azure, backends, s3 from src.lib.data.storage.core import client, header +from src.lib.utils import osmo_errors class TestBackends(unittest.TestCase): @@ -144,6 +145,88 @@ def test_path_backend_contains_sub_path(self): self.assertTrue(storage_backend_2 in storage_backend_1) self.assertTrue(storage_backend_1 not in storage_backend_2) + def test_s3_backend_requires_access_key_id(self): + """Test that S3 backend raises error when access_key_id is None.""" + # Arrange + s3_backend = 
cast( + backends.S3Backend, + backends.construct_storage_backend(uri='s3://test-bucket/path'), + ) + + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + s3_backend.client_factory( + access_key_id=None, + access_key='test-secret', + ) + self.assertIn('Access key ID', str(context.exception)) + + def test_s3_backend_requires_access_key(self): + """Test that S3 backend raises error when access_key is None.""" + # Arrange + s3_backend = cast( + backends.S3Backend, + backends.construct_storage_backend(uri='s3://test-bucket/path'), + ) + + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + s3_backend.client_factory( + access_key_id='test-key-id', + access_key=None, + ) + self.assertIn('secret access key', str(context.exception)) + + def test_swift_backend_requires_credentials(self): + """Test that Swift backend raises error when credentials are None.""" + # Arrange + swift_backend = cast( + backends.SwiftBackend, + backends.construct_storage_backend( + uri='swift://swift.example.com/AUTH_namespace/container/path' + ), + ) + + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + swift_backend.client_factory( + access_key_id=None, + access_key=None, + ) + self.assertIn('Access key ID', str(context.exception)) + + def test_gs_backend_requires_credentials(self): + """Test that GS backend raises error when credentials are None.""" + # Arrange + gs_backend = cast( + backends.GSBackend, + backends.construct_storage_backend(uri='gs://test-bucket/path'), + ) + + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + gs_backend.client_factory( + access_key_id=None, + access_key=None, + ) + self.assertIn('Access key ID', str(context.exception)) + + def test_tos_backend_requires_credentials(self): + """Test that TOS backend raises error when credentials are None.""" + # Arrange + tos_backend = cast( + backends.TOSBackend, + 
backends.construct_storage_backend(uri='tos://tos-s3-us-east-1.example.com/bucket/path'), + ) + + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + tos_backend.client_factory( + access_key_id=None, + access_key=None, + ) + self.assertIn('Access key ID', str(context.exception)) + class TestAzureBackend(unittest.TestCase): """ From f2d996cac958ce76894b2c077467e2702aed8ced Mon Sep 17 00:00:00 2001 From: Allen Greaves Date: Thu, 11 Dec 2025 20:54:15 -0800 Subject: [PATCH 6/6] feat(client_configs): implement workload identity credential handling and add unit tests --- src/cli/credential.py | 3 +- src/lib/data/storage/backends/backends.py | 16 ++- src/lib/data/storage/backends/common.py | 14 +++ src/lib/utils/client_configs.py | 37 ++++--- src/lib/utils/credentials.py | 9 ++ src/lib/utils/tests/BUILD | 10 ++ src/lib/utils/tests/test_client_configs.py | 110 +++++++++++++++++++++ src/service/core/workflow/objects.py | 16 +-- src/utils/connectors/postgres.py | 68 ++++++++----- src/utils/job/task.py | 8 ++ 10 files changed, 247 insertions(+), 44 deletions(-) create mode 100644 src/lib/utils/tests/test_client_configs.py diff --git a/src/cli/credential.py b/src/cli/credential.py index d32f5fd1..741bb90e 100644 --- a/src/cli/credential.py +++ b/src/cli/credential.py @@ -211,7 +211,8 @@ def setup_parser(parser: argparse._SubParsersAction): '+-----------------+---------------------------+---------------------------------------+\n' '| REGISTRY | auth | registry, username |\n' '+-----------------+---------------------------+---------------------------------------+\n' - '| DATA | access_key_id, access_key | endpoint, region (default: us-east-1) |\n' + '| DATA | endpoint | access_key_id, access_key, |\n' + '| | | region (default: us-east-1) |\n' '+-----------------+---------------------------+---------------------------------------+\n' '| GENERIC | | |\n' '+-----------------+---------------------------+---------------------------------------+\n' 
diff --git a/src/lib/data/storage/backends/backends.py b/src/lib/data/storage/backends/backends.py index 0413c3eb..2af138b7 100644 --- a/src/lib/data/storage/backends/backends.py +++ b/src/lib/data/storage/backends/backends.py @@ -935,7 +935,10 @@ def _validate_auth(): try: client.execute_api(_validate_auth, azure.AzureErrorHandler()) except client.OSMODataStorageClientError as err: - raise osmo_errors.OSMOCredentialError(f'Data auth validation error: {err}') + auth_method = 'connection string' if connection_string else 'DefaultAzureCredential' + raise osmo_errors.OSMOCredentialError( + f'Data auth validation error using {auth_method}: {err}' + ) @override def region( @@ -953,6 +956,17 @@ def region( def default_region(self) -> str: return constants.DEFAULT_AZURE_REGION + @override + @property + def supports_environment_auth(self) -> bool: + """ + Azure Blob Storage supports environment-based authentication. + + Returns: + bool: Always True for Azure Blob Storage. + """ + return True + @override def client_factory( self, diff --git a/src/lib/data/storage/backends/common.py b/src/lib/data/storage/backends/common.py index 97bdb391..50c6ba79 100644 --- a/src/lib/data/storage/backends/common.py +++ b/src/lib/data/storage/backends/common.py @@ -241,3 +241,17 @@ def client_factory( Returns a factory for creating storage clients. """ pass + + @property + def supports_environment_auth(self) -> bool: + """ + Returns whether this storage backend supports environment-based authentication. + + When True, the backend can authenticate without explicit access keys, + using credentials from the runtime environment (e.g., managed identities, + service account tokens, or credential chains). + + Returns: + bool: True if environment-based auth is supported, False otherwise. 
+ """ + return False diff --git a/src/lib/utils/client_configs.py b/src/lib/utils/client_configs.py index cafe87d3..95fd7797 100644 --- a/src/lib/utils/client_configs.py +++ b/src/lib/utils/client_configs.py @@ -23,6 +23,7 @@ import yaml from . import cache, common, credentials, osmo_errors +from ..data import constants def get_client_config_dir() -> str: @@ -60,25 +61,35 @@ def get_cache_config() -> Optional[cache.CacheConfig]: @functools.lru_cache() def get_credentials(url: str) -> credentials.DataCredential: + """ + Get credentials for a storage profile. + + For storage backends that support environment-based authentication (e.g., + Azure DefaultAzureCredential, AWS IAM roles), the config may contain entries + with None access keys. This allows the SDK to use its credential chain + (CLI credentials, managed identity, workload identity, etc.). + """ osmo_directory = get_client_config_dir() password_file = osmo_directory + '/config.yaml' if os.path.isfile(password_file): with open(password_file, 'r', encoding='utf-8') as file: configs = yaml.safe_load(file.read()) - if url in configs['auth']['data']: - data_cred_dict = configs['auth']['data'][url] - data_cred = credentials.DataCredential( - access_key_id=data_cred_dict['access_key_id'], - access_key=data_cred_dict['access_key'], - endpoint=url, - region=data_cred_dict['region'], - ) - return data_cred - raise osmo_errors.OSMOError(f'Credential not set for {url}. 
Please set credentials using: \n' + - 'osmo credential set my_cred --type DATA ' + - '--payload access_key_id=your_s3_username access_key=your_s3_key' + - ' endpoint=your_endpoint region=endpoint_region') + if configs and 'auth' in configs and 'data' in configs['auth']: + if url in configs['auth']['data']: + data_cred_dict = configs['auth']['data'][url] + return credentials.DataCredential( + access_key_id=data_cred_dict.get('access_key_id'), + access_key=data_cred_dict.get('access_key'), + endpoint=url, + region=data_cred_dict.get('region', constants.DEFAULT_BOTO3_REGION), + ) + + raise osmo_errors.OSMOError( + f'Credential not set for {url}. Please set credentials using:\n' + 'osmo credential set my_cred --type DATA ' + '--payload access_key_id=<access_key_id> access_key=<access_key>' + ' endpoint=<endpoint> region=<region>') def get_client_state_dir() -> str: diff --git a/src/lib/utils/credentials.py b/src/lib/utils/credentials.py index f39e3751..0fb6ac0a 100644 --- a/src/lib/utils/credentials.py +++ b/src/lib/utils/credentials.py @@ -46,6 +46,15 @@ class BasicDataCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): default=None, description='The authentication secret for the data service') + def get_access_key_id_value(self) -> str | None: + """ + Safely returns the access key ID value, or None if not set. + + This supports workload identity authentication where credentials + are obtained from the environment rather than explicit keys. + """ + return self.access_key_id if self.access_key_id else None + def get_access_key_value(self) -> str | None: + """ + Safely returns the access key secret value, or None if not set.
diff --git a/src/lib/utils/tests/BUILD b/src/lib/utils/tests/BUILD index 8b4f6d88..2d1def98 100644 --- a/src/lib/utils/tests/BUILD +++ b/src/lib/utils/tests/BUILD @@ -50,3 +50,13 @@ osmo_py_test( "//src/lib/utils:credentials", ] ) + +osmo_py_test( + name = "test_client_configs", + srcs = ["test_client_configs.py"], + deps = [ + "//src/lib/utils:client_configs", + "//src/lib/utils:credentials", + "//src/lib/utils:osmo_errors", + ] +) diff --git a/src/lib/utils/tests/test_client_configs.py b/src/lib/utils/tests/test_client_configs.py new file mode 100644 index 00000000..25c32e1b --- /dev/null +++ b/src/lib/utils/tests/test_client_configs.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Tests for client_configs module. 
+""" + +import os +import tempfile +import unittest +from unittest import mock + +import yaml + +from src.lib.utils import client_configs, credentials, osmo_errors + + +class TestGetCredentials(unittest.TestCase): + """Tests for get_credentials function.""" + + def setUp(self): + """Clear LRU cache before each test.""" + client_configs.get_credentials.cache_clear() + + def test_get_credentials_with_none_access_keys(self): + """Test that credentials with None access keys are returned for environment-based auth.""" + with tempfile.TemporaryDirectory() as tmpdir: + config_file = os.path.join(tmpdir, 'config.yaml') + config = { + 'auth': { + 'data': { + 'azure://mystorageaccount': { + 'access_key_id': None, + 'access_key': None, + 'region': 'eastus', + } + } + } + } + with open(config_file, 'w', encoding='utf-8') as f: + yaml.dump(config, f) + + with mock.patch.object(client_configs, 'get_client_config_dir', return_value=tmpdir): + client_configs.get_credentials.cache_clear() + url = 'azure://mystorageaccount' + cred = client_configs.get_credentials(url) + + self.assertIsInstance(cred, credentials.DataCredential) + self.assertIsNone(cred.access_key_id) + self.assertIsNone(cred.get_access_key_value()) + self.assertEqual(cred.endpoint, url) + self.assertEqual(cred.region, 'eastus') + + def test_get_credentials_with_explicit_keys(self): + """Test that credentials with explicit access keys are returned correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + config_file = os.path.join(tmpdir, 'config.yaml') + config = { + 'auth': { + 'data': { + 's3://mybucket': { + 'access_key_id': 'my_key_id', + 'access_key': 'my_secret_key', + 'region': 'us-west-2', + } + } + } + } + with open(config_file, 'w', encoding='utf-8') as f: + yaml.dump(config, f) + + with mock.patch.object(client_configs, 'get_client_config_dir', return_value=tmpdir): + client_configs.get_credentials.cache_clear() + url = 's3://mybucket' + cred = client_configs.get_credentials(url) + + 
self.assertIsInstance(cred, credentials.DataCredential) + self.assertEqual(cred.access_key_id, 'my_key_id') + self.assertEqual(cred.get_access_key_value(), 'my_secret_key') + self.assertEqual(cred.endpoint, url) + self.assertEqual(cred.region, 'us-west-2') + + def test_get_credentials_missing_raises_error(self): + """Test that missing credentials raise error with helpful message.""" + with tempfile.TemporaryDirectory() as tmpdir: + with mock.patch.object(client_configs, 'get_client_config_dir', return_value=tmpdir): + client_configs.get_credentials.cache_clear() + url = 's3://mybucket' + with self.assertRaises(osmo_errors.OSMOError) as context: + client_configs.get_credentials(url) + + self.assertIn('Credential not set', str(context.exception)) + self.assertIn('osmo credential set', str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/service/core/workflow/objects.py b/src/service/core/workflow/objects.py index 43a5ea49..e2fa5f2b 100644 --- a/src/service/core/workflow/objects.py +++ b/src/service/core/workflow/objects.py @@ -574,17 +574,21 @@ def valid_cred(self, workflow_config: connectors.WorkflowConfig): class UserDataCredential(credentials.DataCredential, extra=pydantic.Extra.forbid): - """ Authentication information for a data service. """ - access_key: str = pydantic.Field( - description='The authentication secret for the data service') # type: ignore + """ + Authentication information for a data service. + + When access_key_id and access_key are None, the storage backend will use + environment-based authentication (e.g., Azure DefaultAzureCredential, + AWS IAM roles, managed identity, workload identity). 
+ """ @staticmethod def type() -> connectors.CredentialType: return connectors.CredentialType.DATA def to_db_row(self, user: str, postgres: connectors.PostgresConnector) -> CredentialRecord: - payload = {'access_key_id': self.access_key_id, - 'access_key': self.access_key, + payload = {'access_key_id': self.get_access_key_id_value(), + 'access_key': self.get_access_key_value(), 'region': self.region} payload = postgres.encrypt_dict(payload, user) return CredentialRecord(self.type().value, @@ -595,7 +599,7 @@ def valid_cred(self, workflow_config: connectors.WorkflowConfig): storage_info = storage.construct_storage_backend(self.endpoint, True) if storage_info.scheme in workflow_config.credential_config.disable_data_validation: return - storage_info.data_auth(self.access_key_id, self.access_key, self.region) + storage_info.data_auth(self.get_access_key_id_value(), self.get_access_key_value(), self.region) class UserCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): diff --git a/src/utils/connectors/postgres.py b/src/utils/connectors/postgres.py index 9e8486a0..41dad078 100644 --- a/src/utils/connectors/postgres.py +++ b/src/utils/connectors/postgres.py @@ -501,6 +501,8 @@ def decrypt_credential(self, db_row) -> Dict: def encrypt_dict(self, input_dict: Dict, user: str) -> Dict: result = {} for key, value in input_dict.items(): + if value is None: + continue encrypted = self.secret_manager.encrypt(value, user) result[key] = encrypted.value return result @@ -1210,6 +1212,33 @@ def func(new_encrypted: str): self.execute_commit_command(cmd, (new_encrypted,)) return func + def _get_bucket_credential( + self, + bucket: 'BucketConfig', + bucket_info: storage.StorageBackend, + ) -> credentials.DecryptedDataCredential | None: + """ + Get credential for a bucket based on its configuration. + + Returns a DecryptedDataCredential if the bucket has a default credential + or supports environment-based authentication, otherwise returns None. 
+ """ + if bucket.default_credential: + return credentials.DecryptedDataCredential( + region=bucket.region, + access_key_id=bucket.default_credential.access_key_id, + access_key=bucket.default_credential.get_access_key_value(), + endpoint=bucket_info.profile + ) + if bucket_info.supports_environment_auth: + return credentials.DecryptedDataCredential( + region=bucket.region, + access_key_id=None, + access_key=None, + endpoint=bucket_info.profile + ) + return None + def get_data_cred(self, user: str, profile: str) -> credentials.DecryptedDataCredential: """ Fetch data credentials by profile. """ select_data_cmd = PostgresSelectCommand( @@ -1219,22 +1248,18 @@ def get_data_cred(self, user: str, profile: str) -> credentials.DecryptedDataCre row = self.execute_fetch_command(*select_data_cmd.get_args()) if row: return credentials.DecryptedDataCredential(**self.decrypt_credential(row[0])) - else: - # Check default bucket creds - for bucket in self.get_dataset_configs().buckets.values(): - bucket_info = storage.construct_storage_backend(bucket.dataset_path) - if bucket_info.profile == profile: - if bucket.default_credential: - return credentials.DecryptedDataCredential( - region=bucket.region, - access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.get_access_key_value(), - endpoint=bucket_info.profile - ) - break - - raise osmo_errors.OSMOCredentialError( - f'Could not find {profile} credential for user {user}.') + + # Check bucket credentials + for bucket in self.get_dataset_configs().buckets.values(): + bucket_info = storage.construct_storage_backend(bucket.dataset_path) + if bucket_info.profile == profile: + cred = self._get_bucket_credential(bucket, bucket_info) + if cred: + return cred + break + + raise osmo_errors.OSMOCredentialError( + f'Could not find {profile} credential for user {user}.') def get_all_data_creds(self, user: str) -> Dict[str, credentials.DecryptedDataCredential]: """ Fetch all data credentials for user. 
""" @@ -1251,13 +1276,10 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.DecryptedDataCr # Add default bucket creds for bucket in self.get_dataset_configs().buckets.values(): bucket_info = storage.construct_storage_backend(bucket.dataset_path) - if bucket_info.profile not in user_creds and bucket.default_credential: - user_creds[bucket_info.profile] = credentials.DecryptedDataCredential( - region=bucket.region, - access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.get_access_key_value(), - endpoint=bucket_info.profile - ) + if bucket_info.profile not in user_creds: + cred = self._get_bucket_credential(bucket, bucket_info) + if cred: + user_creds[bucket_info.profile] = cred return user_creds def get_generic_cred(self, user: str, cred_name: str) -> Any: diff --git a/src/utils/job/task.py b/src/utils/job/task.py index d66ba109..0aad02d2 100644 --- a/src/utils/job/task.py +++ b/src/utils/job/task.py @@ -2603,6 +2603,14 @@ def fetch_creds(user: str, data_creds: Dict[str, credentials.DecryptedDataCreden if not disabled_data or backend_info.scheme not in disabled_data: raise osmo_errors.OSMOCredentialError( f'Could not find {backend_info.profile} credential for user {user}.') + # For storage backends that support environment-based authentication, + # return credentials with None keys to allow the SDK to use its credential chain + if backend_info.supports_environment_auth: + return credentials.DecryptedDataCredential( + access_key_id=None, + access_key=None, + region=backend_info.default_region, + ).dict() return {} return data_creds[backend_info.profile].dict()