|
29 | 29 | import os |
30 | 30 | import ssl |
31 | 31 | import warnings |
32 | | -from typing import Any, Collection, Mapping, Optional, Union |
| 32 | +from typing import Any, Callable, Collection, Mapping, Optional, Union |
33 | 33 |
|
34 | 34 | import urllib3 |
35 | 35 |
|
@@ -146,9 +146,14 @@ def __init__( |
146 | 146 | ) |
147 | 147 |
|
148 | 148 | if http_auth is not None: |
149 | | - if isinstance(http_auth, (tuple, list)): |
150 | | - http_auth = ":".join(http_auth) |
151 | | - self.headers.update(urllib3.make_headers(basic_auth=http_auth)) |
| 149 | + if isinstance(http_auth, Callable): # type: ignore |
| 150 | + pass |
| 151 | + elif isinstance(http_auth, (tuple, list)): |
| 152 | + self.headers.update( |
| 153 | + urllib3.make_headers(basic_auth=":".join(http_auth)) |
| 154 | + ) |
| 155 | + else: |
| 156 | + self.headers.update(urllib3.make_headers(basic_auth=http_auth)) |
152 | 157 |
|
153 | 158 | # if providing an SSL context, raise error if any other SSL related flag is used |
154 | 159 | if ssl_context and ( |
@@ -285,6 +290,9 @@ async def perform_request( |
285 | 290 | if headers: |
286 | 291 | req_headers.update(headers) |
287 | 292 |
|
| 293 | + if isinstance(self._http_auth, Callable): # type: ignore |
| 294 | + req_headers.update(self._http_auth(method, str(url), body)) |
| 295 | + |
288 | 296 | if self.http_compress and body: |
289 | 297 | body = self._gzip_compress(body) |
290 | 298 | req_headers["content-encoding"] = "gzip" |
|
0 commit comments