diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index 6bfa8b9053..4c211b2f2a 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -1265,8 +1265,8 @@ def _configure_lwip_max_sockets(conf: dict) -> None: # CONFIG_LWIP_MAX_SOCKETS is a single VFS socket pool shared by all socket # types (TCP clients, TCP listeners, and UDP). Include all three counts. - tcp_sockets, udp_sockets, tcp_listen = get_socket_counts() - total_sockets = tcp_sockets + udp_sockets + tcp_listen + sc = get_socket_counts() + total_sockets = sc.tcp + sc.udp + sc.tcp_listen # User specified their own value - respect it but warn if insufficient if user_max_sockets is not None: @@ -1287,9 +1287,9 @@ def _configure_lwip_max_sockets(conf: dict) -> None: "at least %d.", user_sockets_int, total_sockets, - tcp_sockets, - udp_sockets, - tcp_listen, + sc.tcp, + sc.udp, + sc.tcp_listen, total_sockets, ) # User's value already added via sdkconfig_options processing @@ -1300,13 +1300,19 @@ def _configure_lwip_max_sockets(conf: dict) -> None: max_sockets = max(DEFAULT_MAX_SOCKETS, total_sockets) log_level = logging.INFO if max_sockets > DEFAULT_MAX_SOCKETS else logging.DEBUG + sock_min = " (min)" if max_sockets > total_sockets else "" _LOGGER.log( log_level, - "Setting CONFIG_LWIP_MAX_SOCKETS to %d (%d TCP + %d UDP + %d TCP_LISTEN)", + "Setting CONFIG_LWIP_MAX_SOCKETS to %d%s " + "(TCP=%d [%s], UDP=%d [%s], TCP_LISTEN=%d [%s])", max_sockets, - tcp_sockets, - udp_sockets, - tcp_listen, + sock_min, + sc.tcp, + sc.tcp_details, + sc.udp, + sc.udp_details, + sc.tcp_listen, + sc.tcp_listen_details, ) add_idf_sdkconfig_option("CONFIG_LWIP_MAX_SOCKETS", max_sockets) diff --git a/esphome/components/libretiny/__init__.py b/esphome/components/libretiny/__init__.py index 0daf9733b8..2291114d9a 100644 --- a/esphome/components/libretiny/__init__.py +++ b/esphome/components/libretiny/__init__.py @@ -290,7 +290,7 @@ def _configure_lwip(config: dict) -> None: Setting ESP8266 ESP32 BK SDK RTL SDK LN SDK New ──────────────────────────────────────────────────────────────────────────── TCP_SND_BUF 2×MSS 4×MSS 10×MSS 5×MSS 7×MSS 4×MSS - TCP_WND 4×MSS 4×MSS 10×MSS 2×MSS 3×MSS 4×MSS + TCP_WND 4×MSS 4×MSS 3/10×MSS 2×MSS 3×MSS 4×MSS MEM_LIBC_MALLOC 1 1 0 0 1 1 MEMP_MEM_MALLOC 1 1 0 0 0 1 MEM_SIZE N/A* N/A* 16/32KB 5KB N/A* N/A* BK @@ -313,21 +313,22 @@ def _configure_lwip(config: dict) -> None: **** RTL/LN LT overlay overrides to flat 7. ***** Not defined in RTL SDK — lwIP opt.h defaults shown. "dynamic" = auto-calculated from component socket registrations via - socket.get_socket_counts() with minimums of 10 TCP / 8 UDP. + socket.get_socket_counts() with minimums of 8 TCP / 6 UDP. """ from esphome.components.socket import ( + MIN_TCP_LISTEN_SOCKETS, MIN_TCP_SOCKETS, MIN_UDP_SOCKETS, get_socket_counts, ) - raw_tcp, raw_udp, raw_tcp_listen = get_socket_counts() + sc = get_socket_counts() # Apply platform minimums — ensure headroom for ESPHome's needs - tcp_sockets = max(MIN_TCP_SOCKETS, raw_tcp) - udp_sockets = max(MIN_UDP_SOCKETS, raw_udp) + tcp_sockets = max(MIN_TCP_SOCKETS, sc.tcp) + udp_sockets = max(MIN_UDP_SOCKETS, sc.udp) # Listening sockets — registered by components (api, ota, web_server_base, etc.) - # Not all components register yet, so ensure a minimum of 2 (api + ota baseline). - listening_tcp = max(raw_tcp_listen, 2) + # Not all components register yet, so ensure a minimum for baseline operation. + listening_tcp = max(MIN_TCP_LISTEN_SOCKETS, sc.tcp_listen) # TCP_SND_BUF: ESPAsyncWebServer allocates malloc(tcp_sndbuf()) per # response chunk. At 10×MSS=14.6KB (BK default) this causes OOM (#14095). @@ -396,6 +397,21 @@ def _configure_lwip(config: dict) -> None: if CORE.is_bk72xx: lwip_opts.append("PBUF_POOL_SIZE=10") + tcp_min = " (min)" if tcp_sockets > sc.tcp else "" + udp_min = " (min)" if udp_sockets > sc.udp else "" + listen_min = " (min)" if listening_tcp > sc.tcp_listen else "" + _LOGGER.info( + "Configuring lwIP: TCP=%d%s [%s], UDP=%d%s [%s], TCP_LISTEN=%d%s [%s]", + tcp_sockets, + tcp_min, + sc.tcp_details, + udp_sockets, + udp_min, + sc.udp_details, + listening_tcp, + listen_min, + sc.tcp_listen_details, + ) cg.add_platformio_option("custom_options.lwip", lwip_opts) diff --git a/esphome/components/socket/__init__.py b/esphome/components/socket/__init__.py index de5c6d2dd6..d82f0c7aba 100644 --- a/esphome/components/socket/__init__.py +++ b/esphome/components/socket/__init__.py @@ -1,4 +1,5 @@ from collections.abc import Callable, MutableMapping +from dataclasses import dataclass from enum import StrEnum import logging @@ -21,12 +22,14 @@ KEY_SOCKET_CONSUMERS_TCP = "socket_consumers_tcp" KEY_SOCKET_CONSUMERS_UDP = "socket_consumers_udp" KEY_SOCKET_CONSUMERS_TCP_LISTEN = "socket_consumers_tcp_listen" -# Recommended minimum socket counts to ensure headroom. +# Recommended minimum socket counts. # Platforms should apply these (or their own) on top of get_socket_counts(). -# TCP: Typical setup: api(3) + web_server(5) = 8 registered, +2 headroom for ota-transfer/other = 10 total. -# UDP: dhcp(1) + dns(1) + mdns(2) + wake_loop(1) = 5 base, +3 headroom. -MIN_TCP_SOCKETS = 10 -MIN_UDP_SOCKETS = 8 +# These cover minimal configs (e.g. api-only without web_server). +# When web_server is present, its 5 registered sockets push past the TCP minimum. +MIN_TCP_SOCKETS = 8 +MIN_UDP_SOCKETS = 6 +# Minimum listening sockets — at least api + ota baseline. +MIN_TCP_LISTEN_SOCKETS = 2 # Wake loop threadsafe support tracking KEY_WAKE_LOOP_THREADSAFE_REQUIRED = "wake_loop_threadsafe_required" @@ -68,8 +71,27 @@ def consume_sockets( return _consume_sockets -def get_socket_counts() -> tuple[int, int, int]: - """Return (tcp_count, udp_count, tcp_listen_count) of raw registered socket needs. +def _format_consumers(consumers: dict[str, int]) -> str: + """Format consumer dict as 'name=count, ...' or 'none'.""" + if not consumers: + return "none" + return ", ".join(f"{name}={count}" for name, count in sorted(consumers.items())) + + +@dataclass(frozen=True) +class SocketCounts: + """Socket counts and component details for platform configuration.""" + + tcp: int + udp: int + tcp_listen: int + tcp_details: str + udp_details: str + tcp_listen_details: str + + +def get_socket_counts() -> SocketCounts: + """Return socket counts and component details for platform configuration. Platforms call this during code generation to configure lwIP socket limits. All components will have registered their needs by then. @@ -83,25 +105,21 @@ def get_socket_counts() -> tuple[int, int, int]: udp = sum(udp_consumers.values()) tcp_listen = sum(tcp_listen_consumers.values()) - tcp_list = ", ".join( - f"{name}={count}" for name, count in sorted(tcp_consumers.items()) - ) - udp_list = ", ".join( - f"{name}={count}" for name, count in sorted(udp_consumers.items()) - ) - tcp_listen_list = ", ".join( - f"{name}={count}" for name, count in sorted(tcp_listen_consumers.items()) - ) + tcp_details = _format_consumers(tcp_consumers) + udp_details = _format_consumers(udp_consumers) + tcp_listen_details = _format_consumers(tcp_listen_consumers) _LOGGER.debug( "Socket counts: TCP=%d (%s), UDP=%d (%s), TCP_LISTEN=%d (%s)", tcp, - tcp_list or "none", + tcp_details, udp, - udp_list or "none", + udp_details, tcp_listen, - tcp_listen_list or "none", + tcp_listen_details, + ) + return SocketCounts( + tcp, udp, tcp_listen, tcp_details, udp_details, tcp_listen_details ) - return tcp, udp, tcp_listen def require_wake_loop_threadsafe() -> None: diff --git a/esphome/platformio_api.py b/esphome/platformio_api.py index a7ab9717d3..4c71bdef6b 100644 --- a/esphome/platformio_api.py +++ b/esphome/platformio_api.py @@ -5,6 +5,7 @@ import os from pathlib import Path import re import subprocess +import time from typing import Any from esphome.const import CONF_COMPILE_PROCESS_LIMIT, CONF_ESPHOME, KEY_CORE @@ -44,31 +45,61 @@ def patch_structhash(): def patch_file_downloader(): - """Patch PlatformIO's FileDownloader to retry on PackageException errors.""" + """Patch PlatformIO's FileDownloader to retry on PackageException errors. + + PlatformIO's FileDownloader uses HTTPSession which lacks built-in retry + for 502/503 errors. We add retries with exponential backoff and close the + session between attempts to force a fresh TCP connection, which may route + to a different CDN edge node. + """ from platformio.package.download import FileDownloader from platformio.package.exception import PackageException + if getattr(FileDownloader.__init__, "_esphome_patched", False): + return + original_init = FileDownloader.__init__ def patched_init(self, *args: Any, **kwargs: Any) -> None: - max_retries = 3 + max_retries = 5 for attempt in range(max_retries): try: - return original_init(self, *args, **kwargs) + original_init(self, *args, **kwargs) + return except PackageException as e: if attempt < max_retries - 1: + # Exponential backoff: 2, 4, 8, 16 seconds + delay = 2 ** (attempt + 1) _LOGGER.warning( - "Package download failed: %s. Retrying... (attempt %d/%d)", + "Package download failed: %s. " + "Retrying in %d seconds... (attempt %d/%d)", str(e), + delay, attempt + 1, max_retries, ) + # Close the response and session to free resources + # and force a new TCP connection on retry, which may + # route to a different CDN edge node + # pylint: disable=protected-access,broad-except + try: + if ( + hasattr(self, "_http_response") + and self._http_response is not None + ): + self._http_response.close() + if hasattr(self, "_http_session"): + self._http_session.close() + except Exception: + pass + # pylint: enable=protected-access,broad-except + time.sleep(delay) else: # Final attempt - re-raise raise - return None + patched_init._esphome_patched = True # type: ignore[attr-defined] # pylint: disable=protected-access FileDownloader.__init__ = patched_init diff --git a/tests/unit_tests/test_platformio_api.py b/tests/unit_tests/test_platformio_api.py index 4d7b635e59..1686144277 100644 --- a/tests/unit_tests/test_platformio_api.py +++ b/tests/unit_tests/test_platformio_api.py @@ -6,7 +6,7 @@ import os from pathlib import Path import shutil from types import SimpleNamespace -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, call, patch import pytest @@ -673,6 +673,200 @@ def test_process_stacktrace_bad_alloc( assert state is False +def test_patch_file_downloader_succeeds_first_try() -> None: + """Test patch_file_downloader succeeds on first attempt.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + original_init = MagicMock() + + with patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type("FileDownloader", (), {"__init__": original_init}) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + original_init.assert_called_once() + + +def test_patch_file_downloader_retries_on_failure() -> None: + """Test patch_file_downloader retries with backoff on PackageException.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + call_count = 0 + + def failing_init(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise mock_exception_cls(f"502 error attempt {call_count}") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", (), {"__init__": failing_init} + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep") as mock_sleep, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should have been called 3 times (2 failures + 1 success) + assert call_count == 3 + + # Should have slept with exponential backoff: 2s, 4s + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(2) + mock_sleep.assert_any_call(4) + + +def test_patch_file_downloader_raises_after_max_retries() -> None: + """Test patch_file_downloader raises after exhausting all retries.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + + def always_failing_init(self, *args, **kwargs): + raise mock_exception_cls("502 error") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", (), {"__init__": always_failing_init} + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep") as mock_sleep, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + with pytest.raises(mock_exception_cls, match="502 error"): + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should have slept 4 times (before attempts 2-5), not on final attempt + assert mock_sleep.call_count == 4 + mock_sleep.assert_has_calls([call(2), call(4), call(8), call(16)]) + + +def test_patch_file_downloader_closes_session_and_response_between_retries() -> None: + """Test patch_file_downloader closes HTTP session and response between retries.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + mock_session = MagicMock() + mock_response = MagicMock() + call_count = 0 + + def failing_init_with_session(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + self._http_session = mock_session + self._http_response = mock_response + if call_count < 2: + raise mock_exception_cls("502 error") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", + (), + {"__init__": failing_init_with_session}, + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep"), + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Both response and session should have been closed between retries + mock_response.close.assert_called_once() + mock_session.close.assert_called_once() + + +def test_patch_file_downloader_idempotent() -> None: + """Test patch_file_downloader does not stack wrappers when called multiple times.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + call_count = 0 + + def counting_init(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + + with patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type("FileDownloader", (), {"__init__": counting_init}) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ): + # Patch multiple times + platformio_api.patch_file_downloader() + platformio_api.patch_file_downloader() + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should only be called once, not 3 times from stacked wrappers + assert call_count == 1 + + def test_platformio_log_filter_allows_non_platformio_messages() -> None: """Test that non-platformio logger messages are allowed through.""" log_filter = platformio_api.PlatformioLogFilter()