
Commit f46439d

feature/ODSC-41635/Support Data Flow Pools (#212)
1 parent 2dbc2a0 commit f46439d

File tree: 4 files changed, +74 -36 lines changed

  ads/jobs/builders/infrastructure/dataflow.py
  docs/source/user_guide/apachespark/dataflow-spark-magic.rst
  docs/source/user_guide/apachespark/dataflow.rst
  tests/unitary/default_setup/jobs/test_jobs_dataflow.py

ads/jobs/builders/infrastructure/dataflow.py (+37 -4)
@@ -391,6 +391,7 @@ class DataFlow(Infrastructure):
     CONST_OCPUS = "ocpus"
     CONST_ID = "id"
     CONST_PRIVATE_ENDPOINT_ID = "private_endpoint_id"
+    CONST_POOL_ID = "pool_id"
     CONST_FREEFORM_TAGS = "freeform_tags"
     CONST_DEFINED_TAGS = "defined_tags"

@@ -411,8 +412,9 @@ class DataFlow(Infrastructure):
         CONST_OCPUS: CONST_OCPUS,
         CONST_ID: CONST_ID,
         CONST_PRIVATE_ENDPOINT_ID: "privateEndpointId",
+        CONST_POOL_ID: "poolId",
         CONST_FREEFORM_TAGS: "freeformTags",
-        CONST_DEFINED_TAGS: "definedTags"
+        CONST_DEFINED_TAGS: "definedTags",
     }

     def __init__(self, spec: dict = None, **kwargs):
@@ -425,8 +427,10 @@ def __init__(self, spec: dict = None, **kwargs):
         spec = {
             k: v
             for k, v in spec.items()
-            if (f"with_{camel_to_snake(k)}" in self.__dir__()
-                or (k == "defined_tags" or "freeform_tags"))
+            if (
+                f"with_{camel_to_snake(k)}" in self.__dir__()
+                or (k == "defined_tags" or "freeform_tags")
+            )
             and v is not None
         }
         defaults.update(spec)
@@ -809,10 +813,34 @@ def with_defined_tag(self, **kwargs) -> "DataFlow":
         """
         return self.set_spec(self.CONST_DEFINED_TAGS, kwargs)

+    def with_pool_id(self, pool_id: str) -> "DataFlow":
+        """
+        Set the Data Flow Pool Id for a Data Flow job.
+
+        Parameters
+        ----------
+        pool_id: str
+            The OCID of a Data Flow Pool.
+
+        Returns
+        -------
+        DataFlow
+            the Data Flow instance itself
+        """
+        if not hasattr(CreateApplicationDetails, "pool_id"):
+            raise EnvironmentError(
+                "Data Flow Pool has not been supported in the current OCI SDK installed."
+            )
+        return self.set_spec(self.CONST_POOL_ID, pool_id)
+
     def __getattr__(self, item):
         if item == self.CONST_DEFINED_TAGS or item == self.CONST_FREEFORM_TAGS:
             return self.get_spec(item)
-        elif f"with_{item}" in self.__dir__() and item != "defined_tag" and item != "freeform_tag":
+        elif (
+            f"with_{item}" in self.__dir__()
+            and item != "defined_tag"
+            and item != "freeform_tag"
+        ):
             return self.get_spec(item)
         raise AttributeError(f"Attribute {item} not found.")

@@ -832,6 +860,11 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
         DataFlow
             a Data Flow job instance
         """
+        if self.pool_id:
+            if not hasattr(CreateApplicationDetails, "pool_id"):
+                raise EnvironmentError(
+                    "Data Flow Pool has not been supported in the current OCI SDK installed."
+                )
         # Set default display_name if not specified - randomly generated easy to remember name
         if not self.name:
             self.name = utils.get_random_name_for_resource()
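
For orientation, here is a minimal sketch of how the new builder method might be used. It assumes the public ads.jobs imports (DataFlow, DataFlowRuntime) and an installed oci SDK whose CreateApplicationDetails exposes pool_id; every OCID, bucket, and script path below is a placeholder.

from ads.jobs import DataFlow, DataFlowRuntime

# Build a Data Flow job that runs on an existing Data Flow Pool.
# Placeholder values throughout. with_pool_id() raises EnvironmentError
# immediately, and create() re-checks, if the installed oci SDK has no
# CreateApplicationDetails.pool_id attribute.
df = (
    DataFlow()
    .with_compartment_id("ocid1.compartment.oc1..<unique_ocid>")
    .with_logs_bucket_uri("oci://mybucket@mynamespace/dflogs")
    .with_driver_shape("VM.Standard.E4.Flex")
    .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
    .with_executor_shape("VM.Standard.E4.Flex")
    .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
    .with_pool_id("ocid1.dataflowpool.oc1..<unique_ocid>")  # new in this commit
)

rt = DataFlowRuntime().with_script_uri("oci://mybucket@mynamespace/scripts/example.py")

df.create(rt)  # submits the application; pool support is re-checked here

Because the payload attribute map translates pool_id to poolId, the pool assignment should also survive the to_dict()/to_yaml() round trips exercised in the tests below.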

docs/source/user_guide/apachespark/dataflow-spark-magic.rst (+1 -1)

@@ -32,7 +32,7 @@ Data Flow Sessions are accessible through the following conda environment:

 * PySpark 3.2 and Data Flow 2.0 (pyspark32_p38_cpu_v2)

-You can customize **pypspark32_p38_cpu_v1**, publish it, and use it as a runtime environment for a Data Flow Session.
+You can customize **pypspark32_p38_cpu_v2**, publish it, and use it as a runtime environment for a Data Flow Session.

 Policies
 ********

docs/source/user_guide/apachespark/dataflow.rst (+11 -11)
@@ -159,9 +159,9 @@ You could submit a notebook using ADS SDK APIs. Here is an example to submit a n
             "ocid1.compartment.oc1.<your compartment id>"
         )
         .with_driver_shape("VM.Standard.E4.Flex")
-        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
-        .with_executor_shape("VM.Standard.E4.Flex")
-        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
+        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
+        .with_executor_shape("VM.Standard.E4.Flex")
+        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
         .with_logs_bucket_uri("oci://mybucket@mytenancy/")
         .with_private_endpoint_id("ocid1.dataflowprivateendpoint.oc1.iad.<your private endpoint ocid>")
         .with_configuration({
@@ -231,8 +231,8 @@ create applications.

 In the following "hello-world" example, ``DataFlow`` is populated with ``compartment_id``,
 ``driver_shape``, ``driver_shape_config``, ``executor_shape``, ``executor_shape_config``
-, ``spark_version``, ``defined_tags`` and ``freeform_tags``. ``DataFlowRuntime`` is
-populated with ``script_uri`` and ``script_bucket``. The ``script_uri`` specifies the
+, ``spark_version``, ``defined_tags`` and ``freeform_tags``. ``DataFlowRuntime`` is
+populated with ``script_uri`` and ``script_bucket``. The ``script_uri`` specifies the
 path to the script. It can be local or remote (an Object Storage path). If the path
 is local, then ``script_bucket`` must be specified additionally because Data Flow
 requires a script to be available in Object Storage. ADS
@@ -270,9 +270,9 @@ accepted. In the next example, the prefix is given for ``script_bucket``.
         .with_compartment_id("oci.xx.<compartment_id>")
         .with_logs_bucket_uri("oci://mybucket@mynamespace/dflogs")
         .with_driver_shape("VM.Standard.E4.Flex")
-        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
-        .with_executor_shape("VM.Standard.E4.Flex")
-        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
+        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
+        .with_executor_shape("VM.Standard.E4.Flex")
+        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
         .with_spark_version("3.0.2")
         .with_defined_tag(
             **{"Oracle-Tags": {"CreatedBy": "[email protected]"}}
@@ -391,9 +391,9 @@ In the next example, ``archive_uri`` is given as an Object Storage location.
         .with_compartment_id("oci1.xxx.<compartment_ocid>")
         .with_logs_bucket_uri("oci://mybucket@mynamespace/prefix")
         .with_driver_shape("VM.Standard.E4.Flex")
-        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
-        .with_executor_shape("VM.Standard.E4.Flex")
-        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
+        .with_driver_shape_config(ocpus=2, memory_in_gbs=32)
+        .with_executor_shape("VM.Standard.E4.Flex")
+        .with_executor_shape_config(ocpus=4, memory_in_gbs=64)
         .with_spark_version("3.0.2")
         .with_configuration({
             "spark.driverEnv.myEnvVariable": "value1",

tests/unitary/default_setup/jobs/test_jobs_dataflow.py (+25 -20)
@@ -28,6 +28,7 @@
     DataFlowRuntime,
     DataFlowNotebookRuntime,
 )
+from oci.data_flow.models import CreateApplicationDetails

 logger.setLevel(logging.DEBUG)

@@ -47,7 +48,13 @@
     language="PYTHON",
     logs_bucket_uri="oci://test_bucket@test_namespace/",
     private_endpoint_id="test_private_endpoint",
+    pool_id="ocid1.dataflowpool.oc1..<unique_ocid>",
 )
+EXPECTED_YAML_LENGTH = 614
+if not hasattr(CreateApplicationDetails, "pool_id"):
+    SAMPLE_PAYLOAD.pop("pool_id")
+    EXPECTED_YAML_LENGTH = 567
+
 random_seed = 42
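
The hasattr() probe above gates pool-specific expectations on the installed oci SDK. An equivalent, purely illustrative pattern (not part of this commit; the test name is hypothetical) is a pytest skip marker:

import pytest
from oci.data_flow.models import CreateApplicationDetails

from ads.jobs import DataFlow

# Illustrative only: skip pool-specific tests when the installed oci SDK
# predates Data Flow Pool support (no CreateApplicationDetails.pool_id).
requires_pool_support = pytest.mark.skipif(
    not hasattr(CreateApplicationDetails, "pool_id"),
    reason="Installed oci SDK does not support Data Flow Pools.",
)

@requires_pool_support
def test_pool_id_round_trip():
    pool_ocid = "ocid1.dataflowpool.oc1..<unique_ocid>"
    df = DataFlow().with_pool_id(pool_ocid)
    assert df.pool_id == pool_ocid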

@@ -124,7 +131,7 @@ def test_create_delete(self, mock_to_dict, mock_client):
             df.lifecycle_state
             == oci.data_flow.models.Application.LIFECYCLE_STATE_DELETED
         )
-        assert len(df.to_yaml()) == 567
+        assert len(df.to_yaml()) == EXPECTED_YAML_LENGTH

     def test_create_df_app_with_default_display_name(
         self,
@@ -319,14 +326,16 @@ def df(self):
             ).with_num_executors(
                 2
             ).with_private_endpoint_id(
-                "test_private_endpoint"
+                SAMPLE_PAYLOAD["private_endpoint_id"]
             ).with_freeform_tag(
                 test_freeform_tags_key="test_freeform_tags_value",
             ).with_defined_tag(
                 test_defined_tags_namespace={
                     "test_defined_tags_key": "test_defined_tags_value"
                 }
             )
+        if SAMPLE_PAYLOAD.get("pool_id", None):
+            df.with_pool_id(SAMPLE_PAYLOAD["pool_id"])
         return df

     def test_create_with_builder_pattern(self, mock_to_dict, mock_client, df):
@@ -341,6 +350,8 @@ def test_create_with_builder_pattern(self, mock_to_dict, mock_client, df):
                 "test_defined_tags_key": "test_defined_tags_value"
             }
         }
+        if SAMPLE_PAYLOAD.get("pool_id", None):
+            assert df.pool_id == SAMPLE_PAYLOAD["pool_id"]

         rt = (
             DataFlowRuntime()
@@ -483,50 +494,44 @@ def test_to_and_from_dict(self, df):
         assert df3_dict["spec"]["numExecutors"] == 2

     def test_shape_and_details(self, mock_to_dict, mock_client, df):
-        df.with_driver_shape(
-            "VM.Standard2.1"
-        ).with_executor_shape(
+        df.with_driver_shape("VM.Standard2.1").with_executor_shape(
             "VM.Standard.E4.Flex"
         )

         rt = (
             DataFlowRuntime()
-                .with_script_uri(SAMPLE_PAYLOAD["file_uri"])
-                .with_archive_uri(SAMPLE_PAYLOAD["archive_uri"])
-                .with_custom_conda(
-                    "oci://my_bucket@my_namespace/conda_environments/cpu/PySpark 3.0 and Data Flow/5.0/pyspark30_p37_cpu_v5"
-                )
-                .with_overwrite(True)
+            .with_script_uri(SAMPLE_PAYLOAD["file_uri"])
+            .with_archive_uri(SAMPLE_PAYLOAD["archive_uri"])
+            .with_custom_conda(
+                "oci://my_bucket@my_namespace/conda_environments/cpu/PySpark 3.0 and Data Flow/5.0/pyspark30_p37_cpu_v5"
+            )
+            .with_overwrite(True)
         )

         with pytest.raises(
             ValueError,
-            match="`executor_shape` and `driver_shape` must be from the same shape family."
+            match="`executor_shape` and `driver_shape` must be from the same shape family.",
         ):
             with patch.object(DataFlowApp, "client", mock_client):
                 with patch.object(DataFlowApp, "to_dict", mock_to_dict):
                     df.create(rt)

-        df.with_driver_shape(
-            "VM.Standard2.1"
-        ).with_driver_shape_config(
+        df.with_driver_shape("VM.Standard2.1").with_driver_shape_config(
             memory_in_gbs=SAMPLE_PAYLOAD["driver_shape_config"]["memory_in_gbs"],
             ocpus=SAMPLE_PAYLOAD["driver_shape_config"]["ocpus"],
-        ).with_executor_shape(
-            "VM.Standard2.16"
-        ).with_executor_shape_config(
+        ).with_executor_shape("VM.Standard2.16").with_executor_shape_config(
             memory_in_gbs=SAMPLE_PAYLOAD["executor_shape_config"]["memory_in_gbs"],
             ocpus=SAMPLE_PAYLOAD["executor_shape_config"]["ocpus"],
         )

         with pytest.raises(
             ValueError,
-            match="Shape config is not required for non flex shape from user end."
+            match="Shape config is not required for non flex shape from user end.",
         ):
             with patch.object(DataFlowApp, "client", mock_client):
                 with patch.object(DataFlowApp, "to_dict", mock_to_dict):
                     df.create(rt)
-
+

 class TestDataFlowNotebookRuntime:
     @pytest.mark.skipif(
