mirror of
https://github.com/fpagliughi/sockpp.git
synced 2026-01-12 00:04:45 +08:00
Merge
This commit is contained in:
@@ -69,6 +69,8 @@ class connector : public stream_socket
|
||||
connector(const connector&) =delete;
|
||||
connector& operator=(const connector&) =delete;
|
||||
|
||||
bool recreate(const sock_address& addr);
|
||||
|
||||
public:
|
||||
/**
|
||||
* Creates an unconnected connector.
|
||||
@@ -80,6 +82,14 @@ public:
|
||||
* @param addr The remote server address.
|
||||
*/
|
||||
connector(const sock_address& addr) { connect(addr); }
|
||||
/**
|
||||
* Creates the connector and attempts to connect to the specified
|
||||
* address, with a timeout.
|
||||
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
|
||||
* @param addr The remote server address.
|
||||
* @param t The duration after which to give up. Zero means never.
|
||||
*/
|
||||
connector(const sock_address& addr, std::chrono::milliseconds t) { connect(addr, t); }
|
||||
/**
|
||||
* Move constructor.
|
||||
* Creates a connector by moving the other connector to this one.
|
||||
@@ -112,6 +122,16 @@ public:
|
||||
* @return @em true on success, @em false on error
|
||||
*/
|
||||
bool connect(const sock_address& addr);
|
||||
/**
|
||||
* Attempts to connect to the specified server, with a timeout.
|
||||
* If the socket is currently connected, this will close the current
|
||||
* connection and open the new one.
|
||||
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
|
||||
* @param addr The remote server address.
|
||||
* @param timeout The duration after which to give up. Zero means never.
|
||||
* @return @em true on success, @em false on error
|
||||
*/
|
||||
bool connect(const sock_address& addr, std::chrono::microseconds timeout);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
@@ -182,6 +202,18 @@ public:
|
||||
* @return @em true on success, @em false on error
|
||||
*/
|
||||
bool connect(const addr_t& addr) { return base::connect(addr); }
|
||||
/**
|
||||
* Attempts to connect to the specified server, with a timeout.
|
||||
* If the socket is currently connected, this will close the current
|
||||
* connection and open the new one.
|
||||
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
|
||||
* @param addr The remote server address.
|
||||
* @param timeout The duration after which to give up. Zero means never.
|
||||
* @return @em true on success, @em false on error
|
||||
*/
|
||||
bool connect(const addr_t& addr, std::chrono::microseconds timeout) {
|
||||
return base::connect(addr, timeout);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
120
include/sockpp/mbedtls_context.h
Normal file
120
include/sockpp/mbedtls_context.h
Normal file
@@ -0,0 +1,120 @@
|
||||
/**
|
||||
* @file mbedtls_socket.h
|
||||
*
|
||||
* TLS (SSL) socket implementation using mbedTLS.
|
||||
*
|
||||
* @author Jens Alfke
|
||||
* @author Couchbase, Inc.
|
||||
* @author www.couchbase.com
|
||||
*
|
||||
* @date August 2019
|
||||
*/
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// This file is part of the "sockpp" C++ socket library.
|
||||
//
|
||||
// Copyright (c) 2014-2019 Frank Pagliughi
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
|
||||
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
||||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
#ifndef __sockpp_mbedtls_socket_h
|
||||
#define __sockpp_mbedtls_socket_h
|
||||
|
||||
#include "sockpp/tls_context.h"
|
||||
#include "sockpp/tls_socket.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
struct mbedtls_pk_context;
|
||||
struct mbedtls_ssl_config;
|
||||
struct mbedtls_x509_crt;
|
||||
|
||||
namespace sockpp {
|
||||
|
||||
/**
|
||||
* A concrete implementation of \ref tls_context, using the mbedTLS library.
|
||||
* You probably don't want to use this class directly, unless you want to instantiate a
|
||||
* custom context so you can have different contexts for different sockets.
|
||||
*/
|
||||
class mbedtls_context : public tls_context {
|
||||
public:
|
||||
mbedtls_context(role_t = CLIENT);
|
||||
~mbedtls_context() override;
|
||||
|
||||
void set_root_certs(const std::string &certData) override;
|
||||
void require_peer_cert(role_t, bool) override;
|
||||
void allow_only_certificate(const std::string &certData) override;
|
||||
|
||||
void allow_only_certificate(mbedtls_x509_crt *certificate);
|
||||
|
||||
/**
|
||||
* Sets the identity certificate and private key using mbedTLS objects.
|
||||
*/
|
||||
void set_identity(mbedtls_x509_crt *certificate,
|
||||
mbedtls_pk_context *private_key);
|
||||
|
||||
void set_identity(const std::string &certificate_data,
|
||||
const std::string &private_key_data) override;
|
||||
|
||||
std::unique_ptr<tls_socket> wrap_socket(std::unique_ptr<stream_socket> socket,
|
||||
role_t,
|
||||
const std::string &peer_name) override;
|
||||
|
||||
role_t role();
|
||||
|
||||
static mbedtls_x509_crt* get_system_root_certs();
|
||||
|
||||
using Logger = std::function<void(int level, const char *filename, int line, const char *message)>;
|
||||
void set_logger(int threshold, Logger);
|
||||
|
||||
private:
|
||||
struct cert;
|
||||
struct key;
|
||||
|
||||
int verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags);
|
||||
static std::unique_ptr<cert> parse_cert(const std::string &cert_data, bool partialOk);
|
||||
|
||||
std::unique_ptr<mbedtls_ssl_config> ssl_config_;
|
||||
std::unique_ptr<cert> root_certs_;
|
||||
std::unique_ptr<cert> pinned_cert_;
|
||||
|
||||
std::unique_ptr<cert> identity_cert_;
|
||||
std::unique_ptr<key> identity_key_;
|
||||
Logger logger_;
|
||||
|
||||
static cert *s_system_root_certs;
|
||||
|
||||
friend class mbedtls_socket;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -78,11 +78,7 @@
|
||||
#ifndef _SSIZE_T_DEFINED
|
||||
#define _SSIZE_T_DEFINED
|
||||
#undef ssize_t
|
||||
#ifdef _WIN64
|
||||
using ssize_t = int64_t;
|
||||
#else
|
||||
using ssize_t = int;
|
||||
#endif // _WIN64
|
||||
using ssize_t = SSIZE_T;
|
||||
#endif // _SSIZE_T_DEFINED
|
||||
|
||||
#ifndef _SUSECONDS_T
|
||||
|
||||
@@ -415,7 +415,7 @@ public:
|
||||
* @param on Whether to turn non-blocking mode on or off.
|
||||
* @return @em true on success, @em false on failure.
|
||||
*/
|
||||
bool set_non_blocking(bool on=true);
|
||||
virtual bool set_non_blocking(bool on=true);
|
||||
/**
|
||||
* Gets a string describing the specified error.
|
||||
* This is typically the returned message from the system strerror().
|
||||
@@ -445,14 +445,16 @@ public:
|
||||
* @li SHUT_RDWR (2) Further reads and writes disallowed.
|
||||
* @return @em true on success, @em false on error.
|
||||
*/
|
||||
bool shutdown(int how=SHUT_RDWR);
|
||||
virtual bool shutdown(int how=SHUT_RDWR);
|
||||
/**
|
||||
* Closes the socket.
|
||||
* After closing the socket, the handle is @em invalid, and can not be
|
||||
* used again until reassigned.
|
||||
* @return @em true if the sock is closed, @em false on error.
|
||||
*/
|
||||
bool close();
|
||||
virtual bool close();
|
||||
|
||||
friend struct ioresult;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -54,6 +54,26 @@ namespace sockpp {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Result of a thread-safe read or write
|
||||
* (\ref read_r, \ref read_n_r, \ref write_r, \ref write_n_r)
|
||||
*/
|
||||
struct ioresult {
|
||||
size_t count = 0; ///< Byte count, or 0 on error or EOF
|
||||
int error = 0; ///< errno value (0 if no error or EOF)
|
||||
|
||||
ioresult() = default;
|
||||
|
||||
ioresult(size_t c, int e) :count(c), error(e) { }
|
||||
|
||||
explicit inline ioresult(ssize_t n) {
|
||||
if (n >= 0)
|
||||
count = size_t(n);
|
||||
else
|
||||
error = socket::get_last_error();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Base class for streaming sockets, such as TCP and Unix Domain.
|
||||
* This is the streaming connection between two peers. It looks like a
|
||||
@@ -230,6 +250,12 @@ public:
|
||||
bool write_timeout(const std::chrono::duration<Rep,Period>& to) {
|
||||
return write_timeout(std::chrono::duration_cast<std::chrono::microseconds>(to));
|
||||
}
|
||||
|
||||
virtual ioresult read_r(void *buf, size_t n);
|
||||
virtual ioresult read_n_r(void *buf, size_t n);
|
||||
virtual ioresult write_r(const void *buf, size_t n);
|
||||
virtual ioresult write_n_r(const void *buf, size_t n);
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
158
include/sockpp/tls_context.h
Normal file
158
include/sockpp/tls_context.h
Normal file
@@ -0,0 +1,158 @@
|
||||
/**
|
||||
* @file tls_context.h
|
||||
*
|
||||
* Context object for TLS (SSL) sockets.
|
||||
*
|
||||
* @author Jens Alfke
|
||||
* @author Couchbase, Inc.
|
||||
* @author www.couchbase.com
|
||||
*
|
||||
* @date August 2019
|
||||
*/
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// This file is part of the "sockpp" C++ socket library.
|
||||
//
|
||||
// Copyright (c) 2014-2019 Frank Pagliughi
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
|
||||
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
||||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
#ifndef __sockpp_tls_context_h
|
||||
#define __sockpp_tls_context_h
|
||||
|
||||
#include "sockpp/platform.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace sockpp {
|
||||
class connector;
|
||||
class stream_socket;
|
||||
class tls_socket;
|
||||
|
||||
/**
|
||||
* Context / configuration for TLS (SSL) connections; also acts as a factory for
|
||||
* \ref tls_socket objects.
|
||||
*
|
||||
* A single context can be shared by any number of \ref tls_socket instances.
|
||||
* A context must remain in scope as long as any socket using it remains in scope.
|
||||
*/
|
||||
class tls_context
|
||||
{
|
||||
public:
|
||||
enum role_t {
|
||||
CLIENT = 0,
|
||||
SERVER = 1
|
||||
};
|
||||
/**
|
||||
* A singleton context that can be used if you don't need any per-connection
|
||||
* configuration.
|
||||
*/
|
||||
static tls_context& default_context();
|
||||
|
||||
virtual ~tls_context() =default;
|
||||
|
||||
/**
|
||||
* Tells whether the context is initialized and valid. Check this after constructing
|
||||
* an instance and do not use if not valid.
|
||||
* @return Zero if valid, a nonzero error code if initialization failed.
|
||||
* The code may be a POSIX code, or one specific to the TLS library.
|
||||
*/
|
||||
int status() const {
|
||||
return status_;
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return status_ == 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Overrides the set of trusted root certificates used for validation.
|
||||
*/
|
||||
virtual void set_root_certs(const std::string &certData) =0;
|
||||
|
||||
/**
|
||||
* Configures whether the peer is required to present a valid certificate, for a connection
|
||||
* using the given role.
|
||||
* * For the CLIENT role the default is true; if you change to false, you take
|
||||
* responsibility for validating the server certificate yourself!
|
||||
* * For the SERVER role the default is false; you can change it to true to require
|
||||
* client certificate authentication.
|
||||
* @param role The role you are configuring this setting for
|
||||
* @param require Pass true to require a valid peer certificate, false to not require.
|
||||
*/
|
||||
virtual void require_peer_cert(role_t role, bool require) =0;
|
||||
|
||||
/**
|
||||
* Requires that the peer have the exact certificate given.
|
||||
* This is known as "cert-pinning". It's more secure, but requires that the client
|
||||
* update its copy of the certificate whenever the server updates it.
|
||||
* @param certData The X.509 certificate in DER or PEM form; or an empty string for
|
||||
* no pinning (the default).
|
||||
*/
|
||||
virtual void allow_only_certificate(const std::string &certData) =0;
|
||||
|
||||
virtual void set_identity(const std::string &certificate_data,
|
||||
const std::string &private_key_data) =0;
|
||||
|
||||
/**
|
||||
* Creates a new \ref tls_socket instance that wraps the given connector socket.
|
||||
* The \ref tls_socket takes ownership of the base socket and will close it when
|
||||
* it's closed.
|
||||
* When this method returns, the TLS handshake will already have completed;
|
||||
* be sure to check the stream's status, since the handshake may have failed.
|
||||
* @param socket The underlying connector socket that TLS will use for I/O.
|
||||
* @param role CLIENT or SERVER mode.
|
||||
* @param peer_name The peer's canonical hostname, or other distinguished name,
|
||||
* to be used for certificate validation.
|
||||
* @return A new \ref tls_socket to use for secure I/O.
|
||||
*/
|
||||
virtual std::unique_ptr<tls_socket> wrap_socket(std::unique_ptr<stream_socket> socket,
|
||||
role_t role,
|
||||
const std::string &peer_name) =0;
|
||||
|
||||
protected:
|
||||
tls_context() =default;
|
||||
|
||||
/**
|
||||
* Sets the error status of the context. Call this if initialization fails.
|
||||
*/
|
||||
void set_status(int s) const {
|
||||
status_ = s;
|
||||
}
|
||||
|
||||
private:
|
||||
tls_context(const tls_context&) =delete;
|
||||
|
||||
mutable int status_ =0;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
158
include/sockpp/tls_socket.h
Normal file
158
include/sockpp/tls_socket.h
Normal file
@@ -0,0 +1,158 @@
|
||||
/**
|
||||
* @file tls_socket.h
|
||||
*
|
||||
* TLS (SSL) sockets.
|
||||
*
|
||||
* @author Jens Alfke
|
||||
* @author Couchbase, Inc.
|
||||
* @author www.couchbase.com
|
||||
*
|
||||
* @date August 2019
|
||||
*/
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// This file is part of the "sockpp" C++ socket library.
|
||||
//
|
||||
// Copyright (c) 2014-2019 Frank Pagliughi
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
|
||||
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
||||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
#ifndef __sockpp_tls_socket_h
|
||||
#define __sockpp_tls_socket_h
|
||||
|
||||
#include "sockpp/stream_socket.h"
|
||||
#include "sockpp/tls_context.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace sockpp {
|
||||
|
||||
/**
|
||||
* Abstract base class of TLS (SSL) sockets.
|
||||
* Instances are created by the \ref wrap_socket factory method of \ref tls_context.
|
||||
* First you create/open a regular \ref stream_socket, then immediately wrap it with TLS.
|
||||
*/
|
||||
class tls_socket : public stream_socket
|
||||
{
|
||||
/** The base class */
|
||||
using base = stream_socket;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Returns the verification status of the peer's certificate.
|
||||
* If the certificate is valid, returns zero.
|
||||
* Otherwise, returns a nonzero value whose interpretation depends on the actual
|
||||
* TLS library in use.
|
||||
* (For example, if using mbedTLS the return value is a bitwise combination of
|
||||
* \c MBEDTLS_X509_BADCERT_XXX and \c MBEDTLS_X509_BADCRL_XXX flags
|
||||
* defined in <mbedtls/x509.h>, or -1 if the TLS handshake failed earlier.)
|
||||
*/
|
||||
virtual uint32_t peer_certificate_status() =0;
|
||||
|
||||
/**
|
||||
* Returns an error message describing any problem with the peer's certificate.
|
||||
*/
|
||||
virtual std::string peer_certificate_status_message() =0;
|
||||
|
||||
/**
|
||||
* Returns the peer's X.509 certificate data, in binary DER format.
|
||||
*/
|
||||
virtual std::string peer_certificate() =0;
|
||||
|
||||
/**
|
||||
* Move assignment.
|
||||
* @param rhs The other socket to move into this one.
|
||||
* @return A reference to this object.
|
||||
*/
|
||||
tls_socket& operator=(tls_socket&& rhs) {
|
||||
base::operator=(std::move(rhs));
|
||||
stream_ = std::move(rhs.stream_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// I/O primitives must be reimplemented in subclasses:
|
||||
|
||||
virtual ssize_t read(void *buf, size_t n) override = 0;
|
||||
virtual ioresult read_r(void *buf, size_t n) override = 0;
|
||||
virtual bool read_timeout(const std::chrono::microseconds& to) override = 0;
|
||||
virtual ssize_t write(const void *buf, size_t n) override = 0;
|
||||
virtual ioresult write_r(const void *buf, size_t n) override = 0;
|
||||
virtual bool write_timeout(const std::chrono::microseconds& to) override = 0;
|
||||
|
||||
virtual ssize_t write(const std::vector<iovec> &ranges) override {
|
||||
return ranges.empty() ? 0 : write(ranges[0].iov_base, ranges[0].iov_len);
|
||||
}
|
||||
|
||||
virtual bool set_non_blocking(bool on) override = 0;
|
||||
|
||||
virtual bool close() override {
|
||||
bool ok = true;
|
||||
if (stream_) {
|
||||
ok = stream_->close();
|
||||
if (!ok && !last_error())
|
||||
clear(stream_->last_error());
|
||||
stream_.reset();
|
||||
}
|
||||
release();
|
||||
return ok;
|
||||
}
|
||||
|
||||
virtual ~tls_socket() {
|
||||
close();
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
static constexpr socket_t PLACEHOLDER_SOCKET = -999;
|
||||
|
||||
tls_socket(std::unique_ptr<stream_socket> stream)
|
||||
:base(PLACEHOLDER_SOCKET)
|
||||
,stream_(move(stream))
|
||||
{ }
|
||||
|
||||
/**
|
||||
* Creates a TLS socket by copying the socket handle from the
|
||||
* specified socket object and transfers ownership of the socket.
|
||||
*/
|
||||
tls_socket(tls_socket&& sock) : tls_socket(std::move(sock.stream_)) {}
|
||||
|
||||
/**
|
||||
* The underlying socket stream that this socket wraps.
|
||||
* The TLS code reads and writes this stream.
|
||||
*/
|
||||
stream_socket &stream() {return *stream_;}
|
||||
|
||||
private:
|
||||
std::unique_ptr<stream_socket> stream_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -73,14 +73,18 @@ bool acceptor::open(const sock_address& addr, int queSize /*=DFLT_QUE_SIZE*/)
|
||||
|
||||
reset(h);
|
||||
|
||||
#if !defined(_WIN32)
|
||||
// TODO: This should be an option
|
||||
if (domain == AF_INET || domain == AF_INET6) {
|
||||
int reuse = 1;
|
||||
if (!set_option(SOL_SOCKET, SO_REUSEADDR, reuse))
|
||||
return close_on_err();
|
||||
}
|
||||
#ifdef WIN32
|
||||
const int reuseSocket = SO_REUSEADDR;
|
||||
#else
|
||||
const int reuseSocket = SO_REUSEPORT;
|
||||
#endif
|
||||
|
||||
// TODO: This should be an option
|
||||
if (domain == AF_INET || domain == AF_INET6) {
|
||||
int reuse = 1;
|
||||
if (!set_option(SOL_SOCKET, reuseSocket, reuse))
|
||||
return close_on_err();
|
||||
}
|
||||
|
||||
if (!bind(addr) || !listen(queSize))
|
||||
return close_on_err();
|
||||
|
||||
@@ -35,28 +35,94 @@
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
#include "sockpp/connector.h"
|
||||
#include <cerrno>
|
||||
|
||||
namespace sockpp {
|
||||
|
||||
#ifdef _WIN32
|
||||
// Winsock calls return non-POSIX error codes
|
||||
#define ERR_IN_PROGRESS WSAEINPROGRESS
|
||||
#define ERR_TIMED_OUT WSAETIMEDOUT
|
||||
#define ERR_WOULD_BLOCK WSAEWOULDBLOCK
|
||||
#else
|
||||
#define ERR_IN_PROGRESS EINPROGRESS
|
||||
#define ERR_TIMED_OUT ETIMEDOUT
|
||||
#define ERR_WOULD_BLOCK EWOULDBLOCK
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool connector::recreate(const sock_address& addr)
|
||||
{
|
||||
sa_family_t domain = addr.family();
|
||||
socket_t h = create_handle(domain);
|
||||
|
||||
if (!check_socket_bool(h))
|
||||
return false;
|
||||
|
||||
// This will close the old connection, if any.
|
||||
reset(h);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool connector::connect(const sock_address& addr)
|
||||
{
|
||||
sa_family_t domain = addr.family();
|
||||
socket_t h = create_handle(domain);
|
||||
|
||||
if (!check_ret_bool(h))
|
||||
if (!recreate(addr))
|
||||
return false;
|
||||
|
||||
// This will close the old connection, if any.
|
||||
reset(h);
|
||||
|
||||
if (!check_ret_bool(::connect(h, addr.sockaddr_ptr(), addr.size())))
|
||||
if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size())))
|
||||
return close_on_err();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
bool connector::connect(const sock_address& addr, std::chrono::microseconds timeout)
|
||||
{
|
||||
if (timeout.count() <= 0)
|
||||
return connect(addr);
|
||||
|
||||
if (!recreate(addr))
|
||||
return false;
|
||||
|
||||
set_non_blocking(true);
|
||||
if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size()))) {
|
||||
if (last_error() == ERR_IN_PROGRESS || last_error() == ERR_WOULD_BLOCK) {
|
||||
// Non-blocking connect -- call `select` to wait until the timeout:
|
||||
// Note: Windows returns errors in exceptset so check it too, the
|
||||
// logic afterwords doesn't change
|
||||
fd_set readset;
|
||||
FD_ZERO(&readset);
|
||||
FD_SET(handle(), &readset);
|
||||
fd_set writeset = readset;
|
||||
fd_set exceptset = readset;
|
||||
timeval tv = to_timeval(timeout);
|
||||
int n = check_ret(::select(handle()+1, &readset, &writeset, &exceptset, &tv));
|
||||
|
||||
if (n > 0) {
|
||||
// Got a socket event, but it might be an error, so check:
|
||||
int err;
|
||||
if (get_option(SOL_SOCKET, SO_ERROR, &err))
|
||||
clear(err);
|
||||
} else if (n == 0) {
|
||||
clear(ERR_TIMED_OUT);
|
||||
}
|
||||
}
|
||||
|
||||
if (last_error() != 0) {
|
||||
close();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
set_non_blocking(false);
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// end namespace sockpp
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ std::string sys_error::error_str(int err)
|
||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
|
||||
buf, sizeof(buf), NULL);
|
||||
#else
|
||||
#ifdef _GNU_SOURCE
|
||||
#if defined(__GLIBC__)
|
||||
auto s = strerror_r(err, buf, sizeof(buf));
|
||||
return s ? std::string(s) : std::string();
|
||||
#else
|
||||
|
||||
710
src/mbedtls_context.cpp
Normal file
710
src/mbedtls_context.cpp
Normal file
@@ -0,0 +1,710 @@
|
||||
// mbedtls_context.cpp
|
||||
//
|
||||
// --------------------------------------------------------------------------
|
||||
// This file is part of the "sockpp" C++ socket library.
|
||||
//
|
||||
// Copyright (c) 2014-2017 Frank Pagliughi
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// 3. Neither the name of the copyright holder nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
|
||||
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
|
||||
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
||||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
#include "sockpp/mbedtls_context.h"
|
||||
#include "sockpp/connector.h"
|
||||
#include "sockpp/exception.h"
|
||||
#include <mbedtls/ctr_drbg.h>
|
||||
#include <mbedtls/debug.h>
|
||||
#include <mbedtls/entropy.h>
|
||||
#include <mbedtls/error.h>
|
||||
#include <mbedtls/net_sockets.h>
|
||||
#include <mbedtls/ssl.h>
|
||||
#include <mutex>
|
||||
#include <chrono>
|
||||
#include <cassert>
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <fcntl.h>
|
||||
#include <TargetConditionals.h>
|
||||
#ifdef TARGET_OS_OSX
|
||||
// For macOS read_system_root_certs():
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#elif !defined(_WIN32)
|
||||
// For Unix read_system_root_certs():
|
||||
#include <dirent.h>
|
||||
#include <fcntl.h>
|
||||
#include <fnmatch.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <sys/stat.h>
|
||||
#else
|
||||
#include <wincrypt.h>
|
||||
#include <sstream>
|
||||
|
||||
#pragma comment (lib, "crypt32.lib")
|
||||
#pragma comment (lib, "cryptui.lib")
|
||||
#endif
|
||||
|
||||
|
||||
// TODO: Better logging(?)
|
||||
#define log(FMT,...) fprintf(stderr, "TLS: " FMT "\n", ## __VA_ARGS__)
|
||||
|
||||
|
||||
namespace sockpp {
|
||||
using namespace std;
|
||||
|
||||
|
||||
static std::string read_system_root_certs();
|
||||
|
||||
|
||||
static int log_mbed_ret(int ret, const char *fn) {
|
||||
if (ret != 0) {
|
||||
char msg[100];
|
||||
mbedtls_strerror(ret, msg, sizeof(msg));
|
||||
log("mbedtls error -0x%04X from %s: %s", -ret, fn, msg);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
// Simple RAII helper for mbedTLS cert struct
|
||||
struct mbedtls_context::cert : public mbedtls_x509_crt
|
||||
{
|
||||
cert() {mbedtls_x509_crt_init(this);}
|
||||
~cert() {mbedtls_x509_crt_free(this);}
|
||||
};
|
||||
|
||||
|
||||
// Simple RAII helper for mbedTLS cert struct
|
||||
struct mbedtls_context::key : public mbedtls_pk_context
|
||||
{
|
||||
key() {mbedtls_pk_init(this);}
|
||||
~key() {mbedtls_pk_free(this);}
|
||||
};
|
||||
|
||||
|
||||
#pragma mark - SOCKET:
|
||||
|
||||
|
||||
/** Concrete implementation of tls_socket using mbedTLS. */
|
||||
class mbedtls_socket : public tls_socket {
|
||||
private:
|
||||
mbedtls_context& context_;
|
||||
mbedtls_ssl_context ssl_;
|
||||
chrono::microseconds read_timeout_ {0L};
|
||||
bool open_ = false;
|
||||
|
||||
public:
|
||||
|
||||
mbedtls_socket(unique_ptr<stream_socket> base,
|
||||
mbedtls_context &context,
|
||||
const string &hostname)
|
||||
:tls_socket(move(base))
|
||||
,context_(context)
|
||||
{
|
||||
mbedtls_ssl_init(&ssl_);
|
||||
if (context.status() != 0) {
|
||||
clear(context.status());
|
||||
return;
|
||||
}
|
||||
|
||||
if (check_mbed_setup(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
|
||||
"mbedtls_ssl_setup"))
|
||||
return;
|
||||
if (!hostname.empty() && check_mbed_setup(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
|
||||
"mbedtls_ssl_set_hostname"))
|
||||
return;
|
||||
|
||||
#if defined(_WIN32)
|
||||
// Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't
|
||||
bool blocking = true;
|
||||
#else
|
||||
int flags = fcntl(stream().handle(), F_GETFL, 0);
|
||||
bool blocking = (flags < 0 || (flags & O_NONBLOCK) == 0);
|
||||
#endif
|
||||
setup_bio(blocking);
|
||||
|
||||
// Run the TLS handshake:
|
||||
int status;
|
||||
do {
|
||||
open_ = true; // temporarily, so BIO methods won't fail
|
||||
status = mbedtls_ssl_handshake(&ssl_);
|
||||
open_ = false;
|
||||
} while (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE
|
||||
|| status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);
|
||||
if (check_mbed_setup(status, "mbedtls_ssl_handshake") != 0)
|
||||
return;
|
||||
|
||||
uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
|
||||
if (verify_flags != 0 && verify_flags != uint32_t(-1)
|
||||
&& !(verify_flags & MBEDTLS_X509_BADCERT_SKIP_VERIFY)) {
|
||||
char vrfy_buf[512];
|
||||
mbedtls_x509_crt_verify_info(vrfy_buf, sizeof( vrfy_buf ), "", verify_flags);
|
||||
log("Cert verify failed: %s", vrfy_buf );
|
||||
reset();
|
||||
clear(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
|
||||
return;
|
||||
}
|
||||
open_ = true;
|
||||
}
|
||||
|
||||
|
||||
void setup_bio(bool nonblocking) {
|
||||
mbedtls_ssl_send_t *f_send = [](void *ctx, const uint8_t *buf, size_t len) {
|
||||
return ((mbedtls_socket*)ctx)->bio_send(buf, len); };
|
||||
mbedtls_ssl_recv_t *f_recv = nullptr;
|
||||
mbedtls_ssl_recv_timeout_t *f_recv_timeout = nullptr;
|
||||
if (nonblocking)
|
||||
f_recv = [](void *ctx, uint8_t *buf, size_t len) {
|
||||
return ((mbedtls_socket*)ctx)->bio_recv(buf, len); };
|
||||
else
|
||||
f_recv_timeout = [](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
|
||||
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); };
|
||||
mbedtls_ssl_set_bio(&ssl_, this, f_send, f_recv, f_recv_timeout);
|
||||
}
|
||||
|
||||
|
||||
~mbedtls_socket() {
|
||||
close();
|
||||
mbedtls_ssl_free(&ssl_);
|
||||
reset(); // remove bogus file descriptor so base class won't call close() on it
|
||||
}
|
||||
|
||||
|
||||
virtual bool close() override {
|
||||
if (open_) {
|
||||
mbedtls_ssl_close_notify(&ssl_);
|
||||
open_ = false;
|
||||
}
|
||||
return tls_socket::close();
|
||||
}
|
||||
|
||||
|
||||
// -------- certificate / trust API
|
||||
|
||||
|
||||
uint32_t peer_certificate_status() override {
|
||||
return mbedtls_ssl_get_verify_result(&ssl_);
|
||||
}
|
||||
|
||||
|
||||
string peer_certificate_status_message() override {
|
||||
uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
|
||||
if (verify_flags == 0 || verify_flags == UINT32_MAX)
|
||||
return "";
|
||||
char message[512];
|
||||
mbedtls_x509_crt_verify_info(message, sizeof( message ), "",
|
||||
verify_flags & ~MBEDTLS_X509_BADCERT_OTHER);
|
||||
size_t len = strlen(message);
|
||||
if (len > 0 && message[len] == '\0')
|
||||
--len;
|
||||
|
||||
string result(message, len);
|
||||
if (verify_flags & MBEDTLS_X509_BADCERT_OTHER) { // flag set by verify_callback()
|
||||
if (!result.empty())
|
||||
result = "\n" + result;
|
||||
result = "The certificate does not match the known pinned certificate" + result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
string peer_certificate() override {
|
||||
auto cert = mbedtls_ssl_get_peer_cert(&ssl_);
|
||||
if (!cert)
|
||||
return "";
|
||||
return string((const char*)cert->raw.p, cert->raw.len);
|
||||
}
|
||||
|
||||
|
||||
// -------- stream_socket I/O
|
||||
|
||||
|
||||
ssize_t read(void *buf, size_t length) override {
|
||||
return check_mbed_io( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
|
||||
}
|
||||
|
||||
|
||||
ioresult read_r(void *buf, size_t length) override {
|
||||
return ioresult_from_mbed( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
|
||||
}
|
||||
|
||||
|
||||
bool read_timeout(const chrono::microseconds& to) override {
|
||||
bool ok = stream().read_timeout(to);
|
||||
if (ok)
|
||||
read_timeout_ = to;
|
||||
return ok;
|
||||
}
|
||||
|
||||
|
||||
ssize_t write(const void *buf, size_t length) override {
|
||||
if (length == 0)
|
||||
return 0;
|
||||
return check_mbed_io( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
|
||||
}
|
||||
|
||||
|
||||
ioresult write_r(const void *buf, size_t length) override {
|
||||
if (length == 0)
|
||||
return {};
|
||||
return ioresult_from_mbed( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
|
||||
}
|
||||
|
||||
|
||||
bool write_timeout(const chrono::microseconds& to) override {
|
||||
return stream().write_timeout(to);
|
||||
}
|
||||
|
||||
|
||||
bool set_non_blocking(bool nonblocking) override {
|
||||
bool ok = stream().set_non_blocking(nonblocking);
|
||||
if (ok)
|
||||
setup_bio(nonblocking);
|
||||
return ok;
|
||||
}
|
||||
|
||||
|
||||
// -------- mbedTLS BIO callbacks
|
||||
|
||||
|
||||
int bio_send(const void* buf, size_t length) {
|
||||
if (!open_)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
return bio_return_value<false>(stream().write_r(buf, length));
|
||||
}
|
||||
|
||||
|
||||
int bio_recv(void* buf, size_t length) {
|
||||
if (!open_)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
return bio_return_value<true>(stream().read_r(buf, length));
|
||||
}
|
||||
|
||||
|
||||
int bio_recv_timeout(void* buf, size_t length, uint32_t timeout) {
|
||||
if (!open_)
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
if (timeout > 0)
|
||||
stream().read_timeout(chrono::milliseconds(timeout));
|
||||
|
||||
int n = bio_recv(buf, length);
|
||||
|
||||
if (timeout > 0)
|
||||
stream().read_timeout(read_timeout_);
|
||||
return (int)n;
|
||||
}
|
||||
|
||||
|
||||
// -------- error handling
|
||||
|
||||
|
||||
// Translates mbedTLS error code to POSIX (errno)
|
||||
static int translate_mbed_err(int mbedErr) {
|
||||
switch (mbedErr) {
|
||||
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
|
||||
return 0;
|
||||
case MBEDTLS_ERR_SSL_WANT_READ:
|
||||
case MBEDTLS_ERR_SSL_WANT_WRITE:
|
||||
log(">>> mbedtls_socket returning EWOULDBLOCK");
|
||||
return EWOULDBLOCK;
|
||||
case MBEDTLS_ERR_NET_CONN_RESET:
|
||||
return ECONNRESET;
|
||||
case MBEDTLS_ERR_NET_RECV_FAILED:
|
||||
case MBEDTLS_ERR_NET_SEND_FAILED:
|
||||
return EIO;
|
||||
default:
|
||||
return mbedErr;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Handles an mbedTLS error return value during setup, closing me on error
|
||||
int check_mbed_setup(int ret, const char *fn) {
|
||||
if (ret != 0) {
|
||||
log_mbed_ret(ret, fn);
|
||||
reset(); // marks me as closed/invalid
|
||||
clear(translate_mbed_err(ret)); // sets last_error
|
||||
stream().close();
|
||||
open_ = false;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
// Handles an mbedTLS read/write return value, storing any error in last_error
|
||||
inline ssize_t check_mbed_io(int mbedResult) {
|
||||
if (mbedResult < 0) {
|
||||
clear(translate_mbed_err(mbedResult)); // sets last_error
|
||||
return -1;
|
||||
}
|
||||
return mbedResult;
|
||||
}
|
||||
|
||||
|
||||
// Handles an mbedTLS read/write return value, converting it to an ioresult.
|
||||
static inline ioresult ioresult_from_mbed(int mbedResult) {
|
||||
if (mbedResult < 0)
|
||||
return ioresult(0, translate_mbed_err(mbedResult));
|
||||
else
|
||||
return ioresult(mbedResult, 0);
|
||||
}
|
||||
|
||||
|
||||
// Translates ioresult to an mbedTLS error code to return from a BIO function.
|
||||
template <bool reading>
|
||||
static int bio_return_value(ioresult result) {
|
||||
if (result.error == 0)
|
||||
return (int)result.count;
|
||||
switch (result.error) {
|
||||
case EPIPE:
|
||||
case ECONNRESET:
|
||||
return MBEDTLS_ERR_NET_CONN_RESET;
|
||||
case EINTR:
|
||||
case EWOULDBLOCK:
|
||||
#if defined(EAGAIN) && EAGAIN != EWOULDBLOCK // these are usually synonyms
|
||||
case EAGAIN:
|
||||
#endif
|
||||
log(">>> BIO returning MBEDTLS_ERR_SSL_WANT_%s", reading ?"READ":"WRITE");
|
||||
return reading ? MBEDTLS_ERR_SSL_WANT_READ
|
||||
: MBEDTLS_ERR_SSL_WANT_WRITE;
|
||||
default:
|
||||
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
|
||||
: MBEDTLS_ERR_NET_SEND_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
#pragma mark - CONTEXT:
|
||||
|
||||
|
||||
static tls_context *s_default_context = nullptr;
|
||||
|
||||
mbedtls_context::cert *mbedtls_context::s_system_root_certs;
|
||||
|
||||
|
||||
tls_context& tls_context::default_context() {
|
||||
if (!s_default_context)
|
||||
s_default_context = new mbedtls_context();
|
||||
return *s_default_context;
|
||||
}
|
||||
|
||||
|
||||
// Returns a shared mbedTLS random-number generator context.
|
||||
static mbedtls_ctr_drbg_context* get_drbg_context() {
|
||||
static const char* k_entropy_personalization = "sockpp";
|
||||
static mbedtls_entropy_context s_entropy;
|
||||
static mbedtls_ctr_drbg_context s_random_ctx;
|
||||
|
||||
static once_flag once;
|
||||
call_once(once, []() {
|
||||
mbedtls_entropy_init( &s_entropy );
|
||||
mbedtls_ctr_drbg_init( &s_random_ctx );
|
||||
int ret = mbedtls_ctr_drbg_seed(&s_random_ctx, mbedtls_entropy_func, &s_entropy,
|
||||
(const uint8_t *)k_entropy_personalization,
|
||||
strlen(k_entropy_personalization));
|
||||
if (ret != 0) {
|
||||
log_mbed_ret(ret, "mbedtls_ctr_drbg_seed");
|
||||
throw sys_error(ret); //FIXME: Not an errno; use different exception?
|
||||
}
|
||||
});
|
||||
return &s_random_ctx;
|
||||
}
|
||||
|
||||
|
||||
unique_ptr<mbedtls_context::cert> mbedtls_context::parse_cert(const std::string &cert_data, bool partialOk) {
|
||||
unique_ptr<cert> c(new cert);
|
||||
mbedtls_x509_crt_init(c.get());
|
||||
int ret = mbedtls_x509_crt_parse(c.get(),
|
||||
(const uint8_t*)cert_data.data(), cert_data.size() + 1);
|
||||
if (ret != 0) {
|
||||
if(ret < 0 || !partialOk) {
|
||||
log_mbed_ret(ret, "mbedtls_x509_crt_parse");
|
||||
if(ret > 0) {
|
||||
ret = MBEDTLS_ERR_X509_CERT_VERIFY_FAILED;
|
||||
}
|
||||
|
||||
throw sys_error(ret);
|
||||
}
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::set_root_certs(const std::string &cert_data) {
|
||||
root_certs_ = parse_cert(cert_data, true);
|
||||
mbedtls_ssl_conf_ca_chain(ssl_config_.get(), root_certs_.get(), nullptr);
|
||||
}
|
||||
|
||||
|
||||
// Returns the set of system trusted root CA certs.
|
||||
mbedtls_x509_crt* mbedtls_context::get_system_root_certs() {
|
||||
static once_flag once;
|
||||
call_once(once, []() {
|
||||
// One-time initialization:
|
||||
string certsPEM = read_system_root_certs();
|
||||
if (!certsPEM.empty())
|
||||
s_system_root_certs = parse_cert(certsPEM, true).release();
|
||||
});
|
||||
return s_system_root_certs;
|
||||
}
|
||||
|
||||
|
||||
mbedtls_context::mbedtls_context(role_t r)
|
||||
:ssl_config_(new mbedtls_ssl_config)
|
||||
{
|
||||
mbedtls_ssl_config_init(ssl_config_.get());
|
||||
mbedtls_ssl_conf_rng(ssl_config_.get(), mbedtls_ctr_drbg_random, get_drbg_context());
|
||||
int endpoint = (r == CLIENT) ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER;
|
||||
set_status(mbedtls_ssl_config_defaults(ssl_config_.get(),
|
||||
endpoint,
|
||||
MBEDTLS_SSL_TRANSPORT_STREAM,
|
||||
MBEDTLS_SSL_PRESET_DEFAULT));
|
||||
if (status() != 0)
|
||||
return;
|
||||
|
||||
auto roots = get_system_root_certs();
|
||||
if (roots)
|
||||
mbedtls_ssl_conf_ca_chain(ssl_config_.get(), roots, nullptr);
|
||||
}
|
||||
|
||||
|
||||
mbedtls_context::~mbedtls_context() {
|
||||
mbedtls_ssl_config_free(ssl_config_.get());
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::set_logger(int threshold, Logger logger) {
|
||||
if (!logger_) {
|
||||
mbedtls_ssl_conf_dbg(ssl_config_.get(), [](void *ctx, int level, const char *file, int line,
|
||||
const char *msg) {
|
||||
auto &logger = ((mbedtls_context*)ctx)->logger_;
|
||||
if (logger)
|
||||
logger(level, file, line, msg);
|
||||
}, this);
|
||||
}
|
||||
logger_ = logger;
|
||||
mbedtls_debug_set_threshold(threshold);
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::require_peer_cert(role_t forRole, bool require) {
|
||||
if (forRole != role())
|
||||
return;
|
||||
int authMode = (require ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_OPTIONAL);
|
||||
mbedtls_ssl_conf_authmode(ssl_config_.get(), authMode);
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::allow_only_certificate(const std::string &cert_data) {
|
||||
pinned_cert_.reset();
|
||||
if (cert_data.empty()) {
|
||||
mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr);
|
||||
} else {
|
||||
pinned_cert_ = parse_cert(cert_data, false);
|
||||
// Install a custom verification callback:
|
||||
mbedtls_ssl_conf_verify(
|
||||
ssl_config_.get(),
|
||||
[](void *ctx, mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
|
||||
return ((mbedtls_context*)ctx)->verify_callback(crt,depth,flags);
|
||||
},
|
||||
this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::allow_only_certificate(mbedtls_x509_crt *certificate) {
|
||||
pinned_cert_.reset();
|
||||
if (certificate) {
|
||||
string cert_data((const char*)certificate->raw.p, certificate->raw.len);
|
||||
allow_only_certificate(cert_data);
|
||||
} else {
|
||||
mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// callback from mbedTLS cert validation (see above)
|
||||
int mbedtls_context::verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
|
||||
if (depth == 0) {
|
||||
if (crt->raw.len == pinned_cert_->raw.len
|
||||
&& 0 == memcmp(crt->raw.p, pinned_cert_->raw.p, crt->raw.len)) {
|
||||
// The cert matches our pinned cert, so mark it as trusted.
|
||||
// (It might still be invalid if it's expired or revoked...)
|
||||
*flags &= ~(MBEDTLS_X509_BADCERT_NOT_TRUSTED | MBEDTLS_X509_BADCERT_CN_MISMATCH);
|
||||
} else {
|
||||
// If cert doesn't match pinned cert, mark it as untrusted.
|
||||
*flags |= MBEDTLS_X509_BADCERT_OTHER;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::set_identity(const std::string &certificate_data,
|
||||
const std::string &private_key_data)
|
||||
{
|
||||
auto ident_cert = parse_cert(certificate_data, false);
|
||||
|
||||
unique_ptr<key> ident_key(new key);
|
||||
int err = mbedtls_pk_parse_key(ident_key.get(),
|
||||
(const uint8_t*) private_key_data.data(),
|
||||
private_key_data.size(), NULL, 0);
|
||||
if( err != 0 ) {
|
||||
log_mbed_ret(err, "mbedtls_pk_parse_key");
|
||||
throw sys_error(err);
|
||||
}
|
||||
|
||||
set_identity(ident_cert.get(), ident_key.get());
|
||||
identity_cert_ = move(ident_cert);
|
||||
identity_key_ = move(ident_key);
|
||||
}
|
||||
|
||||
|
||||
void mbedtls_context::set_identity(mbedtls_x509_crt *certificate,
|
||||
mbedtls_pk_context *private_key)
|
||||
{
|
||||
mbedtls_ssl_conf_own_cert(ssl_config_.get(), certificate, private_key);
|
||||
}
|
||||
|
||||
|
||||
mbedtls_context::role_t mbedtls_context::role() {
|
||||
return (ssl_config_->endpoint == MBEDTLS_SSL_IS_CLIENT) ? CLIENT : SERVER;
|
||||
}
|
||||
|
||||
|
||||
unique_ptr<tls_socket> mbedtls_context::wrap_socket(std::unique_ptr<stream_socket> socket,
|
||||
role_t socketRole,
|
||||
const std::string &peer_name)
|
||||
{
|
||||
assert(socketRole == role());
|
||||
return unique_ptr<tls_socket>(new mbedtls_socket(move(socket), *this, peer_name));
|
||||
}
|
||||
|
||||
|
||||
#pragma mark - PLATFORM SPECIFIC:
|
||||
|
||||
|
||||
// mbedTLS does not have built-in support for reading the OS's trusted root certs.
|
||||
|
||||
#ifdef __APPLE__
|
||||
// Read system root CA certs on macOS.
|
||||
// (Sadly, SecTrustCopyAnchorCertificates() is not available on iOS)
|
||||
static string read_system_root_certs() {
|
||||
#if TARGET_OS_OSX
|
||||
CFArrayRef roots;
|
||||
OSStatus err = SecTrustCopyAnchorCertificates(&roots);
|
||||
if (err)
|
||||
return {};
|
||||
CFDataRef pemData = nullptr;
|
||||
err = SecItemExport(roots, kSecFormatPEMSequence, kSecItemPemArmour, nullptr, &pemData);
|
||||
CFRelease(roots);
|
||||
if (err)
|
||||
return {};
|
||||
string pem((const char*)CFDataGetBytePtr(pemData), CFDataGetLength(pemData));
|
||||
CFRelease(pemData);
|
||||
return pem;
|
||||
#else
|
||||
// fallback -- no certs
|
||||
return "";
|
||||
#endif
|
||||
}
|
||||
|
||||
#elif defined(_WIN32)
|
||||
// Windows:
|
||||
static string read_system_root_certs() {
|
||||
PCCERT_CONTEXT pContext = nullptr;
|
||||
HCERTSTORE hStore = CertOpenSystemStore(NULL, "ROOT");
|
||||
if(hStore == nullptr) {
|
||||
return "";
|
||||
}
|
||||
|
||||
stringstream certs;
|
||||
while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) {
|
||||
certs.write((const char*)pContext->pbCertEncoded, pContext->cbCertEncoded);
|
||||
}
|
||||
|
||||
CertCloseStore(hStore, CERT_CLOSE_STORE_FORCE_FLAG);
|
||||
return certs.str();
|
||||
}
|
||||
|
||||
#else
|
||||
// Read system root CA certs on Linux using OpenSSL's cert directory
|
||||
static string read_system_root_certs() {
|
||||
static constexpr const char* CERTS_DIR = "/etc/ssl/certs/";
|
||||
static constexpr const char* CERTS_FILE = "ca-certificates.crt";
|
||||
|
||||
stringstream certs;
|
||||
char buf[1024];
|
||||
// Subroutine to append a file to the `certs` stream:
|
||||
auto read_file = [&](const string &file) {
|
||||
ifstream in(file);
|
||||
char last_char = '\n';
|
||||
while (in) {
|
||||
in.read(buf, sizeof(buf));
|
||||
auto n = in.gcount();
|
||||
if (n > 0) {
|
||||
certs.write(buf, n);
|
||||
last_char = buf[n-1];
|
||||
}
|
||||
}
|
||||
if (last_char != '\n')
|
||||
certs << '\n';
|
||||
};
|
||||
|
||||
struct stat s;
|
||||
if (stat(CERTS_DIR, &s) == 0 && S_ISDIR(s.st_mode)) {
|
||||
string certs_file = string(CERTS_DIR) + CERTS_FILE;
|
||||
if (stat(certs_file.c_str(), &s) == 0) {
|
||||
// If there is a file containing all the certs, just read it:
|
||||
read_file(certs_file);
|
||||
} else {
|
||||
// Otherwise concatenate all the certs found in the dir:
|
||||
auto dir = opendir(CERTS_DIR);
|
||||
if (dir) {
|
||||
struct dirent *ent;
|
||||
while (nullptr != (ent = readdir(dir))) {
|
||||
if (fnmatch("?*.pem", ent->d_name, FNM_PERIOD) == 0
|
||||
|| fnmatch("?*.crt", ent->d_name, FNM_PERIOD) == 0)
|
||||
read_file(string(CERTS_DIR) + ent->d_name);
|
||||
}
|
||||
closedir(dir);
|
||||
}
|
||||
}
|
||||
}
|
||||
return certs.str();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
@@ -294,7 +294,11 @@ std::string socket::error_str(int err)
|
||||
|
||||
bool socket::shutdown(int how /*=SHUT_RDWR*/)
|
||||
{
|
||||
return check_ret_bool(::shutdown(handle_, how));
|
||||
if(handle_ != INVALID_SOCKET) {
|
||||
return check_ret_bool(::shutdown(release(), how));
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -69,6 +69,16 @@ ssize_t stream_socket::read(void *buf, size_t n)
|
||||
#endif
|
||||
}
|
||||
|
||||
ioresult stream_socket::read_r(void *buf, size_t n)
|
||||
{
|
||||
#if defined(_WIN32)
|
||||
return ioresult(::recv(handle(), reinterpret_cast<char*>(buf),
|
||||
int(n), 0));
|
||||
#else
|
||||
return ioresult(::recv(handle(), buf, n, 0));
|
||||
#endif
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Attempts to read the requested number of bytes by repeatedly calling
|
||||
// read() until it has the data or an error occurs.
|
||||
@@ -94,7 +104,6 @@ ssize_t stream_socket::read_n(void *buf, size_t n)
|
||||
return (nr == 0 && nx < 0) ? nx : ssize_t(nr);
|
||||
}
|
||||
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
ssize_t stream_socket::read(const std::vector<iovec>& ranges)
|
||||
@@ -122,6 +131,23 @@ ssize_t stream_socket::read(const std::vector<iovec>& ranges)
|
||||
#endif
|
||||
}
|
||||
|
||||
ioresult stream_socket::read_n_r(void *buf, size_t n)
|
||||
{
|
||||
ioresult result;
|
||||
uint8_t *b = reinterpret_cast<uint8_t*>(buf);
|
||||
|
||||
while (result.count < n) {
|
||||
ioresult r = read_r(b + result.count, n - result.count);
|
||||
if (r.count == 0) {
|
||||
result.error = r.error;
|
||||
break;
|
||||
}
|
||||
result.count += r.count;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
bool stream_socket::read_timeout(const microseconds& to)
|
||||
@@ -147,6 +173,16 @@ ssize_t stream_socket::write(const void *buf, size_t n)
|
||||
#endif
|
||||
}
|
||||
|
||||
ioresult stream_socket::write_r(const void *buf, size_t n)
|
||||
{
|
||||
#if defined(_WIN32)
|
||||
return ioresult(::send(handle(), reinterpret_cast<const char*>(buf),
|
||||
int(n) , 0));
|
||||
#else
|
||||
return ioresult(::send(handle(), buf, n , 0));
|
||||
#endif
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Attempts to write the entire buffer by repeatedly calling write() until
|
||||
// either all of the data is sent or an error occurs.
|
||||
@@ -171,32 +207,53 @@ ssize_t stream_socket::write_n(const void *buf, size_t n)
|
||||
return (nw == 0 && nx < 0) ? nx : ssize_t(nw);
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
ssize_t stream_socket::write(const std::vector<iovec>& ranges)
|
||||
ioresult stream_socket::write_n_r(const void *buf, size_t n)
|
||||
{
|
||||
if (ranges.empty())
|
||||
return 0;
|
||||
ioresult result;
|
||||
const uint8_t *b = reinterpret_cast<const uint8_t*>(buf);
|
||||
|
||||
#if !defined(_WIN32)
|
||||
return check_ret(::writev(handle(), ranges.data(), int(ranges.size())));
|
||||
#else
|
||||
std::vector<WSABUF> bufs;
|
||||
for (const auto& iovec : ranges) {
|
||||
bufs.push_back({
|
||||
static_cast<ULONG>(iovec.iov_len),
|
||||
static_cast<CHAR*>(iovec.iov_base)
|
||||
});
|
||||
}
|
||||
while (result.count < n) {
|
||||
ioresult r = write_r(b + result.count, n - result.count);
|
||||
if (r.count == 0) {
|
||||
result.error = r.error;
|
||||
break;
|
||||
}
|
||||
result.count += r.count;
|
||||
}
|
||||
|
||||
DWORD nwritten = 0,
|
||||
nmsg = DWORD(bufs.size());
|
||||
|
||||
auto ret = check_ret(::WSASend(handle(), bufs.data(), nmsg, &nwritten, 0, nullptr, nullptr));
|
||||
return ssize_t(ret == SOCKET_ERROR ? ret : nwritten);
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
|
||||
ssize_t stream_socket::write(const std::vector<iovec> &ranges)
|
||||
{
|
||||
#if !defined(_WIN32)
|
||||
msghdr msg = {};
|
||||
msg.msg_iov = const_cast<iovec*>(ranges.data());
|
||||
msg.msg_iovlen = int(ranges.size());
|
||||
if (msg.msg_iovlen == 0)
|
||||
return 0;
|
||||
return check_ret(sendmsg(handle(), &msg, 0));
|
||||
#else
|
||||
if(ranges.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<WSABUF> buffers;
|
||||
for(const auto& iovec : ranges) {
|
||||
buffers.push_back({
|
||||
static_cast<ULONG>(iovec.iov_len),
|
||||
static_cast<CHAR FAR *>(iovec.iov_base)
|
||||
});
|
||||
}
|
||||
|
||||
DWORD written = 0;
|
||||
ssize_t ret = check_ret(WSASend(handle(), buffers.data(), buffers.size(), &written, 0, nullptr, nullptr));
|
||||
return ret == SOCKET_ERROR ? ret : written;
|
||||
#endif
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -212,6 +269,8 @@ bool stream_socket::write_timeout(const microseconds& to)
|
||||
return set_option(SOL_SOCKET, SO_SNDTIMEO, tv);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// end namespace sockpp
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user