444 lines
11 KiB
C++
444 lines
11 KiB
C++
/* Copyright (c) Kuba Szczodrzyński 2022-04-30. */
|
|
|
|
#include "MbedTLSClient.h"
|
|
|
|
#include <IPAddress.h>
|
|
#include <WiFi.h>
|
|
#include <WiFiClient.h>
|
|
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif // __cplusplus
|
|
|
|
#include <mbedtls/debug.h>
|
|
#include <mbedtls/platform.h>
|
|
#include <mbedtls/sha256.h>
|
|
#include <mbedtls/ssl.h>
|
|
|
|
#ifdef __cplusplus
|
|
} // extern "C"
|
|
#endif
|
|
|
|
MbedTLSClient::MbedTLSClient() : WiFiClient() {}
|
|
|
|
MbedTLSClient::MbedTLSClient(int sock) : WiFiClient(sock) {}
|
|
|
|
void MbedTLSClient::stop() {
|
|
WiFiClient::stop();
|
|
LT_V_SSL("Closing SSL connection");
|
|
|
|
if (_sslCfg.ca_chain) {
|
|
mbedtls_x509_crt_free(&_caCert);
|
|
}
|
|
if (_sslCfg.key_cert) {
|
|
mbedtls_x509_crt_free(&_clientCert);
|
|
mbedtls_pk_free(&_clientKey);
|
|
}
|
|
mbedtls_ssl_free(&_sslCtx);
|
|
mbedtls_ssl_config_free(&_sslCfg);
|
|
}
|
|
|
|
/**
 * Prepare fresh mbedTLS state for a connection.
 *
 * Steps: install calloc/free as the mbedTLS heap allocator, then initialize the
 * SSL context and the SSL configuration. Called by connect() before any TLS
 * configuration is applied.
 *
 * NOTE(review): nothing is freed here, so calling this again without stop() in
 * between would presumably leak the previous context/configuration allocations.
 */
void MbedTLSClient::init() {
	// Realtek AmbZ: init platform here to ensure HW crypto is initialized in ssl_init
	mbedtls_platform_set_calloc_free(calloc, free);
	mbedtls_ssl_init(&_sslCtx);
	mbedtls_ssl_config_init(&_sslCfg);
}
|
|
|
|
/**
 * Connect to `ip`:`port` using the credentials configured on this client.
 *
 * Returns 1 on success and 0 on failure, like the sibling overloads.
 * connect(const char *, uint16_t, int32_t) already returns that 1/0 value, so it is
 * forwarded unchanged. Comparing it with 0 again would invert the result.
 *
 * NOTE(review): `timeout` is forwarded, but the host-name overload does not use it.
 * The full connect() uses the `_timeout` member instead.
 */
int MbedTLSClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
	return connect(ipToString(ip).c_str(), port, timeout);
}
|
|
|
|
int MbedTLSClient::connect(const char *host, uint16_t port, int32_t timeout) {
|
|
if (_pskIdentStr && _pskStr)
|
|
return connect(host, port, NULL, NULL, NULL, _pskIdentStr, _pskStr, _alpnProtocols) == 0;
|
|
return connect(host, port, _caCertStr, _clientCertStr, _clientKeyStr, NULL, NULL, _alpnProtocols) == 0;
|
|
}
|
|
|
|
/**
 * Certificate-based connect to `ip`:`port`.
 *
 * The address is converted to its string form and handed to the host-name overload
 * that takes the same certificate arguments.
 *
 * @return 1 on success, 0 on failure
 */
int MbedTLSClient::connect(
	IPAddress ip, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
) {
	const auto host = ipToString(ip);
	return connect(host.c_str(), port, rootCABuf, clientCert, clientKey);
}
|
|
|
|
int MbedTLSClient::connect(
|
|
const char *host, uint16_t port, const char *rootCABuf, const char *clientCert, const char *clientKey
|
|
) {
|
|
return connect(host, port, rootCABuf, clientCert, clientKey, NULL, NULL, _alpnProtocols) == 0;
|
|
}
|
|
|
|
/**
 * Pre-shared-key connect to `ip`:`port`.
 *
 * The address is converted to its string form and handed to the host-name PSK
 * overload.
 *
 * @return 1 on success, 0 on failure
 */
int MbedTLSClient::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psk) {
	const auto host = ipToString(ip);
	return connect(host.c_str(), port, pskIdent, psk);
}
|
|
|
|
int MbedTLSClient::connect(const char *host, uint16_t port, const char *pskIdent, const char *psk) {
|
|
return connect(host, port, NULL, NULL, NULL, pskIdent, psk, _alpnProtocols) == 0;
|
|
}
|
|
|
|
/**
 * mbedTLS RNG callback: fills `output` with `len` bytes taken from rand().
 *
 * Each rand() value is copied with memcpy() because mbedTLS may pass a buffer that
 * is not aligned for int. Storing through an `int *` is undefined behaviour.
 * Depending on the core, it can fault or silently write to the wrong address.
 *
 * NOTE(security): rand() is not a cryptographically secure generator, so the TLS
 * randomness is predictable if its seed is known. This should be backed by a
 * hardware TRNG / proper entropy source.
 *
 * @param data   unused callback context
 * @param output destination buffer (any alignment)
 * @param len    number of bytes to produce
 * @return always 0 (success)
 */
static int ssl_random(void *data, unsigned char *output, size_t len) {
	(void)data;
	while (len > 0) {
		const int value = rand();
		const size_t chunk = len < sizeof(value) ? len : sizeof(value);
		memcpy(output, &value, chunk);
		output += chunk;
		len -= chunk;
	}
	return 0;
}
|
|
|
|
/**
 * mbedTLS debug callback (see the commented-out mbedtls_ssl_conf_dbg() call in connect()).
 * Logs the message with its source line number and debug level; `ctx` and `file` are ignored.
 */
void debug_cb(void *ctx, int level, const char *file, int line, const char *str) {
	(void)ctx;
	(void)file;
	LT_I("%04d: |%d| %s", line, level, str);
}
|
|
|
|
int MbedTLSClient::connect(
|
|
const char *host,
|
|
uint16_t port,
|
|
const char *rootCABuf,
|
|
const char *clientCert,
|
|
const char *clientKey,
|
|
const char *pskIdent,
|
|
const char *psk,
|
|
const char **alpnProtocols
|
|
) {
|
|
LT_D_SSL("Free heap before TLS: TODO");
|
|
|
|
if (!rootCABuf && !pskIdent && !psk && !_insecure && !_useRootCA)
|
|
return -1;
|
|
|
|
IPAddress addr = WiFi.hostByName(host);
|
|
if (!(uint32_t)addr)
|
|
return -1;
|
|
|
|
int ret = WiFiClient::connect(addr, port, _timeout);
|
|
if (ret < 0) {
|
|
LT_E("SSL socket failed");
|
|
return ret;
|
|
}
|
|
|
|
char *uid = "lt-ssl"; // TODO
|
|
|
|
LT_V_SSL("Init SSL");
|
|
init();
|
|
|
|
// mbedtls_debug_set_threshold(4);
|
|
// mbedtls_ssl_conf_dbg(&_sslCfg, debug_cb, NULL);
|
|
|
|
ret = mbedtls_ssl_config_defaults(
|
|
&_sslCfg,
|
|
MBEDTLS_SSL_IS_CLIENT,
|
|
MBEDTLS_SSL_TRANSPORT_STREAM,
|
|
MBEDTLS_SSL_PRESET_DEFAULT
|
|
);
|
|
LT_RET_NZ(ret);
|
|
|
|
#ifdef MBEDTLS_SSL_ALPN
|
|
if (alpnProtocols) {
|
|
ret = mbedtls_ssl_conf_alpn_protocols(&_sslCfg, alpnProtocols);
|
|
LT_RET_NZ(ret);
|
|
}
|
|
#endif
|
|
|
|
if (_insecure) {
|
|
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_NONE);
|
|
} else if (rootCABuf) {
|
|
mbedtls_x509_crt_init(&_caCert);
|
|
mbedtls_ssl_conf_authmode(&_sslCfg, MBEDTLS_SSL_VERIFY_REQUIRED);
|
|
ret = mbedtls_x509_crt_parse(&_caCert, (const unsigned char *)rootCABuf, strlen(rootCABuf) + 1);
|
|
mbedtls_ssl_conf_ca_chain(&_sslCfg, &_caCert, NULL);
|
|
if (ret < 0) {
|
|
mbedtls_x509_crt_free(&_caCert);
|
|
LT_RET(ret);
|
|
}
|
|
} else if (_useRootCA) {
|
|
return -1; // not implemented
|
|
} else if (pskIdent && psk) {
|
|
#ifdef MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED
|
|
uint16_t len = strlen(psk);
|
|
if ((len & 1) != 0 || len > 2 * MBEDTLS_PSK_MAX_LEN) {
|
|
LT_E("PSK length invalid");
|
|
return -1;
|
|
}
|
|
unsigned char pskBin[MBEDTLS_PSK_MAX_LEN] = {};
|
|
for (uint8_t i = 0; i < len; i++) {
|
|
uint8_t c = psk[i];
|
|
c |= 0b00100000; // make lowercase
|
|
c -= '0' * (c >= '0' && c <= '9');
|
|
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
|
|
if (c > 0xf)
|
|
return -1;
|
|
pskBin[i / 2] |= c << (4 * ((i & 1) ^ 1));
|
|
}
|
|
ret = mbedtls_ssl_conf_psk(&_sslCfg, pskBin, len / 2, (const unsigned char *)pskIdent, strlen(pskIdent));
|
|
LT_RET_NZ(ret);
|
|
#else
|
|
return -1;
|
|
#endif
|
|
} else {
|
|
return -1;
|
|
}
|
|
|
|
if (!_insecure && clientCert && clientKey) {
|
|
mbedtls_x509_crt_init(&_clientCert);
|
|
mbedtls_pk_init(&_clientKey);
|
|
LT_V_SSL("Loading client cert");
|
|
ret = mbedtls_x509_crt_parse(&_clientCert, (const unsigned char *)clientCert, strlen(clientCert) + 1);
|
|
if (ret < 0) {
|
|
mbedtls_x509_crt_free(&_clientCert);
|
|
LT_RET(ret);
|
|
}
|
|
LT_V_SSL("Loading private key");
|
|
ret = mbedtls_pk_parse_key(&_clientKey, (const unsigned char *)clientKey, strlen(clientKey) + 1, NULL, 0);
|
|
if (ret < 0) {
|
|
mbedtls_x509_crt_free(&_clientCert);
|
|
LT_RET(ret);
|
|
}
|
|
mbedtls_ssl_conf_own_cert(&_sslCfg, &_clientCert, &_clientKey);
|
|
}
|
|
|
|
LT_V_SSL("Setting TLS hostname");
|
|
ret = mbedtls_ssl_set_hostname(&_sslCtx, host);
|
|
LT_RET_NZ(ret);
|
|
|
|
mbedtls_ssl_conf_rng(&_sslCfg, ssl_random, NULL);
|
|
ret = mbedtls_ssl_setup(&_sslCtx, &_sslCfg);
|
|
LT_RET_NZ(ret);
|
|
|
|
_sockTls = fd();
|
|
mbedtls_ssl_set_bio(&_sslCtx, &_sockTls, mbedtls_net_send, mbedtls_net_recv, NULL);
|
|
|
|
LT_V_SSL("SSL handshake");
|
|
if (_handshakeTimeout == 0)
|
|
_handshakeTimeout = _timeout * 1000;
|
|
unsigned long start = millis();
|
|
while (ret = mbedtls_ssl_handshake(&_sslCtx)) {
|
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
|
|
LT_RET(ret);
|
|
}
|
|
if ((millis() - start) > _handshakeTimeout) {
|
|
LT_E("SSL handshake timeout");
|
|
return -1;
|
|
}
|
|
delay(2);
|
|
}
|
|
|
|
if (clientCert && clientKey) {
|
|
LT_D_SSL(
|
|
"Protocol %s, ciphersuite %s",
|
|
mbedtls_ssl_get_version(&_sslCtx),
|
|
mbedtls_ssl_get_ciphersuite(&_sslCtx)
|
|
);
|
|
ret = mbedtls_ssl_get_record_expansion(&_sslCtx);
|
|
if (ret >= 0)
|
|
LT_D_SSL("Record expansion: %d", ret);
|
|
else {
|
|
LT_W("Record expansion unknown");
|
|
}
|
|
}
|
|
|
|
LT_V_SSL("Verifying certificate");
|
|
ret = mbedtls_ssl_get_verify_result(&_sslCtx);
|
|
if (ret) {
|
|
char buf[512];
|
|
memset(buf, 0, sizeof(buf));
|
|
mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", ret);
|
|
LT_E("Failed to verify peer certificate! Verification info: %s", buf);
|
|
return ret;
|
|
}
|
|
|
|
if (rootCABuf)
|
|
mbedtls_x509_crt_free(&_caCert);
|
|
if (clientCert)
|
|
mbedtls_x509_crt_free(&_clientCert);
|
|
if (clientKey != NULL)
|
|
mbedtls_pk_free(&_clientKey);
|
|
return 0; // OK
|
|
}
|
|
|
|
size_t MbedTLSClient::write(const uint8_t *buf, size_t size) {
|
|
int ret = -1;
|
|
while ((ret = mbedtls_ssl_write(&_sslCtx, buf, size)) <= 0) {
|
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
|
|
LT_RET(ret);
|
|
}
|
|
delay(2);
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
int MbedTLSClient::available() {
|
|
bool peeked = _peeked >= 0;
|
|
if (!connected())
|
|
return peeked;
|
|
|
|
int ret = mbedtls_ssl_read(&_sslCtx, NULL, 0);
|
|
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE && ret < 0) {
|
|
stop();
|
|
return peeked ? peeked : ret;
|
|
}
|
|
return mbedtls_ssl_get_bytes_avail(&_sslCtx) + peeked;
|
|
}
|
|
|
|
int MbedTLSClient::read(uint8_t *buf, size_t size) {
|
|
bool peeked = false;
|
|
int toRead = available();
|
|
if ((!buf && size) || toRead <= 0)
|
|
return -1;
|
|
if (!size)
|
|
return 0;
|
|
if (_peeked >= 0) {
|
|
buf[0] = _peeked;
|
|
_peeked = -1;
|
|
size--;
|
|
toRead--;
|
|
if (!size || !toRead)
|
|
return 1;
|
|
buf++;
|
|
peeked = true;
|
|
}
|
|
|
|
int ret = mbedtls_ssl_read(&_sslCtx, buf, size);
|
|
if (ret < 0) {
|
|
stop();
|
|
return peeked ? peeked : ret;
|
|
}
|
|
return ret + peeked;
|
|
}
|
|
|
|
/**
 * Return the next byte without consuming it.
 *
 * The byte is fetched with timedRead() only when none is cached yet. It is kept in
 * `_peeked`, and read() hands it out first.
 */
int MbedTLSClient::peek() {
	if (_peeked < 0)
		_peeked = timedRead();
	return _peeked;
}
|
|
|
|
/** No-op: write() hands data straight to mbedtls_ssl_write(), so nothing is buffered here. */
void MbedTLSClient::flush() {
	// intentionally empty
}
|
|
|
|
/**
 * Report the last TLS error.
 *
 * Errors are not tracked yet, so this always returns 0 ("no error"). It still
 * writes an empty, NUL-terminated message into `buf` (when `buf` is non-NULL and
 * `size` > 0). Callers that print the buffer would otherwise read uninitialized
 * memory.
 */
int MbedTLSClient::lastError(char *buf, const size_t size) {
	if (buf && size)
		buf[0] = '\0';
	return 0; // TODO (?)
}
|
|
|
|
/**
 * Disable server certificate verification for subsequent connections.
 *
 * Also forgets any configured CA certificate, client certificate/key and PSK
 * credentials.
 *
 * NOTE(review): buffers previously allocated by the load*() helpers are not freed here.
 */
void MbedTLSClient::setInsecure() {
	_insecure = true;
	// drop every stored credential
	_pskStr = NULL;
	_pskIdentStr = NULL;
	_clientKeyStr = NULL;
	_clientCertStr = NULL;
	_caCertStr = NULL;
}
|
|
|
|
void MbedTLSClient::setPreSharedKey(const char *pskIdent, const char *psk) {
|
|
_pskIdentStr = pskIdent;
|
|
_pskStr = psk;
|
|
}
|
|
|
|
void MbedTLSClient::setCACert(const char *rootCA) {
|
|
_caCertStr = rootCA;
|
|
}
|
|
|
|
void MbedTLSClient::setCertificate(const char *clientCA) {
|
|
_clientCertStr = clientCA;
|
|
}
|
|
|
|
void MbedTLSClient::setPrivateKey(const char *privateKey) {
|
|
_clientKeyStr = privateKey;
|
|
}
|
|
|
|
/**
 * Read exactly `size` bytes from `stream` into a newly malloc()-ed, NUL-terminated string.
 *
 * @return the buffer, or NULL if allocation failed or fewer than `size` bytes were
 *         read. The caller owns the returned buffer.
 */
char *streamToStr(Stream &stream, size_t size) {
	char *const text = (char *)malloc(size + 1);
	if (text == NULL)
		return NULL;

	const size_t received = stream.readBytes(text, size);
	if (received == size) {
		text[size] = '\0';
		return text;
	}

	free(text);
	return NULL;
}
|
|
|
|
/**
 * Load a CA certificate of `size` bytes from `stream` for subsequent connections.
 * The buffer comes from streamToStr() and is kept by the client.
 * NOTE(review): a buffer from a previous load is not freed when replaced.
 * @return true on success, false if the data could not be read
 */
bool MbedTLSClient::loadCACert(Stream &stream, size_t size) {
	char *const loaded = streamToStr(stream, size);
	if (loaded == NULL)
		return false;
	_caCertStr = loaded;
	return true;
}

/**
 * Load a client certificate of `size` bytes from `stream` for subsequent connections.
 * The buffer comes from streamToStr() and is kept by the client.
 * NOTE(review): a buffer from a previous load is not freed when replaced.
 * @return true on success, false if the data could not be read
 */
bool MbedTLSClient::loadCertificate(Stream &stream, size_t size) {
	char *const loaded = streamToStr(stream, size);
	if (loaded == NULL)
		return false;
	_clientCertStr = loaded;
	return true;
}

/**
 * Load a client private key of `size` bytes from `stream` for subsequent connections.
 * The buffer comes from streamToStr() and is kept by the client.
 * NOTE(review): a buffer from a previous load is not freed when replaced.
 * @return true on success, false if the data could not be read
 */
bool MbedTLSClient::loadPrivateKey(Stream &stream, size_t size) {
	char *const loaded = streamToStr(stream, size);
	if (loaded == NULL)
		return false;
	_clientKeyStr = loaded;
	return true;
}
|
|
|
|
bool MbedTLSClient::verify(const char *fingerprint, const char *domainName) {
|
|
uint8_t fpLocal[32] = {};
|
|
uint16_t len = strlen(fingerprint);
|
|
uint8_t byte = 0;
|
|
for (uint8_t i = 0; i < len; i++) {
|
|
uint8_t c = fingerprint[i];
|
|
while ((c == ' ' || c == ':') && i < len) {
|
|
c = fingerprint[++i];
|
|
}
|
|
c |= 0b00100000; // make lowercase
|
|
c -= '0' * (c >= '0' && c <= '9');
|
|
c -= ('a' - 10) * (c >= 'a' && c <= 'z');
|
|
if (c > 0xf)
|
|
return -1;
|
|
fpLocal[byte / 2] |= c << (4 * ((byte & 1) ^ 1));
|
|
byte++;
|
|
if (byte >= 64)
|
|
break;
|
|
}
|
|
|
|
uint8_t fpRemote[32];
|
|
if (!getFingerprintSHA256(fpRemote))
|
|
return false;
|
|
|
|
if (memcmp(fpLocal, fpRemote, 32)) {
|
|
LT_D_SSL("Fingerprints don't match");
|
|
return false;
|
|
}
|
|
|
|
if (!domainName)
|
|
return true;
|
|
// TODO domain name verification
|
|
return true;
|
|
}
|
|
|
|
void MbedTLSClient::setHandshakeTimeout(unsigned long handshakeTimeout) {
|
|
_handshakeTimeout = handshakeTimeout * 1000;
|
|
}
|
|
|
|
void MbedTLSClient::setAlpnProtocols(const char **alpnProtocols) {
|
|
_alpnProtocols = alpnProtocols;
|
|
}
|
|
|
|
bool MbedTLSClient::getFingerprintSHA256(uint8_t result[32]) {
|
|
const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&_sslCtx);
|
|
if (!cert) {
|
|
LT_E("Failed to get peer certificate");
|
|
return false;
|
|
}
|
|
mbedtls_sha256_context shaCtx;
|
|
mbedtls_sha256_init(&shaCtx);
|
|
mbedtls_sha256_starts(&shaCtx, false);
|
|
mbedtls_sha256_update(&shaCtx, cert->raw.p, cert->raw.len);
|
|
mbedtls_sha256_finish(&shaCtx, result);
|
|
return true;
|
|
}
|