Skip to content

Commit d2052b1

Browse files
committed
Adds some async changes
We need to audit all decorators and see if they work with async or not...
1 parent 1f05f2e commit d2052b1

File tree

8 files changed

+228
-9
lines changed

8 files changed

+228
-9
lines changed

hamilton/function_modifiers/base.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import collections
33
import functools
4+
import inspect
45
import itertools
56
import logging
67
from abc import ABC
@@ -100,14 +101,20 @@ def __call__(self, fn: Callable):
100101
:param fn: Function to decorate
101102
:return: The function again, with the desired properties.
102103
"""
103-
# stop unwrapping if not a hamilton function
104-
# should only be one level of "hamilton wrapping" - and that's what we attach things to.
104+
# # stop unwrapping if not a hamilton function
105+
# # should only be one level of "hamilton wrapping" - and that's what we attach things to.
105106
self.validate(unwrap(fn, stop=lambda f: not hasattr(f, "__hamilton__")))
106107
if not hasattr(fn, "__hamilton__"):
108+
if inspect.iscoroutinefunction(fn):
107109

108-
@functools.wraps(fn)
109-
def wrapper(*args, **kwargs):
110-
return fn(*args, **kwargs)
110+
@functools.wraps(fn)
111+
async def wrapper(*args, **kwargs):
112+
return await fn(*args, **kwargs)
113+
else:
114+
115+
@functools.wraps(fn)
116+
def wrapper(*args, **kwargs):
117+
return fn(*args, **kwargs)
111118

112119
wrapper.__hamilton__ = True
113120
else:

hamilton/function_modifiers/expanders.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,65 @@ def replacement_function(
231231
# now the error will be clear enough
232232
return node_.callable(*args, **new_kwargs)
233233

234+
async def async_replacement_function(
235+
*args,
236+
upstream_dependencies=upstream_dependencies,
237+
literal_dependencies=literal_dependencies,
238+
grouped_list_dependencies=grouped_list_dependencies,
239+
grouped_dict_dependencies=grouped_dict_dependencies,
240+
former_inputs=list(node_.input_types.keys()), # noqa
241+
**kwargs,
242+
):
243+
"""This function rewrites what is passed in kwargs to the right kwarg for the function.
244+
The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs",
245+
which are what the node declares as its dependencies. So, we just have to loop through all of them to
246+
get the "new" value. This "new" value comes from the parameterization.
247+
248+
Note that much of this code should *probably* live within the source/value/grouped functions, but
249+
it is here as we're not 100% sure about the abstraction.
250+
251+
TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args.
252+
Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call.
253+
"""
254+
new_kwargs = {}
255+
for node_input in former_inputs:
256+
if node_input in upstream_dependencies:
257+
# If the node is specified by `source`, then we get the value from the kwargs
258+
new_kwargs[node_input] = kwargs[upstream_dependencies[node_input].source]
259+
elif node_input in literal_dependencies:
260+
# If the node is specified by `value`, then we get the literal value (no need for kwargs)
261+
new_kwargs[node_input] = literal_dependencies[node_input].value
262+
elif node_input in grouped_list_dependencies:
263+
# If the node is specified by `group`, then we get the list of values from the kwargs or the literal
264+
new_kwargs[node_input] = []
265+
for replacement in grouped_list_dependencies[node_input].sources:
266+
resolved_value = (
267+
kwargs[replacement.source]
268+
if replacement.get_dependency_type()
269+
== ParametrizedDependencySource.UPSTREAM
270+
else replacement.value
271+
)
272+
new_kwargs[node_input].append(resolved_value)
273+
elif node_input in grouped_dict_dependencies:
274+
# If the node is specified by `group`, then we get the dict of values from the kwargs or the literal
275+
new_kwargs[node_input] = {}
276+
for dependency, replacement in grouped_dict_dependencies[
277+
node_input
278+
].sources.items():
279+
resolved_value = (
280+
kwargs[replacement.source]
281+
if replacement.get_dependency_type()
282+
== ParametrizedDependencySource.UPSTREAM
283+
else replacement.value
284+
)
285+
new_kwargs[node_input][dependency] = resolved_value
286+
elif node_input in kwargs:
287+
new_kwargs[node_input] = kwargs[node_input]
288+
# This case is left blank for optional parameters. If we error here, we'll break
289+
# the (supported) case of optionals. We do know whether its optional but for
290+
# now the error will be clear enough
291+
return await node_.callable(*args, **new_kwargs)
292+
234293
new_input_types = {}
235294
grouped_dependencies = {
236295
**grouped_list_dependencies,
@@ -271,7 +330,9 @@ def replacement_function(
271330
name=output_node,
272331
doc_string=docstring, # TODO -- change docstring
273332
callabl=functools.partial(
274-
replacement_function,
333+
replacement_function
334+
if not inspect.iscoroutinefunction(node_.callable)
335+
else async_replacement_function,
275336
**{parameter: val.value for parameter, val in literal_dependencies.items()},
276337
),
277338
input_types=new_input_types,

hamilton/function_modifiers/validation.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
import inspect
23
from collections import defaultdict
34
from typing import Any, Callable, Collection, Dict, List, Type
45

@@ -46,6 +47,14 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, **
4647
result = list(kwargs.values())[0] # This should just have one kwarg
4748
return validator_to_call.validate(result)
4849

50+
async def async_validation_function(
51+
validator_to_call: dq_base.DataValidator = validator, **kwargs
52+
):
53+
result = list(kwargs.values())[0] # This should just have one kwarg
54+
if inspect.isawaitable(result):
55+
result = await result
56+
return validator_to_call.validate(result)
57+
4958
validator_node_name = node_.name + "_" + validator.name()
5059
validator_name_count[validator_node_name] = (
5160
validator_name_count[validator_node_name] + 1
@@ -58,7 +67,9 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, **
5867
name=validator_node_name, # TODO -- determine a good approach towards naming this
5968
typ=dq_base.ValidationResult,
6069
doc_string=validator.description(),
61-
callabl=validation_function,
70+
callabl=validation_function
71+
if not inspect.iscoroutinefunction(node_.callable)
72+
else async_validation_function,
6273
node_source=node.NodeType.STANDARD,
6374
input_types={raw_node.name: (node_.type, node.DependencyType.REQUIRED)},
6475
tags={

tests/function_modifiers/test_base.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from inspect import unwrap
12
from typing import Collection, Dict, List
3+
from unittest.mock import Mock
24

3-
import pytest as pytest
5+
import pytest
46

57
from hamilton import node, settings
68
from hamilton.function_modifiers import InvalidDecoratorException, base
79
from hamilton.function_modifiers.base import (
810
MissingConfigParametersException,
911
NodeTransformer,
12+
NodeTransformLifecycle,
1013
TargetType,
1114
)
1215
from hamilton.node import Node
@@ -149,3 +152,86 @@ def test_add_fn_metadata():
149152
]
150153
assert len(nodes_with_fn_pointer) == len(nodes)
151154
assert all([n.originating_functions == (test_add_fn_metadata,) for n in nodes])
155+
156+
157+
class MockNodeTransformLifecycle(NodeTransformLifecycle):
158+
@classmethod
159+
def get_lifecycle_name(cls):
160+
return "mock_lifecycle"
161+
162+
@classmethod
163+
def allows_multiple(cls):
164+
return True
165+
166+
def validate(self, fn):
167+
pass
168+
169+
170+
def test_decorator_adds_attributes():
171+
mock_decorator = MockNodeTransformLifecycle()
172+
173+
def my_function(a: int) -> int:
174+
pass
175+
176+
decorated_fn = mock_decorator(my_function)
177+
178+
assert hasattr(decorated_fn, "mock_lifecycle")
179+
assert decorated_fn.mock_lifecycle == [mock_decorator]
180+
assert hasattr(decorated_fn, "__hamilton__")
181+
182+
183+
def test_decorator_allows_multiple_raises_error():
184+
class MockMultipleNodeTransformLifecycle(NodeTransformLifecycle):
185+
@classmethod
186+
def get_lifecycle_name(cls):
187+
return "mock_lifecycle"
188+
189+
@classmethod
190+
def allows_multiple(cls):
191+
return False
192+
193+
def validate(self, fn):
194+
pass
195+
196+
mock_decorator = MockMultipleNodeTransformLifecycle()
197+
mock_fn = Mock()
198+
decorated_fn = mock_decorator(mock_fn)
199+
200+
with pytest.raises(ValueError):
201+
mock_decorator(decorated_fn)
202+
203+
204+
def test_decorator_only_wraps_once():
205+
"""Tests that the decorator only wraps once."""
206+
mock_decorator = MockNodeTransformLifecycle()
207+
208+
def my_function(a: int) -> int:
209+
pass
210+
211+
decorated_fn = mock_decorator(my_function)
212+
decorated_fn = mock_decorator(decorated_fn)
213+
decorated_fn = mock_decorator(decorated_fn)
214+
215+
assert decorated_fn.__hamilton__ is True
216+
assert decorated_fn.__wrapped__ == my_function # one level of wrapping only
217+
218+
219+
def test_wrapping_and_unwrapping_logic():
220+
"""Tests unwrapping logic works as expected."""
221+
222+
def my_function(a: int) -> int:
223+
pass
224+
225+
# Wrap the function
226+
wrapped_fn = MockNodeTransformLifecycle()(my_function)
227+
# Unwrap the function
228+
unwrapped_fn = unwrap(wrapped_fn, stop=lambda f: not hasattr(f, "__hamilton__"))
229+
230+
# Ensure the function is unwrapped correctly
231+
assert unwrapped_fn == my_function
232+
assert not hasattr(unwrapped_fn, "__hamilton__")
233+
234+
wrapped_fn2 = MockNodeTransformLifecycle()(wrapped_fn)
235+
unwrapped_fn2 = unwrap(wrapped_fn2, stop=lambda f: not hasattr(f, "__hamilton__"))
236+
assert wrapped_fn2 == wrapped_fn # these should be the same
237+
assert unwrapped_fn2 == my_function # these should be the same
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base import * # noqa: F401, F403
2+
from .base_extended import * # noqa: F401, F403
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
def a(input: int) -> int:
2+
return input * 2
3+
4+
5+
def z(input: int) -> int:
6+
return input * 3
7+
8+
9+
async def aa(input: int) -> int:
10+
return input * 4
11+
12+
13+
async def zz(aa: int) -> int:
14+
return aa * 5
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from hamilton.function_modifiers import check_output, parameterize, value
2+
3+
from tests.resources.decorator_related import base
4+
5+
b_p = parameterize(b={"input": value(1)}, c={"input": value(2)})(base.a)
6+
7+
b_p2 = parameterize(q={"input": value(4)}, r={"input": value(5)})(base.a)
8+
9+
b_p3 = check_output(
10+
range=(0, 10),
11+
)(base.aa)
12+
b_p3.__name__ = "b_p3" # required to register this as `b_p3` in the graph
13+
14+
b_p4 = parameterize(aaa={"input": value(4)}, aab={"input": value(5)})(base.aa)
15+
16+
17+
def d(b: int, c: int) -> int:
18+
return b + c
19+
20+
21+
def e(input: int, a: int) -> int:
22+
return input * 4

tests/test_end_to_end.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import pandas as pd
88
import pytest
99

10-
from hamilton import ad_hoc_utils, base, driver, settings
10+
from hamilton import ad_hoc_utils, async_driver, base, driver, settings
1111
from hamilton.base import DefaultAdapter
1212
from hamilton.data_quality.base import DataValidationError, ValidationResult
1313
from hamilton.execution import executors, grouping
1414
from hamilton.function_modifiers import source, value
1515
from hamilton.io.materialization import from_, to
1616

1717
import tests.resources.data_quality
18+
import tests.resources.decorator_related
1819
import tests.resources.dynamic_config
1920
import tests.resources.example_module
2021
import tests.resources.overrides
@@ -556,3 +557,18 @@ def test_driver_v2_inputs_can_be_none():
556557
with pytest.raises(ValueError):
557558
# validate that None doesn't cause issues
558559
dr.execute(["e"], inputs=None)
560+
561+
562+
def test_function_decorator_reuse():
563+
"""Tests we can reuse a function with multiple decorators"""
564+
dr = driver.Builder().with_modules(tests.resources.decorator_related).build()
565+
result = dr.execute(["a", "b", "c", "e", "q"], inputs={"input": 2})
566+
assert result == {"a": 4, "b": 2, "c": 4, "e": 8, "q": 8}
567+
568+
569+
@pytest.mark.asyncio
570+
async def test_function_decorator_reuse_async():
571+
"""Tests we can reuse a function with multiple decorators"""
572+
dr = await async_driver.Builder().with_modules(tests.resources.decorator_related).build()
573+
result = await dr.execute(["a", "b", "c", "e", "q", "zz", "b_p3", "aaa"], inputs={"input": 2})
574+
assert result == {"a": 4, "aaa": 16, "b": 2, "b_p3": 8, "c": 4, "e": 8, "q": 8, "zz": 40}

0 commit comments

Comments
 (0)