diff --git a/docs/book/user-guide/best-practices/quick-wins.md b/docs/book/user-guide/best-practices/quick-wins.md index d032f76e60d..337dae4af12 100644 --- a/docs/book/user-guide/best-practices/quick-wins.md +++ b/docs/book/user-guide/best-practices/quick-wins.md @@ -17,7 +17,7 @@ micro-setup (under 5 minutes) and any tips or gotchas to anticipate. | [Slack/Discord alerts](#id-4-instant-alerter-notifications-for-successesfailures) | Instant notifications for pipeline events | Stay informed without checking dashboards | | [Cron scheduling](#id-5-schedule-the-pipeline-on-a-cron) | Run pipelines automatically on schedule | Promote notebooks to production workflows | | [Warm pools/resources](#id-6-kill-cold-starts-with-sagemaker-warm-pools--vertex-persistent-resources) | Eliminate cold starts in cloud environments | Reduce iteration time from minutes to seconds | -| [Secret management](#id-7-centralise-secrets-tokens-db-creds-s3-keys) | Centralize credentials and tokens | Keep sensitive data out of code | +| [Secret management](#id-7-centralize-secrets-tokens-db-creds-s3-keys) | Centralize credentials and tokens | Keep sensitive data out of code | | [Local smoke tests](#id-8-run-smoke-tests-locally-before-going-to-the-cloud) | Faster iteration on Docker before cloud | Quick feedback without cloud waiting times | | [Organize with tags](#id-9-organize-with-tags) | Classify and filter ML assets | Find and relate your ML assets with ease | | [Git repo hooks](#id-10-hook-your-git-repo-to-every-run) | Track code state with every run | Perfect reproducibility and faster builds | @@ -291,7 +291,7 @@ zenml stack update your_stack_name -s vertex_persistent Learn more: [AWS SageMaker Orchestrator](https://docs.zenml.io/stacks/stack-components/orchestrators/sagemaker), [Google Cloud Vertex AI Step Operator](https://docs.zenml.io/stacks/stack-components/step-operators/vertex) -## 7 Centralise secrets (tokens, DB creds, S3 keys) +## 7 Centralize secrets (tokens, DB creds, S3 keys) **Why** -- eliminate hardcoded credentials from your code and gain centralized control over sensitive information. Secrets management prevents exposing sensitive information in version control, enables secure credential rotation, and simplifies access management across environments. diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index 7ccec3d5a11..0a3ef796c43 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -27,6 +27,7 @@ from zenml import __version__ from zenml.config.base_settings import BaseSettings, ConfigurationLevel +from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.config.pipeline_run_configuration import PipelineRunConfiguration from zenml.config.pipeline_spec import PipelineSpec from zenml.config.settings_resolver import SettingsResolver @@ -43,7 +44,6 @@ from zenml.utils import pydantic_utils, settings_utils if TYPE_CHECKING: - from zenml.config.source import Source from zenml.pipelines.pipeline_definition import Pipeline from zenml.stack import Stack, StackComponent from zenml.steps.step_invocation import StepInvocation @@ -112,21 +112,12 @@ def compile( with pipeline.__suppress_configure_warnings__(): pipeline.configure(settings=pipeline_settings, merge=False) - settings_to_passdown = { - key: settings - for key, settings in pipeline_settings.items() - if ConfigurationLevel.STEP in settings.LEVEL - } - steps = { invocation_id: self._compile_step_invocation( invocation=invocation, - pipeline_settings=settings_to_passdown, - pipeline_extra=pipeline.configuration.extra, stack=stack, step_config=run_configuration.steps.get(invocation_id), - pipeline_failure_hook_source=pipeline.configuration.failure_hook_source, - pipeline_success_hook_source=pipeline.configuration.success_hook_source, + pipeline_configuration=pipeline.configuration, ) for invocation_id, invocation in self._get_sorted_invocations( pipeline=pipeline @@ -447,24 +438,17 @@ def _get_step_spec( def _compile_step_invocation( self, invocation: "StepInvocation", - pipeline_settings: Dict[str, "BaseSettings"], - pipeline_extra: Dict[str, Any], stack: "Stack", step_config: Optional["StepConfigurationUpdate"], - pipeline_failure_hook_source: Optional["Source"] = None, - pipeline_success_hook_source: Optional["Source"] = None, + pipeline_configuration: "PipelineConfiguration", ) -> Step: """Compiles a ZenML step. Args: invocation: The step invocation to compile. - pipeline_settings: settings configured on the - pipeline of the step. - pipeline_extra: Extra values configured on the pipeline of the step. stack: The stack on which the pipeline will be run. step_config: Run configuration for the step. - pipeline_failure_hook_source: Source for the failure hook. - pipeline_success_hook_source: Source for the success hook. + pipeline_configuration: Configuration for the pipeline. Returns: The compiled step. @@ -488,32 +472,27 @@ def _compile_step_invocation( configuration_level=ConfigurationLevel.STEP, stack=stack, ) - step_extra = step.configuration.extra - step_on_failure_hook_source = step.configuration.failure_hook_source - step_on_success_hook_source = step.configuration.success_hook_source - - step.configure( - settings=pipeline_settings, - extra=pipeline_extra, - on_failure=pipeline_failure_hook_source, - on_success=pipeline_success_hook_source, - merge=False, - ) step.configure( settings=step_settings, - extra=step_extra, - on_failure=step_on_failure_hook_source, - on_success=step_on_success_hook_source, - merge=True, + merge=False, ) parameters_to_ignore = ( set(step_config.parameters) if step_config else set() ) - complete_step_configuration = invocation.finalize( + step_configuration_overrides = invocation.finalize( parameters_to_ignore=parameters_to_ignore ) - return Step(spec=step_spec, config=complete_step_configuration) + full_step_config = ( + step_configuration_overrides.apply_pipeline_configuration( + pipeline_configuration=pipeline_configuration + ) + ) + return Step( + spec=step_spec, + config=full_step_config, + step_config_overrides=step_configuration_overrides, + ) def _get_sorted_invocations( self, diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index fdd3ff93021..8d7910fd93b 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -52,23 +52,30 @@ class PipelineConfigurationUpdate(StrictBaseModel): retry: Optional[StepRetryConfig] = None substitutions: Dict[str, str] = {} - def _get_full_substitutions( - self, start_time: Optional[datetime] + def finalize_substitutions( + self, start_time: Optional[datetime] = None, inplace: bool = False ) -> Dict[str, str]: """Returns the full substitutions dict. Args: start_time: Start time of the pipeline run. + inplace: Whether to update the substitutions in place. Returns: The full substitutions dict including date and time. """ if start_time is None: start_time = utc_now() - ret = self.substitutions.copy() - ret.setdefault("date", start_time.strftime("%Y_%m_%d")) - ret.setdefault("time", start_time.strftime("%H_%M_%S_%f")) - return ret + + if inplace: + dict_ = self.substitutions + else: + dict_ = self.substitutions.copy() + + dict_.setdefault("date", start_time.strftime("%Y_%m_%d")) + dict_.setdefault("time", start_time.strftime("%H_%M_%S_%f")) + + return dict_ class PipelineConfiguration(PipelineConfigurationUpdate): diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 1f87df05b45..6470f34ae6e 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Pipeline configuration classes.""" -from datetime import datetime from typing import ( TYPE_CHECKING, Any, @@ -46,7 +45,7 @@ from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model import Model from zenml.utils import deprecation_utils -from zenml.utils.pydantic_utils import before_validator_handler +from zenml.utils.pydantic_utils import before_validator_handler, update_model if TYPE_CHECKING: from zenml.config import DockerSettings, ResourceSettings @@ -241,23 +240,43 @@ def docker_settings(self) -> "DockerSettings": model_or_dict = model_or_dict.model_dump() return DockerSettings.model_validate(model_or_dict) - def _get_full_substitutions( - self, - pipeline_config: "PipelineConfiguration", - start_time: Optional[datetime], - ) -> Dict[str, str]: - """Get the full set of substitutions for this step configuration. + def apply_pipeline_configuration( + self, pipeline_configuration: "PipelineConfiguration" + ) -> "StepConfiguration": + """Apply the pipeline configuration to this step configuration. Args: - pipeline_config: The pipeline configuration. - start_time: The start time of the pipeline run. + pipeline_configuration: The pipeline configuration to apply. Returns: - The full set of substitutions for this step configuration. + The updated step configuration. """ - ret = pipeline_config._get_full_substitutions(start_time) - ret.update(self.substitutions) - return ret + pipeline_values = pipeline_configuration.model_dump( + include={ + "settings", + "extra", + "failure_hook_source", + "success_hook_source", + "substitutions", + }, + exclude_none=True, + ) + if pipeline_values: + original_values = self.model_dump( + include={ + "settings", + "extra", + "failure_hook_source", + "success_hook_source", + "substitutions", + }, + exclude_none=True, + ) + + updated_config = self.model_copy(update=pipeline_values, deep=True) + return update_model(updated_config, original_values) + else: + return self.model_copy(deep=True) class InputSpec(StrictBaseModel): @@ -311,3 +330,59 @@ class Step(StrictBaseModel): spec: StepSpec config: StepConfiguration + step_config_overrides: StepConfiguration + + @model_validator(mode="before") + @classmethod + @before_validator_handler + def _add_step_config_overrides_if_missing(cls, data: Any) -> Any: + """Add step config overrides if missing. + + This is to ensure backwards compatibility with data stored in the DB + before the `step_config_overrides` field was added. In that case, only + the `config` field, which contains the merged pipeline and step configs, + existed. We have no way to figure out which of those values were defined + on the step vs the pipeline level, so we just use the entire `config` + object as the `step_config_overrides`. + + Args: + data: The values dict used to instantiate the model. + + Returns: + The values dict with the step config overrides added if missing. + """ + if "step_config_overrides" not in data: + data["step_config_overrides"] = data["config"] + + return data + + @classmethod + def from_dict( + cls, + data: Dict[str, Any], + pipeline_configuration: "PipelineConfiguration", + ) -> "Step": + """Create a step from a dictionary. + + This method can create a step from data stored without the merged + `config` attribute, by merging the `step_config_overrides` with the + pipeline configuration. + + Args: + data: The dictionary to create the `Step` object from. + pipeline_configuration: The pipeline configuration to apply to the + step configuration. + + Returns: + The instantiated object. + """ + if "config" not in data: + config = StepConfiguration.model_validate( + data["step_config_overrides"] + ) + config = config.apply_pipeline_configuration( + pipeline_configuration + ) + data["config"] = config + + return cls.model_validate(data) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 1cf36b74410..adfc80795b3 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -199,6 +199,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ) ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING = "ZENML_PREVENT_CLIENT_SIDE_CACHING" ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING = "DISABLE_CREDENTIALS_DISK_CACHING" +ENV_ZENML_RUNNER_PARENT_IMAGE = "ZENML_RUNNER_PARENT_IMAGE" ENV_ZENML_RUNNER_IMAGE_DISABLE_UV = "ZENML_RUNNER_IMAGE_DISABLE_UV" ENV_ZENML_RUNNER_POD_TIMEOUT = "ZENML_RUNNER_POD_TIMEOUT" ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY = ( diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index c62214e2dd5..fce5e77ac27 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -308,8 +308,8 @@ def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: start_time = utc_now() run_name = string_utils.format_name_template( name_template=self._deployment.run_name_template, - substitutions=self._deployment.pipeline_configuration._get_full_substitutions( - start_time + substitutions=self._deployment.pipeline_configuration.finalize_substitutions( + start_time=start_time, ), ) diff --git a/src/zenml/pipelines/run_utils.py b/src/zenml/pipelines/run_utils.py index 2d4f8c6ff6c..990c2ac7dd5 100644 --- a/src/zenml/pipelines/run_utils.py +++ b/src/zenml/pipelines/run_utils.py @@ -72,8 +72,8 @@ def create_placeholder_run( run_request = PipelineRunRequest( name=string_utils.format_name_template( name_template=deployment.run_name_template, - substitutions=deployment.pipeline_configuration._get_full_substitutions( - start_time + substitutions=deployment.pipeline_configuration.finalize_substitutions( + start_time=start_time, ), ), # We set the start time on the placeholder run already to @@ -113,6 +113,7 @@ def get_placeholder_run( size=1, deployment_id=deployment_id, status=ExecutionStatus.INITIALIZING, + hydrate=True, ) if len(runs.items) == 0: return None diff --git a/src/zenml/stack/stack_component.py b/src/zenml/stack/stack_component.py index 1f2f4b81241..18222ff61a3 100644 --- a/src/zenml/stack/stack_component.py +++ b/src/zenml/stack/stack_component.py @@ -529,10 +529,13 @@ def get_settings( else container.pipeline_configuration.settings ) + # Use the current config as a base + settings_dict = self.config.model_dump() + if key in all_settings: - return self.settings_class.model_validate(dict(all_settings[key])) - else: - return self.settings_class() + settings_dict.update(dict(all_settings[key])) + + return self.settings_class.model_validate(settings_dict) def connector_has_expired(self) -> bool: """Checks whether the connector linked to this stack component has expired. diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index 97ed8e29823..ff762ddf76e 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -13,18 +13,16 @@ # permissions and limitations under the License. """Endpoint definitions for deployments.""" -from typing import Optional, Union +from typing import Any, Optional, Union from uuid import UUID -from fastapi import APIRouter, Depends, Security +from fastapi import APIRouter, Depends, Request, Security from zenml.constants import API, PIPELINE_DEPLOYMENTS, VERSION_1 from zenml.logging.step_logging import fetch_logs from zenml.models import ( - Page, PipelineDeploymentFilter, PipelineDeploymentRequest, - PipelineDeploymentResponse, PipelineRunFilter, ) from zenml.zen_server.auth import AuthContext, authorize @@ -45,6 +43,37 @@ zen_store, ) + +# TODO: Remove this as soon as there is only a low number of users running +# versions < 0.82.0. Also change back the return type to +# `PipelineDeploymentResponse` once we have removed the `exclude` logic. +def _should_remove_step_config_overrides( + request: Request, +) -> bool: + """Check if the step config overrides should be removed from the response. + + Args: + request: The request object. + + Returns: + If the step config overrides should be removed from the response. + """ + from packaging import version + + user_agent = request.headers.get("User-Agent", "") + + if not user_agent.startswith("zenml/"): + # This request is not coming from a ZenML client + return False + + client_version = version.parse(user_agent.removeprefix("zenml/")) + + # Versions before 0.82.0 did have `extra="forbid"` in the pydantic model + # that stores the step configurations. This means it would crash if we + # included the `step_config_overrides` in the response. + return client_version < version.parse("0.82.0") + + router = APIRouter( prefix=API + VERSION_1 + PIPELINE_DEPLOYMENTS, tags=["deployments"], @@ -66,13 +95,15 @@ ) @handle_exceptions def create_deployment( + request: Request, deployment: PipelineDeploymentRequest, project_name_or_id: Optional[Union[str, UUID]] = None, _: AuthContext = Security(authorize), -) -> PipelineDeploymentResponse: +) -> Any: """Creates a deployment. Args: + request: The request object. deployment: Deployment to create. project_name_or_id: Optional name or ID of the project. @@ -83,11 +114,21 @@ def create_deployment( project = zen_store().get_project(project_name_or_id) deployment.project = project.id - return verify_permissions_and_create_entity( + deployment_response = verify_permissions_and_create_entity( request_model=deployment, create_method=zen_store().create_deployment, ) + exclude = None + if _should_remove_step_config_overrides(request): + exclude = { + "metadata": { + "step_configurations": {"__all__": {"step_config_overrides"}} + } + } + + return deployment_response.model_dump(mode="json", exclude=exclude) + @router.get( "", @@ -103,16 +144,18 @@ def create_deployment( ) @handle_exceptions def list_deployments( + request: Request, deployment_filter_model: PipelineDeploymentFilter = Depends( make_dependable(PipelineDeploymentFilter) ), project_name_or_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, _: AuthContext = Security(authorize), -) -> Page[PipelineDeploymentResponse]: +) -> Any: """Gets a list of deployments. Args: + request: The request object. deployment_filter_model: Filter model used for pagination, sorting, filtering. project_name_or_id: Optional name or ID of the project to filter by. @@ -125,13 +168,29 @@ def list_deployments( if project_name_or_id: deployment_filter_model.project = project_name_or_id - return verify_permissions_and_list_entities( + page = verify_permissions_and_list_entities( filter_model=deployment_filter_model, resource_type=ResourceType.PIPELINE_DEPLOYMENT, list_method=zen_store().list_deployments, hydrate=hydrate, ) + exclude = None + if _should_remove_step_config_overrides(request): + exclude = { + "items": { + "__all__": { + "metadata": { + "step_configurations": { + "__all__": {"step_config_overrides"} + } + } + } + } + } + + return page.model_dump(mode="json", exclude=exclude) + @router.get( "/{deployment_id}", @@ -139,13 +198,15 @@ def list_deployments( ) @handle_exceptions def get_deployment( + request: Request, deployment_id: UUID, hydrate: bool = True, _: AuthContext = Security(authorize), -) -> PipelineDeploymentResponse: +) -> Any: """Gets a specific deployment using its unique id. Args: + request: The request object. deployment_id: ID of the deployment to get. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -153,12 +214,22 @@ def get_deployment( Returns: A specific deployment object. """ - return verify_permissions_and_get_entity( + deployment = verify_permissions_and_get_entity( id=deployment_id, get_method=zen_store().get_deployment, hydrate=hydrate, ) + exclude = None + if _should_remove_step_config_overrides(request): + exclude = { + "metadata": { + "step_configurations": {"__all__": {"step_config_overrides"}} + } + } + + return deployment.model_dump(mode="json", exclude=exclude) + @router.delete( "/{deployment_id}", diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index f194d0e8a9e..4c3d0144d19 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -1,7 +1,7 @@ """Utility functions to run a pipeline from the server.""" -import copy import hashlib +import os import sys import threading from concurrent.futures import Future, ThreadPoolExecutor @@ -17,11 +17,12 @@ from zenml.config.pipeline_run_configuration import ( PipelineRunConfiguration, ) -from zenml.config.step_configurations import Step, StepConfiguration +from zenml.config.step_configurations import Step, StepConfigurationUpdate from zenml.constants import ( ENV_ZENML_ACTIVE_PROJECT_ID, ENV_ZENML_ACTIVE_STACK_ID, ENV_ZENML_RUNNER_IMAGE_DISABLE_UV, + ENV_ZENML_RUNNER_PARENT_IMAGE, ENV_ZENML_RUNNER_POD_TIMEOUT, handle_bool_env_var, handle_int_env_var, @@ -47,7 +48,7 @@ validate_stack_is_runnable_from_server, ) from zenml.stack.flavor import Flavor -from zenml.utils import dict_utils, requirements_utils, settings_utils +from zenml.utils import pydantic_utils, requirements_utils, settings_utils from zenml.zen_server.auth import AuthContext, generate_access_token from zenml.zen_server.template_execution.runner_entrypoint_configuration import ( RunnerEntrypointConfiguration, @@ -172,7 +173,6 @@ def run_template( deployment_request = deployment_request_from_template( template=template, config=run_config or PipelineRunConfiguration(), - user_id=auth_context.user.id, ) ensure_async_orchestrator(deployment=deployment_request, stack=stack) @@ -386,7 +386,10 @@ def generate_dockerfile( Returns: The Dockerfile. """ - parent_image = f"zenmldocker/zenml:{zenml_version}-py{python_version}" + parent_image = os.environ.get( + ENV_ZENML_RUNNER_PARENT_IMAGE, + f"zenmldocker/zenml:{zenml_version}-py{python_version}", + ) lines = [f"FROM {parent_image}"] if apt_packages: @@ -420,14 +423,12 @@ def generate_dockerfile( def deployment_request_from_template( template: RunTemplateResponse, config: PipelineRunConfiguration, - user_id: UUID, ) -> "PipelineDeploymentRequest": """Generate a deployment request from a template. Args: template: The template from which to create the deployment request. config: The run configuration. - user_id: ID of the user that is trying to run the template. Raises: ValueError: If there are missing/extra step parameters in the run @@ -438,49 +439,37 @@ def deployment_request_from_template( """ deployment = template.source_deployment assert deployment - pipeline_configuration = PipelineConfiguration( - **config.model_dump( - include=set(PipelineConfiguration.model_fields), - exclude={"name", "parameters"}, - ), - name=deployment.pipeline_configuration.name, - parameters=deployment.pipeline_configuration.parameters, - ) - step_config_dict_base = pipeline_configuration.model_dump( - exclude={"name", "parameters", "tags", "enable_pipeline_logs"} + pipeline_update = config.model_dump( + include=set(PipelineConfiguration.model_fields), + exclude={"name", "parameters"}, + exclude_unset=True, + exclude_none=True, + ) + pipeline_configuration = pydantic_utils.update_model( + deployment.pipeline_configuration, pipeline_update ) + steps = {} for invocation_id, step in deployment.step_configurations.items(): - step_config_dict = { - **copy.deepcopy(step_config_dict_base), - **step.config.model_dump( - # TODO: Maybe we need to make some of these configurable via - # yaml as well, e.g. the lazy loaders? - include={ - "name", - "caching_parameters", - "external_input_artifacts", - "model_artifacts_or_metadata", - "client_lazy_loaders", - "substitutions", - "outputs", - } - ), - } - - required_parameters = set(step.config.parameters) - configured_parameters = set() - - if update := config.steps.get(invocation_id): - update_dict = update.model_dump() + step_update = config.steps.get( + invocation_id, StepConfigurationUpdate() + ).model_dump( # Get rid of deprecated name to prevent overriding the step name # with `None`. - update_dict.pop("name", None) - configured_parameters = set(update.parameters) - step_config_dict = dict_utils.recursive_update( - step_config_dict, update=update_dict - ) + exclude={"name"}, + exclude_unset=True, + exclude_none=True, + ) + step_config = pydantic_utils.update_model( + step.step_config_overrides, step_update + ) + merged_step_config = step_config.apply_pipeline_configuration( + pipeline_configuration + ) + + required_parameters = set(step.config.parameters) + configured_parameters = set(step_config.parameters) unknown_parameters = configured_parameters - required_parameters if unknown_parameters: @@ -496,8 +485,11 @@ def deployment_request_from_template( f"parameters for step {invocation_id}: {missing_parameters}." ) - step_config = StepConfiguration.model_validate(step_config_dict) - steps[invocation_id] = Step(spec=step.spec, config=step_config) + steps[invocation_id] = Step( + spec=step.spec, + config=merged_step_config, + step_config_overrides=step_config, + ) code_reference_request = None if deployment.code_reference: diff --git a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py index 80d827f00f9..7abe95bec6e 100644 --- a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py @@ -32,7 +32,6 @@ PipelineDeploymentResponseBody, PipelineDeploymentResponseMetadata, ) -from zenml.utils.json_utils import pydantic_encoder from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.code_repository_schemas import ( CodeReferenceSchema, @@ -188,6 +187,14 @@ def from_request( Returns: The created `PipelineDeploymentSchema`. """ + # Don't include the merged config in the step configurations, we + # reconstruct it in the `to_model` method using the pipeline + # configuration. + step_configurations = { + invocation_id: step.model_dump(mode="json", exclude={"config"}) + for invocation_id, step in request.step_configurations.items() + } + client_env = json.dumps(request.client_environment) if len(client_env) > TEXT_FIELD_MAX_LENGTH: logger.warning( @@ -208,9 +215,8 @@ def from_request( run_name_template=request.run_name_template, pipeline_configuration=request.pipeline_configuration.model_dump_json(), step_configurations=json.dumps( - request.step_configurations, + step_configurations, sort_keys=False, - default=pydantic_encoder, ), client_environment=client_env, client_version=request.client_version, @@ -254,8 +260,10 @@ def to_model( self.pipeline_configuration ) step_configurations = json.loads(self.step_configurations) - for s, c in step_configurations.items(): - step_configurations[s] = Step.model_validate(c) + for invocation_id, step in step_configurations.items(): + step_configurations[invocation_id] = Step.from_dict( + step, pipeline_configuration + ) client_environment = json.loads(self.client_environment) if not include_python_packages: diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 00465141492..afff33ab812 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -317,10 +317,6 @@ def to_model( ) config = deployment.pipeline_configuration - new_substitutions = config._get_full_substitutions(self.start_time) - config = config.model_copy( - update={"substitutions": new_substitutions} - ) client_environment = deployment.client_environment stack = deployment.stack @@ -352,6 +348,8 @@ def to_model( "pipeline_configuration." ) + config.finalize_substitutions(start_time=self.start_time, inplace=True) + body = PipelineRunResponseBody( user=self.user.to_model() if self.user else None, status=ExecutionStatus(self.status), diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 9917e2bb67f..032a024e87f 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -232,9 +232,7 @@ def to_model( The created StepRunResponse. Raises: - ValueError: In case the step run configuration can not be loaded. - RuntimeError: If the step run schema does not have a deployment_id - or a step_configuration. + ValueError: In case the step run configuration is missing. """ input_artifacts = { artifact.name: StepRunInputResponse( @@ -252,49 +250,38 @@ def to_model( artifact.artifact_version.to_model() ) - full_step_config = None + step = None if self.deployment is not None: - step_configuration = json.loads( + step_configurations = json.loads( self.deployment.step_configurations ) - if self.name in step_configuration: - full_step_config = Step.model_validate( - step_configuration[self.name] - ) - new_substitutions = ( - full_step_config.config._get_full_substitutions( - PipelineConfiguration.model_validate_json( - self.deployment.pipeline_configuration - ), - self.pipeline_run.start_time, + if self.name in step_configurations: + pipeline_configuration = ( + PipelineConfiguration.model_validate_json( + self.deployment.pipeline_configuration ) ) - full_step_config = full_step_config.model_copy( - update={ - "config": full_step_config.config.model_copy( - update={"substitutions": new_substitutions} - ) - } - ) - elif not self.step_configuration: - raise ValueError( - f"Unable to load the configuration for step `{self.name}` from the" - f"database. To solve this please delete the pipeline run that this" - f"step run belongs to. Pipeline Run ID: `{self.pipeline_run_id}`." + pipeline_configuration.finalize_substitutions( + start_time=self.pipeline_run.start_time, + inplace=True, ) - - # the step configuration moved into the deployment - the following case is to ensure - # backwards compatibility - if full_step_config is None: - if self.step_configuration: - full_step_config = Step.model_validate_json( - self.step_configuration - ) - else: - raise RuntimeError( - "Step run model creation has failed. Each step run entry " - "should either have a deployment_id or step_configuration." + step = Step.from_dict( + step_configurations[self.name], + pipeline_configuration=pipeline_configuration, ) + if not step and self.step_configuration: + # In this legacy case, we're guaranteed to have the merged + # config stored in the DB, which means we can instantiate the + # `Step` object directly without passing the pipeline + # configuration. + step = Step.model_validate_json(self.step_configuration) + elif not step: + raise ValueError( + f"Unable to load the configuration for step `{self.name}` from " + "the database. To solve this please delete the pipeline run " + "that this step run belongs to. Pipeline Run ID: " + f"`{self.pipeline_run_id}`." + ) body = StepRunResponseBody( user=self.user.to_model() if self.user else None, @@ -311,8 +298,8 @@ def to_model( if include_metadata: metadata = StepRunResponseMetadata( project=self.project.to_model(), - config=full_step_config.config, - spec=full_step_config.spec, + config=step.config, + spec=step.spec, cache_key=self.cache_key, code_hash=self.code_hash, docstring=self.docstring, diff --git a/src/zenml/zen_stores/template_utils.py b/src/zenml/zen_stores/template_utils.py index 6a08b5ac3a4..f51a9a66b35 100644 --- a/src/zenml/zen_stores/template_utils.py +++ b/src/zenml/zen_stores/template_utils.py @@ -96,22 +96,26 @@ def generate_config_template( deployment_model = deployment.to_model(include_metadata=True) steps_configs = { - name: step.config.model_dump( + name: step.step_config_overrides.model_dump( include=set(StepConfigurationUpdate.model_fields), exclude={"name", "outputs"}, + exclude_none=True, + exclude_defaults=True, ) for name, step in deployment_model.step_configurations.items() } for config in steps_configs.values(): - config["settings"].pop("docker", None) + config.get("settings", {}).pop("docker", None) pipeline_config = deployment_model.pipeline_configuration.model_dump( include=set(PipelineRunConfiguration.model_fields), exclude={"schedule", "build", "parameters"}, + exclude_none=True, + exclude_defaults=True, ) - pipeline_config["settings"].pop("docker", None) + pipeline_config.get("settings", {}).pop("docker", None) config_template = { "run_name": deployment_model.run_name_template, diff --git a/tests/integration/functional/zen_server/template_execution/test_template_execution_utils.py b/tests/integration/functional/zen_server/template_execution/test_template_execution_utils.py index 2828c1d391d..4e5d78ae27c 100644 --- a/tests/integration/functional/zen_server/template_execution/test_template_execution_utils.py +++ b/tests/integration/functional/zen_server/template_execution/test_template_execution_utils.py @@ -69,5 +69,4 @@ def test_creating_deployment_request_from_template( deployment_request_from_template( template=template_response, config=PipelineRunConfiguration(), - user_id=deployment.user.id, ) diff --git a/tests/unit/models/test_pipeline_deployment_models.py b/tests/unit/models/test_pipeline_deployment_models.py index 8e1af8e5d9a..71a71ff34d9 100644 --- a/tests/unit/models/test_pipeline_deployment_models.py +++ b/tests/unit/models/test_pipeline_deployment_models.py @@ -30,6 +30,6 @@ def test_pipeline_deployment_base_model_fails_with_long_name(): name="", run_name_template="", pipeline_configuration=PipelineConfiguration(name="aria_best_cat"), - step_configurations={"some_key": long_text}, + step_configurations={"some_key": {"config": {"name": long_text}}}, client_environment={}, ) diff --git a/tests/unit/orchestrators/test_cache_utils.py b/tests/unit/orchestrators/test_cache_utils.py index f6123c838a0..380381e8154 100644 --- a/tests/unit/orchestrators/test_cache_utils.py +++ b/tests/unit/orchestrators/test_cache_utils.py @@ -21,6 +21,7 @@ from typing_extensions import Annotated from zenml.config.compiler import Compiler +from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.config.source import Source from zenml.config.step_configurations import Step from zenml.enums import ExecutionStatus, SorterOps @@ -52,12 +53,9 @@ def _compile_step(step: BaseStep) -> Step: compiler = Compiler() return compiler._compile_step_invocation( invocation=invocation, - pipeline_settings={}, - pipeline_extra={}, stack=Client().active_stack, step_config=None, - pipeline_failure_hook_source=None, - pipeline_success_hook_source=None, + pipeline_configuration=PipelineConfiguration(name=""), )