← requests  /  tests/testserver/server.py

1
import select
2
import socket
3
import ssl
4
import threading
5
6
7
def consume_socket_content(sock, timeout=0.5):
    """Read from *sock* until it goes quiet or the peer closes it.

    The socket is polled with ``select``; reading stops as soon as no data
    becomes readable within *timeout* seconds, or when ``recv`` returns an
    empty bytes object.

    :param sock: socket to read from.
    :param timeout: seconds to wait for more data before giving up.
    :returns: everything received, as a single ``bytes`` object
        (``b""`` if nothing arrived).
    """
    chunk_size = 65536
    received = []

    while True:
        readable = select.select([sock], [], [], timeout)[0]
        if not readable:
            # nothing arrived within the timeout window
            break

        data = sock.recv(chunk_size)
        if not data:
            # empty read: stop, there is nothing more to collect
            break

        received.append(data)

    # join once at the end instead of re-copying the buffer on every chunk
    return b"".join(received)
23
24
25
class Server(threading.Thread):
    """Dummy server used for unit testing.

    The server runs in its own thread, accepts up to ``requests_to_handle``
    connections, passes each accepted socket to ``handler`` and stores the
    handler's return value in ``handler_results``.

    It is meant to be used as a context manager: entering starts the thread
    and returns ``(host, port)``; exiting closes the listening socket and
    joins the thread.
    """

    # Seconds used for every blocking wait (accept, ready/stop/close events).
    WAIT_EVENT_TIMEOUT = 5

    def __init__(
        self,
        handler=None,
        host="localhost",
        port=0,
        requests_to_handle=1,
        wait_to_close_event=None,
    ):
        """Configure the server; nothing is bound until the thread runs.

        :param handler: callable taking the accepted socket; defaults to
            ``consume_socket_content``.
        :param host: address to bind to.
        :param port: port to bind to; ``0`` lets the OS pick one, and
            ``self.port`` is updated with the real value once bound.
        :param requests_to_handle: maximum number of connections to accept.
        :param wait_to_close_event: optional ``threading.Event`` the server
            waits on (up to ``WAIT_EVENT_TIMEOUT``) before shutting down.
        """
        super().__init__()

        self.handler = handler or consume_socket_content
        self.handler_results = []

        self.host = host
        self.port = port
        self.requests_to_handle = requests_to_handle

        # Assigned in run(); kept as None so cleanup is safe if binding
        # fails or the thread never starts.
        self.server_sock = None

        self.wait_to_close_event = wait_to_close_event
        self.ready_event = threading.Event()
        self.stop_event = threading.Event()

    @classmethod
    def text_response_server(cls, text, request_timeout=0.5, **kwargs):
        """Build a server that answers every connection with *text*.

        :param text: response to send, encoded as UTF-8.
        :param request_timeout: timeout passed to ``consume_socket_content``
            while reading the request.
        :param kwargs: forwarded to the class constructor.
        :returns: an instance of ``cls`` whose handler results are the raw
            request bytes that were read.
        """

        def text_response_handler(sock):
            request_content = consume_socket_content(sock, timeout=request_timeout)
            # sendall: a plain send() may transmit only part of the payload
            sock.sendall(text.encode("utf-8"))

            return request_content

        # Use cls (not Server) so subclasses get an instance of themselves;
        # pass the handler by keyword so keyword-only constructors work too.
        return cls(handler=text_response_handler, **kwargs)

    @classmethod
    def basic_response_server(cls, **kwargs):
        """Build a server that replies with an empty ``200 OK`` response."""
        return cls.text_response_server(
            "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n\r\n", **kwargs
        )

    def run(self):
        """Thread body: bind, serve the configured requests, then clean up."""
        try:
            self.server_sock = self._create_socket_and_bind()
            # in case self.port = 0
            self.port = self.server_sock.getsockname()[1]
            self.ready_event.set()
            self._handle_requests()

            if self.wait_to_close_event:
                self.wait_to_close_event.wait(self.WAIT_EVENT_TIMEOUT)
        finally:
            self.ready_event.set()  # just in case of exception
            self._close_server_sock_ignore_errors()
            self.stop_event.set()

    def _create_socket_and_bind(self):
        """Return a listening socket bound to ``(self.host, self.port)``."""
        sock = socket.socket()
        try:
            sock.bind((self.host, self.port))
            sock.listen()
        except Exception:
            # don't leak the descriptor when binding/listening fails
            sock.close()
            raise
        return sock

    def _close_server_sock_ignore_errors(self):
        """Close the listening socket, ignoring OS errors.

        Safe to call when the socket was never created.
        """
        server_sock = getattr(self, "server_sock", None)
        if server_sock is None:
            return
        try:
            server_sock.close()
        except OSError:
            pass

    def _handle_requests(self):
        """Accept and handle up to ``requests_to_handle`` connections.

        Stops early when no connection could be accepted. The accepted
        socket is always closed, even if the handler raises.
        """
        for _ in range(self.requests_to_handle):
            sock = self._accept_connection()
            if sock is None:
                break

            try:
                handler_result = self.handler(sock)
                self.handler_results.append(handler_result)
            finally:
                sock.close()

    def _accept_connection(self):
        """Wait for a client and return its socket, or ``None``.

        ``None`` is returned when nothing connects within
        ``WAIT_EVENT_TIMEOUT`` seconds or the listening socket is unusable.
        """
        try:
            ready, _, _ = select.select(
                [self.server_sock], [], [], self.WAIT_EVENT_TIMEOUT
            )
            if not ready:
                return None

            return self.server_sock.accept()[0]
        except (OSError, ValueError):
            # ValueError: select() was handed a socket that has already been
            # closed (its file descriptor is -1), e.g. by __exit__.
            return None

    def __enter__(self):
        """Start the thread and return ``(host, port)`` once it is ready.

        :raises RuntimeError: if the server does not signal readiness
            within ``WAIT_EVENT_TIMEOUT`` seconds.
        """
        self.start()
        if not self.ready_event.wait(self.WAIT_EVENT_TIMEOUT):
            raise RuntimeError("Timeout waiting for server to be ready.")
        return self.host, self.port

    def __exit__(self, exc_type, exc_value, traceback):
        """Shut the server down and join its thread; never swallows errors."""
        if exc_type is None:
            self.stop_event.wait(self.WAIT_EVENT_TIMEOUT)
        else:
            if self.wait_to_close_event:
                # avoid server from waiting for event timeouts
                # if an exception is found in the main thread
                self.wait_to_close_event.set()

        # ensure server thread doesn't get stuck waiting for connections
        self._close_server_sock_ignore_errors()
        self.join()
        return False  # allow exceptions to propagate
136
137
138
class TLSServer(Server):
    """``Server`` variant whose listening socket is wrapped in TLS."""

    def __init__(
        self,
        *,
        handler=None,
        host="localhost",
        port=0,
        requests_to_handle=1,
        wait_to_close_event=None,
        cert_chain=None,
        keyfile=None,
        mutual_tls=False,
        cacert=None,
    ):
        """Configure the server and its TLS context (keyword-only).

        :param handler: forwarded to ``Server``.
        :param host: forwarded to ``Server``.
        :param port: forwarded to ``Server``.
        :param requests_to_handle: forwarded to ``Server``.
        :param wait_to_close_event: forwarded to ``Server``.
        :param cert_chain: certificate file given to
            ``SSLContext.load_cert_chain``.
        :param keyfile: private key file given to
            ``SSLContext.load_cert_chain``.
        :param mutual_tls: when true, client certificates are verified
            against *cacert*.
        :param cacert: CA file given to
            ``SSLContext.load_verify_locations``; only used when
            *mutual_tls* is true.
        """
        super().__init__(
            handler=handler,
            host=host,
            port=port,
            requests_to_handle=requests_to_handle,
            wait_to_close_event=wait_to_close_event,
        )
        self.cert_chain = cert_chain
        self.keyfile = keyfile
        self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        self.ssl_context.load_cert_chain(self.cert_chain, keyfile=self.keyfile)
        self.mutual_tls = mutual_tls
        self.cacert = cacert
        if mutual_tls:
            # For simplicity, we're going to assume that the client cert is
            # issued by the same CA as our Server certificate
            self.ssl_context.verify_mode = ssl.CERT_OPTIONAL
            self.ssl_context.load_verify_locations(self.cacert)

    def _create_socket_and_bind(self):
        """Return a TLS-wrapped listening socket bound to host/port."""
        sock = socket.socket()
        try:
            sock = self.ssl_context.wrap_socket(sock, server_side=True)
            sock.bind((self.host, self.port))
            sock.listen()
        except Exception:
            # don't leak the descriptor when wrapping/binding/listening fails
            sock.close()
            raise
        return sock
177