Files
esphome/esphome/components/socket/lwip_raw_tcp_impl.cpp

646 lines
19 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "socket.h"
#include "esphome/core/defines.h"
#ifdef USE_SOCKET_IMPL_LWIP_TCP
#include "lwip/ip.h"
#include "lwip/netif.h"
#include "lwip/opt.h"
#include "lwip/tcp.h"
#include <cerrno>
#include <cstring>
#include <array>
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
namespace esphome {
namespace socket {
// Log tag passed to the logging macros for messages emitted by this socket implementation.
static const char *const TAG = "socket.lwip";
// set to 1 to enable verbose lwip logging
// When enabled, LWIP_LOG forwards to ESP_LOGVV and prefixes each message with the address of the
// socket object (`this`), so it can only be used inside non-static member functions.
// When disabled (the default), LWIP_LOG expands to nothing and its arguments are never evaluated.
#if 0
#define LWIP_LOG(msg, ...) ESP_LOGVV(TAG, "socket %p: " msg, this, ##__VA_ARGS__)
#else
#define LWIP_LOG(msg, ...)
#endif
// Socket implementation that talks to the lwIP raw TCP API (a tcp_pcb plus callbacks).
// Incoming connections and received data are delivered through the callbacks installed by init()
// and are buffered inside this object until the application polls them with accept()/read().
class LWIPRawImpl : public Socket {
public:
// Wraps an existing lwIP protocol control block.
// `family` is stored and later selects the sockaddr layout used by bind() and ip2sockaddr_().
// No callbacks are registered here; init() has to be called separately.
LWIPRawImpl(sa_family_t family, struct tcp_pcb *pcb) : pcb_(pcb), family_(family) {}
// Releases everything this socket still owns:
//  - received pbufs that read() never consumed (recv_fn() took ownership of them when it stored
//    them in rx_buf_, so nothing else will free them), and
//  - the pcb itself, which is aborted if it is still open.
~LWIPRawImpl() override {
  if (rx_buf_ != nullptr) {
    // pbuf_free() releases the whole remaining chain starting at the head.
    pbuf_free(rx_buf_);
    rx_buf_ = nullptr;
    rx_buf_offset_ = 0;
  }
  if (pcb_ != nullptr) {
    LWIP_LOG("tcp_abort(%p)", pcb_);
    tcp_abort(pcb_);
    pcb_ = nullptr;
  }
}
// Hooks this object up to its pcb: `this` becomes the callback argument and the static
// accept/recv/err trampolines are installed.
void init() {
  LWIP_LOG("init(%p)", pcb_);
  struct tcp_pcb *const pcb = this->pcb_;
  tcp_arg(pcb, this);
  tcp_err(pcb, LWIPRawImpl::s_err_fn);
  tcp_recv(pcb, LWIPRawImpl::s_recv_fn);
  tcp_accept(pcb, LWIPRawImpl::s_accept_fn);
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return nullptr;
}
if (this->accepted_socket_count_ == 0) {
errno = EWOULDBLOCK;
return nullptr;
}
// Take from front for FIFO ordering
std::unique_ptr<LWIPRawImpl> sock = std::move(this->accepted_sockets_[0]);
// Shift remaining sockets forward
for (uint8_t i = 1; i < this->accepted_socket_count_; i++) {
this->accepted_sockets_[i - 1] = std::move(this->accepted_sockets_[i]);
}
this->accepted_socket_count_--;
LWIP_LOG("Connection accepted by application, queue size: %d", this->accepted_socket_count_);
if (addr != nullptr) {
sock->getpeername(addr, addrlen);
}
LWIP_LOG("accept(%p)", sock.get());
return std::unique_ptr<Socket>(std::move(sock));
}
int bind(const struct sockaddr *name, socklen_t addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr) {
errno = EINVAL;
return 0;
}
ip_addr_t ip;
in_port_t port;
#if LWIP_IPV6
if (family_ == AF_INET) {
if (addrlen < sizeof(sockaddr_in)) {
errno = EINVAL;
return -1;
}
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
port = ntohs(addr4->sin_port);
ip.type = IPADDR_TYPE_V4;
ip.u_addr.ip4.addr = addr4->sin_addr.s_addr;
LWIP_LOG("tcp_bind(%p ip=%s port=%u)", pcb_, ip4addr_ntoa(&ip.u_addr.ip4), port);
} else if (family_ == AF_INET6) {
if (addrlen < sizeof(sockaddr_in6)) {
errno = EINVAL;
return -1;
}
auto *addr6 = reinterpret_cast<const sockaddr_in6 *>(name);
port = ntohs(addr6->sin6_port);
ip.type = IPADDR_TYPE_ANY;
memcpy(&ip.u_addr.ip6.addr, &addr6->sin6_addr.un.u8_addr, 16);
LWIP_LOG("tcp_bind(%p ip=%s port=%u)", pcb_, ip6addr_ntoa(&ip.u_addr.ip6), port);
} else {
errno = EINVAL;
return -1;
}
#else
if (family_ != AF_INET) {
errno = EINVAL;
return -1;
}
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
port = ntohs(addr4->sin_port);
ip.addr = addr4->sin_addr.s_addr;
LWIP_LOG("tcp_bind(%p ip=%u port=%u)", pcb_, ip.addr, port);
#endif
err_t err = tcp_bind(pcb_, &ip, port);
if (err == ERR_USE) {
LWIP_LOG(" -> err ERR_USE");
errno = EADDRINUSE;
return -1;
}
if (err == ERR_VAL) {
LWIP_LOG(" -> err ERR_VAL");
errno = EINVAL;
return -1;
}
if (err != ERR_OK) {
LWIP_LOG(" -> err %d", err);
errno = EIO;
return -1;
}
return 0;
}
int close() override {
if (pcb_ == nullptr) {
errno = ECONNRESET;
return -1;
}
LWIP_LOG("tcp_close(%p)", pcb_);
err_t err = tcp_close(pcb_);
if (err != ERR_OK) {
LWIP_LOG(" -> err %d", err);
tcp_abort(pcb_);
pcb_ = nullptr;
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
pcb_ = nullptr;
return 0;
}
int shutdown(int how) override {
if (pcb_ == nullptr) {
errno = ECONNRESET;
return -1;
}
bool shut_rx = false, shut_tx = false;
if (how == SHUT_RD) {
shut_rx = true;
} else if (how == SHUT_WR) {
shut_tx = true;
} else if (how == SHUT_RDWR) {
shut_rx = shut_tx = true;
} else {
errno = EINVAL;
return -1;
}
LWIP_LOG("tcp_shutdown(%p shut_rx=%d shut_tx=%d)", pcb_, shut_rx ? 1 : 0, shut_tx ? 1 : 0);
err_t err = tcp_shutdown(pcb_, shut_rx, shut_tx);
if (err != ERR_OK) {
LWIP_LOG(" -> err %d", err);
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
return 0;
}
// Writes the address of the remote endpoint of this connection into `name`.
// `*addrlen` is the capacity of `name` on entry and is updated by ip2sockaddr_() on success.
// Returns 0 on success; -1 with errno = ECONNRESET (no pcb) or EINVAL (null argument, or
// whatever ip2sockaddr_() rejects).
int getpeername(struct sockaddr *name, socklen_t *addrlen) override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return -1;
  }
  if (name == nullptr || addrlen == nullptr) {
    errno = EINVAL;
    return -1;
  }
  // The peer is the *remote* side of the pcb; the local side is what getsockname() reports.
  return this->ip2sockaddr_(&pcb_->remote_ip, pcb_->remote_port, name, addrlen);
}
// Returns the remote IP address formatted as text.
// Yields an empty string (and errno = ECONNRESET) when there is no pcb; the string is also
// empty when the address is neither recognised as IPv4 nor (with LWIP_IPV6) as IPv6.
std::string getpeername() override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return std::string();
  }
  // Zero-filled so the result is an empty string if no branch below writes to it.
  char text[50] = {};
  if (IP_IS_V4_VAL(pcb_->remote_ip)) {
    inet_ntoa_r(pcb_->remote_ip, text, sizeof(text));
  }
#if LWIP_IPV6
  else if (IP_IS_V6_VAL(pcb_->remote_ip)) {
    inet6_ntoa_r(pcb_->remote_ip, text, sizeof(text));
  }
#endif
  return std::string(text);
}
int getsockname(struct sockaddr *name, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = ECONNRESET;
return -1;
}
if (name == nullptr || addrlen == nullptr) {
errno = EINVAL;
return -1;
}
return this->ip2sockaddr_(&pcb_->local_ip, pcb_->local_port, name, addrlen);
}
// Returns the local IP address formatted as text.
// Yields an empty string (and errno = ECONNRESET) when there is no pcb; the string is also
// empty when the address is neither recognised as IPv4 nor (with LWIP_IPV6) as IPv6.
std::string getsockname() override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return std::string();
  }
  // Zero-filled so the result is an empty string if no branch below writes to it.
  char text[50] = {};
  if (IP_IS_V4_VAL(pcb_->local_ip)) {
    inet_ntoa_r(pcb_->local_ip, text, sizeof(text));
  }
#if LWIP_IPV6
  else if (IP_IS_V6_VAL(pcb_->local_ip)) {
    inet6_ntoa_r(pcb_->local_ip, text, sizeof(text));
  }
#endif
  return std::string(text);
}
// Reads a socket option as an int. Only two options are understood:
//   SOL_SOCKET / SO_REUSEADDR  - always reported as 1
//   IPPROTO_TCP / TCP_NODELAY  - reports the locally tracked nodelay_ flag
// Returns 0 and sets *optlen = 4 on success; -1 with errno = ECONNRESET (no pcb) or EINVAL
// (null argument, unknown option, or *optlen < 4).
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return -1;
  }
  if (optlen == nullptr || optval == nullptr) {
    errno = EINVAL;
    return -1;
  }
  int value;
  if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
    // lwip doesn't seem to have this feature. Don't send an error
    // to prevent warnings
    value = 1;
  } else if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
    value = nodelay_;
  } else {
    errno = EINVAL;
    return -1;
  }
  // Both supported options need room for one 4-byte int.
  if (*optlen < 4) {
    errno = EINVAL;
    return -1;
  }
  *reinterpret_cast<int *>(optval) = value;
  *optlen = 4;
  return 0;
}
// Sets a socket option. Only two options are understood:
//   SOL_SOCKET / SO_REUSEADDR  - accepted and ignored (the value is never read)
//   IPPROTO_TCP / TCP_NODELAY  - stores the flag in nodelay_; write()/writev() use it to decide
//                                whether to flush immediately
// Returns 0 on success; -1 with errno = ECONNRESET (no pcb) or EINVAL (unknown option,
// optlen != 4, or a null `optval` for TCP_NODELAY).
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return -1;
  }
  if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
    if (optlen != 4) {
      errno = EINVAL;
      return -1;
    }
    // lwip doesn't seem to have this feature. Don't send an error
    // to prevent warnings
    return 0;
  }
  if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
    // optval is dereferenced below, so reject a null pointer instead of crashing.
    if (optval == nullptr || optlen != 4) {
      errno = EINVAL;
      return -1;
    }
    const int val = *reinterpret_cast<const int *>(optval);
    nodelay_ = val != 0;
    return 0;
  }
  errno = EINVAL;
  return -1;
}
// Puts the socket into listening state via tcp_listen_with_backlog().
// On success the pcb returned by lwIP replaces the stored one and the callback argument and
// accept callback are installed on it. On failure the original pcb is aborted.
// Returns 0 on success; -1 with errno = EBADF (no pcb) or EOPNOTSUPP (listen failed).
int listen(int backlog) override {
  if (pcb_ == nullptr) {
    errno = EBADF;
    return -1;
  }
  LWIP_LOG("tcp_listen_with_backlog(%p backlog=%d)", pcb_, backlog);
  struct tcp_pcb *const listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
  if (listen_pcb != nullptr) {
    // tcp_listen reallocates the pcb, replace ours
    pcb_ = listen_pcb;
    // set callbacks on new pcb
    LWIP_LOG("tcp_arg(%p)", pcb_);
    tcp_arg(pcb_, this);
    tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
    return 0;
  }
  tcp_abort(pcb_);
  pcb_ = nullptr;
  errno = EOPNOTSUPP;
  return -1;
}
// Copies up to `len` buffered bytes into `buf` without blocking.
// Returns the number of bytes copied, 0 when the peer closed and nothing is buffered (or when
// `len` is 0), or -1 with errno = ECONNRESET (no pcb) / EWOULDBLOCK (nothing available yet).
// Every chunk handed to the caller is acknowledged to lwIP with tcp_recved().
ssize_t read(void *buf, size_t len) override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return -1;
  }
  const bool nothing_buffered = rx_buf_ == nullptr;
  if (rx_closed_ && nothing_buffered) {
    return 0;
  }
  if (len == 0) {
    return 0;
  }
  if (nothing_buffered) {
    errno = EWOULDBLOCK;
    return -1;
  }
  auto *dest = reinterpret_cast<uint8_t *>(buf);
  size_t total = 0;
  while (len != 0 && rx_buf_ != nullptr) {
    const size_t available = static_cast<size_t>(rx_buf_->len) - rx_buf_offset_;
    if (available == 0)
      break;
    const size_t chunk = std::min(len, available);
    memcpy(dest, reinterpret_cast<uint8_t *>(rx_buf_->payload) + rx_buf_offset_, chunk);
    if (chunk < available) {
      // Head pbuf only partly consumed: remember where to continue next time.
      rx_buf_offset_ += chunk;
    } else {
      // Head pbuf fully consumed: advance to its successor (if any) and release the head.
      auto *const consumed = rx_buf_;
      rx_buf_ = consumed->next;
      if (rx_buf_ != nullptr) {
        // Take a reference on the successor before the old head is released.
        pbuf_ref(rx_buf_);
      }
      pbuf_free(consumed);
      rx_buf_offset_ = 0;
    }
    LWIP_LOG("tcp_recved(%p %u)", pcb_, chunk);
    tcp_recved(pcb_, chunk);
    dest += chunk;
    len -= chunk;
    total += chunk;
  }
  if (total == 0) {
    errno = EWOULDBLOCK;
    return -1;
  }
  return total;
}
// Scatter read: fills the buffers in `iov` in order using read().
// Stops at the first buffer that could not be filled completely. An error from read() is only
// reported (-1) when nothing at all was read; otherwise the bytes read so far are returned.
ssize_t readv(const struct iovec *iov, int iovcnt) override {
  ssize_t total = 0;
  for (int idx = 0; idx < iovcnt; idx++) {
    const struct iovec &vec = iov[idx];
    const ssize_t got = this->read(reinterpret_cast<uint8_t *>(vec.iov_base), vec.iov_len);
    if (got == -1) {
      if (total == 0)
        return got;
      // some data was already read, so report that instead of the error
      break;
    }
    total += got;
    if (static_cast<size_t>(got) != vec.iov_len)
      break;
  }
  return total;
}
ssize_t internal_write(const void *buf, size_t len) {
if (pcb_ == nullptr) {
errno = ECONNRESET;
return -1;
}
if (len == 0)
return 0;
if (buf == nullptr) {
errno = EINVAL;
return 0;
}
auto space = tcp_sndbuf(pcb_);
if (space == 0) {
errno = EWOULDBLOCK;
return -1;
}
size_t to_send = std::min((size_t) space, len);
LWIP_LOG("tcp_write(%p buf=%p %u)", pcb_, buf, to_send);
err_t err = tcp_write(pcb_, buf, to_send, TCP_WRITE_FLAG_COPY);
if (err == ERR_MEM) {
LWIP_LOG(" -> err ERR_MEM");
errno = EWOULDBLOCK;
return -1;
}
if (err != ERR_OK) {
LWIP_LOG(" -> err %d", err);
errno = ECONNRESET;
return -1;
}
return to_send;
}
int internal_output() {
LWIP_LOG("tcp_output(%p)", pcb_);
err_t err = tcp_output(pcb_);
if (err == ERR_ABRT) {
LWIP_LOG(" -> err ERR_ABRT");
// sometimes lwip returns ERR_ABRT for no apparent reason
// the connection works fine afterwards, and back with ESPAsyncTCP we
// indirectly also ignored this error
// FIXME: figure out where this is returned and what it means in this context
return 0;
}
if (err != ERR_OK) {
LWIP_LOG(" -> err %d", err);
errno = ECONNRESET;
return -1;
}
return 0;
}
// Queues `buf` for sending and, when nodelay_ is set, flushes immediately.
// Returns the number of bytes queued, 0 if nothing was queued, or -1 when queuing or the
// flush failed (errno is set by internal_write()/internal_output()).
ssize_t write(const void *buf, size_t len) override {
  const ssize_t written = this->internal_write(buf, len);
  // -1 is an error; 0 means nothing was queued, so there is nothing to flush either
  if (written == -1 || written == 0)
    return written;
  if (nodelay_ && this->internal_output() == -1)
    return -1;
  return written;
}
// Gather write: queues the buffers in `iov` in order using internal_write(), then flushes once
// when nodelay_ is set.
// Stops at the first buffer that could not be queued completely. An error from internal_write()
// is only reported (-1) when nothing at all was queued. Returns the total bytes queued,
// 0 if none, or -1 when the flush failed.
ssize_t writev(const struct iovec *iov, int iovcnt) override {
  ssize_t total = 0;
  for (int idx = 0; idx < iovcnt; idx++) {
    const struct iovec &vec = iov[idx];
    const ssize_t sent = this->internal_write(reinterpret_cast<uint8_t *>(vec.iov_base), vec.iov_len);
    if (sent == -1) {
      if (total == 0)
        return sent;
      // some data was already written, so report that instead of the error
      break;
    }
    total += sent;
    if (static_cast<size_t>(sent) != vec.iov_len)
      break;
  }
  // no need to output if nothing written
  if (total == 0)
    return 0;
  if (nodelay_ && this->internal_output() == -1)
    return -1;
  return total;
}
// Not implemented by this backend: always fails with -1 and errno = ENOSYS.
// All arguments are ignored.
ssize_t sendto(const void *buf, size_t len, int flags, const struct sockaddr *to, socklen_t tolen) override {
  errno = ENOSYS;
  const ssize_t not_supported = -1;
  return not_supported;
}
// This backend only works in non-blocking mode.
// Returns 0 for `blocking == false`; -1 with errno = EINVAL when blocking mode is requested,
// or -1 with errno = ECONNRESET when there is no pcb.
int setblocking(bool blocking) override {
  if (pcb_ == nullptr) {
    errno = ECONNRESET;
    return -1;
  }
  if (!blocking) {
    return 0;
  }
  // blocking operation not supported
  errno = EINVAL;
  return -1;
}
// lwIP accept callback: wraps the new pcb in a socket object and appends it to the accept queue
// for accept() to pick up. When the queue is already full the connection is aborted.
// Returns ERR_ABRT only in the case where tcp_abort() was called here, ERR_OK otherwise.
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
  LWIP_LOG("accept(newpcb=%p err=%d)", newpcb, err);
  const bool usable = err == ERR_OK && newpcb != nullptr;
  if (!usable) {
    // "An error code if there has been an error accepting. Only return ERR_ABRT if you have
    // called tcp_abort from within the callback function!"
    // https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d
    // nothing to do here, we just don't push it to the queue
    return ERR_OK;
  }
  // No free slot left in the accept queue: refuse the connection.
  if (this->accepted_socket_count_ >= MAX_ACCEPTED_SOCKETS) {
    LWIP_LOG("Rejecting connection, queue full (%d)", this->accepted_socket_count_);
    tcp_abort(newpcb);
    // Must return ERR_ABRT since we called tcp_abort()
    return ERR_ABRT;
  }
  auto client = make_unique<LWIPRawImpl>(family_, newpcb);
  client->init();
  // Append at the tail so accept() can hand connections out in arrival order.
  const uint8_t slot = this->accepted_socket_count_;
  this->accepted_sockets_[slot] = std::move(client);
  this->accepted_socket_count_ = static_cast<uint8_t>(slot + 1);
  LWIP_LOG("Accepted connection, queue size: %d", this->accepted_socket_count_);
  return ERR_OK;
}
// lwIP error callback: forgets the pcb so that later calls on this socket fail cleanly
// instead of touching it.
void err_fn(err_t err) {
  LWIP_LOG("err(err=%d)", err);
  // "If a connection is aborted because of an error, the application is alerted of this event by
  // the err callback."
  // pcb is already freed when this callback is called
  // ERR_RST: connection was reset by remote host
  // ERR_ABRT: aborted through tcp_abort or TCP timer
  this->pcb_ = nullptr;
}
// lwIP receive callback: stores incoming pbufs until read() consumes them.
// An error code or a null pbuf marks the receive side as closed. Always returns ERR_OK.
err_t recv_fn(struct pbuf *pb, err_t err) {
  LWIP_LOG("recv(pb=%p err=%d)", pb, err);
  // "An error code if there has been an error receiving Only return ERR_ABRT if you have
  // called tcp_abort from within the callback function!"
  // Both an error and a null pbuf mark the receive side as closed.
  if (err != 0 || pb == nullptr) {
    rx_closed_ = true;
    return ERR_OK;
  }
  if (rx_buf_ != nullptr) {
    // Data is already pending: append the new pbuf to the end of the chain.
    pbuf_cat(rx_buf_, pb);
    return ERR_OK;
  }
  // no need to copy because lwIP gave control of it to us
  rx_buf_ = pb;
  rx_buf_offset_ = 0;
  return ERR_OK;
}
// Static trampoline for the lwIP accept callback; `arg` is the socket object set with tcp_arg().
static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) {
  return static_cast<LWIPRawImpl *>(arg)->accept_fn(newpcb, err);
}
static void s_err_fn(void *arg, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
arg_this->err_fn(err);
}
// Static trampoline for the lwIP receive callback; `arg` is the socket object set with
// tcp_arg(). The `pcb` argument is not needed because the object keeps its own pcb pointer.
static err_t s_recv_fn(void *arg, struct tcp_pcb *pcb, struct pbuf *pb, err_t err) {
  return static_cast<LWIPRawImpl *>(arg)->recv_fn(pb, err);
}
protected:
// Converts an lwIP address plus port into the sockaddr layout matching this socket's family.
//   ip       - lwIP address to convert
//   port     - port in host byte order (as stored in the pcb); written in network byte order,
//              mirroring bind(), which converts sin_port with ntohs()
//   name     - destination buffer
//   addrlen  - in: capacity of `name`; out: size of the structure written
// Returns 0 on success, -1 with errno = EINVAL when the buffer is too small or the socket's
// family is not supported.
int ip2sockaddr_(ip_addr_t *ip, uint16_t port, struct sockaddr *name, socklen_t *addrlen) {
  if (family_ == AF_INET) {
    if (*addrlen < sizeof(struct sockaddr_in)) {
      errno = EINVAL;
      return -1;
    }
    auto *addr = reinterpret_cast<struct sockaddr_in *>(name);
    addr->sin_family = AF_INET;
    *addrlen = addr->sin_len = sizeof(struct sockaddr_in);
    addr->sin_port = htons(port);
    inet_addr_from_ip4addr(&addr->sin_addr, ip_2_ip4(ip));
    return 0;
  }
#if LWIP_IPV6
  if (family_ == AF_INET6) {
    if (*addrlen < sizeof(struct sockaddr_in6)) {
      errno = EINVAL;
      return -1;
    }
    auto *addr = reinterpret_cast<struct sockaddr_in6 *>(name);
    addr->sin6_family = AF_INET6;
    *addrlen = addr->sin6_len = sizeof(struct sockaddr_in6);
    addr->sin6_port = htons(port);
    // AF_INET6 sockets are bound to IPv4 as well, so we may encounter IPv4 addresses that must be converted to IPv6.
    if (IP_IS_V4(ip)) {
      ip_addr_t mapped;
      ip4_2_ipv4_mapped_ipv6(ip_2_ip6(&mapped), ip_2_ip4(ip));
      inet6_addr_from_ip6addr(&addr->sin6_addr, ip_2_ip6(&mapped));
    } else {
      inet6_addr_from_ip6addr(&addr->sin6_addr, ip_2_ip6(ip));
    }
    return 0;
  }
#endif
  // Unsupported address family: report it like the other argument errors instead of leaving
  // errno untouched.
  errno = EINVAL;
  return -1;
}
// lwIP protocol control block backing this socket. Reset to nullptr once the socket has been
// closed or aborted, or after lwIP reported an error through the err callback.
struct tcp_pcb *pcb_;
// Accept queue - holds incoming connections briefly until the event loop calls accept()
// This is NOT a connection pool - just a temporary queue between LWIP callbacks and the main loop
// 3 slots is plenty since connections are pulled out quickly by the event loop
//
// Memory analysis: std::array<3> vs original std::queue implementation:
// - std::queue uses std::deque internally which on 32-bit systems needs:
// 24 bytes (deque object) + 32+ bytes (map array) + heap allocations
// Total: ~56+ bytes minimum, plus heap fragmentation
// - std::array<3>: 12 bytes fixed (3 pointers x 4 bytes)
// Saves ~44+ bytes RAM per listening socket + avoids ALL heap allocations
// Used on ESP8266 and RP2040 (platforms using LWIP_TCP implementation)
static constexpr size_t MAX_ACCEPTED_SOCKETS = 3;
// Pending connections in arrival order; only the first accepted_socket_count_ slots are in use.
std::array<std::unique_ptr<LWIPRawImpl>, MAX_ACCEPTED_SOCKETS> accepted_sockets_;
uint8_t accepted_socket_count_ = 0; // Number of sockets currently in queue
// Set by recv_fn() when lwIP signals an error or delivers a null pbuf; once the buffered data
// is drained, read() returns 0.
bool rx_closed_ = false;
// Head of the chain of received pbufs that read() has not fully consumed yet (nullptr if none).
pbuf *rx_buf_ = nullptr;
// Number of bytes at the start of the head pbuf that read() has already handed out.
size_t rx_buf_offset_ = 0;
// don't use lwip nodelay flag, it sometimes causes reconnect
// instead use it for determining whether to call tcp_output() right after a write
bool nodelay_ = false;
// Address family given at construction; selects the sockaddr layout in bind()/ip2sockaddr_().
sa_family_t family_ = 0;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
auto *pcb = tcp_new();
if (pcb == nullptr)
return nullptr;
auto *sock = new LWIPRawImpl((sa_family_t) domain, pcb); // NOLINT(cppcoreguidelines-owning-memory)
sock->init();
return std::unique_ptr<Socket>{sock};
}
std::unique_ptr<Socket> socket_loop_monitored(int domain, int type, int protocol) {
// LWIPRawImpl doesn't use file descriptors, so monitoring is not applicable
return socket(domain, type, protocol);
}
} // namespace socket
} // namespace esphome
#endif // USE_SOCKET_IMPL_LWIP_TCP