Skip to content

⚡️ Speed up function collect_func_params by 10% #1128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

41 changes: 22 additions & 19 deletions inference/usage_tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,30 @@ def collect_func_params(
func: Callable[[Any], Any], args: Iterable[Any], kwargs: Dict[Any, Any]
) -> Dict[str, Any]:
signature = get_signature(func)

params = {}
if args:
for param, arg_value in zip(signature.parameters.keys(), args):
params[param] = arg_value
if kwargs:
params = {**params, **kwargs}
defaults = set(signature.parameters.keys()).difference(set(params.keys()))
for default_arg in defaults:
default = signature.parameters[default_arg].default
if default is inspect.Parameter.empty:
continue
params[default_arg] = default

signature_params = set(signature.parameters)
if set(params) != signature_params:
if "kwargs" in signature_params:
parameters = signature.parameters

# Initialize params with positional arguments
params = {param: arg_value for param, arg_value in zip(parameters.keys(), args)}

# Update params with keyword arguments
params.update(kwargs)

# Set default values for missing arguments
defaults = {
param: param_obj.default
for param, param_obj in parameters.items()
if param not in params and param_obj.default is not inspect.Parameter.empty
}
params.update(defaults)

# Verify against function signature parameters
signature_keys = set(parameters.keys())
if params.keys() != signature_keys:
if "kwargs" in signature_keys:
params["kwargs"] = kwargs
if "args" in signature_params:
if "args" in signature_keys:
params["args"] = args
if not set(params).issuperset(signature_params):
if not set(params).issuperset(signature_keys):
logger.error("Params mismatch for %s.%s", func.__module__, func.__name__)

return params
Loading