Skip to content

Commit b0a662e

Browse files
committed
Some cleanup
1 parent c961e47 commit b0a662e

File tree

7 files changed

+51
-75
lines changed

7 files changed

+51
-75
lines changed

src/zenml/config/compiler.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from zenml import __version__
2929
from zenml.config.base_settings import BaseSettings, ConfigurationLevel
30+
from zenml.config.pipeline_configurations import PipelineConfiguration
3031
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
3132
from zenml.config.pipeline_spec import PipelineSpec
3233
from zenml.config.settings_resolver import SettingsResolver
@@ -43,7 +44,6 @@
4344
from zenml.utils import pydantic_utils, settings_utils
4445

4546
if TYPE_CHECKING:
46-
from zenml.config.source import Source
4747
from zenml.pipelines.pipeline_definition import Pipeline
4848
from zenml.stack import Stack, StackComponent
4949
from zenml.steps.step_invocation import StepInvocation
@@ -121,12 +121,9 @@ def compile(
121121
steps = {
122122
invocation_id: self._compile_step_invocation(
123123
invocation=invocation,
124-
pipeline_settings=settings_to_passdown,
125-
pipeline_extra=pipeline.configuration.extra,
126124
stack=stack,
127125
step_config=run_configuration.steps.get(invocation_id),
128-
pipeline_failure_hook_source=pipeline.configuration.failure_hook_source,
129-
pipeline_success_hook_source=pipeline.configuration.success_hook_source,
126+
pipeline_configuration=pipeline.configuration,
130127
)
131128
for invocation_id, invocation in self._get_sorted_invocations(
132129
pipeline=pipeline
@@ -438,24 +435,17 @@ def _get_step_spec(
438435
def _compile_step_invocation(
439436
self,
440437
invocation: "StepInvocation",
441-
pipeline_settings: Dict[str, "BaseSettings"],
442-
pipeline_extra: Dict[str, Any],
443438
stack: "Stack",
444439
step_config: Optional["StepConfigurationUpdate"],
445-
pipeline_failure_hook_source: Optional["Source"] = None,
446-
pipeline_success_hook_source: Optional["Source"] = None,
440+
pipeline_configuration: "PipelineConfiguration",
447441
) -> Step:
448442
"""Compiles a ZenML step.
449443
450444
Args:
451445
invocation: The step invocation to compile.
452-
pipeline_settings: settings configured on the
453-
pipeline of the step.
454-
pipeline_extra: Extra values configured on the pipeline of the step.
455446
stack: The stack on which the pipeline will be run.
456447
step_config: Run configuration for the step.
457-
pipeline_failure_hook_source: Source for the failure hook.
458-
pipeline_success_hook_source: Source for the success hook.
448+
pipeline_configuration: Configuration for the pipeline.
459449
460450
Returns:
461451
The compiled step.
@@ -479,32 +469,27 @@ def _compile_step_invocation(
479469
configuration_level=ConfigurationLevel.STEP,
480470
stack=stack,
481471
)
482-
# step_extra = step.configuration.extra
483-
# step_on_failure_hook_source = step.configuration.failure_hook_source
484-
# step_on_success_hook_source = step.configuration.success_hook_source
485-
486-
# step.configure(
487-
# settings=pipeline_settings,
488-
# extra=pipeline_extra,
489-
# on_failure=pipeline_failure_hook_source,
490-
# on_success=pipeline_success_hook_source,
491-
# merge=False,
492-
# )
493472
step.configure(
494473
settings=step_settings,
495-
# extra=step_extra,
496-
# on_failure=step_on_failure_hook_source,
497-
# on_success=step_on_success_hook_source,
498474
merge=False,
499475
)
500476

501477
parameters_to_ignore = (
502478
set(step_config.parameters) if step_config else set()
503479
)
504-
complete_step_configuration = invocation.finalize(
480+
step_configuration_overrides = invocation.finalize(
505481
parameters_to_ignore=parameters_to_ignore
506482
)
507-
return Step(spec=step_spec, config=complete_step_configuration)
483+
full_step_config = (
484+
step_configuration_overrides.apply_pipeline_configuration(
485+
pipeline_configuration=pipeline_configuration
486+
)
487+
)
488+
return Step(
489+
spec=step_spec,
490+
config=full_step_config,
491+
step_config_overrides=step_configuration_overrides,
492+
)
508493

509494
def _get_sorted_invocations(
510495
self,

src/zenml/config/step_configurations.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,21 @@ class Step(StrictBaseModel):
344344

345345
spec: StepSpec
346346
config: StepConfiguration
347+
step_config_overrides: StepConfiguration
348+
349+
@model_validator(mode="before")
350+
@classmethod
351+
@before_validator_handler
352+
def _autocomplete_step_config_overrides(cls, data: Any) -> Any:
353+
"""Autocompletes the step config overrides.
354+
355+
Args:
356+
data: The values dict used to instantiate the model.
357+
358+
Returns:
359+
The values dict with the step config overrides autocompleted.
360+
"""
361+
if "step_config_overrides" not in data:
362+
data["step_config_overrides"] = data["config"]
363+
364+
return data

src/zenml/models/v2/core/pipeline_deployment.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,6 @@ class PipelineDeploymentBase(BaseZenModel):
6060
step_configurations: Dict[str, Step] = Field(
6161
default={}, title="The step configurations for this deployment."
6262
)
63-
raw_step_configurations: Dict[str, Step] = Field(
64-
default={}, title="The raw step configurations for this deployment."
65-
)
6663
client_environment: Dict[str, str] = Field(
6764
default={}, title="The client environment for this deployment."
6865
)
@@ -146,9 +143,6 @@ class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata):
146143
step_configurations: Dict[str, Step] = Field(
147144
default={}, title="The step configurations for this deployment."
148145
)
149-
raw_step_configurations: Dict[str, Step] = Field(
150-
default={}, title="The raw step configurations for this deployment."
151-
)
152146
client_environment: Dict[str, str] = Field(
153147
default={}, title="The client environment for this deployment."
154148
)
@@ -247,15 +241,6 @@ def step_configurations(self) -> Dict[str, Step]:
247241
"""
248242
return self.get_metadata().step_configurations
249243

250-
@property
251-
def raw_step_configurations(self) -> Dict[str, Step]:
252-
"""The `raw_step_configurations` property.
253-
254-
Returns:
255-
the value of the property.
256-
"""
257-
return self.get_metadata().raw_step_configurations
258-
259244
@property
260245
def client_environment(self) -> Dict[str, str]:
261246
"""The `client_environment` property.

src/zenml/zen_server/template_execution/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def deployment_request_from_template(
371371
)
372372

373373
steps = {}
374-
for invocation_id, step in deployment.raw_step_configurations.items():
374+
for invocation_id, step in deployment.step_configurations.items():
375375
step_update = config.steps.get(
376376
invocation_id, StepConfigurationUpdate()
377377
).model_dump(
@@ -380,10 +380,12 @@ def deployment_request_from_template(
380380
exclude={"name"},
381381
exclude_unset=True,
382382
)
383-
step_config = pydantic_utils.update_model(step.config, step_update)
384-
# step_config = step_config.apply_pipeline_configuration(
385-
# pipeline_configuration
386-
# )
383+
step_config = pydantic_utils.update_model(
384+
step.step_config_overrides, step_update
385+
)
386+
merged_step_config = step_config.apply_pipeline_configuration(
387+
pipeline_configuration
388+
)
387389

388390
required_parameters = set(step.config.parameters)
389391
configured_parameters = set(step_config.parameters)
@@ -402,7 +404,11 @@ def deployment_request_from_template(
402404
f"parameters for step {invocation_id}: {missing_parameters}."
403405
)
404406

405-
steps[invocation_id] = Step(spec=step.spec, config=step_config)
407+
steps[invocation_id] = Step(
408+
spec=step.spec,
409+
config=merged_step_config,
410+
step_config_overrides=step_config,
411+
)
406412

407413
code_reference_request = None
408414
if deployment.code_reference:
@@ -420,8 +426,7 @@ def deployment_request_from_template(
420426
run_name_template=config.run_name
421427
or get_default_run_name(pipeline_name=pipeline_configuration.name),
422428
pipeline_configuration=pipeline_configuration,
423-
# step_configurations=steps,
424-
raw_step_configurations=steps,
429+
step_configurations=steps,
425430
client_environment={},
426431
client_version=zenml_version,
427432
server_version=zenml_version,

src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def from_request(
197197
run_name_template=request.run_name_template,
198198
pipeline_configuration=request.pipeline_configuration.model_dump_json(),
199199
step_configurations=json.dumps(
200-
request.raw_step_configurations,
200+
request.step_configurations,
201201
sort_keys=False,
202202
default=pydantic_encoder,
203203
),
@@ -249,16 +249,7 @@ def to_model(
249249
run_name_template=self.run_name_template,
250250
pipeline_configuration=pipeline_configuration,
251251
raw_step_configurations=step_configurations,
252-
step_configurations={
253-
k: v.model_copy(
254-
update={
255-
"config": v.config.apply_pipeline_configuration(
256-
pipeline_configuration
257-
)
258-
}
259-
)
260-
for k, v in step_configurations.items()
261-
},
252+
step_configurations=step_configurations,
262253
client_environment=json.loads(self.client_environment),
263254
client_version=self.client_version,
264255
server_version=self.server_version,

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ def to_model(
254254

255255
full_step_config = None
256256
if self.deployment is not None:
257-
pipeline_configuration = PipelineConfiguration.model_validate_json(
258-
self.deployment.pipeline_configuration
259-
)
260257
step_configuration = json.loads(
261258
self.deployment.step_configurations
262259
)
@@ -272,14 +269,9 @@ def to_model(
272269
self.pipeline_run.start_time,
273270
)
274271
)
275-
merged_config = (
276-
full_step_config.config.apply_pipeline_configuration(
277-
pipeline_configuration
278-
)
279-
)
280272
full_step_config = full_step_config.model_copy(
281273
update={
282-
"config": merged_config.model_copy(
274+
"config": full_step_config.config.model_copy(
283275
update={"substitutions": new_substitutions}
284276
)
285277
}

src/zenml/zen_stores/template_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ def generate_config_template(
9696
deployment_model = deployment.to_model(include_metadata=True)
9797

9898
steps_configs = {
99-
name: step.config.model_dump(
99+
name: step.step_config_overrides.model_dump(
100100
include=set(StepConfigurationUpdate.model_fields),
101101
exclude={"name", "outputs"},
102102
exclude_none=True,
103103
)
104-
for name, step in deployment_model.raw_step_configurations.items()
104+
for name, step in deployment_model.step_configurations.items()
105105
}
106106

107107
for config in steps_configs.values():

0 commit comments

Comments
 (0)