Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added python types for default arguments. #452

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 61 additions & 87 deletions environ/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
variables to configure your Django application.
"""


import ast
import contextlib
import itertools
import logging
import os
import re
import sys
from typing import Any, Optional, Union
import warnings
from urllib.parse import (
parse_qs,
Expand Down Expand Up @@ -205,75 +208,63 @@ def __call__(self, var, cast=None, default=NOTSET, parse_default=False):
def __contains__(self, var):
return var in self.ENVIRON

def str(self, var, default=NOTSET, multiline=False):
def str(self, var, default: Union[NoValue, str] = NOTSET, multiline=False):
"""
:rtype: str
"""
value = self.get_value(var, cast=str, default=default)
if multiline:
return re.sub(r'(\\r)?\\n', r'\n', value)
return value
return re.sub(r'(\\r)?\\n', r'\n', value) if multiline else value

def bytes(self, var, default=NOTSET, encoding='utf8'):
"""
:rtype: bytes
"""
value = self.get_value(var, cast=str, default=default)
if hasattr(value, 'encode'):
return value.encode(encoding)
return value
return value.encode(encoding) if hasattr(value, 'encode') else value

def bool(self, var, default=NOTSET):
def bool(self, var, default: Union[bool, NoValue] = NOTSET):
"""
:rtype: bool
"""
return self.get_value(var, cast=bool, default=default)

def int(self, var, default=NOTSET):
def int(self, var, default: Union[int, NoValue] =NOTSET):
"""
:rtype: int
"""
return self.get_value(var, cast=int, default=default)

def float(self, var, default=NOTSET):
def float(self, var, default: Union[float, NoValue] =NOTSET):
"""
:rtype: float
"""
return self.get_value(var, cast=float, default=default)

def json(self, var, default=NOTSET):
def json(self, var, default: Union[NoValue, str] =NOTSET):
"""
:returns: Json parsed
"""
return self.get_value(var, cast=json.loads, default=default)

def list(self, var, cast=None, default=NOTSET):
def list(self, var, cast=None, default: Union[list, NoValue] =NOTSET):
"""
:rtype: list
"""
return self.get_value(
var,
cast=list if not cast else [cast],
default=default
)
return self.get_value(var, cast=[cast] if cast else list, default=default)

def tuple(self, var, cast=None, default=NOTSET):
def tuple(self, var, cast=None, default: Union[tuple, NoValue] =NOTSET):
"""
:rtype: tuple
"""
return self.get_value(
var,
cast=tuple if not cast else (cast,),
default=default
)
return self.get_value(var, cast=(cast, ) if cast else tuple, default=default)

def dict(self, var, cast=dict, default=NOTSET):
"""
:rtype: dict
"""
return self.get_value(var, cast=cast, default=default)

def url(self, var, default=NOTSET):
def url(self, var, default: Union[str, NoValue] =NOTSET):
"""
:rtype: urllib.parse.ParseResult
"""
Expand Down Expand Up @@ -336,13 +327,13 @@ def search_url(self, var=DEFAULT_SEARCH_ENV, default=NOTSET, engine=None):
engine=engine
)

def path(self, var, default=NOTSET, **kwargs):
def path(self, var, default: str = NOTSET, **kwargs):
"""
:rtype: Path
"""
return Path(self.get_value(var, default=default), **kwargs)
return Path(str(self.get_value(var, default=default)), **kwargs)

def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
def get_value(self, var, cast: Optional[str] = None, default: Any = NOTSET, parse_default=False):
"""Return value for given environment variable.

:param str var:
Expand All @@ -361,7 +352,7 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
"get '%s' casted as '%s' with default '%s'",
var, cast, default)

var_name = f'{self.prefix}{var}'
var_name = f"{self.prefix}{var}"
if var_name in self.scheme:
var_info = self.scheme[var_name]

Expand All @@ -375,19 +366,16 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
cast = var_info[0]

if default is self.NOTSET:
try:
with contextlib.suppress(IndexError):
default = var_info[1]
except IndexError:
pass
else:
if not cast:
cast = var_info
elif not cast:
cast = var_info

try:
value = self.ENVIRON[var_name]
value: Any = self.ENVIRON[var_name]
except KeyError as exc:
if default is self.NOTSET:
error_msg = f'Set the {var} environment variable'
error_msg = f"Set the {var} environment variable"
raise ImproperlyConfigured(error_msg) from exc

value = default
Expand All @@ -403,10 +391,13 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False):
value = value.replace(escape, prefix)

# Smart casting
if self.smart_cast:
if cast is None and default is not None and \
not isinstance(default, NoValue):
cast = type(default)
if (
self.smart_cast
and cast is None
and default is not None
and not isinstance(default, NoValue)
):
cast = type(default)

value = None if default is None and value == '' else value

Expand Down Expand Up @@ -457,7 +448,7 @@ def parse_value(cls, value, cast):
elif cast is tuple:
val = value.strip('(').strip(')').split(',')
# pylint: disable=consider-using-generator
value = tuple([x for x in val if x])
value = tuple(x for x in val if x)
elif cast is float:
# clean string
float_str = re.sub(r'[^\d,.-]', '', value)
Expand All @@ -467,7 +458,7 @@ def parse_value(cls, value, cast):
if len(parts) == 1:
float_str = parts[0]
else:
float_str = f"{''.join(parts[0:-1])}.{parts[-1]}"
float_str = f"{''.join(parts[:-1])}.{parts[-1]}"
value = float(float_str)
else:
value = cast(value)
Expand Down Expand Up @@ -581,21 +572,17 @@ def db_url_config(cls, url, engine=None):
config_options = {}
for k, v in parse_qs(url.query).items():
if k.upper() in cls._DB_BASE_OPTIONS:
config.update({k.upper(): _cast(v[0])})
config[k.upper()] = _cast(v[0])
else:
config_options.update({k: _cast_int(v[0])})
config_options[k] = _cast_int(v[0])
config['OPTIONS'] = config_options

if engine:
config['ENGINE'] = engine
else:
config['ENGINE'] = url.scheme

config['ENGINE'] = engine or url.scheme
if config['ENGINE'] in cls.DB_SCHEMES:
config['ENGINE'] = cls.DB_SCHEMES[config['ENGINE']]

if not config.get('ENGINE', False):
warnings.warn(f'Engine not recognized from url: {config}')
warnings.warn(f"Engine not recognized from url: {config}")
return {}

return config
Expand Down Expand Up @@ -630,9 +617,7 @@ def cache_url_config(cls, url, backend=None):

# Add the drive to LOCATION
if url.scheme == 'filecache':
config.update({
'LOCATION': url.netloc + url.path,
})
config['LOCATION'] = url.netloc + url.path

# urlparse('pymemcache://127.0.0.1:11211')
# => netloc='127.0.0.1:11211', path=''
Expand All @@ -643,21 +628,11 @@ def cache_url_config(cls, url, backend=None):
# urlparse('memcache:///tmp/memcached.sock')
# => netloc='', path='/tmp/memcached.sock'
if not url.netloc and url.scheme in ['memcache', 'pymemcache']:
config.update({
'LOCATION': 'unix:' + url.path,
})
config['LOCATION'] = f'unix:{url.path}'
elif url.scheme.startswith('redis'):
if url.hostname:
scheme = url.scheme.replace('cache', '')
else:
scheme = 'unix'
locations = [scheme + '://' + loc + url.path
for loc in url.netloc.split(',')]
if len(locations) == 1:
config['LOCATION'] = locations[0]
else:
config['LOCATION'] = locations

scheme = url.scheme.replace('cache', '') if url.hostname else 'unix'
locations = [f'{scheme}://{loc}{url.path}' for loc in url.netloc.split(',')]
config['LOCATION'] = locations[0] if len(locations) == 1 else locations
if url.query:
config_options = {}
for k, v in parse_qs(url.query).items():
Expand Down Expand Up @@ -687,7 +662,7 @@ def email_url_config(cls, url, backend=None):

config = {}

url = urlparse(url) if not isinstance(url, cls.URL_CLASS) else url
url = url if isinstance(url, cls.URL_CLASS) else urlparse(url)

# Remove query strings
path = url.path[1:]
Expand Down Expand Up @@ -738,9 +713,7 @@ def search_url_config(cls, url, engine=None):
:rtype: dict
"""

config = {}

url = urlparse(url) if not isinstance(url, cls.URL_CLASS) else url
url = url if isinstance(url, cls.URL_CLASS) else urlparse(url)

# Remove query strings.
path = url.path[1:]
Expand All @@ -756,7 +729,7 @@ def search_url_config(cls, url, engine=None):
params = parse_qs(url.query)
if 'EXCLUDED_INDEXES' in params:
config['EXCLUDED_INDEXES'] \
= params['EXCLUDED_INDEXES'][0].split(',')
= params['EXCLUDED_INDEXES'][0].split(',')
if 'INCLUDE_SPELLING' in params:
config['INCLUDE_SPELLING'] = cls.parse_value(
params['INCLUDE_SPELLING'][0],
Expand All @@ -770,9 +743,11 @@ def search_url_config(cls, url, engine=None):

if url.scheme == 'simple':
return config
if url.scheme in ['solr'] + cls.ELASTICSEARCH_FAMILY:
if 'KWARGS' in params:
config['KWARGS'] = params['KWARGS'][0]
if (
url.scheme in ['solr'] + cls.ELASTICSEARCH_FAMILY
and 'KWARGS' in params
):
config['KWARGS'] = params['KWARGS'][0]

# remove trailing slash
if path.endswith('/'):
Expand Down Expand Up @@ -804,7 +779,7 @@ def search_url_config(cls, url, engine=None):
config['INDEX_NAME'] = index
return config

config['PATH'] = '/' + path
config['PATH'] = f'/{path}'

if url.scheme == 'whoosh':
if 'STORAGE' in params:
Expand Down Expand Up @@ -882,21 +857,18 @@ def read_env(cls, env_file=None, overwrite=False, encoding='utf8',
def _keep_escaped_format_characters(match):
"""Keep escaped newline/tabs in quoted strings"""
escaped_char = match.group(1)
if escaped_char in 'rnt':
return '\\' + escaped_char
return escaped_char
return '\\' + escaped_char if escaped_char in 'rnt' else escaped_char

for line in content.splitlines():
m1 = re.match(r'\A(?:export )?([A-Za-z_0-9]+)=(.*)\Z', line)
if m1:
key, val = m1.group(1), m1.group(2)
key, val = m1[1], m1[2]
m2 = re.match(r"\A'(.*)'\Z", val)
if m2:
val = m2.group(1)
val = m2[1]
m3 = re.match(r'\A"(.*)"\Z', val)
if m3:
val = re.sub(r'\\(.)', _keep_escaped_format_characters,
m3.group(1))
val = re.sub(r'\\(.)', _keep_escaped_format_characters, m3[1])
overrides[key] = str(val)
elif not line or line.startswith('#'):
# ignore warnings for empty line-breaks or comments
Expand Down Expand Up @@ -992,9 +964,11 @@ def __ne__(self, other):
return not self.__eq__(other)

def __add__(self, other):
if not isinstance(other, Path):
return Path(self.__root__, other)
return Path(self.__root__, other.__root__)
return (
Path(self.__root__, other.__root__)
if isinstance(other, Path)
else Path(self.__root__, other)
)

def __sub__(self, other):
if isinstance(other, int):
Expand All @@ -1019,7 +993,7 @@ def __contains__(self, item):
return item.__root__.startswith(base_path)

def __repr__(self):
return f'<Path:{self.__root__}>'
return f"<Path:{self.__root__}>"

def __str__(self):
return self.__root__
Expand Down