Enable setting SSL_CTX for each client task.

This commit is contained in:
Xie Han
2024-04-18 21:42:24 +08:00
committed by xiehan
parent e55ac75d6f
commit e00a31a8d8
9 changed files with 79 additions and 38 deletions

View File

@@ -103,7 +103,8 @@ bool ComplexDnsTask::init_success()
auto *ep = &WFGlobal::get_global_settings()->dns_server_params;
ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep,
uri_.host, route_result_);
uri_.host, ssl_ctx_,
route_result_);
freeaddrinfo(addr);
if (ret < 0)
{

View File

@@ -484,7 +484,8 @@ private:
int ComplexHttpProxyTask::init_ssl_connection()
{
SSL *ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
WFConnection *conn;
if (!ssl)

View File

@@ -350,7 +350,21 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
if (is_ssl_)
{
if (!(resp->get_capability_flags() & 0x800))
if (resp->get_capability_flags() & 0x800)
{
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
if (!ssl)
{
state_ = WFT_STATE_SYS_ERROR;
error_ = errno;
return 0;
}
SSL_set_connect_state(ssl);
}
else
{
this->resp = std::move(*(MySQLResponse *)resp);
state_ = WFT_STATE_TASK_ERROR;
@@ -358,15 +372,6 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
return 0;
}
ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
if (!ssl)
{
state_ = WFT_STATE_SYS_ERROR;
error_ = errno;
return 0;
}
SSL_set_connect_state(ssl);
}
auto *conn = this->get_connection();

View File

@@ -25,6 +25,7 @@
#include <time.h>
#include <utility>
#include <functional>
#include <openssl/ssl.h>
#include "URIParser.h"
#include "RedisMessage.h"
#include "HttpMessage.h"
@@ -445,6 +446,12 @@ public:
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (T *)> callback);
public:
static T *create_server_task(CommService *service,
std::function<void (T *)>& process);

View File

@@ -29,6 +29,7 @@
#include <functional>
#include <utility>
#include <atomic>
#include <openssl/ssl.h>
#include "WFGlobal.h"
#include "Workflow.h"
#include "WFTask.h"
@@ -75,6 +76,7 @@ public:
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(), std::move(cb))
{
type_ = TT_TCP;
ssl_ctx_ = NULL;
fixed_addr_ = false;
retry_max_ = retry_max;
retry_times_ = 0;
@@ -116,6 +118,8 @@ public:
enum TransportType get_transport_type() const { return type_; }
void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; }
virtual const ParsedURI *get_current_uri() const { return &uri_; }
void set_redirect(const ParsedURI& uri)
@@ -168,6 +172,7 @@ protected:
enum TransportType type_;
ParsedURI uri_;
std::string info_;
SSL_CTX *ssl_ctx_;
bool fixed_addr_;
bool redirect_;
CTX ctx_;
@@ -225,7 +230,7 @@ void WFComplexClientTask<REQ, RESP, CTX>::init(enum TransportType type,
info_.assign(info);
params.use_tls_sni = false;
if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, &params,
"", route_result_) < 0)
"", ssl_ctx_, route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -315,6 +320,7 @@ WFRouterTask *WFComplexClientTask<REQ, RESP, CTX>::route()
.type = type_,
.uri = uri_,
.info = info_.c_str(),
.ssl_ctx = ssl_ctx_,
.fixed_addr = fixed_addr_,
.retry_times = retry_times_,
.tracing = &tracing_,
@@ -483,15 +489,19 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
char buf[8];
std::string url = "scheme://";
ParsedURI uri;
char buf[32];
sprintf(buf, "%u", port);
url += host;
url += ":";
url += buf;
URIParser::parse(url, uri);
uri.scheme = strdup("scheme");
uri.host = strdup(host.c_str());
uri.port = strdup(buf);
if (!uri.scheme || !uri.host || !uri.port)
{
uri.state = URI_STATE_ERROR;
uri.error = errno;
}
task->init(std::move(uri));
task->set_transport_type(type);
return task;
@@ -541,6 +551,22 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
task->init(type, addr, addrlen, "");
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_server_task(CommService *service,

View File

@@ -424,7 +424,8 @@ static uint64_t __generate_key(enum TransportType type,
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname)
const std::string& hostname,
SSL_CTX *ssl_ctx)
{
const int params[] = {
ep_params->address_family, (int)ep_params->max_connections,
@@ -438,6 +439,7 @@ static uint64_t __generate_key(enum TransportType type,
buf.append((const char *)params, sizeof params);
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
buf.append((const char *)&ssl_ctx, sizeof (void *));
buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int));
if (ep_params->use_tls_sni)
{
@@ -491,11 +493,18 @@ int RouteManager::get(enum TransportType type,
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname,
const std::string& hostname, SSL_CTX *ssl_ctx,
RouteResult& result)
{
uint64_t key = __generate_key(type, addrinfo, other_info,
ep_params, hostname);
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx();
if (ssl_ctx == NULL)
ssl_ctx = global_client_ctx;
}
uint64_t key = __generate_key(type, addrinfo, other_info, ep_params,
hostname, ssl_ctx);
struct rb_node **p = &cache_.rb_node;
struct rb_node *parent = NULL;
RouteResultEntry *bound = NULL;
@@ -522,17 +531,6 @@ int RouteManager::get(enum TransportType type,
}
else
{
int ssl_connect_timeout = 0;
SSL_CTX *ssl_ctx = NULL;
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
static SSL_CTX *client_ssl_ctx = WFGlobal::get_ssl_client_ctx();
ssl_ctx = client_ssl_ctx;
ssl_connect_timeout = ep_params->ssl_connect_timeout;
}
struct RouteParams params = {
.transport_type = type,
.addrinfo = addrinfo,
@@ -541,7 +539,7 @@ int RouteManager::get(enum TransportType type,
.max_connections = ep_params->max_connections,
.connect_timeout = ep_params->connect_timeout,
.response_timeout = ep_params->response_timeout,
.ssl_connect_timeout = ssl_connect_timeout,
.ssl_connect_timeout = ep_params->ssl_connect_timeout,
.use_tls_sni = ep_params->use_tls_sni,
.hostname = hostname,
};

View File

@@ -63,7 +63,7 @@ public:
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname,
const std::string& hostname, SSL_CTX *ssl_ctx,
RouteResult& result);
RouteManager()

View File

@@ -423,7 +423,8 @@ void WFResolverTask::dispatch()
}
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
&ep_params_, hostname, this->result) < 0)
&ep_params_, hostname, ns_params_.ssl_ctx,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -618,7 +619,8 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output,
(unsigned int)ttl_default,
(unsigned int)ttl_min);
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
&ep_params_, hostname, this->result) < 0)
&ep_params_, hostname, ns_params_.ssl_ctx,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;

View File

@@ -81,6 +81,7 @@ struct WFNSParams
enum TransportType type;
ParsedURI& uri;
const char *info;
SSL_CTX *ssl_ctx;
bool fixed_addr;
int retry_times;
WFNSTracing *tracing;