Skip to content

Commit f995abc

Browse files
author
William Grant
committed
add qualifier to filter based on specific arg types
1 parent de7cca1 commit f995abc

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

typed_python/compiler/runtime.py

+45-5
Original file line numberDiff line numberDiff line change
@@ -393,19 +393,25 @@ def resultTypeForCall(self, funcObj, argTypes, kwargTypes):
393393
return _resultTypeCache[key]
394394

395395
@staticmethod
396-
def getNativeIRString(typedFunc: Function) -> str:
396+
def getNativeIRString(typedFunc: Function, *args, **kwargs) -> str:
397397
"""
398398
Given a function compiled with Entrypoint, return a text representation
399399
of the generated native (one layer prior to LLVM) code.
400400
401401
Args:
402402
typedFunc (Function): a decorated python function.
403+
*args (Optional): these optional args should be the Types of the functions' positional arguments
404+
**kwargs (Optional): these keyword args should be the Types of the functions' keyword arguments
403405
404406
Returns:
405407
A string for the function bodies generated (including constructors and destructors)
406408
"""
407409
converter = Runtime.singleton().llvm_compiler.converter
408-
function_name = typedFunc.__name__
410+
411+
if args or kwargs:
412+
function_name = getFullFunctionNameWithArgs(typedFunc, args, kwargs)
413+
else:
414+
function_name = typedFunc.__name__
409415
# relies on us maintaining our naming conventions (tests would break otherwise)
410416
output_str = ""
411417
for key, value in converter._function_definitions.items():
@@ -422,20 +428,25 @@ def getNativeIRString(typedFunc: Function) -> str:
422428
return output_str
423429

424430
@staticmethod
425-
def getLLVMString(typedFunc: Function) -> str:
431+
def getLLVMString(typedFunc: Function, *args, **kwargs) -> str:
426432
"""
427433
Given a function compiled with Entrypoint, return a text representation
428434
of the generated LLVM code.
429435
430436
Args:
431437
typedFunc (Function): a decorated python function.
438+
*args (Optional): these optional args should be the Types of the functions' positional arguments
439+
**kwargs (Optional): these keyword args should be the Types of the functions' keyword arguments
432440
433441
Returns:
434442
A string for the function bodies generated (including constructors and destructors)
435443
"""
436444
converter = Runtime.singleton().llvm_compiler.converter
437-
function_name = typedFunc.__name__
438-
# relies on us maintaining our naming conventions (tests would break otherwise)
445+
446+
if args or kwargs:
447+
function_name = getFullFunctionNameWithArgs(typedFunc, args, kwargs)
448+
else:
449+
function_name = typedFunc.__name__ # relies on us maintaining our naming conventions (tests would break otherwise)
439450
output_str = ""
440451
for key, value in converter._functions_by_name.items():
441452
if function_name in key:
@@ -451,6 +462,35 @@ def getLLVMString(typedFunc: Function) -> str:
451462
return output_str
452463

453464

465+
def getFullFunctionNameWithArgs(funcObj, argTypes, kwargTypes):
466+
"""
467+
Given a Function and a set of types, compile the function to generate the unique name
468+
for that function+argument combination.
469+
470+
Args:
471+
funcObj (Function): a typed_python Function.
472+
argTypes (List): a list of the position arguments for the function.
473+
kwargTypes (Dict): a key:value mapping for the functions' keywords arguments.
474+
"""
475+
assert isinstance(funcObj, typed_python._types.Function)
476+
typeWrapper = lambda t: python_to_native_converter.typedPythonTypeToTypeWrapper(t)
477+
funcObj = _types.prepareArgumentToBePassedToCompiler(funcObj)
478+
argTypes = [typeWrapper(a) for a in argTypes]
479+
kwargTypes = {k: typeWrapper(v) for k, v in kwargTypes.items()}
480+
481+
overload_index = 0
482+
overload = funcObj.overloads[overload_index]
483+
484+
ExpressionConversionContext = typed_python.compiler.expression_conversion_context.ExpressionConversionContext
485+
argumentSignature = ExpressionConversionContext.computeFunctionArgumentTypeSignature(overload, argTypes, kwargTypes)
486+
487+
if argumentSignature is not None:
488+
callTarget = Runtime().singleton().compileFunctionOverload(funcObj, overload_index, argumentSignature, argumentsAreTypes=True)
489+
else:
490+
raise ValueError('no signature found.')
491+
return callTarget.name
492+
493+
454494
def NotCompiled(pyFunc, returnTypeOverride=None):
455495
"""Decorate 'pyFunc' to prevent it from being compiled.
456496

0 commit comments

Comments
 (0)