@@ -231,6 +231,65 @@ def replacement_function(
231231 # now the error will be clear enough
232232 return node_ .callable (* args , ** new_kwargs )
233233
234+ async def async_replacement_function (
235+ * args ,
236+ upstream_dependencies = upstream_dependencies ,
237+ literal_dependencies = literal_dependencies ,
238+ grouped_list_dependencies = grouped_list_dependencies ,
239+ grouped_dict_dependencies = grouped_dict_dependencies ,
240+ former_inputs = list (node_ .input_types .keys ()), # noqa
241+ ** kwargs ,
242+ ):
243+ """This function rewrites what is passed in kwargs to the right kwarg for the function.
244+ The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs",
245+ which are what the node declares as its dependencies. So, we just have to loop through all of them to
246+ get the "new" value. This "new" value comes from the parameterization.
247+
248+ Note that much of this code should *probably* live within the source/value/grouped functions, but
249+ it is here as we're not 100% sure about the abstraction.
250+
251+ TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args.
252+ Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call.
253+ """
254+ new_kwargs = {}
255+ for node_input in former_inputs :
256+ if node_input in upstream_dependencies :
257+ # If the node is specified by `source`, then we get the value from the kwargs
258+ new_kwargs [node_input ] = kwargs [upstream_dependencies [node_input ].source ]
259+ elif node_input in literal_dependencies :
260+ # If the node is specified by `value`, then we get the literal value (no need for kwargs)
261+ new_kwargs [node_input ] = literal_dependencies [node_input ].value
262+ elif node_input in grouped_list_dependencies :
263+ # If the node is specified by `group`, then we get the list of values from the kwargs or the literal
264+ new_kwargs [node_input ] = []
265+ for replacement in grouped_list_dependencies [node_input ].sources :
266+ resolved_value = (
267+ kwargs [replacement .source ]
268+ if replacement .get_dependency_type ()
269+ == ParametrizedDependencySource .UPSTREAM
270+ else replacement .value
271+ )
272+ new_kwargs [node_input ].append (resolved_value )
273+ elif node_input in grouped_dict_dependencies :
274+ # If the node is specified by `group`, then we get the dict of values from the kwargs or the literal
275+ new_kwargs [node_input ] = {}
276+ for dependency , replacement in grouped_dict_dependencies [
277+ node_input
278+ ].sources .items ():
279+ resolved_value = (
280+ kwargs [replacement .source ]
281+ if replacement .get_dependency_type ()
282+ == ParametrizedDependencySource .UPSTREAM
283+ else replacement .value
284+ )
285+ new_kwargs [node_input ][dependency ] = resolved_value
286+ elif node_input in kwargs :
287+ new_kwargs [node_input ] = kwargs [node_input ]
288+ # This case is left blank for optional parameters. If we error here, we'll break
289+ # the (supported) case of optionals. We do know whether its optional but for
290+ # now the error will be clear enough
291+ return await node_ .callable (* args , ** new_kwargs )
292+
234293 new_input_types = {}
235294 grouped_dependencies = {
236295 ** grouped_list_dependencies ,
@@ -271,7 +330,9 @@ def replacement_function(
271330 name = output_node ,
272331 doc_string = docstring , # TODO -- change docstring
273332 callabl = functools .partial (
274- replacement_function ,
333+ replacement_function
334+ if not inspect .iscoroutinefunction (node_ .callable )
335+ else async_replacement_function ,
275336 ** {parameter : val .value for parameter , val in literal_dependencies .items ()},
276337 ),
277338 input_types = new_input_types ,
0 commit comments