Advertisement
Guest User

Untitled

a guest
Nov 4th, 2016
599
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 16.29 KB | None | 0 0
  1. /*
  2.   WiFiClientSecure.cpp - Variant of WiFiClient with TLS support
  3.   Copyright (c) 2015 Ivan Grokhotkov. All rights reserved.
  4.   This file is part of the esp8266 core for Arduino environment.
  5.  
  6.  
  7.   This library is free software; you can redistribute it and/or
  8.   modify it under the terms of the GNU Lesser General Public
  9.   License as published by the Free Software Foundation; either
  10.   version 2.1 of the License, or (at your option) any later version.
  11.  
  12.   This library is distributed in the hope that it will be useful,
  13.   but WITHOUT ANY WARRANTY; without even the implied warranty of
  14.   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  15.   Lesser General Public License for more details.
  16.  
  17.   You should have received a copy of the GNU Lesser General Public
  18.   License along with this library; if not, write to the Free Software
  19.   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
  20.  
  21. */
  22.  
  23. #define LWIP_INTERNAL
  24.  
  25. extern "C"
  26. {
  27. #include "osapi.h"
  28. #include "ets_sys.h"
  29. }
  30. #include <errno.h>
  31. #include "debug.h"
  32. #include "ESP8266WiFi.h"
  33. #include "WiFiClientSecure.h"
  34. #include "WiFiClient.h"
  35. #include "lwip/opt.h"
  36. #include "lwip/ip.h"
  37. #include "lwip/tcp.h"
  38. #include "lwip/inet.h"
  39. #include "lwip/netif.h"
  40. #include "include/ClientContext.h"
  41. #include "c_types.h"
  42.  
  43. #ifdef DEBUG_ESP_SSL
  44. #define DEBUG_SSL
  45. #endif
  46.  
  47. #ifdef DEBUG_SSL
  48. #define SSL_DEBUG_OPTS SSL_DISPLAY_STATES
  49. #else
  50. #define SSL_DEBUG_OPTS 0
  51. #endif
  52.  
  53. class SSLContext
  54. {
  55. public:
  56.     SSLContext()
  57.     {
  58.         if (_ssl_ctx_refcnt == 0) {
  59.             _ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
  60.         }
  61.         ++_ssl_ctx_refcnt;
  62.     }
  63.  
  64.     ~SSLContext()
  65.     {
  66.         if (_ssl) {
  67.             ssl_free(_ssl);
  68.             _ssl = nullptr;
  69.         }
  70.  
  71.         --_ssl_ctx_refcnt;
  72.         if (_ssl_ctx_refcnt == 0) {
  73.             ssl_ctx_free(_ssl_ctx);
  74.         }
  75.  
  76.         s_io_ctx = nullptr;
  77.     }
  78.  
  79.     void ref()
  80.     {
  81.         ++_refcnt;
  82.     }
  83.  
  84.     void unref()
  85.     {
  86.         if (--_refcnt == 0) {
  87.             delete this;
  88.         }
  89.     }
  90.  
  91.     void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
  92.     {
  93.         s_io_ctx = ctx;
  94.         _ssl = ssl_client_new(_ssl_ctx, 0, nullptr, 0, hostName);
  95.         uint32_t t = millis();
  96.  
  97.         while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
  98.             uint8_t* data;
  99.             int rc = ssl_read(_ssl, &data);
  100.             if (rc < SSL_OK) {
  101.                 break;
  102.             }
  103.         }
  104.     }
  105.  
  106.     void stop()
  107.     {
  108.         s_io_ctx = nullptr;
  109.     }
  110.  
  111.     bool connected()
  112.     {
  113.         return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
  114.     }
  115.  
  116.     int read(uint8_t* dst, size_t size)
  117.     {
  118.         if (!_available) {
  119.             if (!_readAll()) {
  120.                 return 0;
  121.             }
  122.         }
  123.         size_t will_copy = (_available < size) ? _available : size;
  124.         memcpy(dst, _read_ptr, will_copy);
  125.         _read_ptr += will_copy;
  126.         _available -= will_copy;
  127.         if (_available == 0) {
  128.             _read_ptr = nullptr;
  129.         }
  130.         return will_copy;
  131.     }
  132.  
  133.     int read()
  134.     {
  135.         if (!_available) {
  136.             if (!_readAll()) {
  137.                 return -1;
  138.             }
  139.         }
  140.         int result = _read_ptr[0];
  141.         ++_read_ptr;
  142.         --_available;
  143.         if (_available == 0) {
  144.             _read_ptr = nullptr;
  145.         }
  146.         return result;
  147.     }
  148.  
  149.     int peek()
  150.     {
  151.         if (!_available) {
  152.             if (!_readAll()) {
  153.                 return -1;
  154.             }
  155.         }
  156.         return _read_ptr[0];
  157.     }
  158.  
  159.     size_t peekBytes(char *dst, size_t size)
  160.     {
  161.         if (!_available) {
  162.             if (!_readAll()) {
  163.                 return -1;
  164.             }
  165.         }
  166.  
  167.         size_t will_copy = (_available < size) ? _available : size;
  168.         memcpy(dst, _read_ptr, will_copy);
  169.         return will_copy;
  170.     }
  171.  
  172.     int available()
  173.     {
  174.         auto cb = _available;
  175.         if (cb == 0) {
  176.             cb = _readAll();
  177.         } else {
  178.             optimistic_yield(100);
  179.         }
  180.         return cb;
  181.     }
  182.  
  183.     bool loadObject(int type, Stream& stream, size_t size)
  184.     {
  185.         std::unique_ptr<uint8_t[]> buf(new uint8_t[size]);
  186.         if (!buf.get()) {
  187.             DEBUGV("loadObject: failed to allocate memory\n");
  188.             return false;
  189.         }
  190.  
  191.         size_t cb = stream.readBytes(buf.get(), size);
  192.         if (cb != size) {
  193.             DEBUGV("loadObject: reading %u bytes, got %u\n", size, cb);
  194.             return false;
  195.         }
  196.  
  197.         return loadObject(type, buf.get(), size);
  198.     }
  199.  
  200.     bool loadObject(int type, const uint8_t* data, size_t size)
  201.     {
  202.         int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
  203.         if (rc != SSL_OK) {
  204.             DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
  205.             return false;
  206.         }
  207.         return true;
  208.     }
  209.  
  210.     operator SSL*()
  211.     {
  212.         return _ssl;
  213.     }
  214.  
  215.     static ClientContext* getIOContext(int fd)
  216.     {
  217.         return s_io_ctx;
  218.     }
  219.  
  220. protected:
  221.     int _readAll()
  222.     {
  223.         if (!_ssl) {
  224.             return 0;
  225.         }
  226.  
  227.         optimistic_yield(100);
  228.  
  229.         uint8_t* data;
  230.         int rc = ssl_read(_ssl, &data);
  231.         if (rc <= 0) {
  232.             if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
  233.                 ssl_free(_ssl);
  234.                 _ssl = nullptr;
  235.             }
  236.             return 0;
  237.         }
  238.         DEBUGV(":wcs ra %d", rc);
  239.         _read_ptr = data;
  240.         _available = rc;
  241.         return _available;
  242.     }
  243.  
  244.     static SSL_CTX* _ssl_ctx;
  245.     static int _ssl_ctx_refcnt;
  246.     SSL* _ssl = nullptr;
  247.     int _refcnt = 0;
  248.     const uint8_t* _read_ptr = nullptr;
  249.     size_t _available = 0;
  250.     static ClientContext* s_io_ctx;
  251. };
  252.  
  253. SSL_CTX* SSLContext::_ssl_ctx = nullptr;
  254. int SSLContext::_ssl_ctx_refcnt = 0;
  255. ClientContext* SSLContext::s_io_ctx = nullptr;
  256.  
  257. WiFiClientSecure::WiFiClientSecure()
  258. {
  259. }
  260.  
  261. WiFiClientSecure::~WiFiClientSecure()
  262. {
  263.     if (_ssl) {
  264.         _ssl->unref();
  265.     }
  266. }
  267.  
  268. WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
  269.     : WiFiClient(static_cast<const WiFiClient&>(other))
  270. {
  271.     _ssl = other._ssl;
  272.     if (_ssl) {
  273.         _ssl->ref();
  274.     }
  275. }
  276.  
  277. WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
  278. {
  279.     (WiFiClient&) *this = rhs;
  280.     _ssl = rhs._ssl;
  281.     if (_ssl) {
  282.         _ssl->ref();
  283.     }
  284.     return *this;
  285. }
  286.  
  287. int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
  288. {
  289.     if (!WiFiClient::connect(ip, port)) {
  290.         return 0;
  291.     }
  292.  
  293.     return _connectSSL(nullptr);
  294. }
  295.  
  296. int WiFiClientSecure::connect(const char* name, uint16_t port)
  297. {
  298.     IPAddress remote_addr;
  299.     if (!WiFi.hostByName(name, remote_addr)) {
  300.         return 0;
  301.     }
  302.     if (!WiFiClient::connect(remote_addr, port)) {
  303.         return 0;
  304.     }
  305.     return _connectSSL(name);
  306. }
  307.  
  308. int WiFiClientSecure::_connectSSL(const char* hostName)
  309. {
  310.  //   if (_ssl) {
  311.   //      _ssl->unref();
  312.   //      _ssl = nullptr;
  313. //    }
  314.  
  315.  //   _ssl = new SSLContext;
  316. //    _ssl->ref();
  317.     _ssl->connect(_client, hostName, 5000);
  318.  
  319.     auto status = ssl_handshake_status(*_ssl);
  320.     if (status != SSL_OK) {
  321.         _ssl->unref();
  322.         _ssl = nullptr;
  323.         return 0;
  324.     }
  325.  
  326.     return 1;
  327. }
  328.  
  329. size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
  330. {
  331.     if (!_ssl) {
  332.         return 0;
  333.     }
  334.  
  335.     int rc = ssl_write(*_ssl, buf, size);
  336.     if (rc >= 0) {
  337.         return rc;
  338.     }
  339.  
  340.     if (rc != SSL_CLOSE_NOTIFY) {
  341.         _ssl->unref();
  342.         _ssl = nullptr;
  343.     }
  344.  
  345.     return 0;
  346. }
  347.  
  348. int WiFiClientSecure::read(uint8_t *buf, size_t size)
  349. {
  350.     if (!_ssl) {
  351.         return 0;
  352.     }
  353.  
  354.     return _ssl->read(buf, size);
  355. }
  356.  
  357. int WiFiClientSecure::read()
  358. {
  359.     if (!_ssl) {
  360.         return -1;
  361.     }
  362.  
  363.     return _ssl->read();
  364. }
  365.  
  366. int WiFiClientSecure::peek()
  367. {
  368.     if (!_ssl) {
  369.         return -1;
  370.     }
  371.  
  372.     return _ssl->peek();
  373. }
  374.  
  375. size_t WiFiClientSecure::peekBytes(uint8_t *buffer, size_t length)
  376. {
  377.     size_t count = 0;
  378.  
  379.     if (!_ssl) {
  380.         return 0;
  381.     }
  382.  
  383.     _startMillis = millis();
  384.     while ((available() < (int) length) && ((millis() - _startMillis) < _timeout)) {
  385.         yield();
  386.     }
  387.  
  388.     if (!_ssl) {
  389.         return 0;
  390.     }
  391.  
  392.     if (available() < (int) length) {
  393.         count = available();
  394.     } else {
  395.         count = length;
  396.     }
  397.  
  398.     return _ssl->peekBytes((char *)buffer, count);
  399. }
  400.  
  401. int WiFiClientSecure::available()
  402. {
  403.     if (!_ssl) {
  404.         return 0;
  405.     }
  406.  
  407.     return _ssl->available();
  408. }
  409.  
  410.  
  411. /*
  412. SSL     TCP     RX data     connected
  413. null    x       x           N
  414. !null   x       Y           Y
  415. Y       Y       x           Y
  416. x       N       N           N
  417. err     x       N           N
  418. */
  419. uint8_t WiFiClientSecure::connected()
  420. {
  421.     if (_ssl) {
  422.         if (_ssl->available()) {
  423.             return true;
  424.         }
  425.         if (_client && _client->state() == ESTABLISHED && _ssl->connected()) {
  426.             return true;
  427.         }
  428.     }
  429.     return false;
  430. }
  431.  
  432. void WiFiClientSecure::stop()
  433. {
  434.     if (_ssl) {
  435.         _ssl->stop();
  436.     }
  437.     WiFiClient::stop();
  438. }
  439.  
  440. static bool parseHexNibble(char pb, uint8_t* res)
  441. {
  442.     if (pb >= '0' && pb <= '9') {
  443.         *res = (uint8_t) (pb - '0'); return true;
  444.     } else if (pb >= 'a' && pb <= 'f') {
  445.         *res = (uint8_t) (pb - 'a' + 10); return true;
  446.     } else if (pb >= 'A' && pb <= 'F') {
  447.         *res = (uint8_t) (pb - 'A' + 10); return true;
  448.     }
  449.     return false;
  450. }
  451.  
  452. // Compare a name from certificate and domain name, return true if they match
  453. static bool matchName(const String& name, const String& domainName)
  454. {
  455.     int wildcardPos = name.indexOf('*');
  456.     if (wildcardPos == -1) {
  457.         // Not a wildcard, expect an exact match
  458.         return name == domainName;
  459.     }
  460.     int firstDotPos = name.indexOf('.');
  461.     if (wildcardPos > firstDotPos) {
  462.         // Wildcard is not part of leftmost component of domain name
  463.         // Do not attempt to match (rfc6125 6.4.3.1)
  464.         return false;
  465.     }
  466.     if (wildcardPos != 0 || firstDotPos != 1) {
  467.         // Matching of wildcards such as baz*.example.com and b*z.example.com
  468.         // is optional. Maybe implement this in the future?
  469.         return false;
  470.     }
  471.     int domainNameFirstDotPos = domainName.indexOf('.');
  472.     if (domainNameFirstDotPos < 0) {
  473.         return false;
  474.     }
  475.     return domainName.substring(domainNameFirstDotPos) == name.substring(firstDotPos);
  476. }
  477.  
  478. bool WiFiClientSecure::verify(const char* fp, const char* domain_name)
  479. {
  480.     if (!_ssl) {
  481.         return false;
  482.     }
  483.  
  484.     uint8_t sha1[20];
  485.     int len = strlen(fp);
  486.     int pos = 0;
  487.     for (size_t i = 0; i < sizeof(sha1); ++i) {
  488.         while (pos < len && ((fp[pos] == ' ') || (fp[pos] == ':'))) {
  489.             ++pos;
  490.         }
  491.         if (pos > len - 2) {
  492.             DEBUGV("pos:%d len:%d fingerprint too short\r\n", pos, len);
  493.             return false;
  494.         }
  495.         uint8_t high, low;
  496.         if (!parseHexNibble(fp[pos], &high) || !parseHexNibble(fp[pos+1], &low)) {
  497.             DEBUGV("pos:%d len:%d invalid hex sequence: %c%c\r\n", pos, len, fp[pos], fp[pos+1]);
  498.             return false;
  499.         }
  500.         pos += 2;
  501.         sha1[i] = low | (high << 4);
  502.     }
  503.     if (ssl_match_fingerprint(*_ssl, sha1) != 0) {
  504.         DEBUGV("fingerprint doesn't match\r\n");
  505.         return false;
  506.     }
  507.  
  508.     return _verifyDN(domain_name);
  509. }
  510.  
  511. bool WiFiClientSecure::_verifyDN(const char* domain_name)
  512. {
  513.     DEBUGV("domain name: '%s'\r\n", (domain_name)?domain_name:"(null)");
  514.     String domain_name_str(domain_name);
  515.     domain_name_str.toLowerCase();
  516.  
  517.     const char* san = NULL;
  518.     int i = 0;
  519.     while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
  520.         if (matchName(String(san), domain_name_str)) {
  521.             return true;
  522.         }
  523.         DEBUGV("SAN %d: '%s', no match\r\n", i, san);
  524.         ++i;
  525.     }
  526.     const char* common_name = ssl_get_cert_dn(*_ssl, SSL_X509_CERT_COMMON_NAME);
  527.     if (common_name && matchName(String(common_name), domain_name_str)) {
  528.         return true;
  529.     }
  530.     DEBUGV("CN: '%s', no match\r\n", (common_name)?common_name:"(null)");
  531.  
  532.     return false;
  533. }
  534.  
  535. bool WiFiClientSecure::verifyCertChain(const char* domain_name)
  536. {
  537.     if (!_ssl) {
  538.         return false;
  539.     }
  540.     int rc = ssl_verify_cert(*_ssl);
  541.     if (rc != SSL_OK) {
  542.         DEBUGV("ssl_verify_cert returned %d\n", rc);
  543.         return false;
  544.     }
  545.  
  546.     return _verifyDN(domain_name);
  547. }
  548.  
  549. void WiFiClientSecure::setCertificate(const uint8_t* cert_data, size_t size)
  550. {
  551.    
  552.  
  553.     if (!_ssl) {
  554.         return;
  555.     }
  556.     _ssl->loadObject(SSL_OBJ_X509_CERT, cert_data, size);
  557. }
  558.  
  559. void WiFiClientSecure::setPrivateKey(const uint8_t* pk, size_t size)
  560. {
  561.     if (!_ssl) {
  562.         return;
  563.     }
  564.     _ssl->loadObject(SSL_OBJ_RSA_KEY, pk, size);
  565. }
  566.  
  567. void WiFiClientSecure::setCACert(const uint8_t* pk, size_t size)
  568. {
  569.     if (!_ssl) {
  570.         return;
  571.     }
  572.     _ssl->loadObject(SSL_OBJ_X509_CACERT, pk, size);
  573. }
  574.  
  575. bool WiFiClientSecure::loadCACert(Stream& stream, size_t size)
  576. {
  577.     if (!_ssl) {
  578.         return false;
  579.     }
  580.     return _ssl->loadObject(SSL_OBJ_X509_CACERT, stream, size);
  581. }
  582.  
  583. bool WiFiClientSecure::loadCertificate(Stream& stream, size_t size)
  584. {
  585.     _ssl = new SSLContext;
  586.     _ssl->ref();
  587.    
  588.     if (!_ssl) {
  589.         return false;
  590.     }
  591.     return _ssl->loadObject(SSL_OBJ_X509_CERT, stream, size);
  592. }
  593.  
  594. bool WiFiClientSecure::loadPrivateKey(Stream& stream, size_t size)
  595. {
  596.     if (!_ssl) {
  597.         return false;
  598.     }
  599.     return _ssl->loadObject(SSL_OBJ_RSA_KEY, stream, size);
  600. }
  601.  
  602. extern "C" int __ax_port_read(int fd, uint8_t* buffer, size_t count)
  603. {
  604.     ClientContext* _client = SSLContext::getIOContext(fd);
  605.     if (!_client || _client->state() != ESTABLISHED && !_client->getSize()) {
  606.         errno = EIO;
  607.         return -1;
  608.     }
  609.     size_t cb = _client->read((char*) buffer, count);
  610.     if (cb != count) {
  611.         errno = EAGAIN;
  612.     }
  613.     if (cb == 0) {
  614.         optimistic_yield(100);
  615.         return -1;
  616.     }
  617.     return cb;
  618. }
  619. extern "C" void ax_port_read() __attribute__ ((weak, alias("__ax_port_read")));
  620.  
  621. extern "C" int __ax_port_write(int fd, uint8_t* buffer, size_t count)
  622. {
  623.     ClientContext* _client = SSLContext::getIOContext(fd);
  624.     if (!_client || _client->state() != ESTABLISHED) {
  625.         errno = EIO;
  626.         return -1;
  627.     }
  628.  
  629.     size_t cb = _client->write(buffer, count);
  630.     if (cb != count) {
  631.         errno = EAGAIN;
  632.     }
  633.     return cb;
  634. }
  635. extern "C" void ax_port_write() __attribute__ ((weak, alias("__ax_port_write")));
  636.  
  637. extern "C" int __ax_get_file(const char *filename, uint8_t **buf)
  638. {
  639.     *buf = 0;
  640.     return 0;
  641. }
  642. extern "C" void ax_get_file() __attribute__ ((weak, alias("__ax_get_file")));
  643.  
  644.  
  645. #ifdef DEBUG_TLS_MEM
  646. #define DEBUG_TLS_MEM_PRINT(...) DEBUGV(__VA_ARGS__)
  647. #else
  648. #define DEBUG_TLS_MEM_PRINT(...)
  649. #endif
  650.  
  651. extern "C" void* ax_port_malloc(size_t size, const char* file, int line)
  652. {
  653.     void* result = malloc(size);
  654.     if (result == nullptr) {
  655.         DEBUG_TLS_MEM_PRINT("%s:%d malloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap());
  656.     }
  657.     if (size >= 1024) {
  658.         DEBUG_TLS_MEM_PRINT("%s:%d malloc %d, left %d\r\n", file, line, size, ESP.getFreeHeap());
  659.     }
  660.     return result;
  661. }
  662.  
  663. extern "C" void* ax_port_calloc(size_t size, size_t count, const char* file, int line)
  664. {
  665.     void* result = ax_port_malloc(size * count, file, line);
  666.     memset(result, 0, size * count);
  667.     return result;
  668. }
  669.  
  670. extern "C" void* ax_port_realloc(void* ptr, size_t size, const char* file, int line)
  671. {
  672.     void* result = realloc(ptr, size);
  673.     if (result == nullptr) {
  674.         DEBUG_TLS_MEM_PRINT("%s:%d realloc %d failed, left %d\r\n", file, line, size, ESP.getFreeHeap());
  675.     }
  676.     if (size >= 1024) {
  677.         DEBUG_TLS_MEM_PRINT("%s:%d realloc %d, left %d\r\n", file, line, size, ESP.getFreeHeap());
  678.     }
  679.     return result;
  680. }
  681.  
  682. extern "C" void ax_port_free(void* ptr)
  683. {
  684.     free(ptr);
  685. }
  686.  
  687. extern "C" void __ax_wdt_feed()
  688. {
  689.     optimistic_yield(10000);
  690. }
  691. extern "C" void ax_wdt_feed() __attribute__ ((weak, alias("__ax_wdt_feed")));
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement