mirror of
https://github.com/sogou/workflow.git
synced 2026-02-08 01:33:17 +08:00
Enable setting SSL_CTX for each client task.
This commit is contained in:
@@ -103,7 +103,8 @@ bool ComplexDnsTask::init_success()
|
||||
|
||||
auto *ep = &WFGlobal::get_global_settings()->dns_server_params;
|
||||
ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep,
|
||||
uri_.host, route_result_);
|
||||
uri_.host, ssl_ctx_,
|
||||
route_result_);
|
||||
freeaddrinfo(addr);
|
||||
if (ret < 0)
|
||||
{
|
||||
|
||||
@@ -484,7 +484,8 @@ private:
|
||||
|
||||
int ComplexHttpProxyTask::init_ssl_connection()
|
||||
{
|
||||
SSL *ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
|
||||
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
|
||||
SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
|
||||
WFConnection *conn;
|
||||
|
||||
if (!ssl)
|
||||
|
||||
@@ -350,7 +350,21 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
|
||||
|
||||
if (is_ssl_)
|
||||
{
|
||||
if (!(resp->get_capability_flags() & 0x800))
|
||||
if (resp->get_capability_flags() & 0x800)
|
||||
{
|
||||
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
|
||||
|
||||
ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
|
||||
if (!ssl)
|
||||
{
|
||||
state_ = WFT_STATE_SYS_ERROR;
|
||||
error_ = errno;
|
||||
return 0;
|
||||
}
|
||||
|
||||
SSL_set_connect_state(ssl);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->resp = std::move(*(MySQLResponse *)resp);
|
||||
state_ = WFT_STATE_TASK_ERROR;
|
||||
@@ -358,15 +372,6 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
|
||||
return 0;
|
||||
}
|
||||
|
||||
ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
|
||||
if (!ssl)
|
||||
{
|
||||
state_ = WFT_STATE_SYS_ERROR;
|
||||
error_ = errno;
|
||||
return 0;
|
||||
}
|
||||
|
||||
SSL_set_connect_state(ssl);
|
||||
}
|
||||
|
||||
auto *conn = this->get_connection();
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include <time.h>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <openssl/ssl.h>
|
||||
#include "URIParser.h"
|
||||
#include "RedisMessage.h"
|
||||
#include "HttpMessage.h"
|
||||
@@ -445,6 +446,12 @@ public:
|
||||
int retry_max,
|
||||
std::function<void (T *)> callback);
|
||||
|
||||
static T *create_client_task(enum TransportType type,
|
||||
const struct sockaddr *addr,
|
||||
socklen_t addrlen,
|
||||
SSL_CTX *ssl_ctx,
|
||||
int retry_max,
|
||||
std::function<void (T *)> callback);
|
||||
public:
|
||||
static T *create_server_task(CommService *service,
|
||||
std::function<void (T *)>& process);
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <atomic>
|
||||
#include <openssl/ssl.h>
|
||||
#include "WFGlobal.h"
|
||||
#include "Workflow.h"
|
||||
#include "WFTask.h"
|
||||
@@ -75,6 +76,7 @@ public:
|
||||
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(), std::move(cb))
|
||||
{
|
||||
type_ = TT_TCP;
|
||||
ssl_ctx_ = NULL;
|
||||
fixed_addr_ = false;
|
||||
retry_max_ = retry_max;
|
||||
retry_times_ = 0;
|
||||
@@ -116,6 +118,8 @@ public:
|
||||
|
||||
enum TransportType get_transport_type() const { return type_; }
|
||||
|
||||
void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; }
|
||||
|
||||
virtual const ParsedURI *get_current_uri() const { return &uri_; }
|
||||
|
||||
void set_redirect(const ParsedURI& uri)
|
||||
@@ -168,6 +172,7 @@ protected:
|
||||
enum TransportType type_;
|
||||
ParsedURI uri_;
|
||||
std::string info_;
|
||||
SSL_CTX *ssl_ctx_;
|
||||
bool fixed_addr_;
|
||||
bool redirect_;
|
||||
CTX ctx_;
|
||||
@@ -225,7 +230,7 @@ void WFComplexClientTask<REQ, RESP, CTX>::init(enum TransportType type,
|
||||
info_.assign(info);
|
||||
params.use_tls_sni = false;
|
||||
if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, ¶ms,
|
||||
"", route_result_) < 0)
|
||||
"", ssl_ctx_, route_result_) < 0)
|
||||
{
|
||||
this->state = WFT_STATE_SYS_ERROR;
|
||||
this->error = errno;
|
||||
@@ -315,6 +320,7 @@ WFRouterTask *WFComplexClientTask<REQ, RESP, CTX>::route()
|
||||
.type = type_,
|
||||
.uri = uri_,
|
||||
.info = info_.c_str(),
|
||||
.ssl_ctx = ssl_ctx_,
|
||||
.fixed_addr = fixed_addr_,
|
||||
.retry_times = retry_times_,
|
||||
.tracing = &tracing_,
|
||||
@@ -483,15 +489,19 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
|
||||
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
|
||||
{
|
||||
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
|
||||
char buf[8];
|
||||
std::string url = "scheme://";
|
||||
ParsedURI uri;
|
||||
char buf[32];
|
||||
|
||||
sprintf(buf, "%u", port);
|
||||
url += host;
|
||||
url += ":";
|
||||
url += buf;
|
||||
URIParser::parse(url, uri);
|
||||
uri.scheme = strdup("scheme");
|
||||
uri.host = strdup(host.c_str());
|
||||
uri.port = strdup(buf);
|
||||
if (!uri.scheme || !uri.host || !uri.port)
|
||||
{
|
||||
uri.state = URI_STATE_ERROR;
|
||||
uri.error = errno;
|
||||
}
|
||||
|
||||
task->init(std::move(uri));
|
||||
task->set_transport_type(type);
|
||||
return task;
|
||||
@@ -541,6 +551,22 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
|
||||
return task;
|
||||
}
|
||||
|
||||
template<class REQ, class RESP>
|
||||
WFNetworkTask<REQ, RESP> *
|
||||
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
|
||||
const struct sockaddr *addr,
|
||||
socklen_t addrlen,
|
||||
SSL_CTX *ssl_ctx,
|
||||
int retry_max,
|
||||
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
|
||||
{
|
||||
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
|
||||
|
||||
task->set_ssl_ctx(ssl_ctx);
|
||||
task->init(type, addr, addrlen, "");
|
||||
return task;
|
||||
}
|
||||
|
||||
template<class REQ, class RESP>
|
||||
WFNetworkTask<REQ, RESP> *
|
||||
WFNetworkTaskFactory<REQ, RESP>::create_server_task(CommService *service,
|
||||
|
||||
@@ -424,7 +424,8 @@ static uint64_t __generate_key(enum TransportType type,
|
||||
const struct addrinfo *addrinfo,
|
||||
const std::string& other_info,
|
||||
const struct EndpointParams *ep_params,
|
||||
const std::string& hostname)
|
||||
const std::string& hostname,
|
||||
SSL_CTX *ssl_ctx)
|
||||
{
|
||||
const int params[] = {
|
||||
ep_params->address_family, (int)ep_params->max_connections,
|
||||
@@ -438,6 +439,7 @@ static uint64_t __generate_key(enum TransportType type,
|
||||
buf.append((const char *)params, sizeof params);
|
||||
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
|
||||
{
|
||||
buf.append((const char *)&ssl_ctx, sizeof (void *));
|
||||
buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int));
|
||||
if (ep_params->use_tls_sni)
|
||||
{
|
||||
@@ -491,11 +493,18 @@ int RouteManager::get(enum TransportType type,
|
||||
const struct addrinfo *addrinfo,
|
||||
const std::string& other_info,
|
||||
const struct EndpointParams *ep_params,
|
||||
const std::string& hostname,
|
||||
const std::string& hostname, SSL_CTX *ssl_ctx,
|
||||
RouteResult& result)
|
||||
{
|
||||
uint64_t key = __generate_key(type, addrinfo, other_info,
|
||||
ep_params, hostname);
|
||||
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
|
||||
{
|
||||
static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx();
|
||||
if (ssl_ctx == NULL)
|
||||
ssl_ctx = global_client_ctx;
|
||||
}
|
||||
|
||||
uint64_t key = __generate_key(type, addrinfo, other_info, ep_params,
|
||||
hostname, ssl_ctx);
|
||||
struct rb_node **p = &cache_.rb_node;
|
||||
struct rb_node *parent = NULL;
|
||||
RouteResultEntry *bound = NULL;
|
||||
@@ -522,17 +531,6 @@ int RouteManager::get(enum TransportType type,
|
||||
}
|
||||
else
|
||||
{
|
||||
int ssl_connect_timeout = 0;
|
||||
SSL_CTX *ssl_ctx = NULL;
|
||||
|
||||
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
|
||||
{
|
||||
static SSL_CTX *client_ssl_ctx = WFGlobal::get_ssl_client_ctx();
|
||||
|
||||
ssl_ctx = client_ssl_ctx;
|
||||
ssl_connect_timeout = ep_params->ssl_connect_timeout;
|
||||
}
|
||||
|
||||
struct RouteParams params = {
|
||||
.transport_type = type,
|
||||
.addrinfo = addrinfo,
|
||||
@@ -541,7 +539,7 @@ int RouteManager::get(enum TransportType type,
|
||||
.max_connections = ep_params->max_connections,
|
||||
.connect_timeout = ep_params->connect_timeout,
|
||||
.response_timeout = ep_params->response_timeout,
|
||||
.ssl_connect_timeout = ssl_connect_timeout,
|
||||
.ssl_connect_timeout = ep_params->ssl_connect_timeout,
|
||||
.use_tls_sni = ep_params->use_tls_sni,
|
||||
.hostname = hostname,
|
||||
};
|
||||
|
||||
@@ -63,7 +63,7 @@ public:
|
||||
const struct addrinfo *addrinfo,
|
||||
const std::string& other_info,
|
||||
const struct EndpointParams *ep_params,
|
||||
const std::string& hostname,
|
||||
const std::string& hostname, SSL_CTX *ssl_ctx,
|
||||
RouteResult& result);
|
||||
|
||||
RouteManager()
|
||||
|
||||
@@ -423,7 +423,8 @@ void WFResolverTask::dispatch()
|
||||
}
|
||||
|
||||
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
|
||||
&ep_params_, hostname, this->result) < 0)
|
||||
&ep_params_, hostname, ns_params_.ssl_ctx,
|
||||
this->result) < 0)
|
||||
{
|
||||
this->state = WFT_STATE_SYS_ERROR;
|
||||
this->error = errno;
|
||||
@@ -618,7 +619,8 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output,
|
||||
(unsigned int)ttl_default,
|
||||
(unsigned int)ttl_min);
|
||||
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
|
||||
&ep_params_, hostname, this->result) < 0)
|
||||
&ep_params_, hostname, ns_params_.ssl_ctx,
|
||||
this->result) < 0)
|
||||
{
|
||||
this->state = WFT_STATE_SYS_ERROR;
|
||||
this->error = errno;
|
||||
|
||||
@@ -81,6 +81,7 @@ struct WFNSParams
|
||||
enum TransportType type;
|
||||
ParsedURI& uri;
|
||||
const char *info;
|
||||
SSL_CTX *ssl_ctx;
|
||||
bool fixed_addr;
|
||||
int retry_times;
|
||||
WFNSTracing *tracing;
|
||||
|
||||
Reference in New Issue
Block a user