diff --git a/include/sockpp/connector.h b/include/sockpp/connector.h index 683ca31..f7776ce 100644 --- a/include/sockpp/connector.h +++ b/include/sockpp/connector.h @@ -69,6 +69,8 @@ class connector : public stream_socket connector(const connector&) =delete; connector& operator=(const connector&) =delete; + bool recreate(const sock_address& addr); + public: /** * Creates an unconnected connector. @@ -80,6 +82,14 @@ public: * @param addr The remote server address. */ connector(const sock_address& addr) { connect(addr); } + /** + * Creates the connector and attempts to connect to the specified + * address, with a timeout. + * If the operation times out, the \ref last_error will be set to ETIMEOUT. + * @param addr The remote server address. + * @param t The duration after which to give up. Zero means never. + */ + connector(const sock_address& addr, std::chrono::milliseconds t) { connect(addr, t); } /** * Move constructor. * Creates a connector by moving the other connector to this one. @@ -112,6 +122,16 @@ public: * @return @em true on success, @em false on error */ bool connect(const sock_address& addr); + /** + * Attempts to connect to the specified server, with a timeout. + * If the socket is currently connected, this will close the current + * connection and open the new one. + * If the operation times out, the \ref last_error will be set to ETIMEOUT. + * @param addr The remote server address. + * @param timeout The duration after which to give up. Zero means never. + * @return @em true on success, @em false on error + */ + bool connect(const sock_address& addr, std::chrono::microseconds timeout); }; ///////////////////////////////////////////////////////////////////////////// @@ -182,6 +202,18 @@ public: * @return @em true on success, @em false on error */ bool connect(const addr_t& addr) { return base::connect(addr); } + /** + * Attempts to connect to the specified server, with a timeout. + * If the socket is currently connected, this will close the current + * connection and open the new one. + * If the operation times out, the \ref last_error will be set to ETIMEOUT. + * @param addr The remote server address. + * @param timeout The duration after which to give up. Zero means never. + * @return @em true on success, @em false on error + */ + bool connect(const addr_t& addr, std::chrono::microseconds timeout) { + return base::connect(addr, timeout); + } }; ///////////////////////////////////////////////////////////////////////////// diff --git a/include/sockpp/mbedtls_context.h b/include/sockpp/mbedtls_context.h new file mode 100644 index 0000000..850c4ca --- /dev/null +++ b/include/sockpp/mbedtls_context.h @@ -0,0 +1,120 @@ +/** + * @file mbedtls_socket.h + * + * TLS (SSL) socket implementation using mbedTLS. + * + * @author Jens Alfke + * @author Couchbase, Inc. + * @author www.couchbase.com + * + * @date August 2019 + */ + +// -------------------------------------------------------------------------- +// This file is part of the "sockpp" C++ socket library. +// +// Copyright (c) 2014-2019 Frank Pagliughi +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// -------------------------------------------------------------------------- + +#ifndef __sockpp_mbedtls_socket_h +#define __sockpp_mbedtls_socket_h + +#include "sockpp/tls_context.h" +#include "sockpp/tls_socket.h" +#include +#include +#include + +struct mbedtls_pk_context; +struct mbedtls_ssl_config; +struct mbedtls_x509_crt; + +namespace sockpp { + + /** + * A concrete implementation of \ref tls_context, using the mbedTLS library. + * You probably don't want to use this class directly, unless you want to instantiate a + * custom context so you can have different contexts for different sockets. + */ + class mbedtls_context : public tls_context { + public: + mbedtls_context(role_t = CLIENT); + ~mbedtls_context() override; + + void set_root_certs(const std::string &certData) override; + void require_peer_cert(role_t, bool) override; + void allow_only_certificate(const std::string &certData) override; + + void allow_only_certificate(mbedtls_x509_crt *certificate); + + /** + * Sets the identity certificate and private key using mbedTLS objects. + */ + void set_identity(mbedtls_x509_crt *certificate, + mbedtls_pk_context *private_key); + + void set_identity(const std::string &certificate_data, + const std::string &private_key_data) override; + + std::unique_ptr wrap_socket(std::unique_ptr socket, + role_t, + const std::string &peer_name) override; + + role_t role(); + + static mbedtls_x509_crt* get_system_root_certs(); + + using Logger = std::function; + void set_logger(int threshold, Logger); + + private: + struct cert; + struct key; + + int verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags); + static std::unique_ptr parse_cert(const std::string &cert_data, bool partialOk); + + std::unique_ptr ssl_config_; + std::unique_ptr root_certs_; + std::unique_ptr pinned_cert_; + + std::unique_ptr identity_cert_; + std::unique_ptr identity_key_; + Logger logger_; + + static cert *s_system_root_certs; + + friend class mbedtls_socket; + }; + +} + +#endif diff --git a/include/sockpp/platform.h b/include/sockpp/platform.h index 0a8797a..c4e1586 100644 --- a/include/sockpp/platform.h +++ b/include/sockpp/platform.h @@ -78,11 +78,7 @@ #ifndef _SSIZE_T_DEFINED #define _SSIZE_T_DEFINED #undef ssize_t - #ifdef _WIN64 - using ssize_t = int64_t; - #else - using ssize_t = int; - #endif // _WIN64 + using ssize_t = SSIZE_T; #endif // _SSIZE_T_DEFINED #ifndef _SUSECONDS_T diff --git a/include/sockpp/socket.h b/include/sockpp/socket.h index c5b3f44..4540907 100644 --- a/include/sockpp/socket.h +++ b/include/sockpp/socket.h @@ -415,7 +415,7 @@ public: * @param on Whether to turn non-blocking mode on or off. * @return @em true on success, @em false on failure. */ - bool set_non_blocking(bool on=true); + virtual bool set_non_blocking(bool on=true); /** * Gets a string describing the specified error. * This is typically the returned message from the system strerror(). @@ -445,14 +445,16 @@ public: * @li SHUT_RDWR (2) Further reads and writes disallowed. * @return @em true on success, @em false on error. */ - bool shutdown(int how=SHUT_RDWR); + virtual bool shutdown(int how=SHUT_RDWR); /** * Closes the socket. * After closing the socket, the handle is @em invalid, and can not be * used again until reassigned. * @return @em true if the sock is closed, @em false on error. */ - bool close(); + virtual bool close(); + + friend struct ioresult; }; ///////////////////////////////////////////////////////////////////////////// diff --git a/include/sockpp/stream_socket.h b/include/sockpp/stream_socket.h index dce71a8..e95f0de 100644 --- a/include/sockpp/stream_socket.h +++ b/include/sockpp/stream_socket.h @@ -54,6 +54,26 @@ namespace sockpp { ///////////////////////////////////////////////////////////////////////////// +/** + * Result of a thread-safe read or write + * (\ref read_r, \ref read_n_r, \ref write_r, \ref write_n_r) + */ +struct ioresult { + size_t count = 0; ///< Byte count, or 0 on error or EOF + int error = 0; ///< errno value (0 if no error or EOF) + + ioresult() = default; + + ioresult(size_t c, int e) :count(c), error(e) { } + + explicit inline ioresult(ssize_t n) { + if (n >= 0) + count = size_t(n); + else + error = socket::get_last_error(); + } +}; + /** * Base class for streaming sockets, such as TCP and Unix Domain. * This is the streaming connection between two peers. It looks like a @@ -230,6 +250,12 @@ public: bool write_timeout(const std::chrono::duration& to) { return write_timeout(std::chrono::duration_cast(to)); } + + virtual ioresult read_r(void *buf, size_t n); + virtual ioresult read_n_r(void *buf, size_t n); + virtual ioresult write_r(const void *buf, size_t n); + virtual ioresult write_n_r(const void *buf, size_t n); + }; ///////////////////////////////////////////////////////////////////////////// diff --git a/include/sockpp/tls_context.h b/include/sockpp/tls_context.h new file mode 100644 index 0000000..fee505f --- /dev/null +++ b/include/sockpp/tls_context.h @@ -0,0 +1,158 @@ +/** + * @file tls_context.h + * + * Context object for TLS (SSL) sockets. + * + * @author Jens Alfke + * @author Couchbase, Inc. + * @author www.couchbase.com + * + * @date August 2019 + */ + +// -------------------------------------------------------------------------- +// This file is part of the "sockpp" C++ socket library. +// +// Copyright (c) 2014-2019 Frank Pagliughi +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// -------------------------------------------------------------------------- + +#ifndef __sockpp_tls_context_h +#define __sockpp_tls_context_h + +#include "sockpp/platform.h" +#include +#include + +namespace sockpp { + class connector; + class stream_socket; + class tls_socket; + + /** + * Context / configuration for TLS (SSL) connections; also acts as a factory for + * \ref tls_socket objects. + * + * A single context can be shared by any number of \ref tls_socket instances. + * A context must remain in scope as long as any socket using it remains in scope. + */ + class tls_context + { + public: + enum role_t { + CLIENT = 0, + SERVER = 1 + }; + /** + * A singleton context that can be used if you don't need any per-connection + * configuration. + */ + static tls_context& default_context(); + + virtual ~tls_context() =default; + + /** + * Tells whether the context is initialized and valid. Check this after constructing + * an instance and do not use if not valid. + * @return Zero if valid, a nonzero error code if initialization failed. + * The code may be a POSIX code, or one specific to the TLS library. + */ + int status() const { + return status_; + } + + operator bool() const { + return status_ == 0; + } + + /** + * Overrides the set of trusted root certificates used for validation. + */ + virtual void set_root_certs(const std::string &certData) =0; + + /** + * Configures whether the peer is required to present a valid certificate, for a connection + * using the given role. + * * For the CLIENT role the default is true; if you change to false, you take + * responsibility for validating the server certificate yourself! + * * For the SERVER role the default is false; you can change it to true to require + * client certificate authentication. + * @param role The role you are configuring this setting for + * @param require Pass true to require a valid peer certificate, false to not require. + */ + virtual void require_peer_cert(role_t role, bool require) =0; + + /** + * Requires that the peer have the exact certificate given. + * This is known as "cert-pinning". It's more secure, but requires that the client + * update its copy of the certificate whenever the server updates it. + * @param certData The X.509 certificate in DER or PEM form; or an empty string for + * no pinning (the default). + */ + virtual void allow_only_certificate(const std::string &certData) =0; + + virtual void set_identity(const std::string &certificate_data, + const std::string &private_key_data) =0; + + /** + * Creates a new \ref tls_socket instance that wraps the given connector socket. + * The \ref tls_socket takes ownership of the base socket and will close it when + * it's closed. + * When this method returns, the TLS handshake will already have completed; + * be sure to check the stream's status, since the handshake may have failed. + * @param socket The underlying connector socket that TLS will use for I/O. + * @param role CLIENT or SERVER mode. + * @param peer_name The peer's canonical hostname, or other distinguished name, + * to be used for certificate validation. + * @return A new \ref tls_socket to use for secure I/O. + */ + virtual std::unique_ptr wrap_socket(std::unique_ptr socket, + role_t role, + const std::string &peer_name) =0; + + protected: + tls_context() =default; + + /** + * Sets the error status of the context. Call this if initialization fails. + */ + void set_status(int s) const { + status_ = s; + } + + private: + tls_context(const tls_context&) =delete; + + mutable int status_ =0; + }; + +} + +#endif diff --git a/include/sockpp/tls_socket.h b/include/sockpp/tls_socket.h new file mode 100644 index 0000000..f1cdfa5 --- /dev/null +++ b/include/sockpp/tls_socket.h @@ -0,0 +1,158 @@ +/** + * @file tls_socket.h + * + * TLS (SSL) sockets. + * + * @author Jens Alfke + * @author Couchbase, Inc. + * @author www.couchbase.com + * + * @date August 2019 + */ + +// -------------------------------------------------------------------------- +// This file is part of the "sockpp" C++ socket library. +// +// Copyright (c) 2014-2019 Frank Pagliughi +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// -------------------------------------------------------------------------- + +#ifndef __sockpp_tls_socket_h +#define __sockpp_tls_socket_h + +#include "sockpp/stream_socket.h" +#include "sockpp/tls_context.h" +#include +#include + +namespace sockpp { + + /** + * Abstract base class of TLS (SSL) sockets. + * Instances are created by the \ref wrap_socket factory method of \ref tls_context. + * First you create/open a regular \ref stream_socket, then immediately wrap it with TLS. + */ + class tls_socket : public stream_socket + { + /** The base class */ + using base = stream_socket; + + public: + /** + * Returns the verification status of the peer's certificate. + * If the certificate is valid, returns zero. + * Otherwise, returns a nonzero value whose interpretation depends on the actual + * TLS library in use. + * (For example, if using mbedTLS the return value is a bitwise combination of + * \c MBEDTLS_X509_BADCERT_XXX and \c MBEDTLS_X509_BADCRL_XXX flags + * defined in , or -1 if the TLS handshake failed earlier.) + */ + virtual uint32_t peer_certificate_status() =0; + + /** + * Returns an error message describing any problem with the peer's certificate. + */ + virtual std::string peer_certificate_status_message() =0; + + /** + * Returns the peer's X.509 certificate data, in binary DER format. + */ + virtual std::string peer_certificate() =0; + + /** + * Move assignment. + * @param rhs The other socket to move into this one. + * @return A reference to this object. + */ + tls_socket& operator=(tls_socket&& rhs) { + base::operator=(std::move(rhs)); + stream_ = std::move(rhs.stream_); + return *this; + } + + // I/O primitives must be reimplemented in subclasses: + + virtual ssize_t read(void *buf, size_t n) override = 0; + virtual ioresult read_r(void *buf, size_t n) override = 0; + virtual bool read_timeout(const std::chrono::microseconds& to) override = 0; + virtual ssize_t write(const void *buf, size_t n) override = 0; + virtual ioresult write_r(const void *buf, size_t n) override = 0; + virtual bool write_timeout(const std::chrono::microseconds& to) override = 0; + + virtual ssize_t write(const std::vector &ranges) override { + return ranges.empty() ? 0 : write(ranges[0].iov_base, ranges[0].iov_len); + } + + virtual bool set_non_blocking(bool on) override = 0; + + virtual bool close() override { + bool ok = true; + if (stream_) { + ok = stream_->close(); + if (!ok && !last_error()) + clear(stream_->last_error()); + stream_.reset(); + } + release(); + return ok; + } + + virtual ~tls_socket() { + close(); + } + + + protected: + static constexpr socket_t PLACEHOLDER_SOCKET = -999; + + tls_socket(std::unique_ptr stream) + :base(PLACEHOLDER_SOCKET) + ,stream_(move(stream)) + { } + + /** + * Creates a TLS socket by copying the socket handle from the + * specified socket object and transfers ownership of the socket. + */ + tls_socket(tls_socket&& sock) : tls_socket(std::move(sock.stream_)) {} + + /** + * The underlying socket stream that this socket wraps. + * The TLS code reads and writes this stream. + */ + stream_socket &stream() {return *stream_;} + + private: + std::unique_ptr stream_; + }; + +} + +#endif diff --git a/src/acceptor.cpp b/src/acceptor.cpp index af63669..718c04e 100644 --- a/src/acceptor.cpp +++ b/src/acceptor.cpp @@ -73,14 +73,18 @@ bool acceptor::open(const sock_address& addr, int queSize /*=DFLT_QUE_SIZE*/) reset(h); - #if !defined(_WIN32) - // TODO: This should be an option - if (domain == AF_INET || domain == AF_INET6) { - int reuse = 1; - if (!set_option(SOL_SOCKET, SO_REUSEADDR, reuse)) - return close_on_err(); - } + #ifdef WIN32 + const int reuseSocket = SO_REUSEADDR; + #else + const int reuseSocket = SO_REUSEPORT; #endif + + // TODO: This should be an option + if (domain == AF_INET || domain == AF_INET6) { + int reuse = 1; + if (!set_option(SOL_SOCKET, reuseSocket, reuse)) + return close_on_err(); + } if (!bind(addr) || !listen(queSize)) return close_on_err(); diff --git a/src/connector.cpp b/src/connector.cpp index 0ebcfe4..bb0aefa 100644 --- a/src/connector.cpp +++ b/src/connector.cpp @@ -35,28 +35,94 @@ // -------------------------------------------------------------------------- #include "sockpp/connector.h" +#include namespace sockpp { +#ifdef _WIN32 + // Winsock calls return non-POSIX error codes + #define ERR_IN_PROGRESS WSAEINPROGRESS + #define ERR_TIMED_OUT WSAETIMEDOUT + #define ERR_WOULD_BLOCK WSAEWOULDBLOCK +#else + #define ERR_IN_PROGRESS EINPROGRESS + #define ERR_TIMED_OUT ETIMEDOUT + #define ERR_WOULD_BLOCK EWOULDBLOCK +#endif + +///////////////////////////////////////////////////////////////////////////// + +bool connector::recreate(const sock_address& addr) +{ + sa_family_t domain = addr.family(); + socket_t h = create_handle(domain); + + if (!check_socket_bool(h)) + return false; + + // This will close the old connection, if any. + reset(h); + return true; +} + + ///////////////////////////////////////////////////////////////////////////// bool connector::connect(const sock_address& addr) { - sa_family_t domain = addr.family(); - socket_t h = create_handle(domain); - - if (!check_ret_bool(h)) + if (!recreate(addr)) return false; - // This will close the old connection, if any. - reset(h); - - if (!check_ret_bool(::connect(h, addr.sockaddr_ptr(), addr.size()))) + if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size()))) return close_on_err(); return true; } +///////////////////////////////////////////////////////////////////////////// + +bool connector::connect(const sock_address& addr, std::chrono::microseconds timeout) +{ + if (timeout.count() <= 0) + return connect(addr); + + if (!recreate(addr)) + return false; + + set_non_blocking(true); + if (!check_ret_bool(::connect(handle(), addr.sockaddr_ptr(), addr.size()))) { + if (last_error() == ERR_IN_PROGRESS || last_error() == ERR_WOULD_BLOCK) { + // Non-blocking connect -- call `select` to wait until the timeout: + // Note: Windows returns errors in exceptset so check it too, the + // logic afterwords doesn't change + fd_set readset; + FD_ZERO(&readset); + FD_SET(handle(), &readset); + fd_set writeset = readset; + fd_set exceptset = readset; + timeval tv = to_timeval(timeout); + int n = check_ret(::select(handle()+1, &readset, &writeset, &exceptset, &tv)); + + if (n > 0) { + // Got a socket event, but it might be an error, so check: + int err; + if (get_option(SOL_SOCKET, SO_ERROR, &err)) + clear(err); + } else if (n == 0) { + clear(ERR_TIMED_OUT); + } + } + + if (last_error() != 0) { + close(); + return false; + } + } + + set_non_blocking(false); + return true; +} + ///////////////////////////////////////////////////////////////////////////// // end namespace sockpp } diff --git a/src/exception.cpp b/src/exception.cpp index 5be94d0..948d247 100644 --- a/src/exception.cpp +++ b/src/exception.cpp @@ -65,7 +65,7 @@ std::string sys_error::error_str(int err) NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buf, sizeof(buf), NULL); #else - #ifdef _GNU_SOURCE + #if defined(__GLIBC__) auto s = strerror_r(err, buf, sizeof(buf)); return s ? std::string(s) : std::string(); #else diff --git a/src/mbedtls_context.cpp b/src/mbedtls_context.cpp new file mode 100644 index 0000000..6e45142 --- /dev/null +++ b/src/mbedtls_context.cpp @@ -0,0 +1,710 @@ +// mbedtls_context.cpp +// +// -------------------------------------------------------------------------- +// This file is part of the "sockpp" C++ socket library. +// +// Copyright (c) 2014-2017 Frank Pagliughi +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// -------------------------------------------------------------------------- + +#include "sockpp/mbedtls_context.h" +#include "sockpp/connector.h" +#include "sockpp/exception.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __APPLE__ + #include + #include + #ifdef TARGET_OS_OSX + // For macOS read_system_root_certs(): + #include + #endif +#elif !defined(_WIN32) + // For Unix read_system_root_certs(): + #include + #include + #include + #include + #include + #include + #include +#else + #include + #include + + #pragma comment (lib, "crypt32.lib") + #pragma comment (lib, "cryptui.lib") +#endif + + +// TODO: Better logging(?) +#define log(FMT,...) fprintf(stderr, "TLS: " FMT "\n", ## __VA_ARGS__) + + +namespace sockpp { + using namespace std; + + + static std::string read_system_root_certs(); + + + static int log_mbed_ret(int ret, const char *fn) { + if (ret != 0) { + char msg[100]; + mbedtls_strerror(ret, msg, sizeof(msg)); + log("mbedtls error -0x%04X from %s: %s", -ret, fn, msg); + } + return ret; + } + + + // Simple RAII helper for mbedTLS cert struct + struct mbedtls_context::cert : public mbedtls_x509_crt + { + cert() {mbedtls_x509_crt_init(this);} + ~cert() {mbedtls_x509_crt_free(this);} + }; + + + // Simple RAII helper for mbedTLS cert struct + struct mbedtls_context::key : public mbedtls_pk_context + { + key() {mbedtls_pk_init(this);} + ~key() {mbedtls_pk_free(this);} + }; + + +#pragma mark - SOCKET: + + + /** Concrete implementation of tls_socket using mbedTLS. */ + class mbedtls_socket : public tls_socket { + private: + mbedtls_context& context_; + mbedtls_ssl_context ssl_; + chrono::microseconds read_timeout_ {0L}; + bool open_ = false; + + public: + + mbedtls_socket(unique_ptr base, + mbedtls_context &context, + const string &hostname) + :tls_socket(move(base)) + ,context_(context) + { + mbedtls_ssl_init(&ssl_); + if (context.status() != 0) { + clear(context.status()); + return; + } + + if (check_mbed_setup(mbedtls_ssl_setup(&ssl_, context_.ssl_config_.get()), + "mbedtls_ssl_setup")) + return; + if (!hostname.empty() && check_mbed_setup(mbedtls_ssl_set_hostname(&ssl_, hostname.c_str()), + "mbedtls_ssl_set_hostname")) + return; + +#if defined(_WIN32) + // Winsock does not allow us to tell if a socket is nonblocking, so assume it isn't + bool blocking = true; +#else + int flags = fcntl(stream().handle(), F_GETFL, 0); + bool blocking = (flags < 0 || (flags & O_NONBLOCK) == 0); +#endif + setup_bio(blocking); + + // Run the TLS handshake: + int status; + do { + open_ = true; // temporarily, so BIO methods won't fail + status = mbedtls_ssl_handshake(&ssl_); + open_ = false; + } while (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE + || status == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS); + if (check_mbed_setup(status, "mbedtls_ssl_handshake") != 0) + return; + + uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_); + if (verify_flags != 0 && verify_flags != uint32_t(-1) + && !(verify_flags & MBEDTLS_X509_BADCERT_SKIP_VERIFY)) { + char vrfy_buf[512]; + mbedtls_x509_crt_verify_info(vrfy_buf, sizeof( vrfy_buf ), "", verify_flags); + log("Cert verify failed: %s", vrfy_buf ); + reset(); + clear(MBEDTLS_ERR_X509_CERT_VERIFY_FAILED); + return; + } + open_ = true; + } + + + void setup_bio(bool nonblocking) { + mbedtls_ssl_send_t *f_send = [](void *ctx, const uint8_t *buf, size_t len) { + return ((mbedtls_socket*)ctx)->bio_send(buf, len); }; + mbedtls_ssl_recv_t *f_recv = nullptr; + mbedtls_ssl_recv_timeout_t *f_recv_timeout = nullptr; + if (nonblocking) + f_recv = [](void *ctx, uint8_t *buf, size_t len) { + return ((mbedtls_socket*)ctx)->bio_recv(buf, len); }; + else + f_recv_timeout = [](void *ctx, uint8_t *buf, size_t len, uint32_t timeout) { + return ((mbedtls_socket*)ctx)->bio_recv_timeout(buf, len, timeout); }; + mbedtls_ssl_set_bio(&ssl_, this, f_send, f_recv, f_recv_timeout); + } + + + ~mbedtls_socket() { + close(); + mbedtls_ssl_free(&ssl_); + reset(); // remove bogus file descriptor so base class won't call close() on it + } + + + virtual bool close() override { + if (open_) { + mbedtls_ssl_close_notify(&ssl_); + open_ = false; + } + return tls_socket::close(); + } + + + // -------- certificate / trust API + + + uint32_t peer_certificate_status() override { + return mbedtls_ssl_get_verify_result(&ssl_); + } + + + string peer_certificate_status_message() override { + uint32_t verify_flags = mbedtls_ssl_get_verify_result(&ssl_); + if (verify_flags == 0 || verify_flags == UINT32_MAX) + return ""; + char message[512]; + mbedtls_x509_crt_verify_info(message, sizeof( message ), "", + verify_flags & ~MBEDTLS_X509_BADCERT_OTHER); + size_t len = strlen(message); + if (len > 0 && message[len] == '\0') + --len; + + string result(message, len); + if (verify_flags & MBEDTLS_X509_BADCERT_OTHER) { // flag set by verify_callback() + if (!result.empty()) + result = "\n" + result; + result = "The certificate does not match the known pinned certificate" + result; + } + return result; + } + + + string peer_certificate() override { + auto cert = mbedtls_ssl_get_peer_cert(&ssl_); + if (!cert) + return ""; + return string((const char*)cert->raw.p, cert->raw.len); + } + + + // -------- stream_socket I/O + + + ssize_t read(void *buf, size_t length) override { + return check_mbed_io( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) ); + } + + + ioresult read_r(void *buf, size_t length) override { + return ioresult_from_mbed( mbedtls_ssl_read(&ssl_, (uint8_t*)buf, length) ); + } + + + bool read_timeout(const chrono::microseconds& to) override { + bool ok = stream().read_timeout(to); + if (ok) + read_timeout_ = to; + return ok; + } + + + ssize_t write(const void *buf, size_t length) override { + if (length == 0) + return 0; + return check_mbed_io( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) ); + } + + + ioresult write_r(const void *buf, size_t length) override { + if (length == 0) + return {}; + return ioresult_from_mbed( mbedtls_ssl_write(&ssl_, (const uint8_t*)buf, length) ); + } + + + bool write_timeout(const chrono::microseconds& to) override { + return stream().write_timeout(to); + } + + + bool set_non_blocking(bool nonblocking) override { + bool ok = stream().set_non_blocking(nonblocking); + if (ok) + setup_bio(nonblocking); + return ok; + } + + + // -------- mbedTLS BIO callbacks + + + int bio_send(const void* buf, size_t length) { + if (!open_) + return MBEDTLS_ERR_NET_CONN_RESET; + return bio_return_value(stream().write_r(buf, length)); + } + + + int bio_recv(void* buf, size_t length) { + if (!open_) + return MBEDTLS_ERR_NET_CONN_RESET; + return bio_return_value(stream().read_r(buf, length)); + } + + + int bio_recv_timeout(void* buf, size_t length, uint32_t timeout) { + if (!open_) + return MBEDTLS_ERR_NET_CONN_RESET; + if (timeout > 0) + stream().read_timeout(chrono::milliseconds(timeout)); + + int n = bio_recv(buf, length); + + if (timeout > 0) + stream().read_timeout(read_timeout_); + return (int)n; + } + + + // -------- error handling + + + // Translates mbedTLS error code to POSIX (errno) + static int translate_mbed_err(int mbedErr) { + switch (mbedErr) { + case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: + return 0; + case MBEDTLS_ERR_SSL_WANT_READ: + case MBEDTLS_ERR_SSL_WANT_WRITE: + log(">>> mbedtls_socket returning EWOULDBLOCK"); + return EWOULDBLOCK; + case MBEDTLS_ERR_NET_CONN_RESET: + return ECONNRESET; + case MBEDTLS_ERR_NET_RECV_FAILED: + case MBEDTLS_ERR_NET_SEND_FAILED: + return EIO; + default: + return mbedErr; + } + } + + + // Handles an mbedTLS error return value during setup, closing me on error + int check_mbed_setup(int ret, const char *fn) { + if (ret != 0) { + log_mbed_ret(ret, fn); + reset(); // marks me as closed/invalid + clear(translate_mbed_err(ret)); // sets last_error + stream().close(); + open_ = false; + } + return ret; + } + + + // Handles an mbedTLS read/write return value, storing any error in last_error + inline ssize_t check_mbed_io(int mbedResult) { + if (mbedResult < 0) { + clear(translate_mbed_err(mbedResult)); // sets last_error + return -1; + } + return mbedResult; + } + + + // Handles an mbedTLS read/write return value, converting it to an ioresult. + static inline ioresult ioresult_from_mbed(int mbedResult) { + if (mbedResult < 0) + return ioresult(0, translate_mbed_err(mbedResult)); + else + return ioresult(mbedResult, 0); + } + + + // Translates ioresult to an mbedTLS error code to return from a BIO function. + template + static int bio_return_value(ioresult result) { + if (result.error == 0) + return (int)result.count; + switch (result.error) { + case EPIPE: + case ECONNRESET: + return MBEDTLS_ERR_NET_CONN_RESET; + case EINTR: + case EWOULDBLOCK: +#if defined(EAGAIN) && EAGAIN != EWOULDBLOCK // these are usually synonyms + case EAGAIN: +#endif + log(">>> BIO returning MBEDTLS_ERR_SSL_WANT_%s", reading ?"READ":"WRITE"); + return reading ? MBEDTLS_ERR_SSL_WANT_READ + : MBEDTLS_ERR_SSL_WANT_WRITE; + default: + return reading ? MBEDTLS_ERR_NET_RECV_FAILED + : MBEDTLS_ERR_NET_SEND_FAILED; + } + } + + + }; + + +#pragma mark - CONTEXT: + + + static tls_context *s_default_context = nullptr; + + mbedtls_context::cert *mbedtls_context::s_system_root_certs; + + + tls_context& tls_context::default_context() { + if (!s_default_context) + s_default_context = new mbedtls_context(); + return *s_default_context; + } + + + // Returns a shared mbedTLS random-number generator context. + static mbedtls_ctr_drbg_context* get_drbg_context() { + static const char* k_entropy_personalization = "sockpp"; + static mbedtls_entropy_context s_entropy; + static mbedtls_ctr_drbg_context s_random_ctx; + + static once_flag once; + call_once(once, []() { + mbedtls_entropy_init( &s_entropy ); + mbedtls_ctr_drbg_init( &s_random_ctx ); + int ret = mbedtls_ctr_drbg_seed(&s_random_ctx, mbedtls_entropy_func, &s_entropy, + (const uint8_t *)k_entropy_personalization, + strlen(k_entropy_personalization)); + if (ret != 0) { + log_mbed_ret(ret, "mbedtls_ctr_drbg_seed"); + throw sys_error(ret); //FIXME: Not an errno; use different exception? + } + }); + return &s_random_ctx; + } + + + unique_ptr mbedtls_context::parse_cert(const std::string &cert_data, bool partialOk) { + unique_ptr c(new cert); + mbedtls_x509_crt_init(c.get()); + int ret = mbedtls_x509_crt_parse(c.get(), + (const uint8_t*)cert_data.data(), cert_data.size() + 1); + if (ret != 0) { + if(ret < 0 || !partialOk) { + log_mbed_ret(ret, "mbedtls_x509_crt_parse"); + if(ret > 0) { + ret = MBEDTLS_ERR_X509_CERT_VERIFY_FAILED; + } + + throw sys_error(ret); + } + } + return c; + } + + + void mbedtls_context::set_root_certs(const std::string &cert_data) { + root_certs_ = parse_cert(cert_data, true); + mbedtls_ssl_conf_ca_chain(ssl_config_.get(), root_certs_.get(), nullptr); + } + + + // Returns the set of system trusted root CA certs. + mbedtls_x509_crt* mbedtls_context::get_system_root_certs() { + static once_flag once; + call_once(once, []() { + // One-time initialization: + string certsPEM = read_system_root_certs(); + if (!certsPEM.empty()) + s_system_root_certs = parse_cert(certsPEM, true).release(); + }); + return s_system_root_certs; + } + + + mbedtls_context::mbedtls_context(role_t r) + :ssl_config_(new mbedtls_ssl_config) + { + mbedtls_ssl_config_init(ssl_config_.get()); + mbedtls_ssl_conf_rng(ssl_config_.get(), mbedtls_ctr_drbg_random, get_drbg_context()); + int endpoint = (r == CLIENT) ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER; + set_status(mbedtls_ssl_config_defaults(ssl_config_.get(), + endpoint, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)); + if (status() != 0) + return; + + auto roots = get_system_root_certs(); + if (roots) + mbedtls_ssl_conf_ca_chain(ssl_config_.get(), roots, nullptr); + } + + + mbedtls_context::~mbedtls_context() { + mbedtls_ssl_config_free(ssl_config_.get()); + } + + + void mbedtls_context::set_logger(int threshold, Logger logger) { + if (!logger_) { + mbedtls_ssl_conf_dbg(ssl_config_.get(), [](void *ctx, int level, const char *file, int line, + const char *msg) { + auto &logger = ((mbedtls_context*)ctx)->logger_; + if (logger) + logger(level, file, line, msg); + }, this); + } + logger_ = logger; + mbedtls_debug_set_threshold(threshold); + } + + + void mbedtls_context::require_peer_cert(role_t forRole, bool require) { + if (forRole != role()) + return; + int authMode = (require ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_authmode(ssl_config_.get(), authMode); + } + + + void mbedtls_context::allow_only_certificate(const std::string &cert_data) { + pinned_cert_.reset(); + if (cert_data.empty()) { + mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr); + } else { + pinned_cert_ = parse_cert(cert_data, false); + // Install a custom verification callback: + mbedtls_ssl_conf_verify( + ssl_config_.get(), + [](void *ctx, mbedtls_x509_crt *crt, int depth, uint32_t *flags) { + return ((mbedtls_context*)ctx)->verify_callback(crt,depth,flags); + }, + this); + } + } + + + void mbedtls_context::allow_only_certificate(mbedtls_x509_crt *certificate) { + pinned_cert_.reset(); + if (certificate) { + string cert_data((const char*)certificate->raw.p, certificate->raw.len); + allow_only_certificate(cert_data); + } else { + mbedtls_ssl_conf_verify(ssl_config_.get(), nullptr, nullptr); + } + } + + + // callback from mbedTLS cert validation (see above) + int mbedtls_context::verify_callback(mbedtls_x509_crt *crt, int depth, uint32_t *flags) { + if (depth == 0) { + if (crt->raw.len == pinned_cert_->raw.len + && 0 == memcmp(crt->raw.p, pinned_cert_->raw.p, crt->raw.len)) { + // The cert matches our pinned cert, so mark it as trusted. + // (It might still be invalid if it's expired or revoked...) + *flags &= ~(MBEDTLS_X509_BADCERT_NOT_TRUSTED | MBEDTLS_X509_BADCERT_CN_MISMATCH); + } else { + // If cert doesn't match pinned cert, mark it as untrusted. + *flags |= MBEDTLS_X509_BADCERT_OTHER; + } + } + return 0; + } + + + void mbedtls_context::set_identity(const std::string &certificate_data, + const std::string &private_key_data) + { + auto ident_cert = parse_cert(certificate_data, false); + + unique_ptr ident_key(new key); + int err = mbedtls_pk_parse_key(ident_key.get(), + (const uint8_t*) private_key_data.data(), + private_key_data.size(), NULL, 0); + if( err != 0 ) { + log_mbed_ret(err, "mbedtls_pk_parse_key"); + throw sys_error(err); + } + + set_identity(ident_cert.get(), ident_key.get()); + identity_cert_ = move(ident_cert); + identity_key_ = move(ident_key); + } + + + void mbedtls_context::set_identity(mbedtls_x509_crt *certificate, + mbedtls_pk_context *private_key) + { + mbedtls_ssl_conf_own_cert(ssl_config_.get(), certificate, private_key); + } + + + mbedtls_context::role_t mbedtls_context::role() { + return (ssl_config_->endpoint == MBEDTLS_SSL_IS_CLIENT) ? CLIENT : SERVER; + } + + + unique_ptr mbedtls_context::wrap_socket(std::unique_ptr socket, + role_t socketRole, + const std::string &peer_name) + { + assert(socketRole == role()); + return unique_ptr(new mbedtls_socket(move(socket), *this, peer_name)); + } + + +#pragma mark - PLATFORM SPECIFIC: + + + // mbedTLS does not have built-in support for reading the OS's trusted root certs. + +#ifdef __APPLE__ + // Read system root CA certs on macOS. + // (Sadly, SecTrustCopyAnchorCertificates() is not available on iOS) + static string read_system_root_certs() { + #if TARGET_OS_OSX + CFArrayRef roots; + OSStatus err = SecTrustCopyAnchorCertificates(&roots); + if (err) + return {}; + CFDataRef pemData = nullptr; + err = SecItemExport(roots, kSecFormatPEMSequence, kSecItemPemArmour, nullptr, &pemData); + CFRelease(roots); + if (err) + return {}; + string pem((const char*)CFDataGetBytePtr(pemData), CFDataGetLength(pemData)); + CFRelease(pemData); + return pem; + #else + // fallback -- no certs + return ""; + #endif + } + +#elif defined(_WIN32) + // Windows: + static string read_system_root_certs() { + PCCERT_CONTEXT pContext = nullptr; + HCERTSTORE hStore = CertOpenSystemStore(NULL, "ROOT"); + if(hStore == nullptr) { + return ""; + } + + stringstream certs; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext))) { + certs.write((const char*)pContext->pbCertEncoded, pContext->cbCertEncoded); + } + + CertCloseStore(hStore, CERT_CLOSE_STORE_FORCE_FLAG); + return certs.str(); + } + +#else + // Read system root CA certs on Linux using OpenSSL's cert directory + static string read_system_root_certs() { + static constexpr const char* CERTS_DIR = "/etc/ssl/certs/"; + static constexpr const char* CERTS_FILE = "ca-certificates.crt"; + + stringstream certs; + char buf[1024]; + // Subroutine to append a file to the `certs` stream: + auto read_file = [&](const string &file) { + ifstream in(file); + char last_char = '\n'; + while (in) { + in.read(buf, sizeof(buf)); + auto n = in.gcount(); + if (n > 0) { + certs.write(buf, n); + last_char = buf[n-1]; + } + } + if (last_char != '\n') + certs << '\n'; + }; + + struct stat s; + if (stat(CERTS_DIR, &s) == 0 && S_ISDIR(s.st_mode)) { + string certs_file = string(CERTS_DIR) + CERTS_FILE; + if (stat(certs_file.c_str(), &s) == 0) { + // If there is a file containing all the certs, just read it: + read_file(certs_file); + } else { + // Otherwise concatenate all the certs found in the dir: + auto dir = opendir(CERTS_DIR); + if (dir) { + struct dirent *ent; + while (nullptr != (ent = readdir(dir))) { + if (fnmatch("?*.pem", ent->d_name, FNM_PERIOD) == 0 + || fnmatch("?*.crt", ent->d_name, FNM_PERIOD) == 0) + read_file(string(CERTS_DIR) + ent->d_name); + } + closedir(dir); + } + } + } + return certs.str(); + } + +#endif + + +} diff --git a/src/socket.cpp b/src/socket.cpp index 1f3d850..3255771 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -294,7 +294,11 @@ std::string socket::error_str(int err) bool socket::shutdown(int how /*=SHUT_RDWR*/) { - return check_ret_bool(::shutdown(handle_, how)); + if(handle_ != INVALID_SOCKET) { + return check_ret_bool(::shutdown(release(), how)); + } + + return false; } // -------------------------------------------------------------------------- diff --git a/src/stream_socket.cpp b/src/stream_socket.cpp index 852edfa..b2b62e5 100644 --- a/src/stream_socket.cpp +++ b/src/stream_socket.cpp @@ -69,6 +69,16 @@ ssize_t stream_socket::read(void *buf, size_t n) #endif } +ioresult stream_socket::read_r(void *buf, size_t n) +{ + #if defined(_WIN32) + return ioresult(::recv(handle(), reinterpret_cast(buf), + int(n), 0)); + #else + return ioresult(::recv(handle(), buf, n, 0)); + #endif +} + // -------------------------------------------------------------------------- // Attempts to read the requested number of bytes by repeatedly calling // read() until it has the data or an error occurs. @@ -94,7 +104,6 @@ ssize_t stream_socket::read_n(void *buf, size_t n) return (nr == 0 && nx < 0) ? nx : ssize_t(nr); } - // -------------------------------------------------------------------------- ssize_t stream_socket::read(const std::vector& ranges) @@ -122,6 +131,23 @@ ssize_t stream_socket::read(const std::vector& ranges) #endif } +ioresult stream_socket::read_n_r(void *buf, size_t n) +{ + ioresult result; + uint8_t *b = reinterpret_cast(buf); + + while (result.count < n) { + ioresult r = read_r(b + result.count, n - result.count); + if (r.count == 0) { + result.error = r.error; + break; + } + result.count += r.count; + } + + return result; +} + // -------------------------------------------------------------------------- bool stream_socket::read_timeout(const microseconds& to) @@ -147,6 +173,16 @@ ssize_t stream_socket::write(const void *buf, size_t n) #endif } +ioresult stream_socket::write_r(const void *buf, size_t n) +{ + #if defined(_WIN32) + return ioresult(::send(handle(), reinterpret_cast(buf), + int(n) , 0)); + #else + return ioresult(::send(handle(), buf, n , 0)); + #endif +} + // -------------------------------------------------------------------------- // Attempts to write the entire buffer by repeatedly calling write() until // either all of the data is sent or an error occurs. @@ -171,32 +207,53 @@ ssize_t stream_socket::write_n(const void *buf, size_t n) return (nw == 0 && nx < 0) ? nx : ssize_t(nw); } -// -------------------------------------------------------------------------- - -ssize_t stream_socket::write(const std::vector& ranges) +ioresult stream_socket::write_n_r(const void *buf, size_t n) { - if (ranges.empty()) - return 0; + ioresult result; + const uint8_t *b = reinterpret_cast(buf); - #if !defined(_WIN32) - return check_ret(::writev(handle(), ranges.data(), int(ranges.size()))); - #else - std::vector bufs; - for (const auto& iovec : ranges) { - bufs.push_back({ - static_cast(iovec.iov_len), - static_cast(iovec.iov_base) - }); - } + while (result.count < n) { + ioresult r = write_r(b + result.count, n - result.count); + if (r.count == 0) { + result.error = r.error; + break; + } + result.count += r.count; + } - DWORD nwritten = 0, - nmsg = DWORD(bufs.size()); - - auto ret = check_ret(::WSASend(handle(), bufs.data(), nmsg, &nwritten, 0, nullptr, nullptr)); - return ssize_t(ret == SOCKET_ERROR ? ret : nwritten); - #endif + return result; } +// -------------------------------------------------------------------------- + + +ssize_t stream_socket::write(const std::vector &ranges) +{ + #if !defined(_WIN32) + msghdr msg = {}; + msg.msg_iov = const_cast(ranges.data()); + msg.msg_iovlen = int(ranges.size()); + if (msg.msg_iovlen == 0) + return 0; + return check_ret(sendmsg(handle(), &msg, 0)); + #else + if(ranges.empty()) { + return 0; + } + + std::vector buffers; + for(const auto& iovec : ranges) { + buffers.push_back({ + static_cast(iovec.iov_len), + static_cast(iovec.iov_base) + }); + } + + DWORD written = 0; + ssize_t ret = check_ret(WSASend(handle(), buffers.data(), buffers.size(), &written, 0, nullptr, nullptr)); + return ret == SOCKET_ERROR ? ret : written; + #endif +} // -------------------------------------------------------------------------- @@ -212,6 +269,8 @@ bool stream_socket::write_timeout(const microseconds& to) return set_option(SOL_SOCKET, SO_SNDTIMEO, tv); } + + ///////////////////////////////////////////////////////////////////////////// // end namespace sockpp }