[ota] Replace std::function callbacks with listener interface (#12167)

This commit is contained in:
J. Nick Koston
2025-12-19 11:19:07 -10:00
committed by GitHub
parent 940afdbb12
commit 988b888c63
21 changed files with 274 additions and 206 deletions

View File

@@ -5,7 +5,7 @@ import logging
from esphome import automation
import esphome.codegen as cg
from esphome.components import esp32_ble
from esphome.components import esp32_ble, ota
from esphome.components.esp32 import add_idf_sdkconfig_option
from esphome.components.esp32_ble import (
IDF_MAX_CONNECTIONS,
@@ -328,7 +328,7 @@ async def to_code(config):
# Note: CONFIG_BT_ACL_CONNECTIONS and CONFIG_BTDM_CTRL_BLE_MAX_CONN are now
# configured in esp32_ble component based on max_connections setting
cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts
ota.request_ota_state_listeners() # To be notified when an OTA update starts
cg.add_define("USE_ESP32_BLE_CLIENT")
CORE.add_job(_add_ble_features)

View File

@@ -71,21 +71,24 @@ void ESP32BLETracker::setup() {
global_esp32_ble_tracker = this;
#ifdef USE_OTA
ota::get_global_ota_callback()->add_on_state_callback(
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->stop_scan();
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
for (auto *client : this->clients_) {
client->disconnect();
}
#endif
}
});
#ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_global_state_listener(this);
#endif
}
#ifdef USE_OTA_STATE_LISTENER
void ESP32BLETracker::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->stop_scan();
#ifdef ESPHOME_ESP32_BLE_TRACKER_CLIENT_COUNT
for (auto *client : this->clients_) {
client->disconnect();
}
#endif
}
}
#endif
void ESP32BLETracker::loop() {
if (!this->parent_->is_active()) {
this->ble_was_disabled_ = true;

View File

@@ -22,6 +22,10 @@
#include "esphome/components/esp32_ble/ble_uuid.h"
#include "esphome/components/esp32_ble/ble_scan_result.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
namespace esphome::esp32_ble_tracker {
using namespace esp32_ble;
@@ -241,6 +245,9 @@ class ESP32BLETracker : public Component,
public GAPScanEventHandler,
public GATTcEventHandler,
public BLEStatusEventHandler,
#ifdef USE_OTA_STATE_LISTENER
public ota::OTAGlobalStateListener,
#endif
public Parented<ESP32BLE> {
public:
void set_scan_duration(uint32_t scan_duration) { scan_duration_ = scan_duration; }
@@ -274,6 +281,10 @@ class ESP32BLETracker : public Component,
void gap_scan_event_handler(const BLEScanResult &scan_result) override;
void ble_before_disabled_event_handler() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
/// Add a listener for scanner state changes
void add_scanner_state_listener(BLEScannerStateListener *listener) {
this->scanner_state_listeners_.push_back(listener);

View File

@@ -41,10 +41,6 @@ static constexpr size_t SHA256_HEX_SIZE = 64; // SHA256 hash as hex string (32
#endif // USE_OTA_PASSWORD
void ESPHomeOTAComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK
ota::register_ota_platform(this);
#endif
this->server_ = socket::socket_ip_loop_monitored(SOCK_STREAM, 0); // monitored for incoming connections
if (this->server_ == nullptr) {
this->log_socket_error_(LOG_STR("creation"));
@@ -297,8 +293,8 @@ void ESPHomeOTAComponent::handle_data_() {
// accidentally trigger the update process.
this->log_start_(LOG_STR("update"));
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
#endif
// This will block for a few seconds as it locks flash
@@ -357,8 +353,8 @@ void ESPHomeOTAComponent::handle_data_() {
last_progress = now;
float percentage = (total * 100.0f) / ota_size;
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif
// feed watchdog and give other tasks a chance to run
this->yield_and_feed_watchdog_();
@@ -387,8 +383,8 @@ void ESPHomeOTAComponent::handle_data_() {
delay(10);
ESP_LOGI(TAG, "Update complete");
this->status_clear_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_COMPLETED, 100.0f, 0);
#endif
delay(100); // NOLINT
App.safe_reboot();
@@ -402,8 +398,8 @@ error:
}
this->status_momentary_error("err", 5000);
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif
}

View File

@@ -16,12 +16,6 @@ namespace http_request {
static const char *const TAG = "http_request.ota";
void OtaHttpRequestComponent::setup() {
#ifdef USE_OTA_STATE_CALLBACK
ota::register_ota_platform(this);
#endif
}
void OtaHttpRequestComponent::dump_config() { ESP_LOGCONFIG(TAG, "Over-The-Air updates via HTTP request"); };
void OtaHttpRequestComponent::set_md5_url(const std::string &url) {
@@ -48,24 +42,24 @@ void OtaHttpRequestComponent::flash() {
}
ESP_LOGI(TAG, "Starting update");
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_STARTED, 0.0f, 0);
#endif
auto ota_status = this->do_ota_();
switch (ota_status) {
case ota::OTA_RESPONSE_OK:
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, ota_status);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_COMPLETED, 100.0f, ota_status);
#endif
delay(10);
App.safe_reboot();
break;
default:
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_ERROR, 0.0f, ota_status);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_ERROR, 0.0f, ota_status);
#endif
this->md5_computed_.clear(); // will be reset at next attempt
this->md5_expected_.clear(); // will be reset at next attempt
@@ -165,8 +159,8 @@ uint8_t OtaHttpRequestComponent::do_ota_() {
last_progress = now;
float percentage = container->get_bytes_read() * 100.0f / container->content_length;
ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
#ifdef USE_OTA_STATE_LISTENER
this->notify_state_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif
}
} // while

View File

@@ -24,7 +24,6 @@ enum OtaHttpRequestError : uint8_t {
class OtaHttpRequestComponent : public ota::OTAComponent, public Parented<HttpRequestComponent> {
public:
void setup() override;
void dump_config() override;
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }

View File

@@ -1,5 +1,5 @@
import esphome.codegen as cg
from esphome.components import update
from esphome.components import ota, update
import esphome.config_validation as cv
from esphome.const import CONF_SOURCE
@@ -38,6 +38,6 @@ async def to_code(config):
cg.add(var.set_source_url(config[CONF_SOURCE]))
cg.add_define("USE_OTA_STATE_CALLBACK")
ota.request_ota_state_listeners()
await cg.register_component(var, config)

View File

@@ -20,19 +20,19 @@ static const char *const TAG = "http_request.update";
static const size_t MAX_READ_SIZE = 256;
void HttpRequestUpdate::setup() {
this->ota_parent_->add_on_state_callback([this](ota::OTAState state, float progress, uint8_t err) {
if (state == ota::OTAState::OTA_IN_PROGRESS) {
this->state_ = update::UPDATE_STATE_INSTALLING;
this->update_info_.has_progress = true;
this->update_info_.progress = progress;
this->publish_state();
} else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) {
this->state_ = update::UPDATE_STATE_AVAILABLE;
this->status_set_error(LOG_STR("Failed to install firmware"));
this->publish_state();
}
});
void HttpRequestUpdate::setup() { this->ota_parent_->add_state_listener(this); }
void HttpRequestUpdate::on_ota_state(ota::OTAState state, float progress, uint8_t error) {
if (state == ota::OTAState::OTA_IN_PROGRESS) {
this->state_ = update::UPDATE_STATE_INSTALLING;
this->update_info_.has_progress = true;
this->update_info_.progress = progress;
this->publish_state();
} else if (state == ota::OTAState::OTA_ABORT || state == ota::OTAState::OTA_ERROR) {
this->state_ = update::UPDATE_STATE_AVAILABLE;
this->status_set_error(LOG_STR("Failed to install firmware"));
this->publish_state();
}
}
void HttpRequestUpdate::update() {

View File

@@ -14,7 +14,7 @@
namespace esphome {
namespace http_request {
class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent {
class HttpRequestUpdate final : public update::UpdateEntity, public PollingComponent, public ota::OTAStateListener {
public:
void setup() override;
void update() override;
@@ -29,6 +29,8 @@ class HttpRequestUpdate : public update::UpdateEntity, public PollingComponent {
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
void on_ota_state(ota::OTAState state, float progress, uint8_t error) override;
protected:
HttpRequestComponent *request_parent_;
OtaHttpRequestComponent *ota_parent_;

View File

@@ -7,7 +7,7 @@ from urllib.parse import urljoin
from esphome import automation, external_files, git
from esphome.automation import register_action, register_condition
import esphome.codegen as cg
from esphome.components import esp32, microphone, socket
from esphome.components import esp32, microphone, ota, socket
import esphome.config_validation as cv
from esphome.const import (
CONF_FILE,
@@ -452,7 +452,7 @@ async def to_code(config):
cg.add(var.set_microphone_source(mic_source))
cg.add_define("USE_MICRO_WAKE_WORD")
cg.add_define("USE_OTA_STATE_CALLBACK")
ota.request_ota_state_listeners()
esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1")

View File

@@ -119,18 +119,21 @@ void MicroWakeWord::setup() {
}
});
#ifdef USE_OTA
ota::get_global_ota_callback()->add_on_state_callback(
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->suspend_task_();
} else if (state == ota::OTA_ERROR) {
this->resume_task_();
}
});
#ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_global_state_listener(this);
#endif
}
#ifdef USE_OTA_STATE_LISTENER
void MicroWakeWord::on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->suspend_task_();
} else if (state == ota::OTA_ERROR) {
this->resume_task_();
}
}
#endif
void MicroWakeWord::inference_task(void *params) {
MicroWakeWord *this_mww = (MicroWakeWord *) params;

View File

@@ -9,8 +9,13 @@
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/ring_buffer.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
#include <freertos/event_groups.h>
#include <frontend.h>
@@ -26,13 +31,22 @@ enum State {
STOPPED,
};
class MicroWakeWord : public Component {
class MicroWakeWord : public Component
#ifdef USE_OTA_STATE_LISTENER
,
public ota::OTAGlobalStateListener
#endif
{
public:
void setup() override;
void loop() override;
float get_setup_priority() const override;
void dump_config() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
void start();
void stop();

View File

@@ -13,6 +13,8 @@ from esphome.const import (
from esphome.core import CORE, coroutine_with_priority
from esphome.coroutine import CoroPriority
OTA_STATE_LISTENER_KEY = "ota_state_listener"
CODEOWNERS = ["@esphome/core"]
AUTO_LOAD = ["md5", "safe_mode"]
@@ -86,6 +88,7 @@ BASE_OTA_SCHEMA = cv.Schema(
@coroutine_with_priority(CoroPriority.OTA_UPDATES)
async def to_code(config):
cg.add_define("USE_OTA")
CORE.add_job(final_step)
if CORE.is_rp2040 and CORE.using_arduino:
cg.add_library("Updater", None)
@@ -119,7 +122,24 @@ async def ota_to_code(var, config):
await automation.build_automation(trigger, [(cg.uint8, "x")], conf)
use_state_callback = True
if use_state_callback:
cg.add_define("USE_OTA_STATE_CALLBACK")
request_ota_state_listeners()
def request_ota_state_listeners() -> None:
"""Request that OTA state listeners be compiled in.
Components that need to be notified about OTA state changes (start, progress,
complete, error) should call this function during their code generation.
This enables the add_state_listener() API on OTAComponent.
"""
CORE.data[OTA_STATE_LISTENER_KEY] = True
@coroutine_with_priority(CoroPriority.FINAL)
async def final_step():
"""Final code generation step to configure optional OTA features."""
if CORE.data.get(OTA_STATE_LISTENER_KEY, False):
cg.add_define("USE_OTA_STATE_LISTENER")
FILTER_SOURCE_FILES = filter_source_files_from_platform(

View File

@@ -1,5 +1,5 @@
#pragma once
#ifdef USE_OTA_STATE_CALLBACK
#ifdef USE_OTA_STATE_LISTENER
#include "ota_backend.h"
#include "esphome/core/automation.h"
@@ -7,70 +7,64 @@
namespace esphome {
namespace ota {
class OTAStateChangeTrigger : public Trigger<OTAState> {
class OTAStateChangeTrigger final : public Trigger<OTAState>, public OTAStateListener {
public:
explicit OTAStateChangeTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (!parent->is_failed()) {
trigger(state);
}
});
explicit OTAStateChangeTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
void on_ota_state(OTAState state, float progress, uint8_t error) override {
if (!this->parent_->is_failed()) {
this->trigger(state);
}
}
protected:
OTAComponent *parent_;
};
class OTAStartTrigger : public Trigger<> {
template<OTAState State> class OTAStateTrigger final : public Trigger<>, public OTAStateListener {
public:
explicit OTAStartTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_STARTED && !parent->is_failed()) {
trigger();
}
});
explicit OTAStateTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
void on_ota_state(OTAState state, float progress, uint8_t error) override {
if (state == State && !this->parent_->is_failed()) {
this->trigger();
}
}
protected:
OTAComponent *parent_;
};
class OTAProgressTrigger : public Trigger<float> {
using OTAStartTrigger = OTAStateTrigger<OTA_STARTED>;
using OTAEndTrigger = OTAStateTrigger<OTA_COMPLETED>;
using OTAAbortTrigger = OTAStateTrigger<OTA_ABORT>;
class OTAProgressTrigger final : public Trigger<float>, public OTAStateListener {
public:
explicit OTAProgressTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_IN_PROGRESS && !parent->is_failed()) {
trigger(progress);
}
});
explicit OTAProgressTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
void on_ota_state(OTAState state, float progress, uint8_t error) override {
if (state == OTA_IN_PROGRESS && !this->parent_->is_failed()) {
this->trigger(progress);
}
}
protected:
OTAComponent *parent_;
};
class OTAEndTrigger : public Trigger<> {
class OTAErrorTrigger final : public Trigger<uint8_t>, public OTAStateListener {
public:
explicit OTAEndTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_COMPLETED && !parent->is_failed()) {
trigger();
}
});
}
};
explicit OTAErrorTrigger(OTAComponent *parent) : parent_(parent) { parent->add_state_listener(this); }
class OTAAbortTrigger : public Trigger<> {
public:
explicit OTAAbortTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_ABORT && !parent->is_failed()) {
trigger();
}
});
void on_ota_state(OTAState state, float progress, uint8_t error) override {
if (state == OTA_ERROR && !this->parent_->is_failed()) {
this->trigger(error);
}
}
};
class OTAErrorTrigger : public Trigger<uint8_t> {
public:
explicit OTAErrorTrigger(OTAComponent *parent) {
parent->add_on_state_callback([this, parent](OTAState state, float progress, uint8_t error) {
if (state == OTA_ERROR && !parent->is_failed()) {
trigger(error);
}
});
}
protected:
OTAComponent *parent_;
};
} // namespace ota

View File

@@ -3,7 +3,7 @@
namespace esphome {
namespace ota {
#ifdef USE_OTA_STATE_CALLBACK
#ifdef USE_OTA_STATE_LISTENER
OTAGlobalCallback *global_ota_callback{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
OTAGlobalCallback *get_global_ota_callback() {
@@ -13,7 +13,12 @@ OTAGlobalCallback *get_global_ota_callback() {
return global_ota_callback;
}
void register_ota_platform(OTAComponent *ota_caller) { get_global_ota_callback()->register_ota(ota_caller); }
void OTAComponent::notify_state_(OTAState state, float progress, uint8_t error) {
for (auto *listener : this->state_listeners_) {
listener->on_ota_state(state, progress, error);
}
get_global_ota_callback()->notify_ota_state(state, progress, error, this);
}
#endif
} // namespace ota

View File

@@ -4,8 +4,8 @@
#include "esphome/core/defines.h"
#include "esphome/core/helpers.h"
#ifdef USE_OTA_STATE_CALLBACK
#include "esphome/core/automation.h"
#ifdef USE_OTA_STATE_LISTENER
#include <vector>
#endif
namespace esphome {
@@ -60,62 +60,75 @@ class OTABackend {
virtual bool supports_compression() = 0;
};
class OTAComponent : public Component {
#ifdef USE_OTA_STATE_CALLBACK
/** Listener interface for OTA state changes.
*
* Components can implement this interface to receive OTA state updates
* without the overhead of std::function callbacks.
*/
class OTAStateListener {
public:
void add_on_state_callback(std::function<void(ota::OTAState, float, uint8_t)> &&callback) {
this->state_callback_.add(std::move(callback));
}
virtual ~OTAStateListener() = default;
virtual void on_ota_state(OTAState state, float progress, uint8_t error) = 0;
};
class OTAComponent : public Component {
#ifdef USE_OTA_STATE_LISTENER
public:
void add_state_listener(OTAStateListener *listener) { this->state_listeners_.push_back(listener); }
protected:
/** Extended callback manager with deferred call support.
void notify_state_(OTAState state, float progress, uint8_t error);
/** Notify state with deferral to main loop (for thread safety).
*
* This adds a call_deferred() method for thread-safe execution from other tasks.
* This should be used by OTA implementations that run in separate tasks
* (like web_server OTA) to ensure listeners execute in the main loop.
*/
class StateCallbackManager : public CallbackManager<void(OTAState, float, uint8_t)> {
public:
StateCallbackManager(OTAComponent *component) : component_(component) {}
void notify_state_deferred_(OTAState state, float progress, uint8_t error) {
this->defer([this, state, progress, error]() { this->notify_state_(state, progress, error); });
}
/** Call callbacks with deferral to main loop (for thread safety).
*
* This should be used by OTA implementations that run in separate tasks
* (like web_server OTA) to ensure callbacks execute in the main loop.
*/
void call_deferred(ota::OTAState state, float progress, uint8_t error) {
component_->defer([this, state, progress, error]() { this->call(state, progress, error); });
}
private:
OTAComponent *component_;
};
StateCallbackManager state_callback_{this};
std::vector<OTAStateListener *> state_listeners_;
#endif
};
#ifdef USE_OTA_STATE_CALLBACK
#ifdef USE_OTA_STATE_LISTENER
/** Listener interface for global OTA state changes (includes OTA component pointer).
*
* Used by OTAGlobalCallback to aggregate state from multiple OTA components.
*/
class OTAGlobalStateListener {
public:
virtual ~OTAGlobalStateListener() = default;
virtual void on_ota_global_state(OTAState state, float progress, uint8_t error, OTAComponent *component) = 0;
};
/** Global callback that aggregates OTA state from all OTA components.
*
* OTA components call notify_ota_state() directly with their pointer,
* which forwards the event to all registered global listeners.
*/
class OTAGlobalCallback {
public:
void register_ota(OTAComponent *ota_caller) {
ota_caller->add_on_state_callback([this, ota_caller](OTAState state, float progress, uint8_t error) {
this->state_callback_.call(state, progress, error, ota_caller);
});
}
void add_on_state_callback(std::function<void(OTAState, float, uint8_t, OTAComponent *)> &&callback) {
this->state_callback_.add(std::move(callback));
void add_global_state_listener(OTAGlobalStateListener *listener) { this->global_listeners_.push_back(listener); }
void notify_ota_state(OTAState state, float progress, uint8_t error, OTAComponent *component) {
for (auto *listener : this->global_listeners_) {
listener->on_ota_global_state(state, progress, error, component);
}
}
protected:
CallbackManager<void(OTAState, float, uint8_t, OTAComponent *)> state_callback_{};
std::vector<OTAGlobalStateListener *> global_listeners_;
};
OTAGlobalCallback *get_global_ota_callback();
void register_ota_platform(OTAComponent *ota_caller);
// OTA implementations should use:
// - state_callback_.call() when already in main loop (e.g., esphome OTA)
// - state_callback_.call_deferred() when in separate task (e.g., web_server OTA)
// This ensures proper callback execution in all contexts.
// - notify_state_() when already in main loop (e.g., esphome OTA)
// - notify_state_deferred_() when in separate task (e.g., web_server OTA)
// This ensures proper listener execution in all contexts.
#endif
std::unique_ptr<ota::OTABackend> make_ota_backend();

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from esphome import automation, external_files
import esphome.codegen as cg
from esphome.components import audio, esp32, media_player, network, psram, speaker
from esphome.components import audio, esp32, media_player, network, ota, psram, speaker
import esphome.config_validation as cv
from esphome.const import (
CONF_BUFFER_SIZE,
@@ -342,7 +342,7 @@ async def to_code(config):
var = await media_player.new_media_player(config)
await cg.register_component(var, config)
cg.add_define("USE_OTA_STATE_CALLBACK")
ota.request_ota_state_listeners()
cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE]))

View File

@@ -66,25 +66,8 @@ void SpeakerMediaPlayer::setup() {
this->set_mute_state_(false);
}
#ifdef USE_OTA
ota::get_global_ota_callback()->add_on_state_callback(
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->suspend_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->suspend_tasks();
}
} else if (state == ota::OTA_ERROR) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->resume_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->resume_tasks();
}
}
});
#ifdef USE_OTA_STATE_LISTENER
ota::get_global_ota_callback()->add_global_state_listener(this);
#endif
this->announcement_pipeline_ =
@@ -300,6 +283,27 @@ void SpeakerMediaPlayer::watch_media_commands_() {
}
}
#ifdef USE_OTA_STATE_LISTENER
void SpeakerMediaPlayer::on_ota_global_state(ota::OTAState state, float progress, uint8_t error,
ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->suspend_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->suspend_tasks();
}
} else if (state == ota::OTA_ERROR) {
if (this->media_pipeline_ != nullptr) {
this->media_pipeline_->resume_tasks();
}
if (this->announcement_pipeline_ != nullptr) {
this->announcement_pipeline_->resume_tasks();
}
}
}
#endif
void SpeakerMediaPlayer::loop() {
this->watch_media_commands_();

View File

@@ -5,14 +5,18 @@
#include "audio_pipeline.h"
#include "esphome/components/audio/audio.h"
#include "esphome/components/media_player/media_player.h"
#include "esphome/components/speaker/speaker.h"
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/core/preferences.h"
#ifdef USE_OTA_STATE_LISTENER
#include "esphome/components/ota/ota_backend.h"
#endif
#include <deque>
#include <freertos/FreeRTOS.h>
#include <freertos/queue.h>
@@ -39,12 +43,22 @@ struct VolumeRestoreState {
bool is_muted;
};
class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer {
class SpeakerMediaPlayer : public Component,
public media_player::MediaPlayer
#ifdef USE_OTA_STATE_LISTENER
,
public ota::OTAGlobalStateListener
#endif
{
public:
float get_setup_priority() const override { return esphome::setup_priority::PROCESSOR; }
void setup() override;
void loop() override;
#ifdef USE_OTA_STATE_LISTENER
void on_ota_global_state(ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) override;
#endif
// MediaPlayer implementations
media_player::MediaPlayerTraits get_traits() override;
bool is_muted() const override { return this->is_muted_; }

View File

@@ -84,9 +84,9 @@ void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) {
} else {
ESP_LOGD(TAG, "OTA in progress: %" PRIu32 " bytes read", this->ota_read_length_);
}
#ifdef USE_OTA_STATE_CALLBACK
// Report progress - use call_deferred since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_IN_PROGRESS, percentage, 0);
#ifdef USE_OTA_STATE_LISTENER
// Report progress - use notify_state_deferred_ since we're in web server task
this->parent_->notify_state_deferred_(ota::OTA_IN_PROGRESS, percentage, 0);
#endif
this->last_ota_progress_ = now;
}
@@ -114,9 +114,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
// Initialize OTA on first call
this->ota_init_(filename.c_str());
#ifdef USE_OTA_STATE_CALLBACK
// Notify OTA started - use call_deferred since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_STARTED, 0.0f, 0);
#ifdef USE_OTA_STATE_LISTENER
// Notify OTA started - use notify_state_deferred_ since we're in web server task
this->parent_->notify_state_deferred_(ota::OTA_STARTED, 0.0f, 0);
#endif
// Platform-specific pre-initialization
@@ -134,9 +134,9 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
this->ota_backend_ = ota::make_ota_backend();
if (!this->ota_backend_) {
ESP_LOGE(TAG, "Failed to create OTA backend");
#ifdef USE_OTA_STATE_CALLBACK
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f,
static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN));
#ifdef USE_OTA_STATE_LISTENER
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f,
static_cast<uint8_t>(ota::OTA_RESPONSE_ERROR_UNKNOWN));
#endif
return;
}
@@ -148,8 +148,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
if (error_code != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA begin failed: %d", error_code);
this->ota_backend_.reset();
#ifdef USE_OTA_STATE_CALLBACK
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#ifdef USE_OTA_STATE_LISTENER
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif
return;
}
@@ -166,8 +166,8 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
ESP_LOGE(TAG, "OTA write failed: %d", error_code);
this->ota_backend_->abort();
this->ota_backend_.reset();
#ifdef USE_OTA_STATE_CALLBACK
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#ifdef USE_OTA_STATE_LISTENER
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif
return;
}
@@ -186,15 +186,15 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Platf
error_code = this->ota_backend_->end();
if (error_code == ota::OTA_RESPONSE_OK) {
this->ota_success_ = true;
#ifdef USE_OTA_STATE_CALLBACK
// Report completion before reboot - use call_deferred since we're in web server task
this->parent_->state_callback_.call_deferred(ota::OTA_COMPLETED, 100.0f, 0);
#ifdef USE_OTA_STATE_LISTENER
// Report completion before reboot - use notify_state_deferred_ since we're in web server task
this->parent_->notify_state_deferred_(ota::OTA_COMPLETED, 100.0f, 0);
#endif
this->schedule_ota_reboot_();
} else {
ESP_LOGE(TAG, "OTA end failed: %d", error_code);
#ifdef USE_OTA_STATE_CALLBACK
this->parent_->state_callback_.call_deferred(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#ifdef USE_OTA_STATE_LISTENER
this->parent_->notify_state_deferred_(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
#endif
}
this->ota_backend_.reset();
@@ -232,10 +232,6 @@ void WebServerOTAComponent::setup() {
// AsyncWebServer takes ownership of the handler and will delete it when the server is destroyed
base->add_handler(new OTARequestHandler(this)); // NOLINT
#ifdef USE_OTA_STATE_CALLBACK
// Register with global OTA callback system
ota::register_ota_platform(this);
#endif
}
void WebServerOTAComponent::dump_config() { ESP_LOGCONFIG(TAG, "Web Server OTA"); }

View File

@@ -146,7 +146,7 @@
#define USE_OTA_PASSWORD
#define USE_OTA_SHA256
#define ALLOW_OTA_DOWNGRADE_MD5
#define USE_OTA_STATE_CALLBACK
#define USE_OTA_STATE_LISTENER
#define USE_OTA_VERSION 2
#define USE_TIME_TIMEZONE
#define USE_WIFI