diff --git a/inference/usage_tracking/utils.py b/inference/usage_tracking/utils.py index 1df3443168..12fdf58ccc 100644 --- a/inference/usage_tracking/utils.py +++ b/inference/usage_tracking/utils.py @@ -19,27 +19,30 @@ def collect_func_params( func: Callable[[Any], Any], args: Iterable[Any], kwargs: Dict[Any, Any] ) -> Dict[str, Any]: signature = get_signature(func) - - params = {} - if args: - for param, arg_value in zip(signature.parameters.keys(), args): - params[param] = arg_value - if kwargs: - params = {**params, **kwargs} - defaults = set(signature.parameters.keys()).difference(set(params.keys())) - for default_arg in defaults: - default = signature.parameters[default_arg].default - if default is inspect.Parameter.empty: - continue - params[default_arg] = default - - signature_params = set(signature.parameters) - if set(params) != signature_params: - if "kwargs" in signature_params: + parameters = signature.parameters + + # Initialize params with positional arguments + params = {param: arg_value for param, arg_value in zip(parameters.keys(), args)} + + # Update params with keyword arguments + params.update(kwargs) + + # Set default values for missing arguments + defaults = { + param: param_obj.default + for param, param_obj in parameters.items() + if param not in params and param_obj.default is not inspect.Parameter.empty + } + params.update(defaults) + + # Verify against function signature parameters + signature_keys = set(parameters.keys()) + if params.keys() != signature_keys: + if "kwargs" in signature_keys: params["kwargs"] = kwargs - if "args" in signature_params: + if "args" in signature_keys: params["args"] = args - if not set(params).issuperset(signature_params): + if not set(params).issuperset(signature_keys): logger.error("Params mismatch for %s.%s", func.__module__, func.__name__) return params