Skip to content

Commit f97c976

Browse files
committed
Handle total timeouts
1 parent 6d8116b commit f97c976

14 files changed

+156
-31
lines changed

httpcore/_async/connection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from types import TracebackType
55
from typing import Iterable, Iterator, Optional, Type
66

7+
from httpcore._utils import OverallTimeoutHandler
8+
79
from .._backends.auto import AutoBackend
810
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
911
from .._exceptions import ConnectError, ConnectTimeout
@@ -105,6 +107,8 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
105107
sni_hostname = request.extensions.get("sni_hostname", None)
106108
timeout = timeouts.get("connect", None)
107109

110+
overall_timeout = OverallTimeoutHandler(timeouts)
111+
108112
retries_left = self._retries
109113
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
110114

@@ -115,11 +119,12 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
115119
"host": self._origin.host.decode("ascii"),
116120
"port": self._origin.port,
117121
"local_address": self._local_address,
118-
"timeout": timeout,
122+
"timeout": overall_timeout.get_minimum_timeout(timeout),
119123
"socket_options": self._socket_options,
120124
}
121125
async with Trace("connect_tcp", logger, request, kwargs) as trace:
122-
stream = await self._network_backend.connect_tcp(**kwargs)
126+
with overall_timeout:
127+
stream = await self._network_backend.connect_tcp(**kwargs)
123128
trace.return_value = stream
124129
else:
125130
kwargs = {

httpcore/_async/connection_pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from types import TracebackType
44
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type
55

6+
from httpcore._utils import OverallTimeoutHandler
7+
68
from .._backends.auto import AutoBackend
79
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
810
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
@@ -174,6 +176,7 @@ async def handle_async_request(self, request: Request) -> Response:
174176

175177
timeouts = request.extensions.get("timeout", {})
176178
timeout = timeouts.get("pool", None)
179+
overall_timeout = OverallTimeoutHandler(timeouts)
177180

178181
with self._optional_thread_lock:
179182
# Add the incoming request to our request queue.
@@ -188,8 +191,11 @@ async def handle_async_request(self, request: Request) -> Response:
188191
closing = self._assign_requests_to_connections()
189192
await self._close_connections(closing)
190193

191-
# Wait until this request has an assigned connection.
192-
connection = await pool_request.wait_for_connection(timeout=timeout)
194+
with overall_timeout:
195+
# Wait until this request has an assigned connection.
196+
connection = await pool_request.wait_for_connection(
197+
timeout=overall_timeout.get_minimum_timeout(timeout)
198+
)
193199

194200
try:
195201
# Send the request on the assigned connection.

httpcore/_async/http11.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import h11
1818

19+
from httpcore._utils import OverallTimeoutHandler
20+
1921
from .._backends.base import AsyncNetworkStream
2022
from .._exceptions import (
2123
ConnectionNotAvailable,
@@ -147,25 +149,37 @@ async def handle_async_request(self, request: Request) -> Response:
147149
async def _send_request_headers(self, request: Request) -> None:
148150
timeouts = request.extensions.get("timeout", {})
149151
timeout = timeouts.get("write", None)
152+
overall_timeout = OverallTimeoutHandler(timeouts)
150153

151154
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
152155
event = h11.Request(
153156
method=request.method,
154157
target=request.url.target,
155158
headers=request.headers,
156159
)
157-
await self._send_event(event, timeout=timeout)
160+
with overall_timeout:
161+
await self._send_event(
162+
event, timeout=overall_timeout.get_minimum_timeout(timeout)
163+
)
158164

159165
async def _send_request_body(self, request: Request) -> None:
160166
timeouts = request.extensions.get("timeout", {})
161167
timeout = timeouts.get("write", None)
168+
overall_timeout = OverallTimeoutHandler(timeouts)
162169

163170
assert isinstance(request.stream, AsyncIterable)
164171
async for chunk in request.stream:
165172
event = h11.Data(data=chunk)
166-
await self._send_event(event, timeout=timeout)
167173

168-
await self._send_event(h11.EndOfMessage(), timeout=timeout)
174+
with overall_timeout:
175+
await self._send_event(
176+
event, timeout=overall_timeout.get_minimum_timeout(timeout)
177+
)
178+
179+
with overall_timeout:
180+
await self._send_event(
181+
h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout)
182+
)
169183

170184
async def _send_event(
171185
self, event: h11.Event, timeout: Optional[float] = None
@@ -181,9 +195,13 @@ async def _receive_response_headers(
181195
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
182196
timeouts = request.extensions.get("timeout", {})
183197
timeout = timeouts.get("read", None)
198+
overall_timeout = OverallTimeoutHandler(timeouts)
184199

185200
while True:
186-
event = await self._receive_event(timeout=timeout)
201+
with overall_timeout:
202+
event = await self._receive_event(
203+
timeout=overall_timeout.get_minimum_timeout(timeout)
204+
)
187205
if isinstance(event, h11.Response):
188206
break
189207
if (
@@ -205,9 +223,12 @@ async def _receive_response_headers(
205223
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
206224
timeouts = request.extensions.get("timeout", {})
207225
timeout = timeouts.get("read", None)
226+
overall_timeout = OverallTimeoutHandler(timeouts)
208227

209228
while True:
210-
event = await self._receive_event(timeout=timeout)
229+
event = await self._receive_event(
230+
timeout=overall_timeout.get_minimum_timeout(timeout)
231+
)
211232
if isinstance(event, h11.Data):
212233
yield bytes(event.data)
213234
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):

httpcore/_async/http2.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import h2.exceptions
1111
import h2.settings
1212

13+
from httpcore._utils import OverallTimeoutHandler
14+
1315
from .._backends.base import AsyncNetworkStream
1416
from .._exceptions import (
1517
ConnectionNotAvailable,
@@ -430,12 +432,16 @@ async def _read_incoming_data(
430432
) -> typing.List[h2.events.Event]:
431433
timeouts = request.extensions.get("timeout", {})
432434
timeout = timeouts.get("read", None)
435+
overall_timeout = OverallTimeoutHandler(timeouts)
433436

434437
if self._read_exception is not None:
435438
raise self._read_exception # pragma: nocover
436439

437440
try:
438-
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
441+
with overall_timeout:
442+
data = await self._network_stream.read(
443+
self.READ_NUM_BYTES, overall_timeout.get_minimum_timeout(timeout)
444+
)
439445
if data == b"":
440446
raise RemoteProtocolError("Server disconnected")
441447
except Exception as exc:
@@ -458,6 +464,7 @@ async def _read_incoming_data(
458464
async def _write_outgoing_data(self, request: Request) -> None:
459465
timeouts = request.extensions.get("timeout", {})
460466
timeout = timeouts.get("write", None)
467+
overall_timeout = OverallTimeoutHandler(timeouts)
461468

462469
async with self._write_lock:
463470
data_to_send = self._h2_state.data_to_send()
@@ -466,7 +473,8 @@ async def _write_outgoing_data(self, request: Request) -> None:
466473
raise self._write_exception # pragma: nocover
467474

468475
try:
469-
await self._network_stream.write(data_to_send, timeout)
476+
with overall_timeout:
477+
await self._network_stream.write(data_to_send, timeout)
470478
except Exception as exc: # pragma: nocover
471479
# If we get a network error we should:
472480
#

httpcore/_async/http_proxy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from base64 import b64encode
44
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union
55

6+
from httpcore._utils import OverallTimeoutHandler
7+
68
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
79
from .._exceptions import ProxyError
810
from .._models import (
@@ -266,6 +268,7 @@ def __init__(
266268
async def handle_async_request(self, request: Request) -> Response:
267269
timeouts = request.extensions.get("timeout", {})
268270
timeout = timeouts.get("connect", None)
271+
overall_timeout = OverallTimeoutHandler(timeouts)
269272

270273
async with self._connect_lock:
271274
if not self._connected:
@@ -311,10 +314,11 @@ async def handle_async_request(self, request: Request) -> Response:
311314
kwargs = {
312315
"ssl_context": ssl_context,
313316
"server_hostname": self._remote_origin.host.decode("ascii"),
314-
"timeout": timeout,
317+
"timeout": overall_timeout.get_minimum_timeout(timeout),
315318
}
316319
async with Trace("start_tls", logger, request, kwargs) as trace:
317-
stream = await stream.start_tls(**kwargs)
320+
with overall_timeout:
321+
stream = await stream.start_tls(**kwargs)
318322
trace.return_value = stream
319323

320324
# Determine if we should be using HTTP/1.1 or HTTP/2

httpcore/_async/socks_proxy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from socksio import socks5
66

7+
from httpcore._utils import OverallTimeoutHandler
8+
79
from .._backends.auto import AutoBackend
810
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
911
from .._exceptions import ConnectionNotAvailable, ProxyError
@@ -218,6 +220,7 @@ async def handle_async_request(self, request: Request) -> Response:
218220
timeouts = request.extensions.get("timeout", {})
219221
sni_hostname = request.extensions.get("sni_hostname", None)
220222
timeout = timeouts.get("connect", None)
223+
overall_timeout = OverallTimeoutHandler(timeouts)
221224

222225
async with self._connect_lock:
223226
if self._connection is None:
@@ -226,10 +229,11 @@ async def handle_async_request(self, request: Request) -> Response:
226229
kwargs = {
227230
"host": self._proxy_origin.host.decode("ascii"),
228231
"port": self._proxy_origin.port,
229-
"timeout": timeout,
232+
"timeout": overall_timeout.get_minimum_timeout(timeout),
230233
}
231234
async with Trace("connect_tcp", logger, request, kwargs) as trace:
232-
stream = await self._network_backend.connect_tcp(**kwargs)
235+
with overall_timeout:
236+
stream = await self._network_backend.connect_tcp(**kwargs)
233237
trace.return_value = stream
234238

235239
# Connect to the remote host using socks5

httpcore/_sync/connection.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from types import TracebackType
55
from typing import Iterable, Iterator, Optional, Type
66

7+
from httpcore._utils import OverallTimeoutHandler
8+
79
from .._backends.sync import SyncBackend
810
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
911
from .._exceptions import ConnectError, ConnectTimeout
@@ -105,6 +107,8 @@ def _connect(self, request: Request) -> NetworkStream:
105107
sni_hostname = request.extensions.get("sni_hostname", None)
106108
timeout = timeouts.get("connect", None)
107109

110+
overall_timeout = OverallTimeoutHandler(timeouts)
111+
108112
retries_left = self._retries
109113
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
110114

@@ -115,11 +119,12 @@ def _connect(self, request: Request) -> NetworkStream:
115119
"host": self._origin.host.decode("ascii"),
116120
"port": self._origin.port,
117121
"local_address": self._local_address,
118-
"timeout": timeout,
122+
"timeout": overall_timeout.get_minimum_timeout(timeout),
119123
"socket_options": self._socket_options,
120124
}
121125
with Trace("connect_tcp", logger, request, kwargs) as trace:
122-
stream = self._network_backend.connect_tcp(**kwargs)
126+
with overall_timeout:
127+
stream = self._network_backend.connect_tcp(**kwargs)
123128
trace.return_value = stream
124129
else:
125130
kwargs = {

httpcore/_sync/connection_pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from types import TracebackType
44
from typing import Iterable, Iterator, Iterable, List, Optional, Type
55

6+
from httpcore._utils import OverallTimeoutHandler
7+
68
from .._backends.sync import SyncBackend
79
from .._backends.base import SOCKET_OPTION, NetworkBackend
810
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
@@ -174,6 +176,7 @@ def handle_request(self, request: Request) -> Response:
174176

175177
timeouts = request.extensions.get("timeout", {})
176178
timeout = timeouts.get("pool", None)
179+
overall_timeout = OverallTimeoutHandler(timeouts)
177180

178181
with self._optional_thread_lock:
179182
# Add the incoming request to our request queue.
@@ -188,8 +191,11 @@ def handle_request(self, request: Request) -> Response:
188191
closing = self._assign_requests_to_connections()
189192
self._close_connections(closing)
190193

191-
# Wait until this request has an assigned connection.
192-
connection = pool_request.wait_for_connection(timeout=timeout)
194+
with overall_timeout:
195+
# Wait until this request has an assigned connection.
196+
connection = pool_request.wait_for_connection(
197+
timeout=overall_timeout.get_minimum_timeout(timeout)
198+
)
193199

194200
try:
195201
# Send the request on the assigned connection.

httpcore/_sync/http11.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import h11
1818

19+
from httpcore._utils import OverallTimeoutHandler
20+
1921
from .._backends.base import NetworkStream
2022
from .._exceptions import (
2123
ConnectionNotAvailable,
@@ -147,25 +149,37 @@ def handle_request(self, request: Request) -> Response:
147149
def _send_request_headers(self, request: Request) -> None:
148150
timeouts = request.extensions.get("timeout", {})
149151
timeout = timeouts.get("write", None)
152+
overall_timeout = OverallTimeoutHandler(timeouts)
150153

151154
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
152155
event = h11.Request(
153156
method=request.method,
154157
target=request.url.target,
155158
headers=request.headers,
156159
)
157-
self._send_event(event, timeout=timeout)
160+
with overall_timeout:
161+
self._send_event(
162+
event, timeout=overall_timeout.get_minimum_timeout(timeout)
163+
)
158164

159165
def _send_request_body(self, request: Request) -> None:
160166
timeouts = request.extensions.get("timeout", {})
161167
timeout = timeouts.get("write", None)
168+
overall_timeout = OverallTimeoutHandler(timeouts)
162169

163170
assert isinstance(request.stream, Iterable)
164171
for chunk in request.stream:
165172
event = h11.Data(data=chunk)
166-
self._send_event(event, timeout=timeout)
167173

168-
self._send_event(h11.EndOfMessage(), timeout=timeout)
174+
with overall_timeout:
175+
self._send_event(
176+
event, timeout=overall_timeout.get_minimum_timeout(timeout)
177+
)
178+
179+
with overall_timeout:
180+
self._send_event(
181+
h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout)
182+
)
169183

170184
def _send_event(
171185
self, event: h11.Event, timeout: Optional[float] = None
@@ -181,9 +195,13 @@ def _receive_response_headers(
181195
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
182196
timeouts = request.extensions.get("timeout", {})
183197
timeout = timeouts.get("read", None)
198+
overall_timeout = OverallTimeoutHandler(timeouts)
184199

185200
while True:
186-
event = self._receive_event(timeout=timeout)
201+
with overall_timeout:
202+
event = self._receive_event(
203+
timeout=overall_timeout.get_minimum_timeout(timeout)
204+
)
187205
if isinstance(event, h11.Response):
188206
break
189207
if (
@@ -205,9 +223,12 @@ def _receive_response_headers(
205223
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
206224
timeouts = request.extensions.get("timeout", {})
207225
timeout = timeouts.get("read", None)
226+
overall_timeout = OverallTimeoutHandler(timeouts)
208227

209228
while True:
210-
event = self._receive_event(timeout=timeout)
229+
event = self._receive_event(
230+
timeout=overall_timeout.get_minimum_timeout(timeout)
231+
)
211232
if isinstance(event, h11.Data):
212233
yield bytes(event.data)
213234
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):

0 commit comments

Comments
 (0)