Skip to content

Commit c961e47

Browse files
committed
WIP
1 parent bdaf385 commit c961e47

File tree

9 files changed

+125
-68
lines changed

9 files changed

+125
-68
lines changed

src/zenml/config/compiler.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,17 @@ def compile(
149149
deployment = PipelineDeploymentBase(
150150
run_name_template=run_name,
151151
pipeline_configuration=pipeline.configuration,
152-
step_configurations=steps,
152+
step_configurations={
153+
k: v.model_copy(
154+
update={
155+
"config": v.config.apply_pipeline_configuration(
156+
pipeline.configuration
157+
)
158+
}
159+
)
160+
for k, v in steps.items()
161+
},
162+
raw_step_configurations=steps,
153163
client_environment=get_run_environment_dict(),
154164
client_version=client_version,
155165
server_version=server_version,
@@ -469,23 +479,23 @@ def _compile_step_invocation(
469479
configuration_level=ConfigurationLevel.STEP,
470480
stack=stack,
471481
)
472-
step_extra = step.configuration.extra
473-
step_on_failure_hook_source = step.configuration.failure_hook_source
474-
step_on_success_hook_source = step.configuration.success_hook_source
475-
476-
step.configure(
477-
settings=pipeline_settings,
478-
extra=pipeline_extra,
479-
on_failure=pipeline_failure_hook_source,
480-
on_success=pipeline_success_hook_source,
481-
merge=False,
482-
)
482+
# step_extra = step.configuration.extra
483+
# step_on_failure_hook_source = step.configuration.failure_hook_source
484+
# step_on_success_hook_source = step.configuration.success_hook_source
485+
486+
# step.configure(
487+
# settings=pipeline_settings,
488+
# extra=pipeline_extra,
489+
# on_failure=pipeline_failure_hook_source,
490+
# on_success=pipeline_success_hook_source,
491+
# merge=False,
492+
# )
483493
step.configure(
484494
settings=step_settings,
485-
extra=step_extra,
486-
on_failure=step_on_failure_hook_source,
487-
on_success=step_on_success_hook_source,
488-
merge=True,
495+
# extra=step_extra,
496+
# on_failure=step_on_failure_hook_source,
497+
# on_success=step_on_success_hook_source,
498+
merge=False,
489499
)
490500

491501
parameters_to_ignore = (

src/zenml/config/step_configurations.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from zenml.model.lazy_load import ModelVersionDataLazyLoader
4747
from zenml.model.model import Model
4848
from zenml.utils import deprecation_utils
49-
from zenml.utils.pydantic_utils import before_validator_handler
49+
from zenml.utils.pydantic_utils import before_validator_handler, update_model
5050

5151
if TYPE_CHECKING:
5252
from zenml.config import DockerSettings, ResourceSettings
@@ -259,6 +259,39 @@ def _get_full_substitutions(
259259
ret.update(self.substitutions)
260260
return ret
261261

262+
def apply_pipeline_configuration(
263+
self, pipeline_configuration: "PipelineConfiguration"
264+
) -> "StepConfiguration":
265+
"""Apply the pipeline configuration to this step configuration.
266+
267+
Args:
268+
pipeline_configuration: The pipeline configuration to apply.
269+
"""
270+
pipeline_values = pipeline_configuration.model_dump(
271+
include={
272+
"settings",
273+
"extra",
274+
"failure_hook_source",
275+
"success_hook_source",
276+
},
277+
exclude_none=True,
278+
)
279+
if pipeline_values:
280+
original_values = self.model_dump(
281+
include={
282+
"settings",
283+
"extra",
284+
"failure_hook_source",
285+
"success_hook_source",
286+
},
287+
exclude_none=True,
288+
)
289+
290+
updated_config = self.model_copy(update=pipeline_values, deep=True)
291+
return update_model(updated_config, original_values)
292+
else:
293+
return self.model_copy(deep=True)
294+
262295

263296
class InputSpec(StrictBaseModel):
264297
"""Step input specification."""

src/zenml/models/v2/core/pipeline_deployment.py

+15
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class PipelineDeploymentBase(BaseZenModel):
6060
step_configurations: Dict[str, Step] = Field(
6161
default={}, title="The step configurations for this deployment."
6262
)
63+
raw_step_configurations: Dict[str, Step] = Field(
64+
default={}, title="The raw step configurations for this deployment."
65+
)
6366
client_environment: Dict[str, str] = Field(
6467
default={}, title="The client environment for this deployment."
6568
)
@@ -143,6 +146,9 @@ class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata):
143146
step_configurations: Dict[str, Step] = Field(
144147
default={}, title="The step configurations for this deployment."
145148
)
149+
raw_step_configurations: Dict[str, Step] = Field(
150+
default={}, title="The raw step configurations for this deployment."
151+
)
146152
client_environment: Dict[str, str] = Field(
147153
default={}, title="The client environment for this deployment."
148154
)
@@ -241,6 +247,15 @@ def step_configurations(self) -> Dict[str, Step]:
241247
"""
242248
return self.get_metadata().step_configurations
243249

250+
@property
251+
def raw_step_configurations(self) -> Dict[str, Step]:
252+
"""The `raw_step_configurations` property.
253+
254+
Returns:
255+
the value of the property.
256+
"""
257+
return self.get_metadata().raw_step_configurations
258+
244259
@property
245260
def client_environment(self) -> Dict[str, str]:
246261
"""The `client_environment` property.

src/zenml/pipelines/run_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def get_placeholder_run(
113113
size=1,
114114
deployment_id=deployment_id,
115115
status=ExecutionStatus.INITIALIZING,
116+
hydrate=True,
116117
)
117118
if len(runs.items) == 0:
118119
return None

src/zenml/zen_server/template_execution/utils.py

+27-46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utility functions to run a pipeline from the server."""
22

3-
import copy
43
import hashlib
54
import sys
65
from typing import Any, Dict, List, Optional
@@ -16,7 +15,7 @@
1615
from zenml.config.pipeline_run_configuration import (
1716
PipelineRunConfiguration,
1817
)
19-
from zenml.config.step_configurations import Step, StepConfiguration
18+
from zenml.config.step_configurations import Step, StepConfigurationUpdate
2019
from zenml.constants import (
2120
ENV_ZENML_ACTIVE_PROJECT_ID,
2221
ENV_ZENML_ACTIVE_STACK_ID,
@@ -43,7 +42,7 @@
4342
validate_stack_is_runnable_from_server,
4443
)
4544
from zenml.stack.flavor import Flavor
46-
from zenml.utils import dict_utils, requirements_utils, settings_utils
45+
from zenml.utils import pydantic_utils, requirements_utils, settings_utils
4746
from zenml.zen_server.auth import AuthContext, generate_access_token
4847
from zenml.zen_server.template_execution.runner_entrypoint_configuration import (
4948
RunnerEntrypointConfiguration,
@@ -106,7 +105,6 @@ def run_template(
106105
deployment_request = deployment_request_from_template(
107106
template=template,
108107
config=run_config or PipelineRunConfiguration(),
109-
user_id=auth_context.user.id,
110108
)
111109

112110
ensure_async_orchestrator(deployment=deployment_request, stack=stack)
@@ -345,14 +343,12 @@ def generate_dockerfile(
345343
def deployment_request_from_template(
346344
template: RunTemplateResponse,
347345
config: PipelineRunConfiguration,
348-
user_id: UUID,
349346
) -> "PipelineDeploymentRequest":
350347
"""Generate a deployment request from a template.
351348
352349
Args:
353350
template: The template from which to create the deployment request.
354351
config: The run configuration.
355-
user_id: ID of the user that is trying to run the template.
356352
357353
Raises:
358354
ValueError: If there are missing/extra step parameters in the run
@@ -363,49 +359,34 @@ def deployment_request_from_template(
363359
"""
364360
deployment = template.source_deployment
365361
assert deployment
366-
pipeline_configuration = PipelineConfiguration(
367-
**config.model_dump(
368-
include=set(PipelineConfiguration.model_fields),
369-
exclude={"name", "parameters"},
370-
),
371-
name=deployment.pipeline_configuration.name,
372-
parameters=deployment.pipeline_configuration.parameters,
373-
)
374362

375-
step_config_dict_base = pipeline_configuration.model_dump(
376-
exclude={"name", "parameters", "tags", "enable_pipeline_logs"}
363+
pipeline_update = config.model_dump(
364+
include=set(PipelineConfiguration.model_fields),
365+
exclude={"name", "parameters"},
366+
# TODO: Make sure all unset values are actually passed as unset
367+
exclude_unset=True,
368+
)
369+
pipeline_configuration = pydantic_utils.update_model(
370+
deployment.pipeline_configuration, pipeline_update
377371
)
378-
steps = {}
379-
for invocation_id, step in deployment.step_configurations.items():
380-
step_config_dict = {
381-
**copy.deepcopy(step_config_dict_base),
382-
**step.config.model_dump(
383-
# TODO: Maybe we need to make some of these configurable via
384-
# yaml as well, e.g. the lazy loaders?
385-
include={
386-
"name",
387-
"caching_parameters",
388-
"external_input_artifacts",
389-
"model_artifacts_or_metadata",
390-
"client_lazy_loaders",
391-
"substitutions",
392-
"outputs",
393-
}
394-
),
395-
}
396-
397-
required_parameters = set(step.config.parameters)
398-
configured_parameters = set()
399372

400-
if update := config.steps.get(invocation_id):
401-
update_dict = update.model_dump()
373+
steps = {}
374+
for invocation_id, step in deployment.raw_step_configurations.items():
375+
step_update = config.steps.get(
376+
invocation_id, StepConfigurationUpdate()
377+
).model_dump(
402378
# Get rid of deprecated name to prevent overriding the step name
403379
# with `None`.
404-
update_dict.pop("name", None)
405-
configured_parameters = set(update.parameters)
406-
step_config_dict = dict_utils.recursive_update(
407-
step_config_dict, update=update_dict
408-
)
380+
exclude={"name"},
381+
exclude_unset=True,
382+
)
383+
step_config = pydantic_utils.update_model(step.config, step_update)
384+
# step_config = step_config.apply_pipeline_configuration(
385+
# pipeline_configuration
386+
# )
387+
388+
required_parameters = set(step.config.parameters)
389+
configured_parameters = set(step_config.parameters)
409390

410391
unknown_parameters = configured_parameters - required_parameters
411392
if unknown_parameters:
@@ -421,7 +402,6 @@ def deployment_request_from_template(
421402
f"parameters for step {invocation_id}: {missing_parameters}."
422403
)
423404

424-
step_config = StepConfiguration.model_validate(step_config_dict)
425405
steps[invocation_id] = Step(spec=step.spec, config=step_config)
426406

427407
code_reference_request = None
@@ -440,7 +420,8 @@ def deployment_request_from_template(
440420
run_name_template=config.run_name
441421
or get_default_run_name(pipeline_name=pipeline_configuration.name),
442422
pipeline_configuration=pipeline_configuration,
443-
step_configurations=steps,
423+
# step_configurations=steps,
424+
raw_step_configurations=steps,
444425
client_environment={},
445426
client_version=zenml_version,
446427
server_version=zenml_version,

src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def from_request(
197197
run_name_template=request.run_name_template,
198198
pipeline_configuration=request.pipeline_configuration.model_dump_json(),
199199
step_configurations=json.dumps(
200-
request.step_configurations,
200+
request.raw_step_configurations,
201201
sort_keys=False,
202202
default=pydantic_encoder,
203203
),
@@ -248,7 +248,17 @@ def to_model(
248248
project=self.project.to_model(),
249249
run_name_template=self.run_name_template,
250250
pipeline_configuration=pipeline_configuration,
251-
step_configurations=step_configurations,
251+
raw_step_configurations=step_configurations,
252+
step_configurations={
253+
k: v.model_copy(
254+
update={
255+
"config": v.config.apply_pipeline_configuration(
256+
pipeline_configuration
257+
)
258+
}
259+
)
260+
for k, v in step_configurations.items()
261+
},
252262
client_environment=json.loads(self.client_environment),
253263
client_version=self.client_version,
254264
server_version=self.server_version,

src/zenml/zen_stores/schemas/step_run_schemas.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ def to_model(
254254

255255
full_step_config = None
256256
if self.deployment is not None:
257+
pipeline_configuration = PipelineConfiguration.model_validate_json(
258+
self.deployment.pipeline_configuration
259+
)
257260
step_configuration = json.loads(
258261
self.deployment.step_configurations
259262
)
@@ -269,9 +272,14 @@ def to_model(
269272
self.pipeline_run.start_time,
270273
)
271274
)
275+
merged_config = (
276+
full_step_config.config.apply_pipeline_configuration(
277+
pipeline_configuration
278+
)
279+
)
272280
full_step_config = full_step_config.model_copy(
273281
update={
274-
"config": full_step_config.config.model_copy(
282+
"config": merged_config.model_copy(
275283
update={"substitutions": new_substitutions}
276284
)
277285
}

src/zenml/zen_stores/template_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def generate_config_template(
101101
exclude={"name", "outputs"},
102102
exclude_none=True,
103103
)
104-
for name, step in deployment_model.step_configurations.items()
104+
for name, step in deployment_model.raw_step_configurations.items()
105105
}
106106

107107
for config in steps_configs.values():

tests/integration/functional/zen_server/template_execution/test_template_execution_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,4 @@ def test_creating_deployment_request_from_template(
6969
deployment_request_from_template(
7070
template=template_response,
7171
config=PipelineRunConfiguration(),
72-
user_id=deployment.user.id,
7372
)

0 commit comments

Comments
 (0)