diff --git a/esphome/cpp_generator.py b/esphome/cpp_generator.py index 6f1af01a5..4f91696ca 100644 --- a/esphome/cpp_generator.py +++ b/esphome/cpp_generator.py @@ -19,11 +19,21 @@ from esphome.core import ( TimePeriodNanoseconds, TimePeriodSeconds, ) +from esphome.coroutine import CoroPriority, coroutine_with_priority from esphome.helpers import cpp_string_escape, indent_all_but_first_and_last from esphome.types import Expression, SafeExpType, TemplateArgsType from esphome.util import OrderedDict from esphome.yaml_util import ESPHomeDataBase +# Keys for lambda deduplication storage in CORE.data +_KEY_LAMBDA_DEDUP = "lambda_dedup" +_KEY_LAMBDA_DEDUP_DECLARATIONS = "lambda_dedup_declarations" + +# Regex patterns for static variable detection (compiled once) +_RE_CPP_SINGLE_LINE_COMMENT = re.compile(r"//.*?$", re.MULTILINE) +_RE_CPP_MULTI_LINE_COMMENT = re.compile(r"/\*.*?\*/", re.DOTALL) +_RE_STATIC_VARIABLE = re.compile(r"\bstatic\s+(?!cast|assert|pointer_cast)\w+\s+\w+") + class RawExpression(Expression): __slots__ = ("text",) @@ -188,7 +198,7 @@ class LambdaExpression(Expression): def __init__( self, parts, parameters, capture: str = "=", return_type=None, source=None - ): + ) -> None: self.parts = parts if not isinstance(parameters, ParameterListExpression): parameters = ParameterListExpression(*parameters) @@ -197,16 +207,21 @@ class LambdaExpression(Expression): self.capture = capture self.return_type = safe_exp(return_type) if return_type is not None else None - def __str__(self): + def format_body(self) -> str: + """Format the lambda body with source directive and content.""" + body = "" + if self.source is not None: + body += f"{self.source.as_line_directive}\n" + body += self.content + return body + + def __str__(self) -> str: # Stateless lambdas (empty capture) implicitly convert to function pointers # when assigned to function pointer types - no unary + needed cpp = f"[{self.capture}]({self.parameters})" if self.return_type is not None: cpp += f" -> {self.return_type}" - cpp += " {\n" - if self.source is not None: - cpp += f"{self.source.as_line_directive}\n" - cpp += f"{self.content}\n}}" + cpp += f" {{\n{self.format_body()}\n}}" return indent_all_but_first_and_last(cpp) @property @@ -214,6 +229,37 @@ class LambdaExpression(Expression): return "".join(str(part) for part in self.parts) +class SharedFunctionLambdaExpression(LambdaExpression): + """A lambda expression that references a shared deduplicated function. + + This class wraps a function pointer but maintains the LambdaExpression + interface so calling code works unchanged. + """ + + __slots__ = ("_func_name",) + + def __init__( + self, + func_name: str, + parameters: TemplateArgsType, + return_type: SafeExpType | None = None, + ) -> None: + # Initialize parent with empty parts since we're just a function reference + super().__init__( + [], parameters, capture="", return_type=return_type, source=None + ) + self._func_name = func_name + + def __str__(self) -> str: + # Just return the function name - it's already a function pointer + return self._func_name + + @property + def content(self) -> str: + # No content, just a function reference + return "" + + # pylint: disable=abstract-method class Literal(Expression, metaclass=abc.ABCMeta): __slots__ = () @@ -583,6 +629,25 @@ def add_global(expression: SafeExpType | Statement, prepend: bool = False): CORE.add_global(expression, prepend) +@coroutine_with_priority(CoroPriority.FINAL) +async def flush_lambda_dedup_declarations() -> None: + """Flush all deferred lambda deduplication declarations to global scope. + + This is a coroutine that runs with FINAL priority (after all components) + to ensure all referenced variables are declared before the shared + lambda functions that use them. + """ + if _KEY_LAMBDA_DEDUP_DECLARATIONS not in CORE.data: + return + + declarations = CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS] + for func_declaration in declarations: + add_global(RawStatement(func_declaration)) + + # Clear the list so we don't add them again + CORE.data[_KEY_LAMBDA_DEDUP_DECLARATIONS] = [] + + def add_library(name: str, version: str | None, repository: str | None = None): """Add a library to the codegen library storage. @@ -656,6 +721,93 @@ async def get_variable_with_full_id(id_: ID) -> tuple[ID, "MockObj"]: return await CORE.get_variable_with_full_id(id_) +def _has_static_variables(code: str) -> bool: + """Check if code contains static variable definitions. + + Static variables in lambdas should not be deduplicated because each lambda + instance should have its own static variable state. + + Args: + code: The lambda body code to check + + Returns: + True if code contains static variable definitions + """ + # Remove C++ comments to avoid false positives + # Remove single-line comments (// ...) + code_no_comments = _RE_CPP_SINGLE_LINE_COMMENT.sub("", code) + # Remove multi-line comments (/* ... */) + code_no_comments = _RE_CPP_MULTI_LINE_COMMENT.sub("", code_no_comments) + + # Match: static + # But not: static_cast, static_assert, static_pointer_cast + return bool(_RE_STATIC_VARIABLE.search(code_no_comments)) + + +def _get_shared_lambda_name(lambda_expr: LambdaExpression) -> str | None: + """Get the shared function name for a lambda expression. + + If an identical lambda was already generated, returns the existing shared + function name. Otherwise, creates a new shared function and returns its name. + + Lambdas with static variables are not deduplicated to preserve their + independent state. + + Args: + lambda_expr: The lambda expression to deduplicate + + Returns: + The name of the shared function for this lambda (either existing or newly created), + or None if the lambda should not be deduplicated (e.g., contains static variables) + """ + # Create a unique key from the lambda content, parameters, and return type + content = lambda_expr.content + + # Don't deduplicate lambdas with static variables - each instance needs its own state + if _has_static_variables(content): + return None + param_str = str(lambda_expr.parameters) + return_str = ( + str(lambda_expr.return_type) if lambda_expr.return_type is not None else "void" + ) + + # Use tuple of (content, params, return_type) as key + lambda_key = (content, param_str, return_str) + + # Initialize deduplication storage in CORE.data if not exists + if _KEY_LAMBDA_DEDUP not in CORE.data: + CORE.data[_KEY_LAMBDA_DEDUP] = {} + # Register the flush job to run after all components (FINAL priority) + # This ensures all variables are declared before shared lambda functions + CORE.add_job(flush_lambda_dedup_declarations) + + lambda_cache = CORE.data[_KEY_LAMBDA_DEDUP] + + # Check if we've seen this lambda before + if lambda_key in lambda_cache: + # Return name of existing shared function + return lambda_cache[lambda_key] + + # First occurrence - create a shared function + # Use the cache size as the function number + func_name = f"shared_lambda_{len(lambda_cache)}" + + # Build the function declaration using lambda's body formatting + func_declaration = ( + f"{return_str} {func_name}({param_str}) {{\n{lambda_expr.format_body()}\n}}" + ) + + # Store the declaration to be added later (after all variable declarations) + # We can't add it immediately because it might reference variables not yet declared + CORE.data.setdefault(_KEY_LAMBDA_DEDUP_DECLARATIONS, []).append(func_declaration) + + # Store in cache + lambda_cache[lambda_key] = func_name + + # Return the function name (this is the first occurrence, but we still generate shared function) + return func_name + + async def process_lambda( value: Lambda, parameters: TemplateArgsType, @@ -713,6 +865,19 @@ async def process_lambda( location.line += value.content_offset else: location = None + + # Lambda deduplication: Only deduplicate stateless lambdas (empty capture). + # Stateful lambdas cannot be shared as they capture different contexts. + # Lambdas with static variables are also not deduplicated to preserve independent state. + if capture == "": + lambda_expr = LambdaExpression( + parts, parameters, capture, return_type, location + ) + func_name = _get_shared_lambda_name(lambda_expr) + if func_name is not None: + # Return a shared function reference instead of inline lambda + return SharedFunctionLambdaExpression(func_name, parameters, return_type) + return LambdaExpression(parts, parameters, capture, return_type, location) diff --git a/tests/component_tests/text/test_text.py b/tests/component_tests/text/test_text.py index bfc3131f6..56dee205b 100644 --- a/tests/component_tests/text/test_text.py +++ b/tests/component_tests/text/test_text.py @@ -1,4 +1,6 @@ -"""Tests for the binary sensor component.""" +"""Tests for the text component.""" + +from esphome.core import CORE def test_text_is_setup(generate_main): @@ -56,15 +58,22 @@ def test_text_config_value_mode_set(generate_main): assert "it_3->traits.set_mode(text::TEXT_MODE_PASSWORD);" in main_cpp -def test_text_config_lamda_is_set(generate_main): +def test_text_config_lambda_is_set(generate_main) -> None: """ - Test if lambda is set for lambda mode (optimized with stateless lambda) + Test if lambda is set for lambda mode (optimized with stateless lambda and deduplication) """ # Given # When main_cpp = generate_main("tests/component_tests/text/test_text.yaml") + # Get both global and main sections to find the shared lambda definition + full_cpp = CORE.cpp_global_section + main_cpp + # Then - assert "it_4->set_template([]() -> esphome::optional {" in main_cpp - assert 'return std::string{"Hello"};' in main_cpp + # Lambda is deduplicated into a shared function (reference in main section) + assert "it_4->set_template(shared_lambda_" in main_cpp + # Lambda body should be in the code somewhere + assert 'return std::string{"Hello"};' in full_cpp + # Verify the shared lambda function is defined (in global section) + assert "esphome::optional shared_lambda_" in full_cpp diff --git a/tests/unit_tests/test_lambda_dedup.py b/tests/unit_tests/test_lambda_dedup.py new file mode 100644 index 000000000..bbf5f02e6 --- /dev/null +++ b/tests/unit_tests/test_lambda_dedup.py @@ -0,0 +1,286 @@ +"""Tests for lambda deduplication in cpp_generator.""" + +from esphome import cpp_generator as cg +from esphome.core import CORE + + +def test_deduplicate_identical_lambdas() -> None: + """Test that identical stateless lambdas are deduplicated.""" + # Create two identical lambda expressions + lambda1 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + # Try to deduplicate them + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + # Both should get the same function name (deduplication happened) + assert func_name1 == func_name2 + assert func_name1 == "shared_lambda_0" + + +def test_different_lambdas_not_deduplicated() -> None: + """Test that different lambdas get different function names.""" + lambda1 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["return 24;"], # Different content + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + # Different lambdas should get different function names + assert func_name1 != func_name2 + assert func_name1 == "shared_lambda_0" + assert func_name2 == "shared_lambda_1" + + +def test_different_return_types_not_deduplicated() -> None: + """Test that lambdas with different return types are not deduplicated.""" + lambda1 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["return 42;"], # Same content + parameters=[], + capture="", + return_type=cg.RawExpression("float"), # Different return type + ) + + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + # Different return types = different functions + assert func_name1 != func_name2 + + +def test_different_parameters_not_deduplicated() -> None: + """Test that lambdas with different parameters are not deduplicated.""" + lambda1 = cg.LambdaExpression( + parts=["return x;"], + parameters=[("int", "x")], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["return x;"], # Same content + parameters=[("float", "x")], # Different parameter type + capture="", + return_type=cg.RawExpression("int"), + ) + + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + # Different parameters = different functions + assert func_name1 != func_name2 + + +def test_flush_lambda_dedup_declarations() -> None: + """Test that deferred declarations are properly stored for later flushing.""" + # Create a lambda which will create a deferred declaration + lambda1 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + cg._get_shared_lambda_name(lambda1) + + # Check that declaration was stored + assert cg._KEY_LAMBDA_DEDUP_DECLARATIONS in CORE.data + assert len(CORE.data[cg._KEY_LAMBDA_DEDUP_DECLARATIONS]) == 1 + + # Verify the declaration content is correct + declaration = CORE.data[cg._KEY_LAMBDA_DEDUP_DECLARATIONS][0] + assert "shared_lambda_0" in declaration + assert "return 42;" in declaration + + # Note: The actual flushing happens via CORE.add_job with FINAL priority + # during real code generation, so we don't test that here + + +def test_shared_function_lambda_expression() -> None: + """Test SharedFunctionLambdaExpression behaves correctly.""" + shared_lambda = cg.SharedFunctionLambdaExpression( + func_name="shared_lambda_0", + parameters=[], + return_type=cg.RawExpression("int"), + ) + + # Should output just the function name + assert str(shared_lambda) == "shared_lambda_0" + + # Should have empty capture (stateless) + assert shared_lambda.capture == "" + + # Should have empty content (just a reference) + assert shared_lambda.content == "" + + +def test_lambda_deduplication_counter() -> None: + """Test that lambda counter increments correctly.""" + # Create 3 different lambdas + for i in range(3): + lambda_expr = cg.LambdaExpression( + parts=[f"return {i};"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + func_name = cg._get_shared_lambda_name(lambda_expr) + assert func_name == f"shared_lambda_{i}" + + +def test_lambda_format_body() -> None: + """Test that format_body correctly formats lambda body with source.""" + # Without source + lambda1 = cg.LambdaExpression( + parts=["return 42;"], + parameters=[], + capture="", + return_type=None, + source=None, + ) + assert lambda1.format_body() == "return 42;" + + # With source would need a proper source object, skip for now + + +def test_stateful_lambdas_not_deduplicated() -> None: + """Test that stateful lambdas (non-empty capture) are not deduplicated.""" + # _get_shared_lambda_name is only called for stateless lambdas (capture == "") + # Stateful lambdas bypass deduplication entirely in process_lambda + + # Verify that a stateful lambda would NOT get deduplicated + # by checking it's not in the stateless dedup cache + stateful_lambda = cg.LambdaExpression( + parts=["return x + y;"], + parameters=[], + capture="=", # Non-empty capture means stateful + return_type=cg.RawExpression("int"), + ) + + # Stateful lambdas should NOT be passed to _get_shared_lambda_name + # This is enforced by the `if capture == ""` check in process_lambda + # We verify the lambda has a non-empty capture + assert stateful_lambda.capture != "" + assert stateful_lambda.capture == "=" + + +def test_static_variable_detection() -> None: + """Test detection of static variables in lambda code.""" + # Should detect static variables + assert cg._has_static_variables("static int counter = 0;") + assert cg._has_static_variables("static bool flag = false; return flag;") + assert cg._has_static_variables(" static float value = 1.0; ") + + # Should NOT detect static_cast, static_assert, etc. (with underscores) + assert not cg._has_static_variables("return static_cast(value);") + assert not cg._has_static_variables("static_assert(sizeof(int) == 4);") + assert not cg._has_static_variables("auto ptr = static_pointer_cast(bar);") + + # Edge case: 'cast', 'assert', 'pointer_cast' are NOT C++ keywords + # Someone could use them as type names, but we should NOT flag them + # because they're not actually static variables with state + # NOTE: These are valid C++ but extremely unlikely in ESPHome lambdas + assert not cg._has_static_variables("static cast obj;") # 'cast' as type name + assert not cg._has_static_variables("static assert value;") # 'assert' as type name + assert not cg._has_static_variables( + "static pointer_cast ptr;" + ) # 'pointer_cast' as type + + # Should NOT detect in comments + assert not cg._has_static_variables("// static int x = 0;\nreturn 42;") + assert not cg._has_static_variables("/* static int y = 0; */ return 42;") + + # Should detect even with comments elsewhere + assert cg._has_static_variables("// comment\nstatic int x = 0;\nreturn x;") + + # Should NOT detect non-static code + assert not cg._has_static_variables("int counter = 0; return counter++;") + assert not cg._has_static_variables("return 42;") + + # Should handle newlines between static and type/variable + assert cg._has_static_variables("static int\nfoo = 0;") + assert cg._has_static_variables("static\nint\nbar = 0;") + assert cg._has_static_variables( + "static int \n foo = 0;" + ) # Mixed spaces/newlines + + +def test_lambdas_with_static_not_deduplicated() -> None: + """Test that lambdas with static variables are not deduplicated.""" + # Two identical lambdas with static variables + lambda1 = cg.LambdaExpression( + parts=["static int counter = 0; return counter++;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["static int counter = 0; return counter++;"], + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + # Should return None (not deduplicated) + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + assert func_name1 is None + assert func_name2 is None + + +def test_lambdas_without_static_still_deduplicated() -> None: + """Test that lambdas without static variables are still deduplicated.""" + # Two identical lambdas WITHOUT static variables + lambda1 = cg.LambdaExpression( + parts=["int counter = 0; return counter++;"], # No static + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + lambda2 = cg.LambdaExpression( + parts=["int counter = 0; return counter++;"], # No static + parameters=[], + capture="", + return_type=cg.RawExpression("int"), + ) + + # Should be deduplicated (same function name) + func_name1 = cg._get_shared_lambda_name(lambda1) + func_name2 = cg._get_shared_lambda_name(lambda2) + + assert func_name1 is not None + assert func_name2 is not None + assert func_name1 == func_name2