[dashboard] Add ESPHOME_TRUSTED_DOMAINS support to events WebSocket (#12479)
This commit is contained in:
@@ -164,8 +164,24 @@ def websocket_method(name):
|
||||
return wrap
|
||||
|
||||
|
||||
class CheckOriginMixin:
|
||||
"""Mixin to handle WebSocket origin checks for reverse proxy setups."""
|
||||
|
||||
def check_origin(self, origin: str) -> bool:
|
||||
if "ESPHOME_TRUSTED_DOMAINS" not in os.environ:
|
||||
return super().check_origin(origin)
|
||||
trusted_domains = [
|
||||
s.strip() for s in os.environ["ESPHOME_TRUSTED_DOMAINS"].split(",")
|
||||
]
|
||||
url = urlparse(origin)
|
||||
if url.hostname in trusted_domains:
|
||||
return True
|
||||
_LOGGER.info("check_origin %s, domain is not trusted", origin)
|
||||
return False
|
||||
|
||||
|
||||
@websocket_class
|
||||
class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
|
||||
class EsphomeCommandWebSocket(CheckOriginMixin, tornado.websocket.WebSocketHandler):
|
||||
"""Base class for ESPHome websocket commands."""
|
||||
|
||||
def __init__(
|
||||
@@ -183,18 +199,6 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
|
||||
# use Popen() with a reading thread instead
|
||||
self._use_popen = os.name == "nt"
|
||||
|
||||
def check_origin(self, origin):
|
||||
if "ESPHOME_TRUSTED_DOMAINS" not in os.environ:
|
||||
return super().check_origin(origin)
|
||||
trusted_domains = [
|
||||
s.strip() for s in os.environ["ESPHOME_TRUSTED_DOMAINS"].split(",")
|
||||
]
|
||||
url = urlparse(origin)
|
||||
if url.hostname in trusted_domains:
|
||||
return True
|
||||
_LOGGER.info("check_origin %s, domain is not trusted", origin)
|
||||
return False
|
||||
|
||||
def open(self, *args: str, **kwargs: str) -> None:
|
||||
"""Handle new WebSocket connection."""
|
||||
# Ensure messages from the subprocess are sent immediately
|
||||
@@ -601,7 +605,7 @@ DASHBOARD_SUBSCRIBER = DashboardSubscriber()
|
||||
|
||||
|
||||
@websocket_class
|
||||
class DashboardEventsWebSocket(tornado.websocket.WebSocketHandler):
|
||||
class DashboardEventsWebSocket(CheckOriginMixin, tornado.websocket.WebSocketHandler):
|
||||
"""WebSocket handler for real-time dashboard events."""
|
||||
|
||||
_event_listeners: list[Callable[[], None]] | None = None
|
||||
|
||||
@@ -1567,3 +1567,90 @@ async def test_dashboard_yaml_loading_with_packages_and_secrets(
|
||||
# If we get here, secret resolution worked!
|
||||
assert "esphome" in config
|
||||
assert config["esphome"]["name"] == "test-download-secrets"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_check_origin_default_same_origin(
|
||||
dashboard: DashboardTestHelper,
|
||||
) -> None:
|
||||
"""Test WebSocket uses default same-origin check when ESPHOME_TRUSTED_DOMAINS not set."""
|
||||
# Ensure ESPHOME_TRUSTED_DOMAINS is not set
|
||||
env = os.environ.copy()
|
||||
env.pop("ESPHOME_TRUSTED_DOMAINS", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
from tornado.httpclient import HTTPRequest
|
||||
|
||||
url = f"ws://127.0.0.1:{dashboard.port}/events"
|
||||
# Same origin should work (default Tornado behavior)
|
||||
request = HTTPRequest(
|
||||
url, headers={"Origin": f"http://127.0.0.1:{dashboard.port}"}
|
||||
)
|
||||
ws = await websocket_connect(request)
|
||||
try:
|
||||
msg = await ws.read_message()
|
||||
assert msg is not None
|
||||
data = json.loads(msg)
|
||||
assert data["event"] == "initial_state"
|
||||
finally:
|
||||
ws.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_check_origin_trusted_domain(
|
||||
dashboard: DashboardTestHelper,
|
||||
) -> None:
|
||||
"""Test WebSocket accepts connections from trusted domains."""
|
||||
with patch.dict(os.environ, {"ESPHOME_TRUSTED_DOMAINS": "trusted.example.com"}):
|
||||
from tornado.httpclient import HTTPRequest
|
||||
|
||||
url = f"ws://127.0.0.1:{dashboard.port}/events"
|
||||
request = HTTPRequest(url, headers={"Origin": "https://trusted.example.com"})
|
||||
ws = await websocket_connect(request)
|
||||
try:
|
||||
# Should receive initial state
|
||||
msg = await ws.read_message()
|
||||
assert msg is not None
|
||||
data = json.loads(msg)
|
||||
assert data["event"] == "initial_state"
|
||||
finally:
|
||||
ws.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_check_origin_untrusted_domain(
|
||||
dashboard: DashboardTestHelper,
|
||||
) -> None:
|
||||
"""Test WebSocket rejects connections from untrusted domains."""
|
||||
with patch.dict(os.environ, {"ESPHOME_TRUSTED_DOMAINS": "trusted.example.com"}):
|
||||
from tornado.httpclient import HTTPRequest
|
||||
|
||||
url = f"ws://127.0.0.1:{dashboard.port}/events"
|
||||
request = HTTPRequest(url, headers={"Origin": "https://untrusted.example.com"})
|
||||
with pytest.raises(HTTPClientError) as exc_info:
|
||||
await websocket_connect(request)
|
||||
# Should get HTTP 403 Forbidden due to origin check failure
|
||||
assert exc_info.value.code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_check_origin_multiple_trusted_domains(
|
||||
dashboard: DashboardTestHelper,
|
||||
) -> None:
|
||||
"""Test WebSocket accepts connections from multiple trusted domains."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"ESPHOME_TRUSTED_DOMAINS": "first.example.com, second.example.com"},
|
||||
):
|
||||
from tornado.httpclient import HTTPRequest
|
||||
|
||||
url = f"ws://127.0.0.1:{dashboard.port}/events"
|
||||
# Test second domain in list (with space after comma)
|
||||
request = HTTPRequest(url, headers={"Origin": "https://second.example.com"})
|
||||
ws = await websocket_connect(request)
|
||||
try:
|
||||
msg = await ws.read_message()
|
||||
assert msg is not None
|
||||
data = json.loads(msg)
|
||||
assert data["event"] == "initial_state"
|
||||
finally:
|
||||
ws.close()
|
||||
|
||||
Reference in New Issue
Block a user