Skip to content

Improve run template UX #3602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/book/user-guide/best-practices/quick-wins.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ micro-setup (under 5 minutes) and any tips or gotchas to anticipate.
| [Slack/Discord alerts](#id-4-instant-alerter-notifications-for-successesfailures) | Instant notifications for pipeline events | Stay informed without checking dashboards |
| [Cron scheduling](#id-5-schedule-the-pipeline-on-a-cron) | Run pipelines automatically on schedule | Promote notebooks to production workflows |
| [Warm pools/resources](#id-6-kill-cold-starts-with-sagemaker-warm-pools--vertex-persistent-resources) | Eliminate cold starts in cloud environments | Reduce iteration time from minutes to seconds |
| [Secret management](#id-7-centralise-secrets-tokens-db-creds-s3-keys) | Centralize credentials and tokens | Keep sensitive data out of code |
| [Secret management](#id-7-centralize-secrets-tokens-db-creds-s3-keys) | Centralize credentials and tokens | Keep sensitive data out of code |
| [Local smoke tests](#id-8-run-smoke-tests-locally-before-going-to-the-cloud) | Faster iteration on Docker before cloud | Quick feedback without cloud waiting times |
| [Organize with tags](#id-9-organize-with-tags) | Classify and filter ML assets | Find and relate your ML assets with ease |
| [Git repo hooks](#id-10-hook-your-git-repo-to-every-run) | Track code state with every run | Perfect reproducibility and faster builds |
Expand Down Expand Up @@ -291,7 +291,7 @@ zenml stack update your_stack_name -s vertex_persistent

Learn more: [AWS SageMaker Orchestrator](https://docs.zenml.io/stacks/stack-components/orchestrators/sagemaker), [Google Cloud Vertex AI Step Operator](https://docs.zenml.io/stacks/stack-components/step-operators/vertex)

## 7 Centralise secrets (tokens, DB creds, S3 keys)
## 7 Centralize secrets (tokens, DB creds, S3 keys)

**Why** -- eliminate hardcoded credentials from your code and gain centralized control over sensitive information. Secrets management prevents exposing sensitive information in version control, enables secure credential rotation, and simplifies access management across environments.

Expand Down
53 changes: 16 additions & 37 deletions src/zenml/config/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from zenml import __version__
from zenml.config.base_settings import BaseSettings, ConfigurationLevel
from zenml.config.pipeline_configurations import PipelineConfiguration
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
from zenml.config.pipeline_spec import PipelineSpec
from zenml.config.settings_resolver import SettingsResolver
Expand All @@ -43,7 +44,6 @@
from zenml.utils import pydantic_utils, settings_utils

if TYPE_CHECKING:
from zenml.config.source import Source
from zenml.pipelines.pipeline_definition import Pipeline
from zenml.stack import Stack, StackComponent
from zenml.steps.step_invocation import StepInvocation
Expand Down Expand Up @@ -112,21 +112,12 @@ def compile(
with pipeline.__suppress_configure_warnings__():
pipeline.configure(settings=pipeline_settings, merge=False)

settings_to_passdown = {
key: settings
for key, settings in pipeline_settings.items()
if ConfigurationLevel.STEP in settings.LEVEL
}

steps = {
invocation_id: self._compile_step_invocation(
invocation=invocation,
pipeline_settings=settings_to_passdown,
pipeline_extra=pipeline.configuration.extra,
stack=stack,
step_config=run_configuration.steps.get(invocation_id),
pipeline_failure_hook_source=pipeline.configuration.failure_hook_source,
pipeline_success_hook_source=pipeline.configuration.success_hook_source,
pipeline_configuration=pipeline.configuration,
)
for invocation_id, invocation in self._get_sorted_invocations(
pipeline=pipeline
Expand Down Expand Up @@ -447,24 +438,17 @@ def _get_step_spec(
def _compile_step_invocation(
self,
invocation: "StepInvocation",
pipeline_settings: Dict[str, "BaseSettings"],
pipeline_extra: Dict[str, Any],
stack: "Stack",
step_config: Optional["StepConfigurationUpdate"],
pipeline_failure_hook_source: Optional["Source"] = None,
pipeline_success_hook_source: Optional["Source"] = None,
pipeline_configuration: "PipelineConfiguration",
) -> Step:
"""Compiles a ZenML step.

Args:
invocation: The step invocation to compile.
pipeline_settings: settings configured on the
pipeline of the step.
pipeline_extra: Extra values configured on the pipeline of the step.
stack: The stack on which the pipeline will be run.
step_config: Run configuration for the step.
pipeline_failure_hook_source: Source for the failure hook.
pipeline_success_hook_source: Source for the success hook.
pipeline_configuration: Configuration for the pipeline.

Returns:
The compiled step.
Expand All @@ -488,32 +472,27 @@ def _compile_step_invocation(
configuration_level=ConfigurationLevel.STEP,
stack=stack,
)
step_extra = step.configuration.extra
step_on_failure_hook_source = step.configuration.failure_hook_source
step_on_success_hook_source = step.configuration.success_hook_source

step.configure(
settings=pipeline_settings,
extra=pipeline_extra,
on_failure=pipeline_failure_hook_source,
on_success=pipeline_success_hook_source,
merge=False,
)
step.configure(
settings=step_settings,
extra=step_extra,
on_failure=step_on_failure_hook_source,
on_success=step_on_success_hook_source,
merge=True,
merge=False,
)

parameters_to_ignore = (
set(step_config.parameters) if step_config else set()
)
complete_step_configuration = invocation.finalize(
step_configuration_overrides = invocation.finalize(
parameters_to_ignore=parameters_to_ignore
)
return Step(spec=step_spec, config=complete_step_configuration)
full_step_config = (
step_configuration_overrides.apply_pipeline_configuration(
pipeline_configuration=pipeline_configuration
)
)
return Step(
spec=step_spec,
config=full_step_config,
step_config_overrides=step_configuration_overrides,
)

def _get_sorted_invocations(
self,
Expand Down
19 changes: 13 additions & 6 deletions src/zenml/config/pipeline_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,30 @@ class PipelineConfigurationUpdate(StrictBaseModel):
retry: Optional[StepRetryConfig] = None
substitutions: Dict[str, str] = {}

def _get_full_substitutions(
self, start_time: Optional[datetime]
def finalize_substitutions(
self, start_time: Optional[datetime] = None, inplace: bool = False
) -> Dict[str, str]:
"""Returns the full substitutions dict.

Args:
start_time: Start time of the pipeline run.
inplace: Whether to update the substitutions in place.

Returns:
The full substitutions dict including date and time.
"""
if start_time is None:
start_time = utc_now()
ret = self.substitutions.copy()
ret.setdefault("date", start_time.strftime("%Y_%m_%d"))
ret.setdefault("time", start_time.strftime("%H_%M_%S_%f"))
return ret

if inplace:
dict_ = self.substitutions
else:
dict_ = self.substitutions.copy()

dict_.setdefault("date", start_time.strftime("%Y_%m_%d"))
dict_.setdefault("time", start_time.strftime("%H_%M_%S_%f"))

return dict_


class PipelineConfiguration(PipelineConfigurationUpdate):
Expand Down
103 changes: 89 additions & 14 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# permissions and limitations under the License.
"""Pipeline configuration classes."""

from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -46,7 +45,7 @@
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.model.model import Model
from zenml.utils import deprecation_utils
from zenml.utils.pydantic_utils import before_validator_handler
from zenml.utils.pydantic_utils import before_validator_handler, update_model

if TYPE_CHECKING:
from zenml.config import DockerSettings, ResourceSettings
Expand Down Expand Up @@ -241,23 +240,43 @@ def docker_settings(self) -> "DockerSettings":
model_or_dict = model_or_dict.model_dump()
return DockerSettings.model_validate(model_or_dict)

def _get_full_substitutions(
self,
pipeline_config: "PipelineConfiguration",
start_time: Optional[datetime],
) -> Dict[str, str]:
"""Get the full set of substitutions for this step configuration.
def apply_pipeline_configuration(
self, pipeline_configuration: "PipelineConfiguration"
) -> "StepConfiguration":
"""Apply the pipeline configuration to this step configuration.

Args:
pipeline_config: The pipeline configuration.
start_time: The start time of the pipeline run.
pipeline_configuration: The pipeline configuration to apply.

Returns:
The full set of substitutions for this step configuration.
The updated step configuration.
"""
ret = pipeline_config._get_full_substitutions(start_time)
ret.update(self.substitutions)
return ret
pipeline_values = pipeline_configuration.model_dump(
include={
"settings",
"extra",
"failure_hook_source",
"success_hook_source",
"substitutions",
},
exclude_none=True,
)
if pipeline_values:
original_values = self.model_dump(
include={
"settings",
"extra",
"failure_hook_source",
"success_hook_source",
"substitutions",
},
exclude_none=True,
)

updated_config = self.model_copy(update=pipeline_values, deep=True)
return update_model(updated_config, original_values)
else:
return self.model_copy(deep=True)


class InputSpec(StrictBaseModel):
Expand Down Expand Up @@ -311,3 +330,59 @@ class Step(StrictBaseModel):

spec: StepSpec
config: StepConfiguration
step_config_overrides: StepConfiguration

@model_validator(mode="before")
@classmethod
@before_validator_handler
def _add_step_config_overrides_if_missing(cls, data: Any) -> Any:
"""Add step config overrides if missing.

This is to ensure backwards compatibility with data stored in the DB
before the `step_config_overrides` field was added. In that case, only
the `config` field, which contains the merged pipeline and step configs,
existed. We have no way to figure out which of those values were defined
on the step vs the pipeline level, so we just use the entire `config`
object as the `step_config_overrides`.

Args:
data: The values dict used to instantiate the model.

Returns:
The values dict with the step config overrides added if missing.
"""
if "step_config_overrides" not in data:
data["step_config_overrides"] = data["config"]

return data

@classmethod
def from_dict(
cls,
data: Dict[str, Any],
pipeline_configuration: "PipelineConfiguration",
) -> "Step":
"""Create a step from a dictionary.

This method can create a step from data stored without the merged
`config` attribute, by merging the `step_config_overrides` with the
pipeline configuration.

Args:
data: The dictionary to create the `Step` object from.
pipeline_configuration: The pipeline configuration to apply to the
step configuration.

Returns:
The instantiated object.
"""
if "config" not in data:
config = StepConfiguration.model_validate(
data["step_config_overrides"]
)
config = config.apply_pipeline_configuration(
pipeline_configuration
)
data["config"] = config

return cls.model_validate(data)
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
)
ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING = "ZENML_PREVENT_CLIENT_SIDE_CACHING"
ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING = "DISABLE_CREDENTIALS_DISK_CACHING"
ENV_ZENML_RUNNER_PARENT_IMAGE = "ZENML_RUNNER_PARENT_IMAGE"
ENV_ZENML_RUNNER_IMAGE_DISABLE_UV = "ZENML_RUNNER_IMAGE_DISABLE_UV"
ENV_ZENML_RUNNER_POD_TIMEOUT = "ZENML_RUNNER_POD_TIMEOUT"
ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY = (
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
start_time = utc_now()
run_name = string_utils.format_name_template(
name_template=self._deployment.run_name_template,
substitutions=self._deployment.pipeline_configuration._get_full_substitutions(
start_time
substitutions=self._deployment.pipeline_configuration.finalize_substitutions(
start_time=start_time,
),
)

Expand Down
5 changes: 3 additions & 2 deletions src/zenml/pipelines/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def create_placeholder_run(
run_request = PipelineRunRequest(
name=string_utils.format_name_template(
name_template=deployment.run_name_template,
substitutions=deployment.pipeline_configuration._get_full_substitutions(
start_time
substitutions=deployment.pipeline_configuration.finalize_substitutions(
start_time=start_time,
),
),
# We set the start time on the placeholder run already to
Expand Down Expand Up @@ -113,6 +113,7 @@ def get_placeholder_run(
size=1,
deployment_id=deployment_id,
status=ExecutionStatus.INITIALIZING,
hydrate=True,
)
if len(runs.items) == 0:
return None
Expand Down
9 changes: 6 additions & 3 deletions src/zenml/stack/stack_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,13 @@ def get_settings(
else container.pipeline_configuration.settings
)

# Use the current config as a base
settings_dict = self.config.model_dump()

if key in all_settings:
return self.settings_class.model_validate(dict(all_settings[key]))
else:
return self.settings_class()
settings_dict.update(dict(all_settings[key]))

return self.settings_class.model_validate(settings_dict)

def connector_has_expired(self) -> bool:
"""Checks whether the connector linked to this stack component has expired.
Expand Down
Loading
Loading