This commit is contained in:
fpagliughi
2020-02-16 10:00:57 -05:00
13 changed files with 1382 additions and 47 deletions

View File

@@ -69,6 +69,8 @@ class connector : public stream_socket
connector(const connector&) =delete;
connector& operator=(const connector&) =delete;
bool recreate(const sock_address& addr);
public:
/**
* Creates an unconnected connector.
@@ -80,6 +82,14 @@ public:
* @param addr The remote server address.
*/
connector(const sock_address& addr) { connect(addr); }
/**
* Creates the connector and attempts to connect to the specified
* address, with a timeout.
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
* @param addr The remote server address.
* @param t The duration after which to give up. Zero means never.
*/
connector(const sock_address& addr, std::chrono::milliseconds t) { connect(addr, t); }
/**
* Move constructor.
* Creates a connector by moving the other connector to this one.
@@ -112,6 +122,16 @@ public:
* @return @em true on success, @em false on error
*/
bool connect(const sock_address& addr);
/**
* Attempts to connect to the specified server, with a timeout.
* If the socket is currently connected, this will close the current
* connection and open the new one.
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
* @param addr The remote server address.
* @param timeout The duration after which to give up. Zero means never.
* @return @em true on success, @em false on error
*/
bool connect(const sock_address& addr, std::chrono::microseconds timeout);
};
/////////////////////////////////////////////////////////////////////////////
@@ -182,6 +202,18 @@ public:
* @return @em true on success, @em false on error
*/
bool connect(const addr_t& addr) { return base::connect(addr); }
/**
* Attempts to connect to the specified server, with a timeout.
* If the socket is currently connected, this will close the current
* connection and open the new one.
* If the operation times out, the \ref last_error will be set to ETIMEOUT.
* @param addr The remote server address.
* @param timeout The duration after which to give up. Zero means never.
* @return @em true on success, @em false on error
*/
bool connect(const addr_t& addr, std::chrono::microseconds timeout) {
return base::connect(addr, timeout);
}
};
/////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,120 @@
/**
* @file mbedtls_socket.h
*
* TLS (SSL) socket implementation using mbedTLS.
*
* @author Jens Alfke
* @author Couchbase, Inc.
* @author www.couchbase.com
*
* @date August 2019
*/
// --------------------------------------------------------------------------
// This file is part of the "sockpp" C++ socket library.
//
// Copyright (c) 2014-2019 Frank Pagliughi
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// --------------------------------------------------------------------------
#ifndef __sockpp_mbedtls_socket_h
#define __sockpp_mbedtls_socket_h
#include "sockpp/tls_context.h"
#include "sockpp/tls_socket.h"
#include <memory>
#include <string>
#include <functional>
struct mbedtls_pk_context;
struct mbedtls_ssl_config;
struct mbedtls_x509_crt;
namespace sockpp {
/**
* A concrete implementation of \ref tls_context, using the mbedTLS library.
* You probably don't want to use this class directly, unless you want to instantiate a
* custom context so you can have different contexts for different sockets.
*/
class mbedtls_context : public tls_context {
public:
mbedtls_context(role_t = CLIENT);
~mbedtls_context() override;
void set_root_certs(const std::string &certData) override;
void require_peer_cert(role_t, bool) override;
void allow_only_certificate(const std::string &certData) override;
void allow_only_certificate(mbedtls_x509_crt *certificate);
/**
* Sets the identity certificate and private key using mbedTLS objects.
*/
void set_identity(mbedtls_x509_crt *certificate,
mbedtls_pk_context *private_key);
void set_identity(const std::string &certificate_data,
const std::string &private_key_data) override;
std::unique_ptr<tls_socket> wrap_socket(std::unique_ptr<stream_socket> socket,
role_t,
const std::string &peer_name) override;
role_t role();
static mbedtls_x509_crt* get_system_root_certs();
using Logger = std::function<void(int level, const char *filename, int line, const char *message)>;
void set_logger(int threshold, Logger);
private:
struct cert;
struct key;
int verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags);
static std::unique_ptr<cert> parse_cert(const std::string &cert_data, bool partialOk);
std::unique_ptr<mbedtls_ssl_config> ssl_config_;
std::unique_ptr<cert> root_certs_;
std::unique_ptr<cert> pinned_cert_;
std::unique_ptr<cert> identity_cert_;
std::unique_ptr<key> identity_key_;
Logger logger_;
static cert *s_system_root_certs;
friend class mbedtls_socket;
};
}
#endif

View File

@@ -78,11 +78,7 @@
#ifndef _SSIZE_T_DEFINED
#define _SSIZE_T_DEFINED
#undef ssize_t
#ifdef _WIN64
using ssize_t = int64_t;
#else
using ssize_t = int;
#endif // _WIN64
using ssize_t = SSIZE_T;
#endif // _SSIZE_T_DEFINED
#ifndef _SUSECONDS_T

View File

@@ -415,7 +415,7 @@ public:
* @param on Whether to turn non-blocking mode on or off.
* @return @em true on success, @em false on failure.
*/
bool set_non_blocking(bool on=true);
virtual bool set_non_blocking(bool on=true);
/**
* Gets a string describing the specified error.
* This is typically the returned message from the system strerror().
@@ -445,14 +445,16 @@ public:
* @li SHUT_RDWR (2) Further reads and writes disallowed.
* @return @em true on success, @em false on error.
*/
bool shutdown(int how=SHUT_RDWR);
virtual bool shutdown(int how=SHUT_RDWR);
/**
* Closes the socket.
* After closing the socket, the handle is @em invalid, and can not be
* used again until reassigned.
* @return @em true if the sock is closed, @em false on error.
*/
bool close();
virtual bool close();
friend struct ioresult;
};
/////////////////////////////////////////////////////////////////////////////

View File

@@ -54,6 +54,26 @@ namespace sockpp {
/////////////////////////////////////////////////////////////////////////////
/**
* Result of a thread-safe read or write
* (\ref read_r, \ref read_n_r, \ref write_r, \ref write_n_r)
*/
struct ioresult {
size_t count = 0; ///< Byte count, or 0 on error or EOF
int error = 0; ///< errno value (0 if no error or EOF)
ioresult() = default;
ioresult(size_t c, int e) :count(c), error(e) { }
explicit inline ioresult(ssize_t n) {
if (n >= 0)
count = size_t(n);
else
error = socket::get_last_error();
}
};
/**
* Base class for streaming sockets, such as TCP and Unix Domain.
* This is the streaming connection between two peers. It looks like a
@@ -230,6 +250,12 @@ public:
bool write_timeout(const std::chrono::duration<Rep,Period>& to) {
return write_timeout(std::chrono::duration_cast<std::chrono::microseconds>(to));
}
virtual ioresult read_r(void *buf, size_t n);
virtual ioresult read_n_r(void *buf, size_t n);
virtual ioresult write_r(const void *buf, size_t n);
virtual ioresult write_n_r(const void *buf, size_t n);
};
/////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,158 @@
/**
* @file tls_context.h
*
* Context object for TLS (SSL) sockets.
*
* @author Jens Alfke
* @author Couchbase, Inc.
* @author www.couchbase.com
*
* @date August 2019
*/
// --------------------------------------------------------------------------
// This file is part of the "sockpp" C++ socket library.
//
// Copyright (c) 2014-2019 Frank Pagliughi
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// --------------------------------------------------------------------------
#ifndef __sockpp_tls_context_h
#define __sockpp_tls_context_h
#include "sockpp/platform.h"
#include <memory>
#include <string>
namespace sockpp {
class connector;
class stream_socket;
class tls_socket;
/**
* Context / configuration for TLS (SSL) connections; also acts as a factory for
* \ref tls_socket objects.
*
* A single context can be shared by any number of \ref tls_socket instances.
* A context must remain in scope as long as any socket using it remains in scope.
*/
class tls_context
{
public:
enum role_t {
CLIENT = 0,
SERVER = 1
};
/**
* A singleton context that can be used if you don't need any per-connection
* configuration.
*/
static tls_context& default_context();
virtual ~tls_context() =default;
/**
* Tells whether the context is initialized and valid. Check this after constructing
* an instance and do not use if not valid.
* @return Zero if valid, a nonzero error code if initialization failed.
* The code may be a POSIX code, or one specific to the TLS library.
*/
int status() const {
return status_;
}
operator bool() const {
return status_ == 0;
}
/**
* Overrides the set of trusted root certificates used for validation.
*/
virtual void set_root_certs(const std::string &certData) =0;
/**
* Configures whether the peer is required to present a valid certificate, for a connection
* using the given role.
* * For the CLIENT role the default is true; if you change to false, you take
* responsibility for validating the server certificate yourself!
* * For the SERVER role the default is false; you can change it to true to require
* client certificate authentication.
* @param role The role you are configuring this setting for
* @param require Pass true to require a valid peer certificate, false to not require.
*/
virtual void require_peer_cert(role_t role, bool require) =0;
/**
* Requires that the peer have the exact certificate given.
* This is known as "cert-pinning". It's more secure, but requires that the client
* update its copy of the certificate whenever the server updates it.
* @param certData The X.509 certificate in DER or PEM form; or an empty string for
* no pinning (the default).
*/
virtual void allow_only_certificate(const std::string &certData) =0;
virtual void set_identity(const std::string &certificate_data,
const std::string &private_key_data) =0;
/**
* Creates a new \ref tls_socket instance that wraps the given connector socket.
* The \ref tls_socket takes ownership of the base socket and will close it when
* it's closed.
* When this method returns, the TLS handshake will already have completed;
* be sure to check the stream's status, since the handshake may have failed.
* @param socket The underlying connector socket that TLS will use for I/O.
* @param role CLIENT or SERVER mode.
* @param peer_name The peer's canonical hostname, or other distinguished name,
* to be used for certificate validation.
* @return A new \ref tls_socket to use for secure I/O.
*/
virtual std::unique_ptr<tls_socket> wrap_socket(std::unique_ptr<stream_socket> socket,
role_t role,
const std::string &peer_name) =0;
protected:
tls_context() =default;
/**
* Sets the error status of the context. Call this if initialization fails.
*/
void set_status(int s) const {
status_ = s;
}
private:
tls_context(const tls_context&) =delete;
mutable int status_ =0;
};
}
#endif

158
include/sockpp/tls_socket.h Normal file
View File

@@ -0,0 +1,158 @@
/**
* @file tls_socket.h
*
* TLS (SSL) sockets.
*
* @author Jens Alfke
* @author Couchbase, Inc.
* @author www.couchbase.com
*
* @date August 2019
*/
// --------------------------------------------------------------------------
// This file is part of the "sockpp" C++ socket library.
//
// Copyright (c) 2014-2019 Frank Pagliughi
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// --------------------------------------------------------------------------
#ifndef __sockpp_tls_socket_h
#define __sockpp_tls_socket_h
#include "sockpp/stream_socket.h"
#include "sockpp/tls_context.h"
#include <memory>
#include <string>
namespace sockpp {
/**
* Abstract base class of TLS (SSL) sockets.
* Instances are created by the \ref wrap_socket factory method of \ref tls_context.
* First you create/open a regular \ref stream_socket, then immediately wrap it with TLS.
*/
class tls_socket : public stream_socket
{
/** The base class */
using base = stream_socket;
public:
/**
* Returns the verification status of the peer's certificate.
* If the certificate is valid, returns zero.
* Otherwise, returns a nonzero value whose interpretation depends on the actual
* TLS library in use.
* (For example, if using mbedTLS the return value is a bitwise combination of
* \c MBEDTLS_X509_BADCERT_XXX and \c MBEDTLS_X509_BADCRL_XXX flags
* defined in <mbedtls/x509.h>, or -1 if the TLS handshake failed earlier.)
*/
virtual uint32_t peer_certificate_status() =0;
/**
* Returns an error message describing any problem with the peer's certificate.
*/
virtual std::string peer_certificate_status_message() =0;
/**
* Returns the peer's X.509 certificate data, in binary DER format.
*/
virtual std::string peer_certificate() =0;
/**
* Move assignment.
* @param rhs The other socket to move into this one.
* @return A reference to this object.
*/
tls_socket& operator=(tls_socket&& rhs) {
base::operator=(std::move(rhs));
stream_ = std::move(rhs.stream_);
return *this;
}
// I/O primitives must be reimplemented in subclasses:
virtual ssize_t read(void *buf, size_t n) override = 0;
virtual ioresult read_r(void *buf, size_t n) override = 0;
virtual bool read_timeout(const std::chrono::microseconds& to) override = 0;
virtual ssize_t write(const void *buf, size_t n) override = 0;
virtual ioresult write_r(const void *buf, size_t n) override = 0;
virtual bool write_timeout(const std::chrono::microseconds& to) override = 0;
virtual ssize_t write(const std::vector<iovec> &ranges) override {
return ranges.empty() ? 0 : write(ranges[0].iov_base, ranges[0].iov_len);
}
virtual bool set_non_blocking(bool on) override = 0;
virtual bool close() override {
bool ok = true;
if (stream_) {
ok = stream_->close();
if (!ok && !last_error())
clear(stream_->last_error());
stream_.reset();
}
release();
return ok;
}
virtual ~tls_socket() {
close();
}
protected:
static constexpr socket_t PLACEHOLDER_SOCKET = -999;
tls_socket(std::unique_ptr<stream_socket> stream)
:base(PLACEHOLDER_SOCKET)
,stream_(move(stream))
{ }
/**
* Creates a TLS socket by copying the socket handle from the
* specified socket object and transfers ownership of the socket.
*/
tls_socket(tls_socket&& sock) : tls_socket(std::move(sock.stream_)) {}
/**
* The underlying socket stream that this socket wraps.
* The TLS code reads and writes this stream.
*/
stream_socket &stream() {return *stream_;}
private:
std::unique_ptr<stream_socket> stream_;
};
}
#endif

View File

@@ -73,14 +73,18 @@ bool acceptor::open(const sock_address& addr, int queSize /*=DFLT_QUE_SIZE*/)
reset(h);
#if !defined(_WIN32)
// TODO: This should be an option
if (domain == AF_INET || domain == AF_INET6) {
int reuse = 1;
if (!set_option(SOL_SOCKET, SO_REUSEADDR, reuse))
return close_on_err();
}
#ifdef WIN32
const int reuseSocket = SO_REUSEADDR;
#else
const int reuseSocket = SO_REUSEPORT;
#endif
// TODO: This should be an option
if (domain == AF_INET || domain == AF_INET6) {
int reuse = 1;
if (!set_option(SOL_SOCKET, reuseSocket, reuse))
return close_on_err();
}
if (!bind(addr) || !listen(queSize))
return close_on_err();

View File

@@ -35,28 +35,94 @@
// --------------------------------------------------------------------------
#include "sockpp/connector.h"
#include <cerrno>
namespace sockpp {
#ifdef _WIN32
// Winsock calls return non-POSIX error codes
#define ERR_IN_PROGRESS WSAEINPROGRESS
#define ERR_TIMED_OUT WSAETIMEDOUT
#define ERR_WOULD_BLOCK WSAEWOULDBLOCK
#else
#define ERR_IN_PROGRESS EINPROGRESS
#define ERR_TIMED_OUT ETIMEDOUT
#define ERR_WOULD_BLOCK EWOULDBLOCK
#endif
/////////////////////////////////////////////////////////////////////////////
bool connector::recreate(const sock_address& addr)
{
sa_family_t domain = addr.family();
socket_t h = create_handle(domain);
if (!check_socket_bool(h))
return false;
// This will close the old connection, if any.
reset(h);
return true;
}
/////////////////////////////////////////////////////////////////////////////
bool connector::connect(const sock_address& addr)
{
sa_family_t domain = addr.family();
socket_t h = create_handle(domain);
if (!check_ret_bool(h))
if (!recreate(addr))
return false;
// This will close the old connection, if any.
reset(h);
if (!check_ret_bool(::connect(h, addr.sockaddr_ptr(), addr.size())))
if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size())))
return close_on_err();
return true;
}
/////////////////////////////////////////////////////////////////////////////
bool connector::connect(const sock_address& addr, std::chrono::microseconds timeout)
{
if (timeout.count() <= 0)
return connect(addr);
if (!recreate(addr))
return false;
set_non_blocking(true);
if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size()))) {
if (last_error() == ERR_IN_PROGRESS || last_error() == ERR_WOULD_BLOCK) {
// Non-blocking connect -- call `select` to wait until the timeout:
// Note: Windows returns errors in exceptset so check it too, the
// logic afterwords doesn't change
fd_set readset;
FD_ZERO(&readset);
FD_SET(handle(), &readset);
fd_set writeset = readset;
fd_set exceptset = readset;
timeval tv = to_timeval(timeout);
int n = check_ret(::select(handle()+1, &readset, &writeset, &exceptset, &tv));
if (n > 0) {
// Got a socket event, but it might be an error, so check:
int err;
if (get_option(SOL_SOCKET, SO_ERROR, &err))
clear(err);
} else if (n == 0) {
clear(ERR_TIMED_OUT);
}
}
if (last_error() != 0) {
close();
return false;
}
}
set_non_blocking(false);
return true;
}
/////////////////////////////////////////////////////////////////////////////
// end namespace sockpp
}

View File

@@ -65,7 +65,7 @@ std::string sys_error::error_str(int err)
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
buf, sizeof(buf), NULL);
#else
#ifdef _GNU_SOURCE
#if defined(__GLIBC__)
auto s = strerror_r(err, buf, sizeof(buf));
return s ? std::string(s) : std::string();
#else

710
src/mbedtls_context.cpp Normal file
View File

@@ -0,0 +1,710 @@
// mbedtls_context.cpp
//
// --------------------------------------------------------------------------
// This file is part of the "sockpp" C++ socket library.
//
// Copyright (c) 2014-2017 Frank Pagliughi
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// --------------------------------------------------------------------------
#include "sockpp/mbedtls_context.h"
#include "sockpp/connector.h"
#include "sockpp/exception.h"
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/debug.h>
#include <mbedtls/entropy.h>
#include <mbedtls/error.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/ssl.h>
#include <mutex>
#include <chrono>
#include <cassert>
#ifdef __APPLE__
#include <fcntl.h>
#include <TargetConditionals.h>
#ifdef TARGET_OS_OSX
// For macOS read_system_root_certs():
#include <Security/Security.h>
#endif
#elif !defined(_WIN32)
// For Unix read_system_root_certs():
#include <dirent.h>
#include <fcntl.h>
#include <fnmatch.h>
#include <fstream>
#include <iostream>
#include <sstream>
#include <sys/stat.h>
#else
#include <wincrypt.h>
#include <sstream>
#pragma comment (lib, "crypt32.lib")
#pragma comment (lib, "cryptui.lib")
#endif
// TODO: Better logging(?)
#define log(FMT,...) fprintf(stderr, "TLS: " FMT "\n", ## __VA_ARGS__)
namespace sockpp {
using namespace std;
static std::string read_system_root_certs();
static int log_mbed_ret(int ret, const char *fn) {
if (ret != 0) {
char msg[100];
mbedtls_strerror(ret, msg, sizeof(msg));
log("mbedtls error -0x%04X from %s: %s", -ret, fn, msg);
}
return ret;
}
// Simple RAII helper for mbedTLS cert struct
struct mbedtls_context::cert : public mbedtls_x509_crt
{
cert() {mbedtls_x509_crt_init(this);}
~cert() {mbedtls_x509_crt_free(this);}
};
// Simple RAII helper for mbedTLS cert struct
struct mbedtls_context::key : public mbedtls_pk_context
{
key() {mbedtls_pk_init(this);}
~key() {mbedtls_pk_free(this);}
};
#pragma mark - SOCKET:
/** Concrete implementation of tls_socket using mbedTLS. */
class mbedtls_socket : public tls_socket {
private:
mbedtls_context& context_;
mbedtls_ssl_context ssl_;
chrono::microseconds read_timeout_ {0L};
bool open_ = false;
public:
mbedtls_socket(unique_ptr<stream_socket> base,
mbedtls_context &context,
const string &hostname)
:tls_socket(move(base))
,context_(context)
{
mbedtls_ssl_init(&ssl_);
if (context.status() != 0) {
clear(context.status());
return;
}
if (check_mbed_setup(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()),
"mbedtls_ssl_setup"))
return;
if (!hostname.empty() && check_mbed_setup(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()),
"mbedtls_ssl_set_hostname"))
return;
#if defined(_WIN32)
// Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't
bool blocking = true;
#else
int flags = fcntl(stream().handle(), F_GETFL, 0);
bool blocking = (flags < 0 || (flags & O_NONBLOCK) == 0);
#endif
setup_bio(blocking);
// Run the TLS handshake:
int status;
do {
open_ = true; // temporarily, so BIO methods won't fail
status = mbedtls_ssl_handshake(&ssl_);
open_ = false;
} while (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE
|| status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS);
if (check_mbed_setup(status, "mbedtls_ssl_handshake") != 0)
return;
uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
if (verify_flags != 0 && verify_flags != uint32_t(-1)
&& !(verify_flags & MBEDTLS_X509_BADCERT_SKIP_VERIFY)) {
char vrfy_buf[512];
mbedtls_x509_crt_verify_info(vrfy_buf, sizeof( vrfy_buf ), "", verify_flags);
log("Cert verify failed: %s", vrfy_buf );
reset();
clear(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED);
return;
}
open_ = true;
}
void setup_bio(bool nonblocking) {
mbedtls_ssl_send_t *f_send = [](void *ctx, const uint8_t *buf, size_t len) {
return ((mbedtls_socket*)ctx)->bio_send(buf, len); };
mbedtls_ssl_recv_t *f_recv = nullptr;
mbedtls_ssl_recv_timeout_t *f_recv_timeout = nullptr;
if (nonblocking)
f_recv = [](void *ctx, uint8_t *buf, size_t len) {
return ((mbedtls_socket*)ctx)->bio_recv(buf, len); };
else
f_recv_timeout = [](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) {
return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); };
mbedtls_ssl_set_bio(&ssl_, this, f_send, f_recv, f_recv_timeout);
}
~mbedtls_socket() {
close();
mbedtls_ssl_free(&ssl_);
reset(); // remove bogus file descriptor so base class won't call close() on it
}
virtual bool close() override {
if (open_) {
mbedtls_ssl_close_notify(&ssl_);
open_ = false;
}
return tls_socket::close();
}
// -------- certificate / trust API
uint32_t peer_certificate_status() override {
return mbedtls_ssl_get_verify_result(&ssl_);
}
string peer_certificate_status_message() override {
uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_);
if (verify_flags == 0 || verify_flags == UINT32_MAX)
return "";
char message[512];
mbedtls_x509_crt_verify_info(message, sizeof( message ), "",
verify_flags & ~MBEDTLS_X509_BADCERT_OTHER);
size_t len = strlen(message);
if (len > 0 && message[len] == '\0')
--len;
string result(message, len);
if (verify_flags & MBEDTLS_X509_BADCERT_OTHER) { // flag set by verify_callback()
if (!result.empty())
result = "\n" + result;
result = "The certificate does not match the known pinned certificate" + result;
}
return result;
}
string peer_certificate() override {
auto cert = mbedtls_ssl_get_peer_cert(&ssl_);
if (!cert)
return "";
return string((const char*)cert->raw.p, cert->raw.len);
}
// -------- stream_socket I/O
ssize_t read(void *buf, size_t length) override {
return check_mbed_io( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
}
ioresult read_r(void *buf, size_t length) override {
return ioresult_from_mbed( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) );
}
bool read_timeout(const chrono::microseconds& to) override {
bool ok = stream().read_timeout(to);
if (ok)
read_timeout_ = to;
return ok;
}
ssize_t write(const void *buf, size_t length) override {
if (length == 0)
return 0;
return check_mbed_io( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
}
ioresult write_r(const void *buf, size_t length) override {
if (length == 0)
return {};
return ioresult_from_mbed( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) );
}
bool write_timeout(const chrono::microseconds& to) override {
return stream().write_timeout(to);
}
bool set_non_blocking(bool nonblocking) override {
bool ok = stream().set_non_blocking(nonblocking);
if (ok)
setup_bio(nonblocking);
return ok;
}
// -------- mbedTLS BIO callbacks
int bio_send(const void* buf, size_t length) {
if (!open_)
return MBEDTLS_ERR_NET_CONN_RESET;
return bio_return_value<false>(stream().write_r(buf, length));
}
int bio_recv(void* buf, size_t length) {
if (!open_)
return MBEDTLS_ERR_NET_CONN_RESET;
return bio_return_value<true>(stream().read_r(buf, length));
}
int bio_recv_timeout(void* buf, size_t length, uint32_t timeout) {
if (!open_)
return MBEDTLS_ERR_NET_CONN_RESET;
if (timeout > 0)
stream().read_timeout(chrono::milliseconds(timeout));
int n = bio_recv(buf, length);
if (timeout > 0)
stream().read_timeout(read_timeout_);
return (int)n;
}
// -------- error handling
// Translates mbedTLS error code to POSIX (errno)
static int translate_mbed_err(int mbedErr) {
switch (mbedErr) {
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
return 0;
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
log(">>> mbedtls_socket returning EWOULDBLOCK");
return EWOULDBLOCK;
case MBEDTLS_ERR_NET_CONN_RESET:
return ECONNRESET;
case MBEDTLS_ERR_NET_RECV_FAILED:
case MBEDTLS_ERR_NET_SEND_FAILED:
return EIO;
default:
return mbedErr;
}
}
// Handles an mbedTLS error return value during setup, closing me on error
int check_mbed_setup(int ret, const char *fn) {
if (ret != 0) {
log_mbed_ret(ret, fn);
reset(); // marks me as closed/invalid
clear(translate_mbed_err(ret)); // sets last_error
stream().close();
open_ = false;
}
return ret;
}
// Handles an mbedTLS read/write return value, storing any error in last_error
inline ssize_t check_mbed_io(int mbedResult) {
if (mbedResult < 0) {
clear(translate_mbed_err(mbedResult)); // sets last_error
return -1;
}
return mbedResult;
}
// Handles an mbedTLS read/write return value, converting it to an ioresult.
static inline ioresult ioresult_from_mbed(int mbedResult) {
if (mbedResult < 0)
return ioresult(0, translate_mbed_err(mbedResult));
else
return ioresult(mbedResult, 0);
}
// Translates ioresult to an mbedTLS error code to return from a BIO function.
template <bool reading>
static int bio_return_value(ioresult result) {
if (result.error == 0)
return (int)result.count;
switch (result.error) {
case EPIPE:
case ECONNRESET:
return MBEDTLS_ERR_NET_CONN_RESET;
case EINTR:
case EWOULDBLOCK:
#if defined(EAGAIN) && EAGAIN != EWOULDBLOCK // these are usually synonyms
case EAGAIN:
#endif
log(">>> BIO returning MBEDTLS_ERR_SSL_WANT_%s", reading ?"READ":"WRITE");
return reading ? MBEDTLS_ERR_SSL_WANT_READ
: MBEDTLS_ERR_SSL_WANT_WRITE;
default:
return reading ? MBEDTLS_ERR_NET_RECV_FAILED
: MBEDTLS_ERR_NET_SEND_FAILED;
}
}
};
#pragma mark - CONTEXT:
static tls_context *s_default_context = nullptr;
mbedtls_context::cert *mbedtls_context::s_system_root_certs;
tls_context& tls_context::default_context() {
if (!s_default_context)
s_default_context = new mbedtls_context();
return *s_default_context;
}
// Returns a shared mbedTLS random-number generator context.
static mbedtls_ctr_drbg_context* get_drbg_context() {
static const char* k_entropy_personalization = "sockpp";
static mbedtls_entropy_context s_entropy;
static mbedtls_ctr_drbg_context s_random_ctx;
static once_flag once;
call_once(once, []() {
mbedtls_entropy_init( &s_entropy );
mbedtls_ctr_drbg_init( &s_random_ctx );
int ret = mbedtls_ctr_drbg_seed(&s_random_ctx, mbedtls_entropy_func, &s_entropy,
(const uint8_t *)k_entropy_personalization,
strlen(k_entropy_personalization));
if (ret != 0) {
log_mbed_ret(ret, "mbedtls_ctr_drbg_seed");
throw sys_error(ret); //FIXME: Not an errno; use different exception?
}
});
return &s_random_ctx;
}
unique_ptr<mbedtls_context::cert> mbedtls_context::parse_cert(const std::string &cert_data, bool partialOk) {
unique_ptr<cert> c(new cert);
mbedtls_x509_crt_init(c.get());
int ret = mbedtls_x509_crt_parse(c.get(),
(const uint8_t*)cert_data.data(), cert_data.size() + 1);
if (ret != 0) {
if(ret < 0 || !partialOk) {
log_mbed_ret(ret, "mbedtls_x509_crt_parse");
if(ret > 0) {
ret = MBEDTLS_ERR_X509_CERT_VERIFY_FAILED;
}
throw sys_error(ret);
}
}
return c;
}
void mbedtls_context::set_root_certs(const std::string &cert_data) {
root_certs_ = parse_cert(cert_data, true);
mbedtls_ssl_conf_ca_chain(ssl_config_.get(), root_certs_.get(), nullptr);
}
// Returns the set of system trusted root CA certs.
mbedtls_x509_crt* mbedtls_context::get_system_root_certs() {
static once_flag once;
call_once(once, []() {
// One-time initialization:
string certsPEM = read_system_root_certs();
if (!certsPEM.empty())
s_system_root_certs = parse_cert(certsPEM, true).release();
});
return s_system_root_certs;
}
mbedtls_context::mbedtls_context(role_t r)
:ssl_config_(new mbedtls_ssl_config)
{
mbedtls_ssl_config_init(ssl_config_.get());
mbedtls_ssl_conf_rng(ssl_config_.get(), mbedtls_ctr_drbg_random, get_drbg_context());
int endpoint = (r == CLIENT) ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER;
set_status(mbedtls_ssl_config_defaults(ssl_config_.get(),
endpoint,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT));
if (status() != 0)
return;
auto roots = get_system_root_certs();
if (roots)
mbedtls_ssl_conf_ca_chain(ssl_config_.get(), roots, nullptr);
}
mbedtls_context::~mbedtls_context() {
mbedtls_ssl_config_free(ssl_config_.get());
}
void mbedtls_context::set_logger(int threshold, Logger logger) {
if (!logger_) {
mbedtls_ssl_conf_dbg(ssl_config_.get(), [](void *ctx, int level, const char *file, int line,
const char *msg) {
auto &logger = ((mbedtls_context*)ctx)->logger_;
if (logger)
logger(level, file, line, msg);
}, this);
}
logger_ = logger;
mbedtls_debug_set_threshold(threshold);
}
void mbedtls_context::require_peer_cert(role_t forRole, bool require) {
if (forRole != role())
return;
int authMode = (require ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_OPTIONAL);
mbedtls_ssl_conf_authmode(ssl_config_.get(), authMode);
}
void mbedtls_context::allow_only_certificate(const std::string &cert_data) {
pinned_cert_.reset();
if (cert_data.empty()) {
mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr);
} else {
pinned_cert_ = parse_cert(cert_data, false);
// Install a custom verification callback:
mbedtls_ssl_conf_verify(
ssl_config_.get(),
[](void *ctx, mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
return ((mbedtls_context*)ctx)->verify_callback(crt,depth,flags);
},
this);
}
}
void mbedtls_context::allow_only_certificate(mbedtls_x509_crt *certificate) {
pinned_cert_.reset();
if (certificate) {
string cert_data((const char*)certificate->raw.p, certificate->raw.len);
allow_only_certificate(cert_data);
} else {
mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr);
}
}
// callback from mbedTLS cert validation (see above)
int mbedtls_context::verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags) {
if (depth == 0) {
if (crt->raw.len == pinned_cert_->raw.len
&& 0 == memcmp(crt->raw.p, pinned_cert_->raw.p, crt->raw.len)) {
// The cert matches our pinned cert, so mark it as trusted.
// (It might still be invalid if it's expired or revoked...)
*flags &= ~(MBEDTLS_X509_BADCERT_NOT_TRUSTED | MBEDTLS_X509_BADCERT_CN_MISMATCH);
} else {
// If cert doesn't match pinned cert, mark it as untrusted.
*flags |= MBEDTLS_X509_BADCERT_OTHER;
}
}
return 0;
}
void mbedtls_context::set_identity(const std::string &certificate_data,
const std::string &private_key_data)
{
auto ident_cert = parse_cert(certificate_data, false);
unique_ptr<key> ident_key(new key);
int err = mbedtls_pk_parse_key(ident_key.get(),
(const uint8_t*) private_key_data.data(),
private_key_data.size(), NULL, 0);
if( err != 0 ) {
log_mbed_ret(err, "mbedtls_pk_parse_key");
throw sys_error(err);
}
set_identity(ident_cert.get(), ident_key.get());
identity_cert_ = move(ident_cert);
identity_key_ = move(ident_key);
}
void mbedtls_context::set_identity(mbedtls_x509_crt *certificate,
mbedtls_pk_context *private_key)
{
mbedtls_ssl_conf_own_cert(ssl_config_.get(), certificate, private_key);
}
mbedtls_context::role_t mbedtls_context::role() {
return (ssl_config_->endpoint == MBEDTLS_SSL_IS_CLIENT) ? CLIENT : SERVER;
}
unique_ptr<tls_socket> mbedtls_context::wrap_socket(std::unique_ptr<stream_socket> socket,
role_t socketRole,
const std::string &peer_name)
{
assert(socketRole == role());
return unique_ptr<tls_socket>(new mbedtls_socket(move(socket), *this, peer_name));
}
#pragma mark - PLATFORM SPECIFIC:
// mbedTLS does not have built-in support for reading the OS's trusted root certs.
#ifdef __APPLE__
// Read system root CA certs on macOS.
// (Sadly, SecTrustCopyAnchorCertificates() is not available on iOS)
static string read_system_root_certs() {
#if TARGET_OS_OSX
CFArrayRef roots;
OSStatus err = SecTrustCopyAnchorCertificates(&roots);
if (err)
return {};
CFDataRef pemData = nullptr;
err = SecItemExport(roots, kSecFormatPEMSequence, kSecItemPemArmour, nullptr, &pemData);
CFRelease(roots);
if (err)
return {};
string pem((const char*)CFDataGetBytePtr(pemData), CFDataGetLength(pemData));
CFRelease(pemData);
return pem;
#else
// fallback -- no certs
return "";
#endif
}
#elif defined(_WIN32)
// Windows:
static string read_system_root_certs() {
PCCERT_CONTEXT pContext = nullptr;
HCERTSTORE hStore = CertOpenSystemStore(NULL, "ROOT");
if(hStore == nullptr) {
return "";
}
stringstream certs;
while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) {
certs.write((const char*)pContext->pbCertEncoded, pContext->cbCertEncoded);
}
CertCloseStore(hStore, CERT_CLOSE_STORE_FORCE_FLAG);
return certs.str();
}
#else
// Read system root CA certs on Linux using OpenSSL's cert directory
static string read_system_root_certs() {
static constexpr const char* CERTS_DIR = "/etc/ssl/certs/";
static constexpr const char* CERTS_FILE = "ca-certificates.crt";
stringstream certs;
char buf[1024];
// Subroutine to append a file to the `certs` stream:
auto read_file = [&](const string &file) {
ifstream in(file);
char last_char = '\n';
while (in) {
in.read(buf, sizeof(buf));
auto n = in.gcount();
if (n > 0) {
certs.write(buf, n);
last_char = buf[n-1];
}
}
if (last_char != '\n')
certs << '\n';
};
struct stat s;
if (stat(CERTS_DIR, &s) == 0 && S_ISDIR(s.st_mode)) {
string certs_file = string(CERTS_DIR) + CERTS_FILE;
if (stat(certs_file.c_str(), &s) == 0) {
// If there is a file containing all the certs, just read it:
read_file(certs_file);
} else {
// Otherwise concatenate all the certs found in the dir:
auto dir = opendir(CERTS_DIR);
if (dir) {
struct dirent *ent;
while (nullptr != (ent = readdir(dir))) {
if (fnmatch("?*.pem", ent->d_name, FNM_PERIOD) == 0
|| fnmatch("?*.crt", ent->d_name, FNM_PERIOD) == 0)
read_file(string(CERTS_DIR) + ent->d_name);
}
closedir(dir);
}
}
}
return certs.str();
}
#endif
}

View File

@@ -294,7 +294,11 @@ std::string socket::error_str(int err)
bool socket::shutdown(int how /*=SHUT_RDWR*/)
{
return check_ret_bool(::shutdown(handle_, how));
if(handle_ != INVALID_SOCKET) {
return check_ret_bool(::shutdown(release(), how));
}
return false;
}
// --------------------------------------------------------------------------

View File

@@ -69,6 +69,16 @@ ssize_t stream_socket::read(void *buf, size_t n)
#endif
}
ioresult stream_socket::read_r(void *buf, size_t n)
{
#if defined(_WIN32)
return ioresult(::recv(handle(), reinterpret_cast<char*>(buf),
int(n), 0));
#else
return ioresult(::recv(handle(), buf, n, 0));
#endif
}
// --------------------------------------------------------------------------
// Attempts to read the requested number of bytes by repeatedly calling
// read() until it has the data or an error occurs.
@@ -94,7 +104,6 @@ ssize_t stream_socket::read_n(void *buf, size_t n)
return (nr == 0 && nx < 0) ? nx : ssize_t(nr);
}
// --------------------------------------------------------------------------
ssize_t stream_socket::read(const std::vector<iovec>& ranges)
@@ -122,6 +131,23 @@ ssize_t stream_socket::read(const std::vector<iovec>& ranges)
#endif
}
ioresult stream_socket::read_n_r(void *buf, size_t n)
{
ioresult result;
uint8_t *b = reinterpret_cast<uint8_t*>(buf);
while (result.count < n) {
ioresult r = read_r(b + result.count, n - result.count);
if (r.count == 0) {
result.error = r.error;
break;
}
result.count += r.count;
}
return result;
}
// --------------------------------------------------------------------------
bool stream_socket::read_timeout(const microseconds& to)
@@ -147,6 +173,16 @@ ssize_t stream_socket::write(const void *buf, size_t n)
#endif
}
ioresult stream_socket::write_r(const void *buf, size_t n)
{
#if defined(_WIN32)
return ioresult(::send(handle(), reinterpret_cast<const char*>(buf),
int(n) , 0));
#else
return ioresult(::send(handle(), buf, n , 0));
#endif
}
// --------------------------------------------------------------------------
// Attempts to write the entire buffer by repeatedly calling write() until
// either all of the data is sent or an error occurs.
@@ -171,32 +207,53 @@ ssize_t stream_socket::write_n(const void *buf, size_t n)
return (nw == 0 && nx < 0) ? nx : ssize_t(nw);
}
// --------------------------------------------------------------------------
ssize_t stream_socket::write(const std::vector<iovec>& ranges)
ioresult stream_socket::write_n_r(const void *buf, size_t n)
{
if (ranges.empty())
return 0;
ioresult result;
const uint8_t *b = reinterpret_cast<const uint8_t*>(buf);
#if !defined(_WIN32)
return check_ret(::writev(handle(), ranges.data(), int(ranges.size())));
#else
std::vector<WSABUF> bufs;
for (const auto& iovec : ranges) {
bufs.push_back({
static_cast<ULONG>(iovec.iov_len),
static_cast<CHAR*>(iovec.iov_base)
});
}
while (result.count < n) {
ioresult r = write_r(b + result.count, n - result.count);
if (r.count == 0) {
result.error = r.error;
break;
}
result.count += r.count;
}
DWORD nwritten = 0,
nmsg = DWORD(bufs.size());
auto ret = check_ret(::WSASend(handle(), bufs.data(), nmsg, &nwritten, 0, nullptr, nullptr));
return ssize_t(ret == SOCKET_ERROR ? ret : nwritten);
#endif
return result;
}
// --------------------------------------------------------------------------
ssize_t stream_socket::write(const std::vector<iovec> &ranges)
{
#if !defined(_WIN32)
msghdr msg = {};
msg.msg_iov = const_cast<iovec*>(ranges.data());
msg.msg_iovlen = int(ranges.size());
if (msg.msg_iovlen == 0)
return 0;
return check_ret(sendmsg(handle(), &msg, 0));
#else
if(ranges.empty()) {
return 0;
}
std::vector<WSABUF> buffers;
for(const auto& iovec : ranges) {
buffers.push_back({
static_cast<ULONG>(iovec.iov_len),
static_cast<CHAR FAR *>(iovec.iov_base)
});
}
DWORD written = 0;
ssize_t ret = check_ret(WSASend(handle(), buffers.data(), buffers.size(), &written, 0, nullptr, nullptr));
return ret == SOCKET_ERROR ? ret : written;
#endif
}
// --------------------------------------------------------------------------
@@ -212,6 +269,8 @@ bool stream_socket::write_timeout(const microseconds& to)
return set_option(SOL_SOCKET, SO_SNDTIMEO, tv);
}
/////////////////////////////////////////////////////////////////////////////
// end namespace sockpp
}