# requests / src/requests/auth.py

"""
requests.auth
~~~~~~~~~~~~~

This module contains the authentication handlers for Requests.
"""
7
8
from __future__ import annotations
9
10
import hashlib
11
import os
12
import re
13
import threading
14
import time
15
import warnings
16
from base64 import b64encode
17
from typing import TYPE_CHECKING, Any, Final, cast, overload
18
19
from ._internal_utils import to_native_string
20
from .compat import basestring, str, urlparse
21
from .cookies import extract_cookies_to_jar
22
from .utils import parse_dict_header
23
24
if TYPE_CHECKING:
25
    from http.cookiejar import CookieJar
26
    from typing import Any
27
28
    from .models import PreparedRequest, Response
29
30
# Media type string for URL-encoded form bodies.
CONTENT_TYPE_FORM_URLENCODED: Final = "application/x-www-form-urlencoded"
# Media type string for multipart form bodies.
CONTENT_TYPE_MULTI_PART: Final = "multipart/form-data"
32
33
34
def _basic_auth_str(username: bytes | str, password: bytes | str) -> str:
    """Build the value of an HTTP Basic ``Authorization`` header.

    ``str`` credentials are encoded as latin1, joined as ``username:password``
    and base64-encoded; the result is prefixed with ``"Basic "``.

    Values that are neither ``str`` nor ``bytes`` are still accepted for
    backwards compatibility: each one triggers a ``DeprecationWarning`` and is
    converted with ``str()``.

    :param username: user name as ``str`` or ``bytes``.
    :param password: password as ``str`` or ``bytes``.
    :return: the complete header value, e.g. ``"Basic dXNlcjpwYXNz"``.
    """

    # "I want us to put a big-ol' comment on top of it that
    # says that this behaviour is dumb but we need to preserve
    # it because people are relying on it."
    #    - Lukasa
    #
    # The two coercion branches below exist solely to maintain backwards
    # compatibility for things like ints. This will be removed in 3.0.0.
    if not isinstance(username, basestring):  # type: ignore[reportUnnecessaryIsInstance]  # runtime guard for non-str/bytes
        username_notice = (
            "Non-string usernames will no longer be supported in Requests "
            f"3.0.0. Please convert the object you've passed in ({username!r}) to "
            "a string or bytes object in the near future to avoid "
            "problems."
        )
        warnings.warn(username_notice, category=DeprecationWarning)
        username = str(username)

    if not isinstance(password, basestring):  # type: ignore[reportUnnecessaryIsInstance]  # runtime guard for non-str/bytes
        # Only the type of the password is interpolated into the message,
        # not its value.
        password_notice = (
            "Non-string passwords will no longer be supported in Requests "
            f"3.0.0. Please convert the object you've passed in ({type(password)!r}) to "
            "a string or bytes object in the near future to avoid "
            "problems."
        )
        warnings.warn(password_notice, category=DeprecationWarning)
        password = str(password)
    # -- End Removal --

    # From here on both credentials are str or bytes; bring them to bytes.
    user_bytes = username.encode("latin1") if isinstance(username, str) else username
    pass_bytes = password.encode("latin1") if isinstance(password, str) else password

    encoded = b64encode(b":".join([user_bytes, pass_bytes])).strip()
    return "Basic " + to_native_string(encoded)
76
77
78
class AuthBase:
    """Common ancestor of every Requests authentication handler.

    Subclasses override ``__call__`` to modify the outgoing request and
    return it.
    """

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        """Apply authentication to ``r``; the base class always raises.

        :raises NotImplementedError: always, because no behaviour is defined here.
        """
        message = "Auth hooks must be callable."
        raise NotImplementedError(message)
83
84
85
class HTTPBasicAuth(AuthBase):
    """Auth hook that sets an HTTP Basic ``Authorization`` header on a request.

    Equality is attribute-based: an instance compares equal to any object
    whose ``username`` and ``password`` attributes equal its own.
    """

    username: bytes | str
    password: bytes | str

    @overload
    def __init__(self, username: str, password: str) -> None: ...
    @overload
    def __init__(self, username: bytes, password: bytes) -> None: ...

    def __init__(self, username: bytes | str, password: bytes | str) -> None:
        """Remember the credentials to send with each request."""
        self.username, self.password = username, password

    def __eq__(self, other: object) -> bool:
        """Return ``True`` when ``other`` carries the same username and password."""
        # Objects lacking either attribute compare against ``None``.
        same_username = self.username == getattr(other, "username", None)
        same_password = self.password == getattr(other, "password", None)
        return bool(same_username and same_password)

    def __ne__(self, other: Any) -> bool:
        """Return the negation of ``self == other``."""
        return not (self == other)

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        """Set the ``Authorization`` header on ``r`` and return ``r``."""
        header_value = _basic_auth_str(self.username, self.password)
        r.headers["Authorization"] = header_value
        return r
114
115
116
class HTTPProxyAuth(HTTPBasicAuth):
    """Auth hook that sets a Basic ``Proxy-Authorization`` header on a request.

    Credentials are stored and compared exactly as in ``HTTPBasicAuth``; only
    the name of the header that is written differs.
    """

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        """Set the ``Proxy-Authorization`` header on ``r`` and return ``r``."""
        header_value = _basic_auth_str(self.username, self.password)
        r.headers["Proxy-Authorization"] = header_value
        return r
122
123
124
class HTTPDigestAuth(AuthBase):
    """Attaches HTTP Digest Authentication to the given Request object.

    The most recent challenge, the nonce bookkeeping, the saved request-body
    position and the 401 retry counter are stored on a ``threading.local``,
    so every thread that uses the same instance keeps its own state.
    """

    username: bytes | str
    password: bytes | str
    _thread_local: threading.local
    # Declared for type checkers: within this class the values below are read
    # and written on ``self._thread_local`` (see ``init_per_thread_state``).
    last_nonce: str
    nonce_count: int
    chal: dict[str, str]
    pos: int | None
    num_401_calls: int | None

    @overload
    def __init__(self, username: str, password: str) -> None: ...
    @overload
    def __init__(self, username: bytes, password: bytes) -> None: ...

    def __init__(self, username: bytes | str, password: bytes | str) -> None:
        """Store the credentials and create the per-thread state container."""
        self.username = username
        self.password = password
        # Keep state in per-thread local storage
        self._thread_local = threading.local()

    def init_per_thread_state(self) -> None:
        """Create the calling thread's digest state the first time it is needed."""
        # Ensure state is initialized just once per-thread
        if not hasattr(self._thread_local, "init"):
            self._thread_local.init = True
            self._thread_local.last_nonce = ""
            self._thread_local.nonce_count = 0
            self._thread_local.chal = {}
            self._thread_local.pos = None
            self._thread_local.num_401_calls = None

    def build_digest_header(self, method: str, url: str) -> str | None:
        """Build the ``Authorization`` header value answering the stored challenge.

        The challenge is read from ``self._thread_local.chal``. The thread's
        nonce count is updated, and on success the challenge nonce is
        remembered as ``last_nonce``.

        :param method: HTTP method of the request being authenticated.
        :param url: request URL; its path and query string form the digest ``uri``.
        :return: the ``Digest ...`` header value, or ``None`` when the challenge
            names an unsupported algorithm or its ``qop`` list does not offer
            ``auth``.
        :rtype: str or None
        :raises KeyError: if the stored challenge has no ``realm`` or ``nonce``.
        """

        realm = self._thread_local.chal["realm"]
        nonce = self._thread_local.chal["nonce"]
        qop = self._thread_local.chal.get("qop")
        algorithm = self._thread_local.chal.get("algorithm")
        opaque = self._thread_local.chal.get("opaque")

        # A challenge without an explicit algorithm is treated as MD5.
        _algorithm = "MD5" if algorithm is None else algorithm.upper()

        # Supported challenge algorithms -> name of the hashlib constructor.
        hashlib_names = {
            "MD5": "md5",
            "MD5-SESS": "md5",
            "SHA": "sha1",
            "SHA-256": "sha256",
            "SHA-512": "sha512",
        }
        hashlib_name = hashlib_names.get(_algorithm)
        if hashlib_name is None:
            return None
        # Resolved lazily so only the requested constructor is looked up.
        new_hash = getattr(hashlib, hashlib_name)

        def hash_utf8(x: str | bytes) -> str:
            """Hex digest of ``x``; ``str`` input is encoded as UTF-8 first."""
            if isinstance(x, str):
                x = x.encode("utf-8")
            return new_hash(x, usedforsecurity=False).hexdigest()

        def KD(s: str, d: str) -> str:
            """Keyed digest: hash of ``secret:data``."""
            return hash_utf8(f"{s}:{d}")

        p_parsed = urlparse(url)
        #: path is request-uri defined in RFC 2616 which should not be empty
        path = p_parsed.path or "/"
        if p_parsed.query:
            path += f"?{p_parsed.query}"

        # NOTE(review): the credentials are interpolated with an f-string, so
        # bytes values would be rendered as their repr (``b'...'``) here and in
        # the ``username=`` field below -- confirm how bytes should be handled.
        A1 = f"{self.username}:{realm}:{self.password}"
        A2 = f"{method}:{path}"

        HA1 = hash_utf8(A1)
        HA2 = hash_utf8(A2)

        # The nonce count restarts whenever the server hands out a new nonce.
        if nonce == self._thread_local.last_nonce:
            self._thread_local.nonce_count += 1
        else:
            self._thread_local.nonce_count = 1
        ncvalue = f"{self._thread_local.nonce_count:08x}"

        # Client nonce: hash of the counter, server nonce, current time and
        # random bytes, truncated to 16 hex characters.
        s = str(self._thread_local.nonce_count).encode("utf-8")
        s += nonce.encode("utf-8")
        s += time.ctime().encode("utf-8")
        s += os.urandom(8)

        cnonce = hashlib.sha1(s, usedforsecurity=False).hexdigest()[:16]
        if _algorithm == "MD5-SESS":
            HA1 = hash_utf8(f"{HA1}:{nonce}:{cnonce}")  # type: ignore[reportConstantRedefinition]  # RFC 2617 terminology

        if not qop:
            respdig = KD(HA1, f"{nonce}:{HA2}")
        elif "auth" in [value.strip() for value in qop.split(",")]:
            # qop is a comma-separated list that may carry whitespace around
            # its members (e.g. "auth-int, auth"), hence the strip().
            noncebit = f"{nonce}:{ncvalue}:{cnonce}:auth:{HA2}"
            respdig = KD(HA1, noncebit)
        else:
            # XXX handle auth-int (entity digest is not implemented).
            return None

        self._thread_local.last_nonce = nonce

        # XXX should the partial digests be encoded too?
        fields = [
            f'username="{self.username}"',
            f'realm="{realm}"',
            f'nonce="{nonce}"',
            f'uri="{path}"',
            f'response="{respdig}"',
        ]
        if opaque:
            fields.append(f'opaque="{opaque}"')
        if algorithm:
            fields.append(f'algorithm="{algorithm}"')
        if qop:
            fields.append(f'qop="auth", nc={ncvalue}, cnonce="{cnonce}"')

        return "Digest " + ", ".join(fields)

    def handle_redirect(self, r: Response, **kwargs: Any) -> None:
        """Reset num_401_calls counter on redirects."""
        if r.is_redirect:
            self._thread_local.num_401_calls = 1

    def handle_401(self, r: Response, **kwargs: Any) -> Response:
        """
        Takes the given response and tries digest-auth, if needed.

        When ``r`` is a 4xx response carrying a ``digest`` challenge and no
        retry has been made yet, the request is copied, given an
        ``Authorization`` header and re-sent on the same connection; the new
        response (with ``r`` appended to its history) is returned. In every
        other case ``r`` is returned unchanged and the retry counter is reset.

        :param r: response delivered to the ``response`` hook.
        :param kwargs: forwarded to ``r.connection.send`` for the retry.
        :rtype: requests.Response
        """

        # If response is not 4xx, do not auth
        # See https://github.com/psf/requests/issues/3772
        if not 400 <= r.status_code < 500:
            self._thread_local.num_401_calls = 1
            return r

        if self._thread_local.pos is not None:
            # Rewind the file position indicator of the body to where
            # it was to resend the request.
            if (seek := getattr(r.request.body, "seek", None)) is not None:
                seek(self._thread_local.pos)
        s_auth = r.headers.get("www-authenticate", "")

        if "digest" in s_auth.lower() and self._thread_local.num_401_calls < 2:
            self._thread_local.num_401_calls += 1
            # Drop the first "Digest " scheme token and parse what follows.
            pat = re.compile(r"digest ", flags=re.IGNORECASE)
            self._thread_local.chal = parse_dict_header(pat.sub("", s_auth, count=1))

            # Consume content and release the original connection
            # to allow our new request to reuse the same one.
            r.content
            r.close()
            prep = r.request.copy()
            cookie_jar = cast("CookieJar", prep._cookies)  # type: ignore[reportPrivateUsage]
            extract_cookies_to_jar(cookie_jar, r.request, r.raw)
            prep.prepare_cookies(cookie_jar)

            _digest_auth = self.build_digest_header(
                cast(str, prep.method), cast(str, prep.url)
            )
            if _digest_auth:
                prep.headers["Authorization"] = _digest_auth
            _r = r.connection.send(prep, **kwargs)
            _r.history.append(r)
            _r.request = prep

            return _r

        self._thread_local.num_401_calls = 1
        return r

    def __call__(self, r: PreparedRequest) -> PreparedRequest:
        """Prepare ``r`` for digest authentication and return it.

        Adds an ``Authorization`` header up front when this thread already
        holds a nonce, records the body position when the body has ``tell``,
        and registers the ``handle_401`` and ``handle_redirect`` response hooks.
        """
        # Initialize per-thread state, if needed
        self.init_per_thread_state()
        # If we have a saved nonce, skip the 401
        if self._thread_local.last_nonce:
            _digest_auth = self.build_digest_header(
                cast(str, r.method), cast(str, r.url)
            )
            if _digest_auth:
                r.headers["Authorization"] = _digest_auth
        if (tell := getattr(r.body, "tell", None)) is not None:
            self._thread_local.pos = tell()
        else:
            # In the case of HTTPDigestAuth being reused and the body of
            # the previous request was a file-like object, pos has the
            # file position of the previous body. Ensure it's set to
            # None.
            self._thread_local.pos = None
        r.register_hook("response", self.handle_401)
        r.register_hook("response", self.handle_redirect)
        self._thread_local.num_401_calls = 1

        return r

    def __eq__(self, other: object) -> bool:
        """Return ``True`` when ``other`` has the same ``username`` and ``password``."""
        return all(
            [
                self.username == getattr(other, "username", None),
                self.password == getattr(other, "password", None),
            ]
        )

    def __ne__(self, other: Any) -> bool:
        """Return the negation of ``self == other``."""
        return not self == other
355