Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7/X] Pipeline user facing build() function #52

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def remote(self, *args, **kwargs):
**cls_options,
)

def _bind(self, *args, **kwargs):
def bind(self, *args, **kwargs):
"""
**Experimental**

Expand Down Expand Up @@ -1041,7 +1041,7 @@ def _remote(

return actor_handle

def _bind(self, *args, **kwargs):
def bind(self, *args, **kwargs):
"""
**Experimental**

Expand Down
15 changes: 13 additions & 2 deletions python/ray/experimental/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,18 @@ def get_import_path(self) -> str:

def to_json(self, encoder_cls) -> Dict[str, Any]:
json_dict = super().to_json_base(encoder_cls, ClassNode.__name__)
json_dict["import_path"] = self.get_import_path()
import_path = self.get_import_path()
error_message = (
"Class used in DAG should not be in-line defined when exporting"
"import path for deployment. Please ensure it has fully "
"qualified name with valid __module__ and __qualname__ for "
"import path, with no __main__ or <locals>. \n"
f"Current import path: {import_path}"
)
assert "__main__" not in import_path, error_message
assert "<locals>" not in import_path, error_message

json_dict["import_path"] = import_path
return json_dict

@classmethod
Expand All @@ -117,7 +128,7 @@ def __init__(self, actor: ClassNode, method_name: str):
self._method_name = method_name
self._options = {}

def _bind(self, *args, **kwargs):
def bind(self, *args, **kwargs):
other_args_to_resolve = {
PARENT_CLASS_NODE_KEY: self._actor,
PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
Expand Down
10 changes: 5 additions & 5 deletions python/ray/experimental/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_stable_uuid(self) -> str:

def execute(self, *args, **kwargs) -> Union[ray.ObjectRef, ray.actor.ActorHandle]:
"""Execute this DAG using the Ray default executor."""
return self._apply_recursive(lambda node: node._execute_impl(*args, **kwargs))
return self.apply_recursive(lambda node: node._execute_impl(*args, **kwargs))

def _get_toplevel_child_nodes(self) -> Set["DAGNode"]:
"""Return the set of nodes specified as top-level args.
Expand Down Expand Up @@ -135,7 +135,7 @@ def _apply_and_replace_all_child_nodes(
"""Apply and replace all immediate child nodes using a given function.

This is a shallow replacement only. To recursively transform nodes in
the DAG, use ``_apply_recursive()``.
the DAG, use ``apply_recursive()``.

Args:
fn: Callable that will be applied once to each child of this node.
Expand Down Expand Up @@ -168,7 +168,7 @@ def _apply_and_replace_all_child_nodes(
new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
)

def _apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
"""Apply callable on each node in this DAG in a bottom-up tree walk.

Args:
Expand Down Expand Up @@ -203,11 +203,11 @@ def __call__(self, node):

return fn(
self._apply_and_replace_all_child_nodes(
lambda node: node._apply_recursive(fn)
lambda node: node.apply_recursive(fn)
)
)

def _apply_functional(
def apply_functional(
self,
source_input_list: Any,
predictate_fn: Callable,
Expand Down
12 changes: 11 additions & 1 deletion python/ray/experimental/dag/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,17 @@ def get_import_path(self):

def to_json(self, encoder_cls) -> Dict[str, Any]:
json_dict = super().to_json_base(encoder_cls, FunctionNode.__name__)
json_dict["import_path"] = self.get_import_path()
import_path = self.get_import_path()
error_message = (
"Function used in DAG should not be in-line defined when exporting"
"import path for deployment. Please ensure it has fully "
"qualified name with valid __module__ and __qualname__ for "
"import path, with no __main__ or <locals>. \n"
f"Current import path: {import_path}"
)
assert "__main__" not in import_path, error_message
assert "<locals>" not in import_path, error_message
json_dict["import_path"] = import_path
return json_dict

@classmethod
Expand Down
16 changes: 16 additions & 0 deletions python/ray/experimental/dag/input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,22 @@ def _execute_impl(self, *args, **kwargs):
def __str__(self) -> str:
return get_dag_node_str(self, f'["{self._key}"]')

def to_json(self, encoder_cls) -> Dict[str, Any]:
json_dict = super().to_json_base(encoder_cls, InputAtrributeNode.__name__)
return json_dict

@classmethod
def from_json(cls, input_json, object_hook=None):
assert input_json[DAGNODE_TYPE_KEY] == InputAtrributeNode.__name__
args_dict = super().from_json_base(input_json, object_hook=object_hook)
node = cls(
args_dict["other_args_to_resolve"]["dag_input_node"],
args_dict["other_args_to_resolve"]["key"],
args_dict["other_args_to_resolve"]["accessor_method"],
)
node._stable_uuid = input_json["uuid"]
return node


class DAGInputData:
"""If user passed multiple args and kwargs directly to dag.execute(), we
Expand Down
72 changes: 36 additions & 36 deletions python/ray/experimental/dag/tests/test_class_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def test_basic_actor_dag(shared_ray_instance):
def combine(x, y):
return x + y

a1 = Actor._bind(10)
res = a1.get._bind()
a1 = Actor.bind(10)
res = a1.get.bind()
print(res)
assert ray.get(res.execute()) == 10

a2 = Actor._bind(10)
a1.inc._bind(2)
a1.inc._bind(4)
a2.inc._bind(6)
dag = combine._bind(a1.get._bind(), a2.get._bind())
a2 = Actor.bind(10)
a1.inc.bind(2)
a1.inc.bind(4)
a2.inc.bind(6)
dag = combine.bind(a1.get.bind(), a2.get.bind())

print(dag)
assert ray.get(dag.execute()) == 32
Expand All @@ -71,9 +71,9 @@ def inc(self, x):
def get(self):
return ray.get(self.inner_actor.get.remote())

outer = OuterActor._bind(Actor._bind(10))
outer.inc._bind(2)
dag = outer.get._bind()
outer = OuterActor.bind(Actor.bind(10))
outer.inc.bind(2)
dag = outer.get.bind()
print(dag)
assert ray.get(dag.execute()) == 12

Expand All @@ -83,19 +83,19 @@ def test_class_as_function_constructor_arg(shared_ray_instance):
def f(actor_handle):
return ray.get(actor_handle.get.remote())

dag = f._bind(Actor._bind(10))
dag = f.bind(Actor.bind(10))
print(dag)
assert ray.get(dag.execute()) == 10


def test_basic_actor_dag_constructor_options(shared_ray_instance):
a1 = Actor._bind(10)
dag = a1.get._bind()
a1 = Actor.bind(10)
dag = a1.get.bind()
print(dag)
assert ray.get(dag.execute()) == 10

a1 = Actor.options(name="Actor", namespace="test", max_pending_calls=10)._bind(10)
dag = a1.get._bind()
a1 = Actor.options(name="Actor", namespace="test", max_pending_calls=10).bind(10)
dag = a1.get.bind()
print(dag)
# Ensure execution result is identical with .options() in init()
assert ray.get(dag.execute()) == 10
Expand All @@ -106,16 +106,16 @@ def test_basic_actor_dag_constructor_options(shared_ray_instance):


def test_actor_method_options(shared_ray_instance):
a1 = Actor._bind(10)
dag = a1.get.options(name="actor_method_options")._bind()
a1 = Actor.bind(10)
dag = a1.get.options(name="actor_method_options").bind()
print(dag)
assert ray.get(dag.execute()) == 10
assert dag.get_options().get("name") == "actor_method_options"


def test_basic_actor_dag_constructor_invalid_options(shared_ray_instance):
a1 = Actor.options(num_cpus=-1)._bind(10)
invalid_dag = a1.get._bind()
a1 = Actor.options(num_cpus=-1).bind(10)
invalid_dag = a1.get.bind()
with pytest.raises(ValueError, match=".*Resource quantities may not be negative.*"):
ray.get(invalid_dag.execute())

Expand All @@ -130,24 +130,24 @@ def test_actor_options_complicated(shared_ray_instance):
def combine(x, y):
return x + y

a1 = Actor.options(name="a1_v0")._bind(10)
res = a1.get.options(name="v1")._bind()
a1 = Actor.options(name="a1_v0").bind(10)
res = a1.get.options(name="v1").bind()
print(res)
assert ray.get(res.execute()) == 10
assert a1.get_options().get("name") == "a1_v0"
assert res.get_options().get("name") == "v1"

a1 = Actor.options(name="a1_v1")._bind(10) # Cannot
a2 = Actor.options(name="a2_v0")._bind(10)
a1.inc.options(name="v1")._bind(2)
a1.inc.options(name="v2")._bind(4)
a2.inc.options(name="v3")._bind(6)
dag = combine.options(name="v4")._bind(a1.get._bind(), a2.get._bind())
a1 = Actor.options(name="a1_v1").bind(10) # Cannot
a2 = Actor.options(name="a2_v0").bind(10)
a1.inc.options(name="v1").bind(2)
a1.inc.options(name="v2").bind(4)
a2.inc.options(name="v3").bind(6)
dag = combine.options(name="v4").bind(a1.get.bind(), a2.get.bind())

print(dag)
assert ray.get(dag.execute()) == 32
test_a1 = dag.get_args()[0] # call graph for a1.get._bind()
test_a2 = dag.get_args()[1] # call graph for a2.get._bind()
test_a1 = dag.get_args()[0] # call graph for a1.get.bind()
test_a2 = dag.get_args()[1] # call graph for a2.get.bind()
assert test_a2.get_options() == {} # No .options() at outer call
# refer to a2 constructor .options() call
assert (
Expand Down Expand Up @@ -198,8 +198,8 @@ def caller(handle):
assert isinstance(handle, ray.actor.ActorHandle), handle
return ray.get(handle.ping.remote())

a1 = Actor._bind()
dag = caller._bind(a1)
a1 = Actor.bind()
dag = caller.bind(a1)
print(dag)
assert ray.get(dag.execute()) == "hello"

Expand Down Expand Up @@ -227,15 +227,15 @@ def pipeline(x, m1, m2, selection):
result = m2.forward.remote(x)
return ray.get(result)

m1 = Model._bind("Even: ")
m2 = Model._bind("Odd: ")
selection = ModelSelection._bind()
m1 = Model.bind("Even: ")
m2 = Model.bind("Odd: ")
selection = ModelSelection.bind()

even_input = pipeline._bind(20, m1, m2, selection)
even_input = pipeline.bind(20, m1, m2, selection)
print(even_input)
assert ray.get(even_input.execute()) == "Even: 20"

odd_input = pipeline._bind(21, m1, m2, selection)
odd_input = pipeline.bind(21, m1, m2, selection)
print(odd_input)
assert ray.get(odd_input.execute()) == "Odd: 21"

Expand Down
48 changes: 24 additions & 24 deletions python/ray/experimental/dag/tests/test_function_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def d(x, y):
ray.get(ct.inc.remote())
return x + y

a_ref = a._bind()
b_ref = b._bind(a_ref)
c_ref = c._bind(a_ref)
d_ref = d._bind(b_ref, c_ref)
d1_ref = d._bind(d_ref, d_ref)
d2_ref = d._bind(d1_ref, d_ref)
dag = d._bind(d2_ref, d_ref)
a_ref = a.bind()
b_ref = b.bind(a_ref)
c_ref = c.bind(a_ref)
d_ref = d.bind(b_ref, c_ref)
d1_ref = d.bind(d_ref, d_ref)
d2_ref = d.bind(d1_ref, d_ref)
dag = d.bind(d2_ref, d_ref)
print(dag)

assert ray.get(dag.execute()) == 28
Expand Down Expand Up @@ -74,10 +74,10 @@ def d(x, y):
ray.get(ct.inc.remote())
return x + y

a_ref = a._bind()
b_ref = b.options(name="b", num_returns=1)._bind(a_ref)
c_ref = c.options(name="c", max_retries=3)._bind(a_ref)
dag = d.options(name="d", num_cpus=2)._bind(b_ref, c_ref)
a_ref = a.bind()
b_ref = b.options(name="b", num_returns=1).bind(a_ref)
c_ref = c.options(name="c", max_retries=3).bind(a_ref)
dag = d.options(name="d", num_cpus=2).bind(b_ref, c_ref)

print(dag)

Expand Down Expand Up @@ -106,12 +106,12 @@ def a():
def b(x):
return x * 2

a_ref = a._bind()
dag = b._bind(a_ref)
a_ref = a.bind()
dag = b.bind(a_ref)

# Ensure current DAG is executable
assert ray.get(dag.execute()) == 4
invalid_dag = b.options(num_cpus=-1)._bind(a_ref)
invalid_dag = b.options(num_cpus=-1).bind(a_ref)
with pytest.raises(ValueError, match=".*Resource quantities may not be negative.*"):
ray.get(invalid_dag.execute())

Expand All @@ -121,17 +121,17 @@ def test_node_accessors(shared_ray_instance):
def a(*a, **kw):
pass

tmp1 = a._bind()
tmp2 = a._bind()
tmp3 = a._bind()
node = a._bind(1, tmp1, x=tmp2, y={"foo": tmp3})
tmp1 = a.bind()
tmp2 = a.bind()
tmp3 = a.bind()
node = a.bind(1, tmp1, x=tmp2, y={"foo": tmp3})
assert node.get_args() == (1, tmp1)
assert node.get_kwargs() == {"x": tmp2, "y": {"foo": tmp3}}
assert node._get_toplevel_child_nodes() == {tmp1, tmp2}
assert node._get_all_child_nodes() == {tmp1, tmp2, tmp3}

tmp4 = a._bind()
tmp5 = a._bind()
tmp4 = a.bind()
tmp5 = a.bind()
replace = {tmp1: tmp4, tmp2: tmp4, tmp3: tmp5}
n2 = node._apply_and_replace_all_child_nodes(lambda x: replace[x])
assert n2._get_all_child_nodes() == {tmp4, tmp5}
Expand Down Expand Up @@ -160,10 +160,10 @@ def d(nested):
ray.get(ct.inc.remote())
return ray.get(nested["x"]) + ray.get(nested["y"])

a_ref = a._bind()
b_ref = b._bind(x=a_ref)
c_ref = c._bind(x=a_ref)
dag = d._bind({"x": b_ref, "y": c_ref})
a_ref = a.bind()
b_ref = b.bind(x=a_ref)
c_ref = c.bind(x=a_ref)
dag = d.bind({"x": b_ref, "y": c_ref})
print(dag)

assert ray.get(dag.execute()) == 7
Expand Down
Loading