diff --git a/google/ads/googleads/client.py b/google/ads/googleads/client.py index 2b2ce06a7..8a8b70894 100644 --- a/google/ads/googleads/client.py +++ b/google/ads/googleads/client.py @@ -17,9 +17,12 @@ import logging.config from google.api_core.gapic_v1.client_info import ClientInfo -import grpc.experimental +import grpc from proto.enums import ProtoEnumMeta +from google.protobuf.message import Message as ProtobufMessageType +from proto import Message as ProtoPlusMessageType + from google.ads.googleads import config, oauth2, util from google.ads.googleads.interceptors import ( MetadataInterceptor, @@ -27,6 +30,9 @@ LoggingInterceptor, ) +from types import ModuleType +from typing import Any, Dict, List, Tuple, Union + _logger = logging.getLogger(__name__) _SERVICE_CLIENT_TEMPLATE = "{}Client" @@ -60,19 +66,18 @@ class _EnumGetter: class instances when accessed. """ - def __init__(self, client): + def __init__(self, client: "GoogleAdsClient") -> None: """Initializer for the _EnumGetter class. Args: - version: a str indicating the version of the Google Ads API to be - used. + client: An instance of the GoogleAdsClient class. """ - self._client = client - self._version = client.version or _DEFAULT_VERSION - self._enums = None - self._use_proto_plus = client.use_proto_plus + self._client: "GoogleAdsClient" = client + self._version: str = client.version or _DEFAULT_VERSION + self._enums: Union[Tuple[str], None] = None + self._use_proto_plus: bool = client.use_proto_plus - def __dir__(self): + def __dir__(self) -> Tuple[str]: """Overrides behavior when dir() is called on instances of this class. It's useful to use dir() to see a list of available attributes. Since @@ -86,7 +91,7 @@ def __dir__(self): return self._enums - def __getattr__(self, name): + def __getattr__(self, name: str) -> Union[ProtoPlusMessageType, ProtobufMessageType]: """Dynamically loads the given enum class instance. Args: @@ -95,14 +100,14 @@ def __getattr__(self, name): Returns: An instance of the enum proto message class. """ - if not name in self.__dir__(): + if name not in self.__dir__(): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) try: enum_class = self._client.get_type(name) - if self._use_proto_plus == True: + if self._use_proto_plus: for attr in dir(enum_class): attr_val = getattr(enum_class, attr) if isinstance(attr_val, ProtoEnumMeta): @@ -114,7 +119,7 @@ def __getattr__(self, name): f"'{type(self).__name__}' object has no attribute '{name}'" ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: """Returns self serialized as a dict. Since this class overrides __getattr__ we define this method to help @@ -126,7 +131,7 @@ def __getstate__(self): """ return self.__dict__.copy() - def __setstate__(self, d): + def __setstate__(self, d: Dict[str, Any]) -> None: """Deserializes self with the given dictionary. Since this class overrides __getattr__ we define this method to help @@ -143,7 +148,11 @@ class GoogleAdsClient: """Google Ads client used to configure settings and fetch services.""" @classmethod - def copy_from(cls, destination, origin): + def copy_from( + cls, + destination: Union[ProtoPlusMessageType, ProtobufMessageType], + origin: Union[ProtoPlusMessageType, ProtobufMessageType] + ) -> Union[ProtoPlusMessageType, ProtobufMessageType]: """Copies protobuf and proto-plus messages into one-another. This method consolidates the CopyFrom logic of protobuf and proto-plus @@ -157,7 +166,7 @@ def copy_from(cls, destination, origin): return util.proto_copy_from(destination, origin) @classmethod - def _get_client_kwargs(cls, config_data): + def _get_client_kwargs(cls, config_data: Dict[str, Any]) -> Dict[str, Any]: """Converts configuration dict into kwargs required by the client. Args: @@ -185,7 +194,7 @@ def _get_client_kwargs(cls, config_data): } @classmethod - def _get_api_services_by_version(cls, version): + def _get_api_services_by_version(cls, version: str) -> ModuleType: """Returns a module with all services and types for a given API version. Args: @@ -207,7 +216,9 @@ def _get_api_services_by_version(cls, version): return version_module @classmethod - def load_from_env(cls, version=None): + def load_from_env( + cls, version: Union[str, None] = None + ) -> "GoogleAdsClient": """Creates a GoogleAdsClient with data stored in the env variables. Args: @@ -220,12 +231,14 @@ def load_from_env(cls, version=None): Raises: ValueError: If the configuration lacks a required field. """ - config_data = config.load_from_env() - kwargs = cls._get_client_kwargs(config_data) + config_data: Dict[str, Any] = config.load_from_env() + kwargs: Dict[str, Any] = cls._get_client_kwargs(config_data) return cls(**dict(version=version, **kwargs)) @classmethod - def load_from_string(cls, yaml_str, version=None): + def load_from_string( + cls, yaml_str: str, version: Union[str, None] = None + ) -> "GoogleAdsClient": """Creates a GoogleAdsClient with data stored in the YAML string. Args: @@ -240,12 +253,14 @@ def load_from_string(cls, yaml_str, version=None): Raises: ValueError: If the configuration lacks a required field. """ - config_data = config.parse_yaml_document_to_dict(yaml_str) - kwargs = cls._get_client_kwargs(config_data) + config_data: Dict[str, Any] = config.parse_yaml_document_to_dict(yaml_str) + kwargs: Dict[str, Any] = cls._get_client_kwargs(config_data) return cls(**dict(version=version, **kwargs)) @classmethod - def load_from_dict(cls, config_dict, version=None): + def load_from_dict( + cls, config_dict: Dict[str, Any], version: Union[str, None] = None + ) -> "GoogleAdsClient": """Creates a GoogleAdsClient with data stored in the config_dict. Args: @@ -260,12 +275,14 @@ def load_from_dict(cls, config_dict, version=None): Raises: ValueError: If the configuration lacks a required field. """ - config_data = config.load_from_dict(config_dict) - kwargs = cls._get_client_kwargs(config_data) + config_data: Dict[str, Any] = config.load_from_dict(config_dict) + kwargs: Dict[str, Any] = cls._get_client_kwargs(config_data) return cls(**dict(version=version, **kwargs)) @classmethod - def load_from_storage(cls, path=None, version=None): + def load_from_storage( + cls, path: Union[str, None] = None, version: Union[str, None] = None + ) -> "GoogleAdsClient": """Creates a GoogleAdsClient with data stored in the specified file. Args: @@ -282,22 +299,22 @@ def load_from_storage(cls, path=None, version=None): IOError: If the configuration file can't be loaded. ValueError: If the configuration file lacks a required field. """ - config_data = config.load_from_yaml_file(path) - kwargs = cls._get_client_kwargs(config_data) + config_data: Dict[str, Any] = config.load_from_yaml_file(path) + kwargs: Dict[str, Any] = cls._get_client_kwargs(config_data) return cls(**dict(version=version, **kwargs)) def __init__( self, - credentials, - developer_token, - endpoint=None, - login_customer_id=None, - logging_config=None, - linked_customer_id=None, - version=None, - http_proxy=None, - use_proto_plus=False, - use_cloud_org_for_api_access=None, + credentials: Dict[str, Any], + developer_token: str, + endpoint: Union[str, None] = None, + login_customer_id: Union[str, None] = None, + logging_config: Union[Dict[str, Any], None] = None, + linked_customer_id: Union[str, None] = None, + version: Union[str, None] = None, + http_proxy: Union[str, None] = None, + use_proto_plus: bool = False, + use_cloud_org_for_api_access: Union[str, None] = None, ): """Initializer for the GoogleAdsClient. @@ -321,22 +338,29 @@ def __init__( if logging_config: logging.config.dictConfig(logging_config) - self.credentials = credentials - self.developer_token = developer_token - self.endpoint = endpoint - self.login_customer_id = login_customer_id - self.linked_customer_id = linked_customer_id - self.version = version - self.http_proxy = http_proxy - self.use_proto_plus = use_proto_plus - self.use_cloud_org_for_api_access = use_cloud_org_for_api_access - self.enums = _EnumGetter(self) + self.credentials: Dict[str, Any] = credentials + self.developer_token: str = developer_token + self.endpoint: Union[str, None] = endpoint + self.login_customer_id: Union[str, None] = login_customer_id + self.linked_customer_id: Union[str, None] = linked_customer_id + self.version: Union[str, None] = version + self.http_proxy: Union[str, None] = http_proxy + self.use_proto_plus: bool = use_proto_plus + self.use_cloud_org_for_api_access: Union[str, None] = ( + use_cloud_org_for_api_access + ) + self.enums: _EnumGetter = _EnumGetter(self) # If given, write the http_proxy channel option for GRPC to use if http_proxy: _GRPC_CHANNEL_OPTIONS.append(("grpc.http_proxy", http_proxy)) - def get_service(self, name, version=_DEFAULT_VERSION, interceptors=None): + def get_service( + self, + name: str, + version: str = _DEFAULT_VERSION, + interceptors: Union[list, None] = None, + ) -> Any: """Returns a service client instance for the specified service_name. Args: @@ -359,13 +383,15 @@ def get_service(self, name, version=_DEFAULT_VERSION, interceptors=None): # override any version specified as an argument. version = self.version if self.version else version # api_module = self._get_api_services_by_version(version) - services_path = f"google.ads.googleads.{version}.services.services" - snaked = util.convert_upper_case_to_snake_case(name) + services_path: str = ( + f"google.ads.googleads.{version}.services.services" + ) + snaked: str = util.convert_upper_case_to_snake_case(name) interceptors = interceptors or [] try: - service_module = import_module(f"{services_path}.{snaked}") - service_client_class = util.get_nested_attr( + service_module: Any = import_module(f"{services_path}.{snaked}") + service_client_class: Any = util.get_nested_attr( service_module, _SERVICE_CLIENT_TEMPLATE.format(name) ) except (AttributeError, ModuleNotFoundError): @@ -374,21 +400,21 @@ def get_service(self, name, version=_DEFAULT_VERSION, interceptors=None): "Ads API {}.".format(name, version) ) - service_transport_class = service_client_class.get_transport_class() + service_transport_class: Any = service_client_class.get_transport_class() - endpoint = ( + endpoint: str = ( self.endpoint if self.endpoint else service_client_class.DEFAULT_ENDPOINT ) - channel = service_transport_class.create_channel( + channel: grpc.Channel = service_transport_class.create_channel( host=endpoint, credentials=self.credentials, options=_GRPC_CHANNEL_OPTIONS, ) - interceptors = interceptors + [ + interceptors: List[Union[grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor]] = interceptors + [ MetadataInterceptor( self.developer_token, self.login_customer_id, @@ -396,18 +422,20 @@ def get_service(self, name, version=_DEFAULT_VERSION, interceptors=None): self.use_cloud_org_for_api_access, ), LoggingInterceptor(_logger, version, endpoint), - ExceptionInterceptor(version, use_proto_plus=self.use_proto_plus), + ExceptionInterceptor( + version, use_proto_plus=self.use_proto_plus + ), ] - channel = grpc.intercept_channel(channel, *interceptors) + channel: grpc.Channel = grpc.intercept_channel(channel, *interceptors) - service_transport = service_transport_class( + service_transport: Any = service_transport_class( channel=channel, client_info=_CLIENT_INFO ) return service_client_class(transport=service_transport) - def get_type(self, name, version=_DEFAULT_VERSION): + def get_type(self, name: str, version: str = _DEFAULT_VERSION) -> Union[ProtoPlusMessageType, ProtobufMessageType]: """Returns the specified common, enum, error, or resource type. Args: @@ -441,22 +469,22 @@ def get_type(self, name, version=_DEFAULT_VERSION): # If version is specified when the instance is created, # override any version specified as an argument. - version = self.version if self.version else version - type_classes = self._get_api_services_by_version(version) + version: str = self.version if self.version else version + type_classes: ModuleType = self._get_api_services_by_version(version) - for type in _MESSAGE_TYPES: - if type == "services": - path = f"{type}.types.{name}" + for type_name in _MESSAGE_TYPES: + if type_name == "services": + path: str = f"{type_name}.types.{name}" else: - path = f"{type}.{name}" + path: str = f"{type_name}.{name}" try: - message_class = util.get_nested_attr(type_classes, path) + message_class: Union[ProtoPlusMessageType, ProtobufMessageType] = util.get_nested_attr(type_classes, path) # type: ignore[no-untyped-call] - if self.use_proto_plus == True: + if self.use_proto_plus: return message_class() else: - return util.convert_proto_plus_to_protobuf(message_class()) + return util.convert_proto_plus_to_protobuf(message_class()) # type: ignore[no-untyped-call] except AttributeError: pass diff --git a/google/ads/googleads/config.py b/google/ads/googleads/config.py index 8041ea0b5..db72dae24 100644 --- a/google/ads/googleads/config.py +++ b/google/ads/googleads/config.py @@ -18,6 +18,7 @@ import logging.config import os import re +from typing import Any, Callable, List, Tuple, TypeVar, Union import yaml @@ -45,7 +46,8 @@ "path_to_private_key_file", "delegated_account", ) -_KEYS_ENV_VARIABLES_MAP = { + +_KEYS_ENV_VARIABLES_MAP: dict[str, str] = { key: _ENV_PREFIX + key.upper() for key in _REQUIRED_KEYS + _OPTIONAL_KEYS @@ -56,8 +58,10 @@ + _SECONDARY_OAUTH2_SERVICE_ACCOUNT_KEYS } +F = TypeVar("F", bound=Callable[..., Any]) + -def _config_validation_decorator(func): +def _config_validation_decorator(func: F) -> F: """A decorator used to easily run validations on configs loaded into dicts. Add this decorator to any method that returns the config as a dict. @@ -67,15 +71,15 @@ def _config_validation_decorator(func): """ @functools.wraps(func) - def validation_wrapper(*args, **kwargs): - config_dict = func(*args, **kwargs) + def validation_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]: + config_dict: dict[str, Any] = func(*args, **kwargs) validate_dict(config_dict) return config_dict return validation_wrapper -def _config_parser_decorator(func): +def _config_parser_decorator(func: F) -> F: """A decorator used to easily parse config values. Since configs can be loaded from different locations such as env vars or @@ -85,20 +89,22 @@ def _config_parser_decorator(func): """ @functools.wraps(func) - def parser_wrapper(*args, **kwargs): - config_dict = func(*args, **kwargs) - parsed_config = convert_login_customer_id_to_str(config_dict) - parsed_config = convert_linked_customer_id_to_str(parsed_config) + def parser_wrapper(*args: Any, **kwargs: Any) -> Any: + config_dict: dict[str, Any] = func(*args, **kwargs) + parsed_config: dict[str, Any] = convert_login_customer_id_to_str( + config_dict + ) + parsed_config: dict[str, Any] = convert_linked_customer_id_to_str(parsed_config) - config_keys = parsed_config.keys() + config_keys: List[str] = parsed_config.keys() if "logging" in config_keys: - logging_config = parsed_config["logging"] + logging_config: dict[str, Any] = parsed_config["logging"] # If the logging config is a dict then it is already in the format # that needs to be returned by this method. if type(logging_config) is not dict: try: - parsed_config["logging"] = json.loads(logging_config) + parsed_config["logging"]: dict[str, Any] = json.loads(logging_config) # The logger is configured here in case deprecation warnings # need to be logged further down in this method. The logger # is otherwise configured by the GoogleAdsClient class. @@ -150,15 +156,15 @@ def parser_wrapper(*args, **kwargs): # variable we need to manually change it to the bool False because # the string "False" is truthy and can easily be incorrectly # converted to the boolean True. - value = parsed_config.get("use_proto_plus", False) - parsed_config["use_proto_plus"] = disambiguate_string_bool(value) + value: Union[str, bool] = parsed_config.get("use_proto_plus", False) + parsed_config["use_proto_plus"]: bool = disambiguate_string_bool(value) return parsed_config return parser_wrapper -def validate_dict(config_data): +def validate_dict(config_data: dict[str, Any]) -> None: """Validates the given configuration dict. Validations that are performed include: @@ -172,7 +178,7 @@ def validate_dict(config_data): Raises: ValueError: If the dict does not contain all required config keys. """ - if not "use_proto_plus" in config_data.keys(): + if "use_proto_plus" not in config_data.keys(): raise ValueError( "The client library configuration is missing the required " '"use_proto_plus" key. Please set this option to either "True" ' @@ -188,13 +194,13 @@ def validate_dict(config_data): ) if "login_customer_id" in config_data: - validate_login_customer_id(config_data["login_customer_id"]) + validate_login_customer_id(str(config_data["login_customer_id"])) if "linked_customer_id" in config_data: - validate_linked_customer_id(config_data["linked_customer_id"]) + validate_linked_customer_id(str(config_data["linked_customer_id"])) -def _validate_customer_id(customer_id, id_type): +def _validate_customer_id(customer_id: Union[str, None], id_type: str) -> None: """Validates a customer ID. Args: @@ -208,7 +214,7 @@ def _validate_customer_id(customer_id, id_type): """ if customer_id is not None: # Checks that the string is comprised only of 10 digits. - pattern = re.compile(r"^\d{10}", re.ASCII) + pattern: re.Pattern = re.compile(r"^\d{10}", re.ASCII) if not pattern.fullmatch(customer_id): raise ValueError( f"The specified {id_type} customer ID is invalid. It must be a " @@ -216,7 +222,7 @@ def _validate_customer_id(customer_id, id_type): ) -def validate_login_customer_id(login_customer_id): +def validate_login_customer_id(login_customer_id: Union[str, None]) -> None: """Validates a login customer ID. Args: login_customer_id: a str from config indicating a login customer ID. @@ -227,7 +233,7 @@ def validate_login_customer_id(login_customer_id): _validate_customer_id(login_customer_id, "login") -def validate_linked_customer_id(linked_customer_id): +def validate_linked_customer_id(linked_customer_id: Union[str, None]) -> None: """Validates a linked customer ID. Args: linked_customer_id: a str from config indicating a linked customer ID. @@ -240,7 +246,7 @@ def validate_linked_customer_id(linked_customer_id): @_config_validation_decorator @_config_parser_decorator -def load_from_yaml_file(path=None): +def load_from_yaml_file(path: Union[str, None] = None) -> dict[str, Any]: """Loads configuration data from a YAML file and returns it as a dict. Args: @@ -258,27 +264,27 @@ def load_from_yaml_file(path=None): # If no path is specified then we check for the environment variable # that may define the path. If that is not defined then we use the # default path. - path_from_env_var = os.environ.get( + path_from_env_var: str = os.environ.get( _ENV_PREFIX + _CONFIG_FILE_PATH_KEY[0].upper() ) - path = ( + path: Tuple[str] = ( path_from_env_var if path_from_env_var else os.path.join(os.path.expanduser("~"), "google-ads.yaml") ) if not os.path.isabs(path): - path = os.path.expanduser(path) + path: str = os.path.expanduser(path) with open(path, "rb") as handle: - yaml_doc = handle.read() + yaml_doc: bytes = handle.read() return parse_yaml_document_to_dict(yaml_doc) @_config_validation_decorator @_config_parser_decorator -def load_from_dict(config_dict): +def load_from_dict(config_dict: dict[str, Any]) -> dict[str, Any]: """Check if the argument is dictionary or not. If successful it calls the parsing decorator, followed by validation decorator. This validates the keys used in the config_dict, before returning to its caller. @@ -302,7 +308,7 @@ def load_from_dict(config_dict): @_config_validation_decorator @_config_parser_decorator -def parse_yaml_document_to_dict(yaml_doc): +def parse_yaml_document_to_dict(yaml_doc: Union[str, bytes]) -> dict[str, Any]: """Parses a YAML document to a dict. Args: @@ -320,7 +326,7 @@ def parse_yaml_document_to_dict(yaml_doc): @_config_validation_decorator @_config_parser_decorator -def load_from_env(): +def load_from_env() -> dict[str, Any]: """Loads configuration data from the environment and returns it as a dict. Returns: @@ -329,7 +335,7 @@ def load_from_env(): Raises: ValueError: If the configuration """ - config_data = { + config_data: dict[str, Any] = { key: os.environ.get(env_variable) for key, env_variable in _KEYS_ENV_VARIABLES_MAP.items() if env_variable in os.environ @@ -343,7 +349,7 @@ def load_from_env(): return config_data -def get_oauth2_installed_app_keys(): +def get_oauth2_installed_app_keys() -> tuple[str, ...]: """A getter that returns the required OAuth2 installed application keys. Returns: @@ -352,7 +358,7 @@ def get_oauth2_installed_app_keys(): return _OAUTH2_INSTALLED_APP_KEYS -def get_oauth2_required_service_account_keys(): +def get_oauth2_required_service_account_keys() -> tuple[str, ...]: """A getter that returns the required OAuth2 service account keys. Returns: @@ -361,7 +367,9 @@ def get_oauth2_required_service_account_keys(): return _OAUTH2_REQUIRED_SERVICE_ACCOUNT_KEYS -def convert_login_customer_id_to_str(config_data): +def convert_login_customer_id_to_str( + config_data: dict[str, Any] +) -> dict[str, Any]: """Parses a config dict's login_customer_id attr value to a str. Like many values from YAML it's possible for login_customer_id to @@ -374,15 +382,17 @@ def convert_login_customer_id_to_str(config_data): Returns: The same config dict object with a mutated login_customer_id attr. """ - login_customer_id = config_data.get("login_customer_id") + login_customer_id: str = config_data.get("login_customer_id") if login_customer_id: - config_data["login_customer_id"] = str(login_customer_id) + config_data["login_customer_id"]: str = str(login_customer_id) return config_data -def convert_linked_customer_id_to_str(config_data): +def convert_linked_customer_id_to_str( + config_data: dict[str, Any] +) -> dict[str, Any]: """Parses a config dict's linked_customer_id attr value to a str. Like many values from YAML it's possible for linked_customer_id to @@ -395,15 +405,15 @@ def convert_linked_customer_id_to_str(config_data): Returns: The same config dict object with a mutated linked_customer_id attr. """ - linked_customer_id = config_data.get("linked_customer_id") + linked_customer_id: str = config_data.get("linked_customer_id") if linked_customer_id: - config_data["linked_customer_id"] = str(linked_customer_id) + config_data["linked_customer_id"]: str = str(linked_customer_id) return config_data -def disambiguate_string_bool(value): +def disambiguate_string_bool(value: Union[str, bool]) -> bool: """Converts a stringified boolean to its bool representation. Args: diff --git a/google/ads/googleads/errors.py b/google/ads/googleads/errors.py index 89463828d..5542e51aa 100644 --- a/google/ads/googleads/errors.py +++ b/google/ads/googleads/errors.py @@ -13,11 +13,20 @@ # limitations under the License. """Errors used by the Google Ads API library.""" +import grpc +from proto import Message as ProtobufMessageType + class GoogleAdsException(Exception): """Exception thrown in response to an API error from GoogleAds servers.""" - def __init__(self, error, call, failure, request_id): + def __init__( + self, + error: grpc.RpcError, + call: grpc.Call, + failure: ProtobufMessageType, + request_id: str, + ) -> None: """Initializer. Args: @@ -27,7 +36,7 @@ def __init__(self, error, call, failure, request_id): GoogleAds API call failed. request_id: a str request ID associated with the GoogleAds API call. """ - self.error = error - self.call = call - self.failure = failure - self.request_id = request_id + self.error: grpc.RpcError = error + self.call: grpc.Call = call + self.failure: ProtobufMessageType = failure + self.request_id: str = request_id diff --git a/google/ads/googleads/oauth2.py b/google/ads/googleads/oauth2.py index 0ed119f2e..575904984 100644 --- a/google/ads/googleads/oauth2.py +++ b/google/ads/googleads/oauth2.py @@ -23,11 +23,17 @@ from google.ads.googleads import config -_SERVICE_ACCOUNT_SCOPES = ["https://www.googleapis.com/auth/adwords"] -_DEFAULT_TOKEN_URI = "https://accounts.google.com/o/oauth2/token" +from typing import Any, Callable, TypeVar, Union +_SERVICE_ACCOUNT_SCOPES: list[str] = [ + "https://www.googleapis.com/auth/adwords" +] +_DEFAULT_TOKEN_URI: str = "https://accounts.google.com/o/oauth2/token" -def _initialize_credentials_decorator(func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def _initialize_credentials_decorator(func: F) -> F: """A decorator used to easily initialize credentials objects. Returns: @@ -35,13 +41,13 @@ def _initialize_credentials_decorator(func): """ @functools.wraps(func) - def initialize_credentials_wrapper(*args, **kwargs): - credentials = func(*args, **kwargs) + def initialize_credentials_wrapper(*args: Any, **kwargs: Any) -> Any: + credentials: Union[InstalledAppCredentials, ServiceAccountCreds] = func(*args, **kwargs) # If the configs contain an http_proxy, refresh credentials through the # proxy URI - proxy = kwargs.get("http_proxy") + proxy: Union[str, None] = kwargs.get("http_proxy") if proxy: - session = Session() + session: Session = Session() session.proxies.update({"http": proxy, "https": proxy}) credentials.refresh(Request(session=session)) else: @@ -53,18 +59,20 @@ def initialize_credentials_wrapper(*args, **kwargs): @_initialize_credentials_decorator def get_installed_app_credentials( - client_id, - client_secret, - refresh_token, - http_proxy=None, - token_uri=_DEFAULT_TOKEN_URI, -): + client_id: str, + client_secret: str, + refresh_token: str, + http_proxy: Union[str, None] = None, + token_uri: str = _DEFAULT_TOKEN_URI, +) -> InstalledAppCredentials: """Creates and returns an instance of oauth2.credentials.Credentials. Args: client_id: A str of the oauth2 client_id from configuration. client_secret: A str of the oauth2 client_secret from configuration. refresh_token: A str of the oauth2 refresh_token from configuration. + http_proxy: An optional str of the http proxy. + token_uri: An optional str of the token URI. Returns: An instance of oauth2.credentials.Credentials @@ -80,13 +88,17 @@ def get_installed_app_credentials( @_initialize_credentials_decorator def get_service_account_credentials( - json_key_file_path, subject, http_proxy=None, scopes=_SERVICE_ACCOUNT_SCOPES -): + json_key_file_path: str, + subject: str, + http_proxy: Union[str, None] = None, + scopes: list[str] = _SERVICE_ACCOUNT_SCOPES, +) -> ServiceAccountCreds: """Creates and returns an instance of oauth2.service_account.Credentials. Args: json_key_file_path: A str of the path to the private key file location. subject: A str of the email address of the delegated account. + http_proxy: An optional str of the http proxy. scopes: A list of additional scopes. Returns: @@ -97,7 +109,9 @@ def get_service_account_credentials( ) -def get_credentials(config_data): +def get_credentials( + config_data: dict[str, Any] +) -> Union[InstalledAppCredentials, ServiceAccountCreds]: """Decides which type of credentials to return based on the given config. Args: @@ -106,25 +120,27 @@ def get_credentials(config_data): Returns: An initialized credentials instance. """ - required_installed_app_keys = config.get_oauth2_installed_app_keys() - required_service_account_keys = ( - config.get_oauth2_required_service_account_keys() - ) + required_installed_app_keys: tuple[ + str, ... + ] = config.get_oauth2_installed_app_keys() + required_service_account_keys: tuple[ + str, ... + ] = config.get_oauth2_required_service_account_keys() if all(key in config_data for key in required_installed_app_keys): # Using the Installed App Flow return get_installed_app_credentials( - config_data.get("client_id"), - config_data.get("client_secret"), - config_data.get("refresh_token"), - http_proxy=config_data.get("http_proxy"), + config_data.get("client_id"), # type: ignore[arg-type] + config_data.get("client_secret"), # type: ignore[arg-type] + config_data.get("refresh_token"), # type: ignore[arg-type] + http_proxy=config_data.get("http_proxy"), # type: ignore[arg-type] ) elif all(key in config_data for key in required_service_account_keys): # Using the Service Account Flow return get_service_account_credentials( - config_data.get("json_key_file_path"), - config_data.get("impersonated_email"), - http_proxy=config_data.get("http_proxy"), + config_data.get("json_key_file_path"), # type: ignore[arg-type] + config_data.get("impersonated_email"), # type: ignore[arg-type] + http_proxy=config_data.get("http_proxy"), # type: ignore[arg-type] ) else: raise ValueError( diff --git a/google/ads/googleads/util.py b/google/ads/googleads/util.py index 32008153f..3068a6723 100644 --- a/google/ads/googleads/util.py +++ b/google/ads/googleads/util.py @@ -19,11 +19,25 @@ from google.protobuf.message import Message as ProtobufMessageType import proto +from typing import Any, List, TypeVar, overload, Union + # This regex matches characters preceded by start of line or an underscore. _RE_FIND_CHARS_TO_UPPERCASE = re.compile(r"(?:_|^)([a-z])") +M = TypeVar("M", bound=ProtobufMessageType) + + +@overload +def get_nested_attr(obj: Any, attr: str) -> Any: + ... + + +@overload +def get_nested_attr(obj: Any, attr: str, default: Any) -> Any: + ... -def get_nested_attr(obj, attr, *args): + +def get_nested_attr(obj: Any, attr: str, *args: Any) -> Any: """Gets the value of a nested attribute from an object. Args: @@ -34,13 +48,15 @@ def get_nested_attr(obj, attr, *args): The object attribute value or the given *args if the attr isn't present. """ - def _getattr(obj, attr): + def _getattr(obj: Any, attr: str) -> Any: return getattr(obj, attr, *args) return functools.reduce(_getattr, [obj] + attr.split(".")) -def set_nested_message_field(message, field_path, value): +def set_nested_message_field( + message: M, field_path: Union[str, list[str]], value: Any +) -> None: """Sets the value of a nested attribute on a protobuf message instance. This method uses "setattr" to update a given field on a protobuf message @@ -71,18 +87,20 @@ def set_nested_message_field(message, field_path, value): names or a list of strings. value: The value that the nested attribute should be set to. """ - if type(field_path) == str: - field_path = field_path.split(".") + if isinstance(field_path, str): + field_path_list: List[str] = field_path.split(".") + else: + field_path_list: Union[str, list[str]] = field_path - if len(field_path) == 1: - setattr(message, field_path[0], value) + if len(field_path_list) == 1: + setattr(message, field_path_list[0], value) else: set_nested_message_field( - getattr(message, field_path[0]), field_path[1:], value + getattr(message, field_path_list[0]), field_path_list[1:], value ) -def convert_upper_case_to_snake_case(string): +def convert_upper_case_to_snake_case(string: str) -> str: """Converts a string from UpperCase to snake_case. Primarily used to translate module names when retrieving them from version @@ -94,8 +112,8 @@ def convert_upper_case_to_snake_case(string): Returns: A new snake_case representation of the given string. """ - new_string = "" - index = 0 + new_string: str = "" + index: int = 0 for char in string: if index == 0: @@ -110,7 +128,7 @@ def convert_upper_case_to_snake_case(string): return new_string -def convert_snake_case_to_upper_case(string): +def convert_snake_case_to_upper_case(string: str) -> str: """Converts a string from snake_case to UpperCase. Primarily used to translate module names when retrieving them from version @@ -120,14 +138,18 @@ def convert_snake_case_to_upper_case(string): string: an arbitrary string to convert. """ - def converter(match): + def converter(match: re.Match) -> str: """Convert a string to strip underscores then uppercase it.""" return match.group().replace("_", "").upper() return _RE_FIND_CHARS_TO_UPPERCASE.sub(converter, string) -def convert_proto_plus_to_protobuf(message): +PPM = TypeVar("PPM", bound=proto.Message) +PBPM = TypeVar("PBPM", bound=ProtobufMessageType) + + +def convert_proto_plus_to_protobuf(message: Union[PPM, PBPM]) -> PBPM: """Converts a proto-plus message to its protobuf counterpart. Args: @@ -137,16 +159,16 @@ def convert_proto_plus_to_protobuf(message): The protobuf version of the proto_plus proto. """ if isinstance(message, proto.Message): - return type(message).pb(message) + return type(message).pb(message) # type: ignore elif isinstance(message, ProtobufMessageType): - return message + return message # type: ignore else: raise TypeError( f"Cannot convert type {type(message)} to protobuf protobuf." ) -def convert_protobuf_to_proto_plus(message): +def convert_protobuf_to_proto_plus(message: Union[PPM, PBPM]) -> PPM: """Converts a protobuf message to a proto-plus message. Args: @@ -156,16 +178,16 @@ def convert_protobuf_to_proto_plus(message): A proto_plus version of the protobuf proto. """ if isinstance(message, ProtobufMessageType): - return proto.Message.wrap(message) + return proto.Message.wrap(message) # type: ignore elif isinstance(message, proto.Message): - return message + return message # type: ignore else: raise TypeError( f"Cannot convert type {type(message)} to a proto_plus protobuf." ) -def proto_copy_from(destination, origin): +def proto_copy_from(destination: Union[PPM, PBPM], origin: Union[PPM, PBPM]) -> None: """Copies protobuf and proto-plus messages into one-another. This method consolidates the CopyFrom logic of protobuf and proto-plus @@ -176,10 +198,10 @@ def proto_copy_from(destination, origin): destination: The protobuf message where changes are being copied. origin: The protobuf message where changes are being copied from. """ - is_dest_proto_plus = isinstance(destination, proto.Message) - is_orig_proto_plus = isinstance(origin, proto.Message) - is_dest_protobuf = isinstance(destination, ProtobufMessageType) - is_orig_protobuf = isinstance(origin, ProtobufMessageType) + is_dest_proto_plus: bool = isinstance(destination, proto.Message) + is_orig_proto_plus: bool = isinstance(origin, proto.Message) + is_dest_protobuf: bool = isinstance(destination, ProtobufMessageType) + is_orig_protobuf: bool = isinstance(origin, ProtobufMessageType) if is_dest_proto_plus and is_orig_proto_plus: proto.Message.copy_from(destination, origin)