|
29 | 29 | import os
|
30 | 30 | import ssl
|
31 | 31 | import warnings
|
32 |
| -from typing import Any, Collection, Mapping, Optional, Union |
| 32 | +from typing import Any, Callable, Collection, Mapping, Optional, Union |
33 | 33 |
|
34 | 34 | import urllib3
|
35 | 35 |
|
@@ -146,9 +146,14 @@ def __init__(
|
146 | 146 | )
|
147 | 147 |
|
148 | 148 | if http_auth is not None:
|
149 |
| - if isinstance(http_auth, (tuple, list)): |
150 |
| - http_auth = ":".join(http_auth) |
151 |
| - self.headers.update(urllib3.make_headers(basic_auth=http_auth)) |
| 149 | + if isinstance(http_auth, Callable): # type: ignore |
| 150 | + pass |
| 151 | + elif isinstance(http_auth, (tuple, list)): |
| 152 | + self.headers.update( |
| 153 | + urllib3.make_headers(basic_auth=":".join(http_auth)) |
| 154 | + ) |
| 155 | + else: |
| 156 | + self.headers.update(urllib3.make_headers(basic_auth=http_auth)) |
152 | 157 |
|
153 | 158 | # if providing an SSL context, raise error if any other SSL related flag is used
|
154 | 159 | if ssl_context and (
|
@@ -285,6 +290,9 @@ async def perform_request(
|
285 | 290 | if headers:
|
286 | 291 | req_headers.update(headers)
|
287 | 292 |
|
| 293 | + if isinstance(self._http_auth, Callable): # type: ignore |
| 294 | + req_headers.update(self._http_auth(method, str(url), body)) |
| 295 | + |
288 | 296 | if self.http_compress and body:
|
289 | 297 | body = self._gzip_compress(body)
|
290 | 298 | req_headers["content-encoding"] = "gzip"
|
|
0 commit comments