[dashboard] Add ESPHOME_TRUSTED_DOMAINS support to events WebSocket (#12479)

This commit is contained in:
J. Nick Koston
2025-12-14 13:30:55 -06:00
committed by GitHub
parent cfc0d8bdfc
commit 780a407b10
2 changed files with 105 additions and 14 deletions

View File

@@ -164,8 +164,24 @@ def websocket_method(name):
return wrap
class CheckOriginMixin:
"""Mixin to handle WebSocket origin checks for reverse proxy setups."""
def check_origin(self, origin: str) -> bool:
if "ESPHOME_TRUSTED_DOMAINS" not in os.environ:
return super().check_origin(origin)
trusted_domains = [
s.strip() for s in os.environ["ESPHOME_TRUSTED_DOMAINS"].split(",")
]
url = urlparse(origin)
if url.hostname in trusted_domains:
return True
_LOGGER.info("check_origin %s, domain is not trusted", origin)
return False
@websocket_class
class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
class EsphomeCommandWebSocket(CheckOriginMixin, tornado.websocket.WebSocketHandler):
"""Base class for ESPHome websocket commands."""
def __init__(
@@ -183,18 +199,6 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
# use Popen() with a reading thread instead
self._use_popen = os.name == "nt"
def check_origin(self, origin):
if "ESPHOME_TRUSTED_DOMAINS" not in os.environ:
return super().check_origin(origin)
trusted_domains = [
s.strip() for s in os.environ["ESPHOME_TRUSTED_DOMAINS"].split(",")
]
url = urlparse(origin)
if url.hostname in trusted_domains:
return True
_LOGGER.info("check_origin %s, domain is not trusted", origin)
return False
def open(self, *args: str, **kwargs: str) -> None:
"""Handle new WebSocket connection."""
# Ensure messages from the subprocess are sent immediately
@@ -601,7 +605,7 @@ DASHBOARD_SUBSCRIBER = DashboardSubscriber()
@websocket_class
class DashboardEventsWebSocket(tornado.websocket.WebSocketHandler):
class DashboardEventsWebSocket(CheckOriginMixin, tornado.websocket.WebSocketHandler):
"""WebSocket handler for real-time dashboard events."""
_event_listeners: list[Callable[[], None]] | None = None

View File

@@ -1567,3 +1567,90 @@ async def test_dashboard_yaml_loading_with_packages_and_secrets(
# If we get here, secret resolution worked!
assert "esphome" in config
assert config["esphome"]["name"] == "test-download-secrets"
@pytest.mark.asyncio
async def test_websocket_check_origin_default_same_origin(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket uses default same-origin check when ESPHOME_TRUSTED_DOMAINS not set."""
# Ensure ESPHOME_TRUSTED_DOMAINS is not set
env = os.environ.copy()
env.pop("ESPHOME_TRUSTED_DOMAINS", None)
with patch.dict(os.environ, env, clear=True):
from tornado.httpclient import HTTPRequest
url = f"ws://127.0.0.1:{dashboard.port}/events"
# Same origin should work (default Tornado behavior)
request = HTTPRequest(
url, headers={"Origin": f"http://127.0.0.1:{dashboard.port}"}
)
ws = await websocket_connect(request)
try:
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
finally:
ws.close()
@pytest.mark.asyncio
async def test_websocket_check_origin_trusted_domain(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket accepts connections from trusted domains."""
with patch.dict(os.environ, {"ESPHOME_TRUSTED_DOMAINS": "trusted.example.com"}):
from tornado.httpclient import HTTPRequest
url = f"ws://127.0.0.1:{dashboard.port}/events"
request = HTTPRequest(url, headers={"Origin": "https://trusted.example.com"})
ws = await websocket_connect(request)
try:
# Should receive initial state
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
finally:
ws.close()
@pytest.mark.asyncio
async def test_websocket_check_origin_untrusted_domain(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket rejects connections from untrusted domains."""
with patch.dict(os.environ, {"ESPHOME_TRUSTED_DOMAINS": "trusted.example.com"}):
from tornado.httpclient import HTTPRequest
url = f"ws://127.0.0.1:{dashboard.port}/events"
request = HTTPRequest(url, headers={"Origin": "https://untrusted.example.com"})
with pytest.raises(HTTPClientError) as exc_info:
await websocket_connect(request)
# Should get HTTP 403 Forbidden due to origin check failure
assert exc_info.value.code == 403
@pytest.mark.asyncio
async def test_websocket_check_origin_multiple_trusted_domains(
dashboard: DashboardTestHelper,
) -> None:
"""Test WebSocket accepts connections from multiple trusted domains."""
with patch.dict(
os.environ,
{"ESPHOME_TRUSTED_DOMAINS": "first.example.com, second.example.com"},
):
from tornado.httpclient import HTTPRequest
url = f"ws://127.0.0.1:{dashboard.port}/events"
# Test second domain in list (with space after comma)
request = HTTPRequest(url, headers={"Origin": "https://second.example.com"})
ws = await websocket_connect(request)
try:
msg = await ws.read_message()
assert msg is not None
data = json.loads(msg)
assert data["event"] == "initial_state"
finally:
ws.close()