Update windows branch codes.

This commit is contained in:
Xie Han
2024-05-09 02:01:09 +08:00
parent 95e45eb549
commit 2ab8e95bb2
25 changed files with 1302 additions and 565 deletions

View File

@@ -77,6 +77,7 @@ public:
enum TransportType transport_type;
std::string scheme;
std::vector<std::string> broker_hosts;
SSL_CTX *ssl_ctx;
KafkaCgroup cgroup;
KafkaMetaList meta_list;
KafkaBrokerMap broker_map;
@@ -192,7 +193,7 @@ private:
int dispatch_locked();
inline KafkaBroker *get_broker(int node_id)
KafkaBroker *get_broker(int node_id)
{
return this->member->broker_map.find_item(node_id);
}
@@ -294,7 +295,7 @@ void KafkaClientTask::kafka_rebalance_callback(__WFKafkaTask *task)
kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
"", 0,
member->ssl_ctx, "", 0,
kafka_heartbeat_callback);
kafka_task->user_data = member;
kafka_task->get_req()->set_api_type(Kafka_Heartbeat);
@@ -327,7 +328,7 @@ void KafkaClientTask::kafka_rebalance_proc(KafkaMember *member, SeriesWork *seri
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
"", 0,
member->ssl_ctx, "", 0,
kafka_rebalance_callback);
task->user_data = member;
task->get_req()->set_config(member->config);
@@ -392,7 +393,7 @@ void KafkaClientTask::kafka_timer_callback(WFTimerTask *task)
kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
"", 0,
member->ssl_ctx, "", 0,
kafka_heartbeat_callback);
kafka_task->user_data = member;
@@ -496,69 +497,70 @@ void KafkaClientTask::kafka_meta_callback(__WFKafkaTask *task)
void KafkaClientTask::kafka_cgroup_callback(__WFKafkaTask *task)
{
KafkaClientTask *t = (KafkaClientTask *)task->user_data;
KafkaMember *member = t->member;
SeriesWork *heartbeat_series = NULL;
void *msg = NULL;
size_t max;
t->member->mutex.lock();
member->mutex.lock();
t->state = task->get_state();
t->error = task->get_error();
t->kafka_error = *static_cast<ComplexKafkaTask *>(task)->get_mutable_ctx();
if (t->state == WFT_STATE_SUCCESS)
{
t->member->cgroup = std::move(*(task->get_resp()->get_cgroup()));
member->cgroup = std::move(*(task->get_resp()->get_cgroup()));
kafka_merge_meta_list(&t->member->meta_list,
kafka_merge_meta_list(&member->meta_list,
task->get_resp()->get_meta_list());
t->meta_list.rewind();
KafkaMeta *meta;
while ((meta = t->meta_list.get_next()) != NULL)
(t->member->meta_status)[meta->get_topic()] = true;
(member->meta_status)[meta->get_topic()] = true;
kafka_merge_broker_list(t->member->scheme,
&t->member->broker_hosts,
&t->member->broker_map,
kafka_merge_broker_list(member->scheme,
&member->broker_hosts,
&member->broker_map,
task->get_resp()->get_broker_list());
t->member->cgroup_status = KAFKA_CGROUP_DONE;
member->cgroup_status = KAFKA_CGROUP_DONE;
if (t->member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT)
if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT)
{
__WFKafkaTask *kafka_task;
KafkaBroker *coordinator = t->member->cgroup.get_coordinator();
kafka_task = __WFKafkaTaskFactory::create_kafka_task(t->member->transport_type,
KafkaBroker *coordinator = member->cgroup.get_coordinator();
kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
"", 0,
member->ssl_ctx, "", 0,
kafka_heartbeat_callback);
kafka_task->user_data = t->member;
t->member->incref();
kafka_task->user_data = member;
member->incref();
kafka_task->get_req()->set_config(t->member->config);
kafka_task->get_req()->set_config(member->config);
kafka_task->get_req()->set_api_type(Kafka_Heartbeat);
kafka_task->get_req()->set_cgroup(t->member->cgroup);
kafka_task->get_req()->set_cgroup(member->cgroup);
kafka_task->get_req()->set_broker(*coordinator);
heartbeat_series = Workflow::create_series_work(kafka_task, nullptr);
t->member->heartbeat_status = KAFKA_HEARTBEAT_DOING;
t->member->heartbeat_series = heartbeat_series;
member->heartbeat_status = KAFKA_HEARTBEAT_DOING;
member->heartbeat_series = heartbeat_series;
}
}
else
{
t->member->cgroup_status = KAFKA_CGROUP_UNINIT;
t->member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT;
t->member->heartbeat_series = NULL;
member->cgroup_status = KAFKA_CGROUP_UNINIT;
member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT;
member->heartbeat_series = NULL;
t->finish = true;
msg = t;
}
max = t->member->cgroup_wait_cnt;
max = member->cgroup_wait_cnt;
char name[64];
snprintf(name, 64, "%p.cgroup", t->member);
t->member->mutex.unlock();
snprintf(name, 64, "%p.cgroup", member);
member->mutex.unlock();
WFTaskFactory::signal_by_name(name, msg, max);
@@ -789,16 +791,17 @@ bool KafkaClientTask::compare_topics(KafkaClientTask *task)
bool KafkaClientTask::check_cgroup()
{
if (this->member->cgroup_outdated &&
this->member->cgroup_status != KAFKA_CGROUP_DOING)
KafkaMember *member = this->member;
if (member->cgroup_outdated && member->cgroup_status != KAFKA_CGROUP_DOING)
{
this->member->cgroup_outdated = false;
this->member->cgroup_status = KAFKA_CGROUP_UNINIT;
this->member->heartbeat_series = NULL;
this->member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT;
member->cgroup_outdated = false;
member->cgroup_status = KAFKA_CGROUP_UNINIT;
member->heartbeat_series = NULL;
member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT;
}
if (this->member->cgroup_status == KAFKA_CGROUP_DOING)
if (member->cgroup_status == KAFKA_CGROUP_DOING)
{
WFConditional *cond;
char name[64];
@@ -806,27 +809,27 @@ bool KafkaClientTask::check_cgroup()
this->wait_cgroup = true;
cond = WFTaskFactory::create_conditional(name, this, &this->msg);
series_of(this)->push_front(cond);
this->member->cgroup_wait_cnt++;
member->cgroup_wait_cnt++;
return false;
}
if ((this->api_type == Kafka_Fetch || this->api_type == Kafka_OffsetCommit) &&
(this->member->cgroup_status == KAFKA_CGROUP_UNINIT))
(member->cgroup_status == KAFKA_CGROUP_UNINIT))
{
__WFKafkaTask *task;
task = __WFKafkaTaskFactory::create_kafka_task(this->url,
task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx,
this->retry_max,
kafka_cgroup_callback);
task->user_data = this;
task->get_req()->set_config(this->config);
task->get_req()->set_api_type(Kafka_FindCoordinator);
task->get_req()->set_cgroup(this->member->cgroup);
task->get_req()->set_meta_list(this->member->meta_list);
task->get_req()->set_cgroup(member->cgroup);
task->get_req()->set_meta_list(member->meta_list);
series_of(this)->push_front(this);
series_of(this)->push_front(task);
this->member->cgroup_status = KAFKA_CGROUP_DOING;
this->member->cgroup_wait_cnt = 0;
member->cgroup_status = KAFKA_CGROUP_DOING;
member->cgroup_wait_cnt = 0;
return false;
}
@@ -835,12 +838,13 @@ bool KafkaClientTask::check_cgroup()
bool KafkaClientTask::check_meta()
{
KafkaMember *member = this->member;
KafkaMetaList *uninit_meta_list;
if (this->get_meta_status(&uninit_meta_list))
return true;
if (this->member->meta_doing)
if (member->meta_doing)
{
WFConditional *cond;
char name[64];
@@ -848,13 +852,13 @@ bool KafkaClientTask::check_meta()
this->wait_cgroup = false;
cond = WFTaskFactory::create_conditional(name, this, &this->msg);
series_of(this)->push_front(cond);
this->member->meta_wait_cnt++;
member->meta_wait_cnt++;
}
else
{
__WFKafkaTask *task;
task = __WFKafkaTaskFactory::create_kafka_task(this->url,
task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx,
this->retry_max,
kafka_meta_callback);
task->user_data = this;
@@ -863,8 +867,8 @@ bool KafkaClientTask::check_meta()
task->get_req()->set_meta_list(*uninit_meta_list);
series_of(this)->push_front(this);
series_of(this)->push_front(task);
this->member->meta_wait_cnt = 0;
this->member->meta_doing = true;
member->meta_wait_cnt = 0;
member->meta_doing = true;
}
delete uninit_meta_list;
@@ -921,6 +925,7 @@ int KafkaClientTask::dispatch_locked()
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
broker->get_host(),
broker->get_port(),
member->ssl_ctx,
this->get_userinfo(),
this->retry_max,
std::move(cb));
@@ -956,6 +961,7 @@ int KafkaClientTask::dispatch_locked()
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
broker->get_host(),
broker->get_port(),
member->ssl_ctx,
this->get_userinfo(),
this->retry_max,
std::move(cb));
@@ -991,6 +997,7 @@ int KafkaClientTask::dispatch_locked()
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
member->ssl_ctx,
this->get_userinfo(),
this->retry_max,
kafka_offsetcommit_callback);
@@ -1020,6 +1027,7 @@ int KafkaClientTask::dispatch_locked()
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
coordinator->get_host(),
coordinator->get_port(),
member->ssl_ctx,
this->get_userinfo(), 0,
kafka_leavegroup_callback);
task->user_data = this;
@@ -1051,6 +1059,7 @@ int KafkaClientTask::dispatch_locked()
task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type,
broker->get_host(),
broker->get_port(),
member->ssl_ctx,
this->get_userinfo(),
this->retry_max,
std::move(cb));
@@ -1580,7 +1589,7 @@ SubTask *WFKafkaTask::done()
return series->pop();
}
int WFKafkaClient::init(const std::string& broker)
int WFKafkaClient::init(const std::string& broker, SSL_CTX *ssl_ctx)
{
std::vector<std::string> broker_hosts;
std::string::size_type ppos = 0;
@@ -1620,6 +1629,7 @@ int WFKafkaClient::init(const std::string& broker)
this->member = new KafkaMember;
this->member->broker_hosts = std::move(broker_hosts);
this->member->ssl_ctx = ssl_ctx;
if (use_ssl)
{
this->member->transport_type = TT_TCP_SSL;
@@ -1629,9 +1639,10 @@ int WFKafkaClient::init(const std::string& broker)
return 0;
}
int WFKafkaClient::init(const std::string& broker, const std::string& group)
int WFKafkaClient::init(const std::string& broker, const std::string& group,
SSL_CTX *ssl_ctx)
{
if (this->init(broker) < 0)
if (this->init(broker, ssl_ctx) < 0)
return -1;
this->member->cgroup.set_group(group);
@@ -1652,8 +1663,7 @@ WFKafkaTask *WFKafkaClient::create_kafka_task(const std::string& query,
int retry_max,
kafka_callback_t cb)
{
WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb),
this);
WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb), this);
return task;
}

View File

@@ -22,6 +22,7 @@
#include <string>
#include <vector>
#include <functional>
#include <openssl/ssl.h>
#include "WFTask.h"
#include "KafkaMessage.h"
#include "KafkaResult.h"
@@ -145,9 +146,22 @@ public:
// example: kafka://kafka.sogou
// example: kafka.sogou:9090
// example: kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou
int init(const std::string& broker_url);
// example: kafkas://kafka.sogou -> kafka over TLS
int init(const std::string& broker_url)
{
return this->init(broker_url, NULL);
}
int init(const std::string& broker_url, const std::string& group);
int init(const std::string& broker_url, const std::string& group)
{
return this->init(broker_url, group, NULL);
}
// With a specific SSL_CTX. Effective only on brokers over TLS.
int init(const std::string& broker_url, SSL_CTX *ssl_ctx);
int init(const std::string& broker_url, const std::string& group,
SSL_CTX *ssl_ctx);
int deinit();

View File

@@ -16,6 +16,7 @@
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <string>
@@ -23,7 +24,7 @@
#include "URIParser.h"
#include "WFMySQLConnection.h"
int WFMySQLConnection::init(const std::string& url)
int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx)
{
std::string query;
ParsedURI uri;
@@ -42,9 +43,12 @@ int WFMySQLConnection::init(const std::string& url)
if (uri.query)
{
this->uri = std::move(uri);
this->ssl_ctx = ssl_ctx;
return 0;
}
}
else if (uri.state == URI_STATE_INVALID)
errno = EINVAL;
return -1;
}

View File

@@ -22,6 +22,7 @@
#include <string>
#include <utility>
#include <functional>
#include <openssl/ssl.h>
#include "URIParser.h"
#include "WFTaskFactory.h"
@@ -31,20 +32,50 @@ public:
/* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8
* IP string is recommmended in url. When using a domain name, the first
* address resovled will be used. Don't use upstream name as a host. */
int init(const std::string& url);
int init(const std::string& url)
{
return init(url, NULL);
}
int init(const std::string& url, SSL_CTX *ssl_ctx);
void deinit() { }
public:
WFMySQLTask *create_query_task(const std::string& query,
mysql_callback_t callback);
mysql_callback_t callback)
{
WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0,
std::move(callback));
this->set_ssl_ctx(task);
task->get_req()->set_query(query);
return task;
}
public:
/* If you don't disconnect manually, the TCP connection will be
* kept alive after this object is deleted, and maybe reused by
* another WFMySQLConnection object with same id and url. */
WFMySQLTask *create_disconnect_task(mysql_callback_t callback);
WFMySQLTask *create_disconnect_task(mysql_callback_t callback)
{
WFMySQLTask *task = this->create_query_task("", std::move(callback));
this->set_ssl_ctx(task);
task->set_keep_alive(0);
return task;
}
protected:
void set_ssl_ctx(WFMySQLTask *task) const
{
using MySQLRequest = protocol::MySQLRequest;
using MySQLResponse = protocol::MySQLResponse;
auto *t = (WFComplexClientTask<MySQLRequest, MySQLResponse> *)task;
/* 'ssl_ctx' can be NULL and will use default. */
t->set_ssl_ctx(this->ssl_ctx);
}
protected:
ParsedURI uri;
SSL_CTX *ssl_ctx;
int id;
public:
@@ -54,25 +85,5 @@ public:
virtual ~WFMySQLConnection() { }
};
inline WFMySQLTask *
WFMySQLConnection::create_query_task(const std::string& query,
mysql_callback_t callback)
{
WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0,
std::move(callback));
task->get_req()->set_query(query);
return task;
}
inline WFMySQLTask *
WFMySQLConnection::create_disconnect_task(mysql_callback_t callback)
{
WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0,
std::move(callback));
task->get_req()->set_query("");
task->set_keep_alive(0);
return task;
}
#endif

View File

@@ -17,9 +17,11 @@
*/
#include <string>
#include <atomic>
#include "DnsMessage.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "DnsMessage.h"
#include "WFServer.h"
using namespace protocol;
@@ -31,6 +33,7 @@ class ComplexDnsTask : public WFComplexClientTask<DnsRequest, DnsResponse,
std::function<void (WFDnsTask *)>>
{
static struct addrinfo hints;
static std::atomic<size_t> seq;
public:
ComplexDnsTask(int retry_max, dns_callback_t&& cb):
@@ -54,24 +57,21 @@ private:
struct addrinfo ComplexDnsTask::hints =
{
/*.ai_flags =*/ AI_NUMERICSERV | AI_NUMERICHOST,
/*.ai_family =*/ AF_UNSPEC,
/*.ai_socktype =*/ SOCK_STREAM,
/*.ai_protocol =*/ 0,
/*.ai_addrlen =*/ 0,
/*.ai_addr =*/ NULL,
/*.ai_canonname =*/ NULL,
/*.ai_next =*/ NULL
/*.ai_flags =*/ AI_NUMERICSERV | AI_NUMERICHOST,
/*.ai_family =*/ AF_UNSPEC,
/*.ai_socktype =*/ SOCK_STREAM
};
std::atomic<size_t> ComplexDnsTask::seq(0);
CommMessageOut *ComplexDnsTask::message_out()
{
DnsRequest *req = this->get_req();
DnsResponse *resp = this->get_resp();
TransportType type = this->get_transport_type();
enum TransportType type = this->get_transport_type();
if (req->get_id() == 0)
req->set_id((this->get_seq() + 1) * 99991 % 65535 + 1);
req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1);
resp->set_request_id(req->get_id());
resp->set_request_name(req->get_question_name());
req->set_single_packet(type == TT_UDP);
@@ -93,21 +93,22 @@ bool ComplexDnsTask::init_success()
if (!this->route_result_.request_object)
{
TransportType type = this->get_transport_type();
enum TransportType type = this->get_transport_type();
struct addrinfo *addr;
int ret;
ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr);
if (ret != 0)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PARSE_FAILED;
this->state = WFT_STATE_DNS_ERROR;
this->error = ret;
return false;
}
auto *ep = &WFGlobal::get_global_settings()->dns_server_params;
ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep,
uri_.host, route_result_);
uri_.host, ssl_ctx_,
route_result_);
freeaddrinfo(addr);
if (ret < 0)
{
@@ -146,7 +147,7 @@ bool ComplexDnsTask::finish_once()
bool ComplexDnsTask::need_redirect()
{
DnsResponse *client_resp = this->get_resp();
TransportType type = this->get_transport_type();
enum TransportType type = this->get_transport_type();
if (type == TT_UDP && client_resp->get_tc() == 1)
{
@@ -189,3 +190,42 @@ WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri,
return task;
}
/**********Server**********/
class WFDnsServerTask : public WFServerTask<DnsRequest, DnsResponse>
{
public:
WFDnsServerTask(CommService *service,
std::function<void (WFDnsTask *)>& proc) :
WFServerTask(service, WFGlobal::get_scheduler(), proc)
{
// this->type = ((WFServerBase *)service)->get_params()->transport_type;
this->type = TT_TCP;
}
protected:
virtual CommMessageIn *message_in()
{
this->get_req()->set_single_packet(this->type == TT_UDP);
return this->WFServerTask::message_in();
}
virtual CommMessageOut *message_out()
{
this->get_resp()->set_single_packet(this->type == TT_UDP);
return this->WFServerTask::message_out();
}
protected:
enum TransportType type;
};
/**********Server Factory**********/
WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service,
std::function<void (WFDnsTask *)>& proc)
{
return new WFDnsServerTask(service, proc);
}

View File

@@ -484,7 +484,8 @@ private:
int ComplexHttpProxyTask::init_ssl_connection()
{
SSL *ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
WFConnection *conn;
if (!ssl)

View File

@@ -19,9 +19,9 @@
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <string>
#include <set>
#include <openssl/ssl.h>
#include <openssl/sha.h>
#include <openssl/evp.h>
#include "StringUtil.h"
@@ -715,12 +715,14 @@ bool __ComplexKafkaTask::finish_once()
/**********Factory**********/
// kafka://user:password:sasl@host:port/api=type&topic=name
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
ParsedURI uri;
task->set_ssl_ctx(ssl_ctx);
ParsedURI uri;
URIParser::parse(url, uri);
task->init(std::move(uri));
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
@@ -728,10 +730,12 @@ __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url,
}
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
task->init(uri);
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
@@ -741,22 +745,36 @@ __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri,
__WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type,
const char *host,
unsigned short port,
SSL_CTX *ssl_ctx,
const std::string& info,
int retry_max,
__kafka_callback_t callback)
{
auto *task = new __ComplexKafkaTask(retry_max, std::move(callback));
std::string url = (type == TT_TCP_SSL ? "kafkas://" : "kafka://");
if (!info.empty())
url += info + "@";
url += host;
url += ":" + std::to_string(port);
task->set_ssl_ctx(ssl_ctx);
ParsedURI uri;
URIParser::parse(url, uri);
char buf[32];
if (type == TT_TCP_SSL)
uri.scheme = strdup("kafkas");
else
uri.scheme = strdup("kafka");
if (!info.empty())
uri.userinfo = strdup(info.c_str());
uri.host = strdup(host);
sprintf(buf, "%u", port);
uri.port = strdup(buf);
if (!uri.scheme || !uri.host || !uri.port ||
(!info.empty() && !uri.userinfo))
{
uri.state = URI_STATE_ERROR;
uri.error = errno;
}
task->init(std::move(uri));
task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT);
return task;

View File

@@ -16,6 +16,7 @@
Authors: Wang Zhulei (wangzhulei@sogou-inc.com)
*/
#include <openssl/ssl.h>
#include "WFTaskFactory.h"
#include "KafkaMessage.h"
@@ -32,16 +33,19 @@ public:
* user task.
*/
static __WFKafkaTask *create_kafka_task(const ParsedURI& uri,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback);
static __WFKafkaTask *create_kafka_task(const std::string& url,
SSL_CTX *ssl_ctx,
int retry_max,
__kafka_callback_t callback);
static __WFKafkaTask *create_kafka_task(enum TransportType type,
const char *host,
unsigned short port,
SSL_CTX *ssl_ctx,
const std::string& info,
int retry_max,
__kafka_callback_t callback);

View File

@@ -241,7 +241,7 @@ CommMessageOut *ComplexMySQLTask::message_out()
break;
case ST_FIRST_USER_REQUEST:
if (this->is_fixed_addr())
if (this->is_fixed_conn())
{
auto *target = (RouteManager::RouteTarget *)this->target;
@@ -350,7 +350,21 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
if (is_ssl_)
{
if (!(resp->get_capability_flags() & 0x800))
if (resp->get_capability_flags() & 0x800)
{
static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx();
ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx);
if (!ssl)
{
state_ = WFT_STATE_SYS_ERROR;
error_ = errno;
return 0;
}
SSL_set_connect_state(ssl);
}
else
{
this->resp = std::move(*(MySQLResponse *)resp);
state_ = WFT_STATE_TASK_ERROR;
@@ -358,15 +372,6 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp)
return 0;
}
ssl = __create_ssl(WFGlobal::get_ssl_client_ctx());
if (!ssl)
{
state_ = WFT_STATE_SYS_ERROR;
error_ = errno;
return 0;
}
SSL_set_connect_state(ssl);
}
auto *conn = this->get_connection();
@@ -712,9 +717,9 @@ bool ComplexMySQLTask::init_success()
if (!transaction.empty())
{
this->WFComplexClientTask::set_info(std::string("?maxconn=1&") +
info + "|txn:" + transaction);
this->set_fixed_addr(true);
this->set_fixed_conn(true);
this->WFComplexClientTask::set_info(info + ("|txn:" + transaction));
}
else
this->WFComplexClientTask::set_info(info);
@@ -741,7 +746,7 @@ bool ComplexMySQLTask::finish_once()
return false;
}
if (this->is_fixed_addr())
if (this->is_fixed_conn())
{
if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0)
{
@@ -767,7 +772,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url,
URIParser::parse(url, uri);
task->init(std::move(uri));
if (task->is_fixed_addr())
if (task->is_fixed_conn())
task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
else
task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);
@@ -782,7 +787,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri,
auto *task = new ComplexMySQLTask(retry_max, std::move(callback));
task->init(uri);
if (task->is_fixed_addr())
if (task->is_fixed_conn())
task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION);
else
task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT);

View File

@@ -407,6 +407,13 @@ public:
int retry_max,
std::function<void (T *)> callback);
static T *create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (T *)> callback);
public:
static T *create_server_task(CommService *service,
std::function<void (T *)>& process);

View File

@@ -26,6 +26,7 @@
#include <functional>
#include <utility>
#include <atomic>
#include <openssl/ssl.h>
#include "PlatformSocket.h"
#include "WFGlobal.h"
#include "Workflow.h"
@@ -73,7 +74,9 @@ public:
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(), std::move(cb))
{
type_ = TT_TCP;
ssl_ctx_ = NULL;
fixed_addr_ = false;
fixed_conn_ = false;
retry_max_ = retry_max;
retry_times_ = 0;
redirect_ = false;
@@ -102,17 +105,19 @@ public:
init_with_uri();
}
void init(TransportType type,
void init(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
const std::string& info);
void set_transport_type(TransportType type)
void set_transport_type(enum TransportType type)
{
type_ = type;
}
TransportType get_transport_type() const { return type_; }
enum TransportType get_transport_type() const { return type_; }
void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; }
virtual const ParsedURI *get_current_uri() const { return &uri_; }
@@ -122,7 +127,7 @@ public:
init(uri);
}
void set_redirect(TransportType type, const struct sockaddr *addr,
void set_redirect(enum TransportType type, const struct sockaddr *addr,
socklen_t addrlen, const std::string& info)
{
redirect_ = true;
@@ -131,9 +136,13 @@ public:
bool is_fixed_addr() const { return this->fixed_addr_; }
bool is_fixed_conn() const { return this->fixed_conn_; }
protected:
void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; }
void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; }
void set_info(const std::string& info)
{
info_.assign(info);
@@ -163,10 +172,12 @@ protected:
}
protected:
TransportType type_;
enum TransportType type_;
ParsedURI uri_;
std::string info_;
SSL_CTX *ssl_ctx_;
bool fixed_addr_;
bool fixed_conn_;
bool redirect_;
CTX ctx_;
int retry_max_;
@@ -205,7 +216,7 @@ void WFComplexClientTask<REQ, RESP, CTX>::clear_prev_state()
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::init(TransportType type,
void WFComplexClientTask<REQ, RESP, CTX>::init(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
const std::string& info)
@@ -216,7 +227,6 @@ void WFComplexClientTask<REQ, RESP, CTX>::init(TransportType type,
auto params = WFGlobal::get_global_settings()->endpoint_params;
struct addrinfo addrinfo = { };
addrinfo.ai_family = addr->sa_family;
addrinfo.ai_socktype = SOCK_STREAM;
addrinfo.ai_addr = (struct sockaddr *)addr;
addrinfo.ai_addrlen = addrlen;
@@ -224,7 +234,7 @@ void WFComplexClientTask<REQ, RESP, CTX>::init(TransportType type,
info_.assign(info);
params.use_tls_sni = false;
if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, &params,
"", route_result_) < 0)
"", ssl_ctx_, route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -277,10 +287,10 @@ template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::init_with_uri()
{
if (redirect_)
{
clear_prev_state();
ns_policy_ = WFGlobal::get_dns_resolver();
}
{
clear_prev_state();
ns_policy_ = WFGlobal::get_dns_resolver();
}
if (uri_.state == URI_STATE_SUCCESS)
{
@@ -311,12 +321,14 @@ WFRouterTask *WFComplexClientTask<REQ, RESP, CTX>::route()
this,
std::placeholders::_1);
struct WFNSParams params = {
/*.type =*/ type_,
/*.uri =*/ uri_,
/*.info =*/ info_.c_str(),
/*.fixed_addr =*/ fixed_addr_,
/*.retry_times =*/ retry_times_,
/*.tracing =*/ &tracing_,
/*.type =*/ type_,
/*.uri =*/ uri_,
/*.info =*/ info_.c_str(),
/*.ssl_ctx =*/ ssl_ctx_,
/*.fixed_addr =*/ fixed_addr_,
/*.fixed_conn =*/ fixed_conn_,
/*.retry_times =*/ retry_times_,
/*.tracing =*/ &tracing_,
};
if (!ns_policy_)
@@ -475,22 +487,26 @@ SubTask *WFComplexClientTask<REQ, RESP, CTX>::done()
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const std::string& host,
unsigned short port,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
char buf[8];
std::string url = "scheme://";
ParsedURI uri;
char buf[32];
sprintf(buf, "%u", port);
url += host;
url += ":";
url += buf;
URIParser::parse(url, uri);
uri.scheme = strdup("scheme");
uri.host = strdup(host.c_str());
uri.port = strdup(buf);
if (!uri.scheme || !uri.host || !uri.port)
{
uri.state = URI_STATE_ERROR;
uri.error = errno;
}
task->init(std::move(uri));
task->set_transport_type(type);
return task;
@@ -498,7 +514,7 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const std::string& url,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
@@ -514,7 +530,7 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const ParsedURI& uri,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
@@ -528,7 +544,7 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
int retry_max,
@@ -540,6 +556,22 @@ WFNetworkTaskFactory<REQ, RESP>::create_client_task(TransportType type,
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_client_task(enum TransportType type,
const struct sockaddr *addr,
socklen_t addrlen,
SSL_CTX *ssl_ctx,
int retry_max,
std::function<void (WFNetworkTask<REQ, RESP> *)> callback)
{
auto *task = new WFComplexClientTask<REQ, RESP>(retry_max, std::move(callback));
task->set_ssl_ctx(ssl_ctx);
task->init(type, addr, addrlen, "");
return task;
}
template<class REQ, class RESP>
WFNetworkTask<REQ, RESP> *
WFNetworkTaskFactory<REQ, RESP>::create_server_task(CommService *service,
@@ -553,6 +585,9 @@ WFNetworkTaskFactory<REQ, RESP>::create_server_task(CommService *service,
class WFServerTaskFactory
{
public:
static WFDnsTask *create_dns_task(CommService *service,
std::function<void (WFDnsTask *)>& proc);
static WFHttpTask *create_http_task(CommService *service,
std::function<void (WFHttpTask *)>& proc)
{
@@ -670,26 +705,24 @@ void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args)
{
auto&& tmp = std::bind(std::forward<FUNC>(func),
std::forward<ARGS>(args)...);
static_cast<__WFGoTask *>(task)->set_go_func(std::move(tmp));
((__WFGoTask *)task)->set_go_func(std::move(tmp));
}
/**********Create go task with nullptr func**********/
template<>
inline WFGoTask *WFTaskFactory::create_go_task<std::nullptr_t>
(const std::string& queue_name,
std::nullptr_t&& func)
template<> inline
WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name,
std::nullptr_t&& func)
{
return new __WFGoTask(WFGlobal::get_exec_queue(queue_name),
WFGlobal::get_compute_executor(),
nullptr);
}
template<>
inline WFGoTask *WFTaskFactory::create_timedgo_task<std::nullptr_t>
(time_t seconds, long nanoseconds,
const std::string& queue_name,
std::nullptr_t&& func)
template<> inline
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
const std::string& queue_name,
std::nullptr_t&& func)
{
return new __WFTimedGoTask(seconds, nanoseconds,
WFGlobal::get_exec_queue(queue_name),
@@ -697,28 +730,25 @@ inline WFGoTask *WFTaskFactory::create_timedgo_task<std::nullptr_t>
nullptr);
}
template<>
inline WFGoTask *WFTaskFactory::create_go_task<std::nullptr_t>
(ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
template<> inline
WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
{
return new __WFGoTask(queue, executor, nullptr);
}
template<>
inline WFGoTask *WFTaskFactory::create_timedgo_task<std::nullptr_t>
(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
template<> inline
WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds,
ExecQueue *queue, Executor *executor,
std::nullptr_t&& func)
{
return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr);
}
template<>
inline void WFTaskFactory::reset_go_task<std::nullptr_t>
(WFGoTask *task, std::nullptr_t&& func)
template<> inline
void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&& func)
{
static_cast<__WFGoTask *>(task)->set_go_func(nullptr);
((__WFGoTask *)task)->set_go_func(nullptr);
}
/**********Template Thread Task Factory**********/

View File

@@ -74,8 +74,8 @@ static inline int __set_fd_nonblock(int fd)
return flags;
}
static int __bind_and_listen(int sockfd, const struct sockaddr *addr,
socklen_t addrlen)
static int __bind_sockaddr(int sockfd, const struct sockaddr *addr,
socklen_t addrlen)
{
struct sockaddr_storage ss;
socklen_t len;
@@ -97,7 +97,7 @@ static int __bind_and_listen(int sockfd, const struct sockaddr *addr,
return -1;
}
return listen(sockfd, SOMAXCONN < 4096 ? 4096 : SOMAXCONN);
return 0;
}
static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry)
@@ -119,6 +119,48 @@ static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry)
return -1;
}
static int __send_to_conn(const void *buf, size_t size,
struct CommConnEntry *entry)
{
const struct sockaddr *addr;
socklen_t addrlen;
int ret;
if (!entry->ssl)
{
entry->target->get_addr(&addr, &addrlen);
return sendto(entry->sockfd, buf, size, 0, addr, addrlen);
}
if (size == 0)
return 0;
ret = SSL_write(entry->ssl, buf, size);
if (ret <= 0)
{
ret = SSL_get_error(entry->ssl, ret);
if (ret != SSL_ERROR_SYSCALL)
errno = -ret;
ret = -1;
}
return ret;
}
static void __release_conn(struct CommConnEntry *entry)
{
delete entry->conn;
if (!entry->service)
pthread_mutex_destroy(&entry->mutex);
if (entry->ssl)
SSL_free(entry->ssl);
close(entry->sockfd);
free(entry);
}
#define SSL_WRITE_BUFSIZE 8192
static int __ssl_writev(SSL *ssl, const struct iovec vectors[], int cnt)
@@ -186,26 +228,7 @@ void CommTarget::deinit()
int CommMessageIn::feedback(const void *buf, size_t size)
{
struct CommConnEntry *entry = this->entry;
int ret;
if (!entry->ssl)
return write(entry->sockfd, buf, size);
if (size == 0)
return 0;
ret = SSL_write(entry->ssl, buf, size);
if (ret <= 0)
{
ret = SSL_get_error(entry->ssl, ret);
if (ret != SSL_ERROR_SYSCALL)
errno = -ret;
ret = -1;
}
return ret;
return __send_to_conn(buf, size, this->entry);
}
void CommMessageIn::renew()
@@ -305,6 +328,9 @@ public:
}
}
public:
int shutdown();
private:
int sockfd;
int ref;
@@ -322,36 +348,50 @@ private:
friend class Communicator;
};
CommSession::~CommSession()
int CommServiceTarget::shutdown()
{
struct CommConnEntry *entry;
struct list_head *pos;
CommTarget *target;
int errno_bak;
int ret = 0;
if (!this->passive)
return;
target = this->target;
if (this->passive == 1)
pthread_mutex_lock(&this->mutex);
if (!list_empty(&this->idle_list))
{
pthread_mutex_lock(&target->mutex);
if (!list_empty(&target->idle_list))
{
pos = target->idle_list.next;
entry = list_entry(pos, struct CommConnEntry, list);
list_del(pos);
entry = list_entry(this->idle_list.next, struct CommConnEntry, list);
list_del(&entry->list);
if (this->service->reliable)
{
errno_bak = errno;
mpoller_del(entry->sockfd, entry->mpoller);
entry->state = CONN_STATE_CLOSING;
errno = errno_bak;
}
else
{
__release_conn(entry);
this->decref();
}
pthread_mutex_unlock(&target->mutex);
ret = 1;
}
((CommServiceTarget *)target)->decref();
pthread_mutex_unlock(&this->mutex);
return ret;
}
CommSession::~CommSession()
{
CommServiceTarget *target;
if (!this->passive)
return;
target = (CommServiceTarget *)this->target;
if (this->passive == 1)
target->shutdown();
target->decref();
}
inline int Communicator::first_timeout(CommSession *session)
@@ -404,19 +444,6 @@ int Communicator::first_timeout_recv(CommSession *session)
return Communicator::first_timeout(session);
}
void Communicator::release_conn(struct CommConnEntry *entry)
{
delete entry->conn;
if (!entry->service)
pthread_mutex_destroy(&entry->mutex);
if (entry->ssl)
SSL_free(entry->ssl);
close(entry->sockfd);
free(entry);
}
void Communicator::shutdown_service(CommService *service)
{
close(service->listen_fd);
@@ -670,7 +697,7 @@ void Communicator::handle_incoming_request(struct poller_result *res)
if (__sync_sub_and_fetch(&entry->ref, 1) == 0)
{
this->release_conn(entry);
__release_conn(entry);
((CommServiceTarget *)target)->decref();
}
}
@@ -752,7 +779,7 @@ void Communicator::handle_incoming_reply(struct poller_result *res)
}
if (__sync_sub_and_fetch(&entry->ref, 1) == 0)
this->release_conn(entry);
__release_conn(entry);
}
}
@@ -824,7 +851,7 @@ void Communicator::handle_reply_result(struct poller_result *res)
session->handle(state, res->error);
if (__sync_sub_and_fetch(&entry->ref, 1) == 0)
{
this->release_conn(entry);
__release_conn(entry);
((CommServiceTarget *)target)->decref();
}
@@ -877,7 +904,7 @@ void Communicator::handle_request_result(struct poller_result *res)
/* do nothing */
pthread_mutex_unlock(&entry->mutex);
if (__sync_sub_and_fetch(&entry->ref, 1) == 0)
this->release_conn(entry);
__release_conn(entry);
break;
}
@@ -910,7 +937,7 @@ struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target,
if (entry->conn)
{
entry->seq = 0;
entry->mpoller = this->mpoller;
entry->mpoller = NULL;
entry->service = service;
entry->target = target;
entry->ssl = NULL;
@@ -996,7 +1023,7 @@ void Communicator::handle_connect_result(struct poller_result *res)
target->release();
session->handle(state, res->error);
this->release_conn(entry);
__release_conn(entry);
break;
}
}
@@ -1012,9 +1039,10 @@ void Communicator::handle_listen_result(struct poller_result *res)
{
case PR_ST_SUCCESS:
target = (CommServiceTarget *)res->data.result;
entry = this->accept_conn(target, service);
entry = Communicator::accept_conn(target, service);
if (entry)
{
entry->mpoller = this->mpoller;
if (service->ssl_ctx)
{
if (__create_ssl(service->ssl_ctx, entry) >= 0 &&
@@ -1045,7 +1073,7 @@ void Communicator::handle_listen_result(struct poller_result *res)
}
}
this->release_conn(entry);
__release_conn(entry);
}
else
close(target->sockfd);
@@ -1064,6 +1092,54 @@ void Communicator::handle_listen_result(struct poller_result *res)
}
}
void Communicator::handle_recvfrom_result(struct poller_result *res)
{
CommService *service = (CommService *)res->data.context;
struct CommConnEntry *entry;
CommTarget *target;
int state, error;
switch (res->state)
{
case PR_ST_SUCCESS:
entry = (struct CommConnEntry *)res->data.result;
target = entry->target;
if (entry->state == CONN_STATE_SUCCESS)
{
state = CS_STATE_TOREPLY;
error = 0;
entry->state = CONN_STATE_IDLE;
list_add(&entry->list, &target->idle_list);
}
else
{
state = CS_STATE_ERROR;
if (entry->state == CONN_STATE_ERROR)
error = entry->error;
else
error = EBADMSG;
}
entry->session->handle(state, error);
if (state == CS_STATE_ERROR)
{
__release_conn(entry);
((CommServiceTarget *)target)->decref();
}
break;
case PR_ST_DELETED:
this->shutdown_service(service);
break;
case PR_ST_ERROR:
case PR_ST_STOPPED:
service->handle_stop(res->error);
break;
}
}
void Communicator::handle_ssl_accept_result(struct poller_result *res)
{
struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context;
@@ -1087,7 +1163,7 @@ void Communicator::handle_ssl_accept_result(struct poller_result *res)
case PR_ST_DELETED:
case PR_ST_ERROR:
case PR_ST_STOPPED:
this->release_conn(entry);
__release_conn(entry);
((CommServiceTarget *)target)->decref();
break;
}
@@ -1184,6 +1260,9 @@ void Communicator::handler_thread_routine(void *context)
case PD_OP_LISTEN:
comm->handle_listen_result(res);
break;
case PD_OP_RECVFROM:
comm->handle_recvfrom_result(res);
break;
case PD_OP_SSL_ACCEPT:
comm->handle_ssl_accept_result(res);
break;
@@ -1343,6 +1422,58 @@ poller_message_t *Communicator::create_reply(void *context)
return session->in;
}
int Communicator::recv_request(const void *buf, size_t size,
struct CommConnEntry *entry)
{
CommService *service = entry->service;
CommTarget *target = entry->target;
CommSession *session;
size_t n;
int ret;
session = service->new_session(entry->seq, entry->conn);
if (!session)
return -1;
session->passive = 1;
entry->session = session;
session->target = target;
session->conn = entry->conn;
session->seq = entry->seq++;
session->out = NULL;
session->in = NULL;
entry->state = CONN_STATE_RECEIVING;
((CommServiceTarget *)target)->incref();
session->in = session->message_in();
if (session->in)
{
session->in->entry = entry;
do
{
n = size;
ret = session->in->append(buf, &n);
if (ret == 0)
{
size -= n;
buf = (const char *)buf + n;
}
else if (ret < 0)
{
entry->error = errno;
entry->state = CONN_STATE_ERROR;
}
else
entry->state = CONN_STATE_SUCCESS;
} while (ret == 0 && size > 0);
}
return 0;
}
int Communicator::partial_written(size_t n, void *context)
{
struct CommConnEntry *entry = (struct CommConnEntry *)context;
@@ -1378,6 +1509,40 @@ void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen,
return NULL;
}
void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen,
const void *buf, size_t size, void *context)
{
CommService *service = (CommService *)context;
struct CommConnEntry *entry;
CommServiceTarget *target;
void *result;
int sockfd;
sockfd = dup(service->listen_fd);
if (sockfd >= 0)
{
result = Communicator::accept(addr, addrlen, sockfd, context);
if (result)
{
target = (CommServiceTarget *)result;
entry = Communicator::accept_conn(target, service);
if (entry)
{
if (Communicator::recv_request(buf, size, entry) >= 0)
return entry;
__release_conn(entry);
}
else
close(sockfd);
target->decref();
}
}
return NULL;
}
void Communicator::callback(struct poller_result *res, void *context)
{
msgqueue_t *msgqueue = (msgqueue_t *)context;
@@ -1502,7 +1667,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session,
int sockfd;
int ret;
sockfd = this->nonblock_connect(target);
sockfd = Communicator::nonblock_connect(target);
if (sockfd >= 0)
{
entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry));
@@ -1515,7 +1680,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session,
if (entry->conn)
{
entry->seq = 0;
entry->mpoller = this->mpoller;
entry->mpoller = NULL;
entry->service = NULL;
entry->target = target;
entry->session = session;
@@ -1598,9 +1763,10 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target)
struct poller_data data;
int timeout;
entry = this->launch_conn(session, target);
entry = Communicator::launch_conn(session, target);
if (entry)
{
entry->mpoller = this->mpoller;
session->conn = entry->conn;
session->seq = entry->seq++;
data.operation = PD_OP_CONNECT;
@@ -1611,7 +1777,7 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target)
if (mpoller_add(&data, timeout, this->mpoller) >= 0)
return 0;
this->release_conn(entry);
__release_conn(entry);
}
return -1;
@@ -1648,15 +1814,21 @@ int Communicator::request(CommSession *session, CommTarget *target)
int Communicator::nonblock_listen(CommService *service)
{
int sockfd = service->create_listen_fd();
int ret;
if (sockfd >= 0)
{
if (__set_fd_nonblock(sockfd) >= 0)
{
if (__bind_and_listen(sockfd, service->bind_addr,
service->addrlen) >= 0)
if (__bind_sockaddr(sockfd, service->bind_addr,
service->addrlen) >= 0)
{
return sockfd;
ret = listen(sockfd, SOMAXCONN);
if (ret >= 0 || errno == EOPNOTSUPP)
{
service->reliable = (ret >= 0);
return sockfd;
}
}
}
@@ -1669,6 +1841,7 @@ int Communicator::nonblock_listen(CommService *service)
int Communicator::bind(CommService *service)
{
struct poller_data data;
int errno_bak = errno;
int sockfd;
sockfd = this->nonblock_listen(service);
@@ -1676,13 +1849,25 @@ int Communicator::bind(CommService *service)
{
service->listen_fd = sockfd;
service->ref = 1;
data.operation = PD_OP_LISTEN;
data.fd = sockfd;
data.accept = Communicator::accept;
data.context = service;
data.result = NULL;
if (service->reliable)
{
data.operation = PD_OP_LISTEN;
data.accept = Communicator::accept;
}
else
{
data.operation = PD_OP_RECVFROM;
data.recvfrom = Communicator::recvfrom;
}
if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0)
{
errno = errno_bak;
return 0;
}
close(sockfd);
}
@@ -1702,7 +1887,7 @@ void Communicator::unbind(CommService *service)
}
}
int Communicator::reply_idle_conn(CommSession *session, CommTarget *target)
int Communicator::reply_reliable(CommSession *session, CommTarget *target)
{
struct CommConnEntry *entry;
struct list_head *pos;
@@ -1734,25 +1919,85 @@ int Communicator::reply_idle_conn(CommSession *session, CommTarget *target)
return ret;
}
int Communicator::reply_message_unreliable(struct CommConnEntry *entry)
{
struct iovec vectors[ENCODE_IOV_MAX];
int cnt;
cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX);
if ((unsigned int)cnt > ENCODE_IOV_MAX)
{
if (cnt > ENCODE_IOV_MAX)
errno = EOVERFLOW;
return -1;
}
if (cnt > 0)
{
struct msghdr message = {
.msg_name = entry->target->addr,
.msg_namelen = entry->target->addrlen,
.msg_iov = vectors,
#ifdef __linux__
.msg_iovlen = (size_t)cnt,
#else
.msg_iovlen = cnt,
#endif
};
if (sendmsg(entry->sockfd, &message, 0) < 0)
return -1;
}
return 0;
}
int Communicator::reply_unreliable(CommSession *session, CommTarget *target)
{
struct CommConnEntry *entry;
struct list_head *pos;
if (!list_empty(&target->idle_list))
{
pos = target->idle_list.next;
entry = list_entry(pos, struct CommConnEntry, list);
list_del(pos);
session->out = session->message_out();
if (session->out)
{
if (this->reply_message_unreliable(entry) >= 0)
return 0;
}
__release_conn(entry);
((CommServiceTarget *)target)->decref();
}
else
errno = ENOENT;
return -1;
}
int Communicator::reply(CommSession *session)
{
struct CommConnEntry *entry;
CommTarget *target;
CommServiceTarget *target;
int errno_bak;
int ret;
if (session->passive != 1)
{
errno = session->passive ? ENOENT : EPERM;
errno = session->passive ? ENOENT : EINVAL;
return -1;
}
errno_bak = errno;
session->passive = 2;
target = session->target;
ret = this->reply_idle_conn(session, target);
if (ret < 0)
return -1;
target = (CommServiceTarget *)session->target;
if (target->service->reliable)
ret = this->reply_reliable(session, target);
else
ret = this->reply_unreliable(session, target);
if (ret == 0)
{
@@ -1760,10 +2005,12 @@ int Communicator::reply(CommSession *session)
session->handle(CS_STATE_SUCCESS, 0);
if (__sync_sub_and_fetch(&entry->ref, 1) == 0)
{
this->release_conn(entry);
((CommServiceTarget *)target)->decref();
__release_conn(entry);
target->decref();
}
}
else if (ret < 0)
return -1;
errno = errno_bak;
return 0;
@@ -1777,7 +2024,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session)
if (session->passive != 1)
{
errno = session->passive ? ENOENT : EPERM;
errno = session->passive ? ENOENT : EINVAL;
return -1;
}
@@ -1785,22 +2032,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session)
if (!list_empty(&target->idle_list))
{
entry = list_entry(target->idle_list.next, struct CommConnEntry, list);
if (!entry->ssl)
ret = write(entry->sockfd, buf, size);
else if (size == 0)
ret = 0;
else
{
ret = SSL_write(entry->ssl, buf, size);
if (ret <= 0)
{
ret = SSL_get_error(entry->ssl, ret);
if (ret != SSL_ERROR_SYSCALL)
errno = -ret;
ret = -1;
}
}
ret = __send_to_conn(buf, size, entry);
}
else
{
@@ -1814,33 +2046,23 @@ int Communicator::push(const void *buf, size_t size, CommSession *session)
int Communicator::shutdown(CommSession *session)
{
CommTarget *target = session->target;
struct CommConnEntry *entry;
int ret;
CommServiceTarget *target;
if (session->passive != 1)
{
errno = session->passive ? ENOENT : EPERM;
errno = session->passive ? ENOENT : EINVAL;
return -1;
}
session->passive = 2;
pthread_mutex_lock(&target->mutex);
if (!list_empty(&target->idle_list))
{
entry = list_entry(target->idle_list.next, struct CommConnEntry, list);
list_del(&entry->list);
ret = mpoller_del(entry->sockfd, entry->mpoller);
entry->state = CONN_STATE_CLOSING;
}
else
target = (CommServiceTarget *)session->target;
if (!target->shutdown())
{
errno = ENOENT;
ret = -1;
return -1;
}
pthread_mutex_unlock(&target->mutex);
return ret;
return 0;
}
int Communicator::sleep(SleepSession *session)

View File

@@ -31,9 +31,8 @@
class CommConnection
{
protected:
public:
virtual ~CommConnection() { }
friend class Communicator;
};
class CommTarget
@@ -91,7 +90,7 @@ private:
public:
virtual ~CommTarget() { }
friend class CommSession;
friend class CommServiceTarget;
friend class Communicator;
};
@@ -223,6 +222,7 @@ private:
void decref();
private:
int reliable;
int listen_fd;
int ref;
@@ -298,16 +298,6 @@ private:
int create_handler_threads(size_t handler_threads);
int nonblock_connect(CommTarget *target);
int nonblock_listen(CommService *service);
struct CommConnEntry *launch_conn(CommSession *session,
CommTarget *target);
struct CommConnEntry *accept_conn(class CommServiceTarget *target,
CommService *service);
void release_conn(struct CommConnEntry *entry);
void shutdown_service(CommService *service);
void shutdown_io_service(IOService *service);
@@ -319,10 +309,13 @@ private:
int send_message(struct CommConnEntry *entry);
int request_idle_conn(CommSession *session, CommTarget *target);
int reply_idle_conn(CommSession *session, CommTarget *target);
int request_new_conn(CommSession *session, CommTarget *target);
int request_idle_conn(CommSession *session, CommTarget *target);
int reply_message_unreliable(struct CommConnEntry *entry);
int reply_reliable(CommSession *session, CommTarget *target);
int reply_unreliable(CommSession *session, CommTarget *target);
void handle_incoming_request(struct poller_result *res);
void handle_incoming_reply(struct poller_result *res);
@@ -336,6 +329,8 @@ private:
void handle_connect_result(struct poller_result *res);
void handle_listen_result(struct poller_result *res);
void handle_recvfrom_result(struct poller_result *res);
void handle_ssl_accept_result(struct poller_result *res);
void handle_sleep_result(struct poller_result *res);
@@ -344,6 +339,14 @@ private:
static void handler_thread_routine(void *context);
static int nonblock_connect(CommTarget *target);
static int nonblock_listen(CommService *service);
static struct CommConnEntry *launch_conn(CommSession *session,
CommTarget *target);
static struct CommConnEntry *accept_conn(class CommServiceTarget *target,
CommService *service);
static int first_timeout(CommSession *session);
static int next_timeout(CommSession *session);
@@ -358,11 +361,17 @@ private:
static poller_message_t *create_request(void *context);
static poller_message_t *create_reply(void *context);
static int recv_request(const void *buf, size_t size,
struct CommConnEntry *entry);
static int partial_written(size_t n, void *context);
static void *accept(const struct sockaddr *addr, socklen_t addrlen,
int sockfd, void *context);
static void *recvfrom(const struct sockaddr *addr, socklen_t addrlen,
const void *buf, size_t size, void *context);
static void callback(struct poller_result *res, void *context);
public:

View File

@@ -651,6 +651,55 @@ static void __poller_handle_connect(struct __poller_node *node,
poller->callback((struct poller_result *)node, poller->context);
}
static void __poller_handle_recvfrom(struct __poller_node *node,
poller_t *poller)
{
struct __poller_node *res = node->res;
struct sockaddr_storage ss;
struct sockaddr *addr = (struct sockaddr *)&ss;
socklen_t addrlen;
void *result;
ssize_t n;
while (1)
{
addrlen = sizeof (struct sockaddr_storage);
n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0,
addr, &addrlen);
if (n < 0)
{
if (errno == EAGAIN)
return;
else
break;
}
result = node->data.recvfrom(addr, addrlen, poller->buf, n,
node->data.context);
if (!result)
break;
res->data = node->data;
res->data.result = result;
res->error = 0;
res->state = PR_ST_SUCCESS;
poller->callback((struct poller_result *)res, poller->context);
res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
node->res = res;
if (!res)
break;
}
if (__poller_remove_node(node, poller))
return;
node->error = errno;
node->state = PR_ST_ERROR;
free(node->res);
poller->callback((struct poller_result *)node, poller->context);
}
static void __poller_handle_ssl_accept(struct __poller_node *node,
poller_t *poller)
{
@@ -849,55 +898,6 @@ static void __poller_handle_notify(struct __poller_node *node,
poller->callback((struct poller_result *)node, poller->context);
}
static void __poller_handle_recvfrom(struct __poller_node *node,
poller_t *poller)
{
struct __poller_node *res = node->res;
struct sockaddr_storage ss;
struct sockaddr *addr = (struct sockaddr *)&ss;
socklen_t addrlen;
void *result;
ssize_t n;
while (1)
{
addrlen = sizeof (struct sockaddr_storage);
n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0,
addr, &addrlen);
if (n < 0)
{
if (errno == EAGAIN)
return;
else
break;
}
result = node->data.recvfrom(addr, addrlen, poller->buf, n,
node->data.context);
if (!result)
break;
res->data = node->data;
res->data.result = result;
res->error = 0;
res->state = PR_ST_SUCCESS;
poller->callback((struct poller_result *)res, poller->context);
res = (struct __poller_node *)malloc(sizeof (struct __poller_node));
node->res = res;
if (!res)
break;
}
if (__poller_remove_node(node, poller))
return;
node->error = errno;
node->state = PR_ST_ERROR;
free(node->res);
poller->callback((struct poller_result *)node, poller->context);
}
static int __poller_handle_pipe(poller_t *poller)
{
struct __poller_node **node = (struct __poller_node **)poller->buf;
@@ -1055,6 +1055,9 @@ static void *__poller_thread_routine(void *arg)
case PD_OP_CONNECT:
__poller_handle_connect(node, poller);
break;
case PD_OP_RECVFROM:
__poller_handle_recvfrom(node, poller);
break;
case PD_OP_SSL_ACCEPT:
__poller_handle_ssl_accept(node, poller);
break;
@@ -1070,9 +1073,6 @@ static void *__poller_thread_routine(void *arg)
case PD_OP_NOTIFY:
__poller_handle_notify(node, poller);
break;
case PD_OP_RECVFROM:
__poller_handle_recvfrom(node, poller);
break;
}
}
@@ -1282,6 +1282,9 @@ static int __poller_data_get_event(int *event, const struct poller_data *data)
case PD_OP_CONNECT:
*event = EPOLLOUT | EPOLLET;
return 0;
case PD_OP_RECVFROM:
*event = EPOLLIN | EPOLLET;
return 1;
case PD_OP_SSL_ACCEPT:
*event = EPOLLIN | EPOLLET;
return 0;
@@ -1297,9 +1300,6 @@ static int __poller_data_get_event(int *event, const struct poller_data *data)
case PD_OP_NOTIFY:
*event = EPOLLIN | EPOLLET;
return 1;
case PD_OP_RECVFROM:
*event = EPOLLIN | EPOLLET;
return 1;
default:
errno = EINVAL;
return -1;

View File

@@ -40,14 +40,14 @@ struct poller_data
#define PD_OP_WRITE 2
#define PD_OP_LISTEN 3
#define PD_OP_CONNECT 4
#define PD_OP_RECVFROM 5
#define PD_OP_SSL_READ PD_OP_READ
#define PD_OP_SSL_WRITE PD_OP_WRITE
#define PD_OP_SSL_ACCEPT 5
#define PD_OP_SSL_CONNECT 6
#define PD_OP_SSL_SHUTDOWN 7
#define PD_OP_EVENT 8
#define PD_OP_NOTIFY 9
#define PD_OP_RECVFROM 10
#define PD_OP_SSL_ACCEPT 6
#define PD_OP_SSL_CONNECT 7
#define PD_OP_SSL_SHUTDOWN 8
#define PD_OP_EVENT 9
#define PD_OP_NOTIFY 10
short operation;
unsigned short iovcnt;
int fd;
@@ -57,10 +57,10 @@ struct poller_data
poller_message_t *(*create_message)(void *);
int (*partial_written)(size_t, void *);
void *(*accept)(const struct sockaddr *, socklen_t, int, void *);
void *(*event)(void *);
void *(*notify)(void *, void *);
void *(*recvfrom)(const struct sockaddr *, socklen_t,
const void *, size_t, void *);
void *(*event)(void *);
void *(*notify)(void *, void *);
};
void *context;
union

View File

@@ -23,42 +23,31 @@
#define GET_CURRENT_SECOND std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now().time_since_epoch()).count()
#define CONFIDENT_INC 10
#define TTL_INC 10
#define TTL_INC 5
const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, int type)
const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port,
int type)
{
int64_t cur_time = GET_CURRENT_SECOND;
int64_t cur = GET_CURRENT_SECOND;
std::lock_guard<std::mutex> lock(mutex_);
const DnsHandle *handle = cache_pool_.get(host_port);
if (handle)
if (handle && ((type == GET_TYPE_TTL && cur > handle->value.expire_time) ||
(type == GET_TYPE_CONFIDENT && cur > handle->value.confident_time)))
{
switch (type)
if (!handle->value.delayed())
{
case GET_TYPE_TTL:
if (cur_time > handle->value.expire_time)
{
const_cast<DnsHandle *>(handle)->value.expire_time += TTL_INC;
cache_pool_.release(handle);
return NULL;
}
DnsHandle *h = const_cast<DnsHandle *>(handle);
if (type == GET_TYPE_TTL)
h->value.expire_time += TTL_INC;
else
h->value.confident_time += TTL_INC;
break;
case GET_TYPE_CONFIDENT:
if (cur_time > handle->value.confident_time)
{
const_cast<DnsHandle *>(handle)->value.confident_time += CONFIDENT_INC;
cache_pool_.release(handle);
return NULL;
}
break;
default:
break;
h->value.addrinfo->ai_flags |= 2;
}
cache_pool_.release(handle);
return NULL;
}
return handle;
@@ -90,3 +79,29 @@ const DnsCache::DnsHandle *DnsCache::put(const HostPort& host_port,
return cache_pool_.put(host_port, {addrinfo, confident_time, expire_time});
}
const DnsCache::DnsHandle *DnsCache::get(const DnsCache::HostPort& host_port)
{
std::lock_guard<std::mutex> lock(mutex_);
return cache_pool_.get(host_port);
}
void DnsCache::release(const DnsCache::DnsHandle *handle)
{
std::lock_guard<std::mutex> lock(mutex_);
cache_pool_.release(handle);
}
void DnsCache::del(const DnsCache::HostPort& key)
{
std::lock_guard<std::mutex> lock(mutex_);
cache_pool_.del(key);
}
DnsCache::DnsCache()
{
}
DnsCache::~DnsCache()
{
}

View File

@@ -35,6 +35,11 @@ struct DnsCacheValue
struct addrinfo *addrinfo;
int64_t confident_time;
int64_t expire_time;
bool delayed() const
{
return addrinfo->ai_flags & 2;
}
};
// RAII: NO. Release handle by user
@@ -47,27 +52,10 @@ public:
using DnsHandle = LRUHandle<HostPort, DnsCacheValue>;
public:
// release handle by get/put
void release(DnsHandle *handle)
{
std::lock_guard<std::mutex> lock(mutex_);
cache_pool_.release(handle);
}
void release(const DnsHandle *handle)
{
std::lock_guard<std::mutex> lock(mutex_);
cache_pool_.release(handle);
}
// get handler
// Need call release when handle no longer needed
//Handle *get(const KEY &key);
const DnsHandle *get(const HostPort& host_port)
{
std::lock_guard<std::mutex> lock(mutex_);
return cache_pool_.get(host_port);
}
const DnsHandle *get(const HostPort& host_port);
const DnsHandle *get(const std::string& host, unsigned short port)
{
@@ -132,12 +120,11 @@ public:
return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min);
}
// release handle by get/put
void release(const DnsHandle *handle);
// delete from cache, deleter delay called when all inuse-handle release.
void del(const HostPort& key)
{
std::lock_guard<std::mutex> lock(mutex_);
cache_pool_.del(key);
}
void del(const HostPort& key);
void del(const std::string& host, unsigned short port)
{
@@ -161,14 +148,22 @@ private:
{
struct addrinfo *ai = value.addrinfo;
if (ai && (ai->ai_flags & AI_PASSIVE))
freeaddrinfo(ai);
else
protocol::DnsUtil::freeaddrinfo(ai);
if (ai)
{
if (ai->ai_flags)
freeaddrinfo(ai);
else
protocol::DnsUtil::freeaddrinfo(ai);
}
}
};
LRUCache<HostPort, DnsCacheValue, ValueDeleter> cache_pool_;
public:
// To prevent inline calling LRUCache's constructor and deconstructor.
DnsCache();
~DnsCache();
};
#endif

View File

@@ -20,6 +20,7 @@
#define _ENDPOINTPARAMS_H_
#include <stddef.h>
#include "PlatformSocket.h"
/**
* @file EndpointParams.h
@@ -37,6 +38,7 @@ enum TransportType
struct EndpointParams
{
int address_family;
size_t max_connections;
int connect_timeout;
int response_timeout;
@@ -46,6 +48,7 @@ struct EndpointParams
static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT =
{
/* address_family = */ AF_INET,
/* .max_connections = */ 200,
/* .connect_timeout = */ 10 * 1000,
/* .response_timeout = */ 10 * 1000,

View File

@@ -76,7 +76,7 @@ private:
};
/* To support TLS SNI. */
class RouteTargetSNI : public RouteManager::RouteTarget
class RouteTargetTCPSNI : public RouteTargetTCP
{
private:
virtual int init_ssl(SSL *ssl)
@@ -91,7 +91,27 @@ private:
std::string hostname;
public:
RouteTargetSNI(const std::string& name) : hostname(name)
RouteTargetTCPSNI(const std::string& name) : hostname(name)
{
}
};
class RouteTargetSCTPSNI : public RouteTargetSCTP
{
private:
virtual int init_ssl(SSL *ssl)
{
if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0)
return 0;
else
return -1;
}
private:
std::string hostname;
public:
RouteTargetSCTPSNI(const std::string& name) : hostname(name)
{
}
};
@@ -101,11 +121,11 @@ public:
struct RouteParams
{
TransportType transport_type;
enum TransportType transport_type;
const struct addrinfo *addrinfo;
uint64_t key;
SSL_CTX *ssl_ctx;
unsigned int max_connections;
size_t max_connections;
int connect_timeout;
int response_timeout;
int ssl_connect_timeout;
@@ -120,7 +140,7 @@ public:
CommSchedObject *request_object;
CommSchedGroup *group;
std::mutex mutex;
std::vector<CommSchedTarget *> targets;
std::vector<RouteManager::RouteTarget *> targets;
struct list_head breaker_list;
uint64_t key;
int nleft;
@@ -139,34 +159,35 @@ public:
int init(const struct RouteParams *params);
void deinit();
void notify_unavailable(CommSchedTarget *target);
void notify_available(CommSchedTarget *target);
void notify_unavailable(RouteManager::RouteTarget *target);
void notify_available(RouteManager::RouteTarget *target);
void check_breaker();
private:
void free_list();
CommSchedTarget *create_target(const struct RouteParams *params,
const struct addrinfo *addrinfo);
RouteManager::RouteTarget *create_target(const struct RouteParams *params,
const struct addrinfo *addrinfo);
int add_group_targets(const struct RouteParams *params);
};
struct __breaker_node
{
CommSchedTarget *target;
RouteManager::RouteTarget *target;
int64_t timeout;
struct list_head breaker_list;
};
CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *params,
const struct addrinfo *addr)
RouteManager::RouteTarget *
RouteResultEntry::create_target(const struct RouteParams *params,
const struct addrinfo *addr)
{
CommSchedTarget *target;
RouteManager::RouteTarget *target;
switch (params->transport_type)
{
case TT_TCP_SSL:
if (params->use_tls_sni)
target = new RouteTargetSNI(params->hostname);
target = new RouteTargetTCPSNI(params->hostname);
else
case TT_TCP:
target = new RouteTargetTCP();
@@ -174,16 +195,19 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param
case TT_UDP:
target = new RouteTargetUDP();
break;
case TT_SCTP:
case TT_SCTP_SSL:
target = new RouteTargetSCTP();
if (params->use_tls_sni)
target = new RouteTargetSCTPSNI(params->hostname);
else
case TT_SCTP:
target = new RouteTargetSCTP();
break;
default:
errno = EINVAL;
return NULL;
}
if (target->init(addr->ai_addr, (socklen_t)addr->ai_addrlen, params->ssl_ctx,
if (target->init(addr->ai_addr, addr->ai_addrlen, params->ssl_ctx,
params->connect_timeout, params->ssl_connect_timeout,
params->response_timeout, params->max_connections) < 0)
{
@@ -197,7 +221,7 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param
int RouteResultEntry::init(const struct RouteParams *params)
{
const struct addrinfo *addr = params->addrinfo;
CommSchedTarget *target;
RouteManager::RouteTarget *target;
if (addr == NULL)//0
{
@@ -238,8 +262,8 @@ int RouteResultEntry::init(const struct RouteParams *params)
int RouteResultEntry::add_group_targets(const struct RouteParams *params)
{
RouteManager::RouteTarget *target;
const struct addrinfo *addr;
CommSchedTarget *target;
for (addr = params->addrinfo; addr; addr = addr->ai_next)
{
@@ -298,7 +322,7 @@ void RouteResultEntry::deinit()
}
}
void RouteResultEntry::notify_unavailable(CommSchedTarget *target)
void RouteResultEntry::notify_unavailable(RouteManager::RouteTarget *target)
{
if (this->targets.size() <= 1)
return;
@@ -324,7 +348,7 @@ void RouteResultEntry::notify_unavailable(CommSchedTarget *target)
this->nleft--;
}
void RouteResultEntry::notify_available(CommSchedTarget *target)
void RouteResultEntry::notify_available(RouteManager::RouteTarget *target)
{
if (this->targets.size() <= 1 || this->nbreak == 0)
return;
@@ -396,23 +420,26 @@ static uint64_t __fnv_hash(const unsigned char *data, size_t size)
return hash;
}
static uint64_t __generate_key(TransportType type,
static uint64_t __generate_key(enum TransportType type,
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname)
const std::string& hostname,
SSL_CTX *ssl_ctx)
{
std::string buf((const char *)&type, sizeof (TransportType));
unsigned int max_conn = ep_params->max_connections;
const int params[] = {
ep_params->address_family, (int)ep_params->max_connections,
ep_params->connect_timeout, ep_params->response_timeout
};
std::string buf((const char *)&type, sizeof (enum TransportType));
if (!other_info.empty())
buf += other_info;
buf.append((const char *)&max_conn, sizeof (unsigned int));
buf.append((const char *)&ep_params->connect_timeout, sizeof (int));
buf.append((const char *)&ep_params->response_timeout, sizeof (int));
if (type == TT_TCP_SSL)
buf.append((const char *)params, sizeof params);
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
buf.append((const char *)&ssl_ctx, sizeof (void *));
buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int));
if (ep_params->use_tls_sni)
{
@@ -462,15 +489,24 @@ RouteManager::~RouteManager()
}
}
int RouteManager::get(TransportType type,
int RouteManager::get(enum TransportType type,
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname,
const std::string& hostname, SSL_CTX *ssl_ctx,
RouteResult& result)
{
uint64_t key = __generate_key(type, addrinfo, other_info,
ep_params, hostname);
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx();
if (ssl_ctx == NULL)
ssl_ctx = global_client_ctx;
}
else
ssl_ctx = NULL;
uint64_t key = __generate_key(type, addrinfo, other_info, ep_params,
hostname, ssl_ctx);
struct rb_node **p = &cache_.rb_node;
struct rb_node *parent = NULL;
RouteResultEntry *bound = NULL;
@@ -497,37 +533,19 @@ int RouteManager::get(TransportType type,
}
else
{
int ssl_connect_timeout = 0;
SSL_CTX *ssl_ctx = NULL;
if (type == TT_TCP_SSL || type == TT_SCTP_SSL)
{
static SSL_CTX *client_ssl_ctx = WFGlobal::get_ssl_client_ctx();
ssl_ctx = client_ssl_ctx;
ssl_connect_timeout = ep_params->ssl_connect_timeout;
}
struct RouteParams params = {
/* .transport_type = */ type,
/* .addrinfo = */ addrinfo,
/* .key = */ key,
/* .ssl_ctx = */ ssl_ctx,
/* .max_connections = */ (unsigned int)ep_params->max_connections,
/* .connect_timeout = */ ep_params->connect_timeout,
/* .response_timeout = */ ep_params->response_timeout,
/* .ssl_connect_timeout = */ ssl_connect_timeout,
/* .use_tls_sni = */ ep_params->use_tls_sni,
/* .hostname = */ hostname,
/*.transport_type =*/ type,
/*.addrinfo =*/ addrinfo,
/*.key =*/ key,
/*.ssl_ctx =*/ ssl_ctx,
/*.max_connections =*/ ep_params->max_connections,
/*.connect_timeout =*/ ep_params->connect_timeout,
/*.response_timeout =*/ ep_params->response_timeout,
/*.ssl_connect_timeout =*/ ep_params->ssl_connect_timeout,
/*.use_tls_sni =*/ ep_params->use_tls_sni,
/*.hostname =*/ hostname,
};
if (StringUtil::start_with(other_info, "?maxconn="))
{
int maxconn = atoi(other_info.c_str() + 9);
if (maxconn > 0)
params.max_connections = maxconn;
}
entry = new RouteResultEntry;
if (entry->init(&params) >= 0)
{
@@ -549,12 +567,12 @@ int RouteManager::get(TransportType type,
void RouteManager::notify_unavailable(void *cookie, CommTarget *target)
{
if (cookie && target)
((RouteResultEntry *)cookie)->notify_unavailable((CommSchedTarget *)target);
((RouteResultEntry *)cookie)->notify_unavailable((RouteTarget *)target);
}
void RouteManager::notify_available(void *cookie, CommTarget *target)
{
if (cookie && target)
((RouteResultEntry *)cookie)->notify_available((CommSchedTarget *)target);
((RouteResultEntry *)cookie)->notify_available((RouteTarget *)target);
}

View File

@@ -43,6 +43,32 @@ public:
class RouteTarget : public CommSchedTarget
{
#if OPENSSL_VERSION_NUMBER >= 0x10100000L
public:
int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx,
int connect_timeout, int ssl_connect_timeout, int response_timeout,
size_t max_connections)
{
int ret = this->CommSchedTarget::init(addr, addrlen, ssl_ctx,
connect_timeout, ssl_connect_timeout,
response_timeout, max_connections);
if (ret >= 0 && ssl_ctx)
SSL_CTX_up_ref(ssl_ctx);
return ret;
}
void deinit()
{
SSL_CTX *ssl_ctx = this->get_ssl_ctx();
this->CommSchedTarget::deinit();
if (ssl_ctx)
SSL_CTX_free(ssl_ctx);
}
#endif
public:
int state;
@@ -57,11 +83,11 @@ public:
};
public:
int get(TransportType type,
int get(enum TransportType type,
const struct addrinfo *addrinfo,
const std::string& other_info,
const struct EndpointParams *ep_params,
const std::string& hostname,
const std::string& hostname, SSL_CTX *ssl_ctx,
RouteResult& result);
RouteManager()

View File

@@ -40,35 +40,21 @@
#define HOSTS_LINEBUF_INIT_SIZE 128
#define PORT_STR_MAX 5
static constexpr struct addrinfo __ai_hints =
{
#ifdef AI_ADDRCONFIG
/*.ai_flags = */ AI_ADDRCONFIG,
#else
/*.ai_flags = */ 0,
#endif
/*.ai_family = */ AF_UNSPEC,
/*.ai_socktype = */ SOCK_STREAM,
/*.ai_protocol = */ 0,
/*.ai_addrlen = */ 0,
/*.ai_addr = */ NULL,
/*.ai_canonname = */ NULL,
/*.ai_next = */ NULL
};
class DnsInput
{
public:
DnsInput() :
port_(0),
numeric_host_(false)
numeric_host_(false),
family_(AF_UNSPEC)
{}
DnsInput(const std::string& host, unsigned short port,
bool numeric_host) :
bool numeric_host, int family) :
host_(host),
port_(port),
numeric_host_(numeric_host)
numeric_host_(numeric_host),
family_(family)
{}
void reset(const std::string& host, unsigned short port)
@@ -76,14 +62,16 @@ public:
host_.assign(host);
port_ = port;
numeric_host_ = false;
family_ = AF_UNSPEC;
}
void reset(const std::string& host, unsigned short port,
bool numeric_host)
bool numeric_host, int family)
{
host_.assign(host);
port_ = port;
numeric_host_ = numeric_host;
family_ = family;
}
const std::string& get_host() const { return host_; }
@@ -94,6 +82,7 @@ protected:
std::string host_;
unsigned short port_;
bool numeric_host_;
int family_;
friend class DnsRoutine;
};
@@ -109,7 +98,12 @@ public:
~DnsOutput()
{
if (addrinfo_)
freeaddrinfo(addrinfo_);
{
if (addrinfo_->ai_flags)
freeaddrinfo(addrinfo_);
else
free(addrinfo_);
}
}
int get_error() const { return error_; }
@@ -137,7 +131,12 @@ public:
static void create(DnsOutput *out, int error, struct addrinfo *ai)
{
if (out->addrinfo_)
freeaddrinfo(out->addrinfo_);
{
if (out->addrinfo_->ai_flags)
freeaddrinfo(out->addrinfo_);
else
free(out->addrinfo_);
}
out->error_ = error;
out->addrinfo_ = ai;
@@ -146,21 +145,21 @@ public:
void DnsRoutine::run(const DnsInput *in, DnsOutput *out)
{
if (!in->host_.empty() && in->host_[0] == '/')
return;
struct addrinfo hints = __ai_hints;
struct addrinfo hints = {
/*.ai_flags =*/ AI_ADDRCONFIG | AI_NUMERICSERV,
/*.ai_family =*/ in->family_,
/*.ai_socktype =*/ SOCK_STREAM,
};
char port_str[PORT_STR_MAX + 1];
hints.ai_flags |= AI_NUMERICSERV;
if (in->is_numeric_host())
hints.ai_flags |= AI_NUMERICHOST;
snprintf(port_str, PORT_STR_MAX + 1, "%u", in->port_);
out->error_ = getaddrinfo(in->host_.c_str(),
port_str,
&hints,
&out->addrinfo_);
out->error_ = getaddrinfo(in->host_.c_str(), port_str,
&hints, &out->addrinfo_);
if (out->error_ == 0)
out->addrinfo_->ai_flags = 1;
}
// Dns Thread task. For internal usage only.
@@ -178,13 +177,18 @@ struct DnsContext
static int __default_family()
{
struct addrinfo hints = {
/*.ai_flags =*/ AI_ADDRCONFIG,
/*.ai_family =*/ AF_UNSPEC,
/*.ai_socktype =*/ SOCK_STREAM,
};
struct addrinfo *res;
struct addrinfo *cur;
int family = AF_UNSPEC;
bool v4 = false;
bool v6 = false;
if (getaddrinfo(NULL, "1", &__ai_hints, &res) == 0)
if (getaddrinfo(NULL, "1", &hints, &res) == 0)
{
for (cur = res; cur; cur = cur->ai_next)
{
@@ -294,18 +298,9 @@ static int __readaddrinfo(const char *path,
return ret;
}
// Add AI_PASSIVE to point that this addrinfo is alloced by getaddrinfo
static void __add_passive_flags(struct addrinfo *ai)
{
while (ai)
{
ai->ai_flags |= AI_PASSIVE;
ai = ai->ai_next;
}
}
static ThreadDnsTask *__create_thread_dns_task(const std::string& host,
unsigned short port,
int family,
thread_dns_callback_t callback)
{
auto *task = WFThreadTaskFactory<DnsInput, DnsOutput>::
@@ -314,12 +309,46 @@ static ThreadDnsTask *__create_thread_dns_task(const std::string& host,
DnsRoutine::run,
std::move(callback));
task->get_input()->reset(host, port);
task->get_input()->reset(host, port, false, family);
return task;
}
static std::string __get_cache_host(const std::string& hostname,
int family)
{
char c;
if (family == AF_UNSPEC)
c = '*';
else if (family == AF_INET)
c = '4';
else if (family == AF_INET6)
c = '6';
else
c = '?';
return hostname + c;
}
static std::string __get_guard_name(const std::string& cache_host,
unsigned short port)
{
std::string guard_name("INTERNAL-dns:");
guard_name.append(cache_host).append(":");
guard_name.append(std::to_string(port));
return guard_name;
}
void WFResolverTask::dispatch()
{
if (this->msg_)
{
this->state = WFT_STATE_DNS_ERROR;
this->error = (intptr_t)msg_;
this->subtask_done();
return;
}
const ParsedURI& uri = ns_params_.uri;
host_ = uri.host ? uri.host : "";
port_ = uri.port ? atoi(uri.port) : 0;
@@ -327,11 +356,22 @@ void WFResolverTask::dispatch()
DnsCache *dns_cache = WFGlobal::get_dns_cache();
const DnsCache::DnsHandle *addr_handle;
std::string hostname = host_;
int family = ep_params_.address_family;
std::string cache_host = __get_cache_host(hostname, family);
if (ns_params_.retry_times == 0)
addr_handle = dns_cache->get_ttl(hostname, port_);
addr_handle = dns_cache->get_ttl(cache_host, port_);
else
addr_handle = dns_cache->get_confident(hostname, port_);
addr_handle = dns_cache->get_confident(cache_host, port_);
if (in_guard_ && (addr_handle == NULL || addr_handle->value.delayed()))
{
if (addr_handle)
dns_cache->release(addr_handle);
this->request_dns();
return;
}
if (addr_handle)
{
@@ -347,7 +387,8 @@ void WFResolverTask::dispatch()
}
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
&ep_params_, hostname, this->result) < 0)
&ep_params_, hostname, ns_params_.ssl_ctx,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -378,11 +419,11 @@ void WFResolverTask::dispatch()
if (ret == 1)
{
DnsInput dns_in(hostname, port_, true); // 'true' means numeric host
// 'true' means numeric host
DnsInput dns_in(hostname, port_, true, AF_UNSPEC);
DnsOutput dns_out;
DnsRoutine::run(&dns_in, &dns_out);
__add_passive_flags((struct addrinfo *)dns_out.get_addrinfo());
dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1);
this->subtask_done();
return;
@@ -392,32 +433,53 @@ void WFResolverTask::dispatch()
const char *hosts = WFGlobal::get_global_settings()->hosts_path;
if (hosts)
{
struct addrinfo hints = {
/*.ai_flags =*/ AI_ADDRCONFIG | AI_NUMERICSERV | AI_NUMERICHOST,
/*.ai_family =*/ ep_params_.address_family,
/*.ai_socktype =*/ SOCK_STREAM,
};
struct addrinfo *ai;
int ret = __readaddrinfo(hosts, host_, port_, &__ai_hints, &ai);
int ret;
ret = __readaddrinfo(hosts, host_, port_, &hints, &ai);
if (ret == 0)
{
DnsOutput out;
DnsRoutine::create(&out, ret, ai);
__add_passive_flags((struct addrinfo *)out.get_addrinfo());
dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_);
this->subtask_done();
return;
}
}
std::string guard_name = __get_guard_name(cache_host, port_);
WFConditional *guard = WFTaskFactory::create_guard(guard_name, this, &msg_);
in_guard_ = true;
has_next_ = true;
series_of(this)->push_front(guard);
this->subtask_done();
}
void WFResolverTask::request_dns()
{
WFDnsClient *client = WFGlobal::get_dns_client();
if (client)
{
static int family = __default_family();
static int default_family = __default_family();
WFResourcePool *respool = WFGlobal::get_dns_respool();
int family = ep_params_.address_family;
if (family == AF_UNSPEC)
family = default_family;
if (family == AF_INET || family == AF_INET6)
{
auto&& cb = std::bind(&WFResolverTask::dns_single_callback,
this,
std::placeholders::_1);
WFDnsTask *dns_task = client->create_dns_task(hostname, std::move(cb));
WFDnsTask *dns_task = client->create_dns_task(host_, std::move(cb));
if (family == AF_INET6)
dns_task->get_req()->set_question_type(DNS_TYPE_AAAA);
@@ -437,10 +499,10 @@ void WFResolverTask::dispatch()
dctx[0].port = port_;
dctx[1].port = port_;
task_v4 = client->create_dns_task(hostname, dns_partial_callback);
task_v4 = client->create_dns_task(host_, dns_partial_callback);
task_v4->user_data = dctx;
task_v6 = client->create_dns_task(hostname, dns_partial_callback);
task_v6 = client->create_dns_task(host_, dns_partial_callback);
task_v6->get_req()->set_question_type(DNS_TYPE_AAAA);
task_v6->user_data = dctx + 1;
@@ -461,11 +523,13 @@ void WFResolverTask::dispatch()
}
else
{
ThreadDnsTask *dns_task;
auto&& cb = std::bind(&WFResolverTask::thread_dns_callback,
this,
std::placeholders::_1);
ThreadDnsTask *dns_task = __create_thread_dns_task(hostname, port_,
std::move(cb));
dns_task = __create_thread_dns_task(host_, port_,
ep_params_.address_family,
std::move(cb));
series_of(this)->push_front(dns_task);
}
@@ -478,12 +542,7 @@ SubTask *WFResolverTask::done()
SeriesWork *series = series_of(this);
if (!has_next_)
{
if (this->callback)
this->callback(this);
delete this;
}
task_callback();
else
has_next_ = false;
@@ -499,7 +558,7 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output,
if (dns_error)
{
if (dns_error == /*EAI_SYSTEM*/ EAI_FAIL)
if (dns_error == EAI_FAIL)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -517,12 +576,15 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output,
struct addrinfo *addrinfo = dns_out->move_addrinfo();
const DnsCache::DnsHandle *addr_handle;
std::string hostname = host_;
int family = ep_params_.address_family;
std::string cache_host = __get_cache_host(hostname, family);
addr_handle = dns_cache->put(hostname, port_, addrinfo,
addr_handle = dns_cache->put(cache_host, port_, addrinfo,
(unsigned int)ttl_default,
(unsigned int)ttl_min);
if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info,
&ep_params_, hostname, this->result) < 0)
&ep_params_, hostname, ns_params_.ssl_ctx,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
@@ -555,10 +617,7 @@ void WFResolverTask::dns_single_callback(void *net_dns_task)
this->error = dns_task->get_error();
}
if (this->callback)
this->callback(this);
delete this;
task_callback();
}
void WFResolverTask::dns_partial_callback(void *net_dns_task)
@@ -618,10 +677,7 @@ void WFResolverTask::dns_parallel_callback(const void *parallel)
delete[] c4;
if (this->callback)
this->callback(this);
delete this;
task_callback();
}
void WFResolverTask::thread_dns_callback(void *thrd_dns_task)
@@ -631,7 +687,6 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task)
if (dns_task->get_state() == WFT_STATE_SUCCESS)
{
DnsOutput *out = dns_task->get_output();
__add_passive_flags((struct addrinfo *)out->get_addrinfo());
dns_callback_internal(out, dns_ttl_default_, dns_ttl_min_);
}
else
@@ -640,6 +695,23 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task)
this->error = dns_task->get_error();
}
task_callback();
}
void WFResolverTask::task_callback()
{
if (in_guard_)
{
int family = ep_params_.address_family;
std::string cache_host = __get_cache_host(host_, family);
std::string guard_name = __get_guard_name(cache_host, port_);
if (this->state == WFT_STATE_DNS_ERROR)
msg_ = (void *)(intptr_t)this->error;
WFTaskFactory::release_guard_safe(guard_name, msg_);
}
if (this->callback)
this->callback(this);

View File

@@ -35,9 +35,14 @@ public:
ns_params_(*ns_params),
ep_params_(*ep_params)
{
if (ns_params_.fixed_conn)
ep_params_.max_connections = 1;
dns_ttl_default_ = dns_ttl_default;
dns_ttl_min_ = dns_ttl_min;
has_next_ = false;
in_guard_ = false;
msg_ = NULL;
}
WFResolverTask(const struct WFNSParams *ns_params,
@@ -45,7 +50,12 @@ public:
WFRouterTask(std::move(cb)),
ns_params_(*ns_params)
{
if (ns_params_.fixed_conn)
ep_params_.max_connections = 1;
has_next_ = false;
in_guard_ = false;
msg_ = NULL;
}
protected:
@@ -62,6 +72,9 @@ private:
unsigned int ttl_default,
unsigned int ttl_min);
void request_dns();
void task_callback();
protected:
struct WFNSParams ns_params_;
unsigned int dns_ttl_default_;
@@ -72,6 +85,8 @@ private:
const char *host_;
unsigned short port_;
bool has_next_;
bool in_guard_;
void *msg_;
};
class WFDnsResolver : public WFNSPolicy

View File

@@ -74,10 +74,12 @@ public:
struct WFNSParams
{
TransportType type;
enum TransportType type;
ParsedURI& uri;
const char *info;
SSL_CTX *ssl_ctx;
bool fixed_addr;
bool fixed_conn;
int retry_times;
WFNSTracing *tracing;
};

View File

@@ -25,7 +25,6 @@
#define DNS_LABELS_MAX 63
#define DNS_NAMES_MAX 256
#define DNS_MSGBASE_INIT_SIZE 514 // 512 + 2(leading length)
#define DNS_HEADER_SIZE sizeof (struct dns_header)
#define MAX(x, y) ((x) <= (y) ? (y) : (x))
struct __dns_record_entry
@@ -102,11 +101,16 @@ static int __dns_parser_parse_host(char *phost, dns_parser_t *parser)
else if ((len & 0xC0) == 0xC0)
{
pointer = __dns_parser_uint16(*cur) & 0x3FFF;
*cur += 2;
if (pointer >= parser->msgsize)
return -2;
// pointer must point to a prior position
if ((const char *)parser->msgbase + pointer >= *cur)
return -2;
*cur += 2;
// backup cur only when the first pointer occurs
if (curbackup == NULL)
curbackup = *cur;
@@ -707,7 +711,7 @@ void dns_parser_init(dns_parser_t *parser)
parser->bufsize = 0;
parser->complete = 0;
parser->single_packet = 0;
memset(&parser->header, 0, DNS_HEADER_SIZE);
memset(&parser->header, 0, sizeof (struct dns_header));
memset(&parser->question, 0, sizeof (struct dns_question));
INIT_LIST_HEAD(&parser->answer_list);
INIT_LIST_HEAD(&parser->authority_list);
@@ -770,16 +774,16 @@ int dns_parser_parse_all(dns_parser_t *parser)
parser->cur = (const char *)parser->msgbase;
h = &parser->header;
if (parser->msgsize < DNS_HEADER_SIZE)
if (parser->msgsize < sizeof (struct dns_header))
return -2;
memcpy(h, parser->msgbase, DNS_HEADER_SIZE);
memcpy(h, parser->msgbase, sizeof (struct dns_header));
h->id = ntohs(h->id);
h->qdcount = ntohs(h->qdcount);
h->ancount = ntohs(h->ancount);
h->nscount = ntohs(h->nscount);
h->arcount = ntohs(h->arcount);
parser->cur += DNS_HEADER_SIZE;
parser->cur += sizeof (struct dns_header);
ret = __dns_parser_parse_question(parser);
if (ret < 0)
@@ -911,6 +915,186 @@ int dns_record_cursor_find_cname(const char *name,
return 1;
}
int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass,
uint32_t ttl, uint16_t rlen, const void *rdata,
struct list_head *list)
{
struct __dns_record_entry *entry;
size_t entry_size = sizeof (struct __dns_record_entry) + rlen;
entry = (struct __dns_record_entry *)malloc(entry_size);
if (!entry)
return -1;
entry->record.name = strdup(name);
if (!entry->record.name)
{
free(entry);
return -1;
}
entry->record.type = type;
entry->record.rclass = rclass;
entry->record.ttl = ttl;
entry->record.rdlength = rlen;
entry->record.rdata = (void *)(entry + 1);
memcpy(entry->record.rdata, rdata, rlen);
list_add_tail(&entry->entry_list, list);
return 0;
}
int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass,
uint32_t ttl, const char *rdata,
struct list_head *list)
{
size_t rlen = strlen(rdata);
// record.rdlength has no meaning for parsed record types, ignore its
// correctness, same for soa/srv/mx record
return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list);
}
int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl,
const char *mname, const char *rname,
uint32_t serial, int32_t refresh,
int32_t retry, int32_t expire, uint32_t minimum,
struct list_head *list)
{
struct __dns_record_entry *entry;
struct dns_record_soa *soa;
size_t entry_size;
char *pname, *pmname, *prname;
entry_size = sizeof (struct __dns_record_entry) +
sizeof (struct dns_record_soa);
entry = (struct __dns_record_entry *)malloc(entry_size);
if (!entry)
return -1;
entry->record.rdata = (void *)(entry + 1);
entry->record.rdlength = 0;
soa = (struct dns_record_soa *)(entry->record.rdata);
pname = strdup(name);
pmname = strdup(mname);
prname = strdup(rname);
if (!pname || !pmname || !prname)
{
free(pname);
free(pmname);
free(prname);
free(entry);
return -1;
}
soa->mname = pmname;
soa->rname = prname;
soa->serial = serial;
soa->refresh = refresh;
soa->retry = retry;
soa->expire = expire;
soa->minimum = minimum;
entry->record.name = pname;
entry->record.type = DNS_TYPE_SOA;
entry->record.rclass = rclass;
entry->record.ttl = ttl;
list_add_tail(&entry->entry_list, list);
return 0;
}
int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl,
uint16_t priority, uint16_t weight,
uint16_t port, const char *target,
struct list_head *list)
{
struct __dns_record_entry *entry;
struct dns_record_srv *srv;
size_t entry_size;
char *pname, *ptarget;
entry_size = sizeof (struct __dns_record_entry) +
sizeof (struct dns_record_srv);
entry = (struct __dns_record_entry *)malloc(entry_size);
if (!entry)
return -1;
entry->record.rdata = (void *)(entry + 1);
entry->record.rdlength = 0;
srv = (struct dns_record_srv *)(entry->record.rdata);
pname = strdup(name);
ptarget = strdup(target);
if (!pname || !ptarget)
{
free(pname);
free(ptarget);
free(entry);
return -1;
}
srv->priority = priority;
srv->weight = weight;
srv->port = port;
srv->target = ptarget;
entry->record.name = pname;
entry->record.type = DNS_TYPE_SRV;
entry->record.rclass = rclass;
entry->record.ttl = ttl;
list_add_tail(&entry->entry_list, list);
return 0;
}
int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl,
int16_t preference, const char *exchange,
struct list_head *list)
{
struct __dns_record_entry *entry;
struct dns_record_mx *mx;
size_t entry_size;
char *pname, *pexchange;
entry_size = sizeof (struct __dns_record_entry) +
sizeof (struct dns_record_mx);
entry = (struct __dns_record_entry *)malloc(entry_size);
if (!entry)
return -1;
entry->record.rdata = (void *)(entry + 1);
entry->record.rdlength = 0;
mx = (struct dns_record_mx *)(entry->record.rdata);
pname = strdup(name);
pexchange = strdup(exchange);
if (!pname || !pexchange)
{
free(pname);
free(pexchange);
free(entry);
return -1;
}
mx->preference = preference;
mx->exchange = pexchange;
entry->record.name = pname;
entry->record.type = DNS_TYPE_MX;
entry->record.rclass = rclass;
entry->record.ttl = ttl;
list_add_tail(&entry->entry_list, list);
return 0;
}
const char *dns_type2str(int type)
{
switch (type)

View File

@@ -78,11 +78,19 @@ enum
DNS_RCODE_REFUSED
};
enum
{
DNS_ANSWER_SECTION = 1,
DNS_AUTHORITY_SECTION = 2,
DNS_ADDITIONAL_SECTION = 3,
};
/**
* dns_header_t is a struct to describe the header of a dns
* request or response packet, but the byte order is not
* transformed.
*/
#pragma pack(1)
struct dns_header
{
uint16_t id;
@@ -112,6 +120,7 @@ struct dns_header
uint16_t nscount;
uint16_t arcount;
};
#pragma pack()
struct dns_question
{
@@ -205,6 +214,29 @@ int dns_record_cursor_find_cname(const char *name,
const char **cname,
dns_record_cursor_t *cursor);
int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass,
uint32_t ttl, uint16_t rlen, const void *rdata,
struct list_head *list);
int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass,
uint32_t ttl, const char *rdata,
struct list_head *list);
int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl,
const char *mname, const char *rname,
uint32_t serial, int32_t refresh,
int32_t retry, int32_t expire, uint32_t minimum,
struct list_head *list);
int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl,
uint16_t priority, uint16_t weight,
uint16_t port, const char *target,
struct list_head *list);
int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl,
int16_t preference, const char *exchange,
struct list_head *list);
const char *dns_type2str(int type);
const char *dns_class2str(int dnsclass);
const char *dns_opcode2str(int opcode);