diff --git a/src/client/WFKafkaClient.cc b/src/client/WFKafkaClient.cc index 217f0cd2..8339aed9 100644 --- a/src/client/WFKafkaClient.cc +++ b/src/client/WFKafkaClient.cc @@ -77,6 +77,7 @@ public: enum TransportType transport_type; std::string scheme; std::vector broker_hosts; + SSL_CTX *ssl_ctx; KafkaCgroup cgroup; KafkaMetaList meta_list; KafkaBrokerMap broker_map; @@ -192,7 +193,7 @@ private: int dispatch_locked(); - inline KafkaBroker *get_broker(int node_id) + KafkaBroker *get_broker(int node_id) { return this->member->broker_map.find_item(node_id); } @@ -294,7 +295,7 @@ void KafkaClientTask::kafka_rebalance_callback(__WFKafkaTask *task) kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), - "", 0, + member->ssl_ctx, "", 0, kafka_heartbeat_callback); kafka_task->user_data = member; kafka_task->get_req()->set_api_type(Kafka_Heartbeat); @@ -327,7 +328,7 @@ void KafkaClientTask::kafka_rebalance_proc(KafkaMember *member, SeriesWork *seri task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), - "", 0, + member->ssl_ctx, "", 0, kafka_rebalance_callback); task->user_data = member; task->get_req()->set_config(member->config); @@ -392,7 +393,7 @@ void KafkaClientTask::kafka_timer_callback(WFTimerTask *task) kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), - "", 0, + member->ssl_ctx, "", 0, kafka_heartbeat_callback); kafka_task->user_data = member; @@ -496,69 +497,70 @@ void KafkaClientTask::kafka_meta_callback(__WFKafkaTask *task) void KafkaClientTask::kafka_cgroup_callback(__WFKafkaTask *task) { KafkaClientTask *t = (KafkaClientTask *)task->user_data; + KafkaMember *member = t->member; SeriesWork *heartbeat_series = NULL; void *msg = NULL; size_t max; - t->member->mutex.lock(); + member->mutex.lock(); t->state = task->get_state(); t->error = task->get_error(); t->kafka_error = *static_cast(task)->get_mutable_ctx(); if (t->state == WFT_STATE_SUCCESS) { - t->member->cgroup = std::move(*(task->get_resp()->get_cgroup())); + member->cgroup = std::move(*(task->get_resp()->get_cgroup())); - kafka_merge_meta_list(&t->member->meta_list, + kafka_merge_meta_list(&member->meta_list, task->get_resp()->get_meta_list()); t->meta_list.rewind(); KafkaMeta *meta; while ((meta = t->meta_list.get_next()) != NULL) - (t->member->meta_status)[meta->get_topic()] = true; + (member->meta_status)[meta->get_topic()] = true; - kafka_merge_broker_list(t->member->scheme, - &t->member->broker_hosts, - &t->member->broker_map, + kafka_merge_broker_list(member->scheme, + &member->broker_hosts, + &member->broker_map, task->get_resp()->get_broker_list()); - t->member->cgroup_status = KAFKA_CGROUP_DONE; + member->cgroup_status = KAFKA_CGROUP_DONE; - if (t->member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) + if (member->heartbeat_status == KAFKA_HEARTBEAT_UNINIT) { __WFKafkaTask *kafka_task; - KafkaBroker *coordinator = t->member->cgroup.get_coordinator(); - kafka_task = __WFKafkaTaskFactory::create_kafka_task(t->member->transport_type, + KafkaBroker *coordinator = member->cgroup.get_coordinator(); + kafka_task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), - "", 0, + member->ssl_ctx, "", 0, kafka_heartbeat_callback); - kafka_task->user_data = t->member; - t->member->incref(); + kafka_task->user_data = member; + member->incref(); - kafka_task->get_req()->set_config(t->member->config); + kafka_task->get_req()->set_config(member->config); kafka_task->get_req()->set_api_type(Kafka_Heartbeat); - kafka_task->get_req()->set_cgroup(t->member->cgroup); + kafka_task->get_req()->set_cgroup(member->cgroup); kafka_task->get_req()->set_broker(*coordinator); heartbeat_series = Workflow::create_series_work(kafka_task, nullptr); - t->member->heartbeat_status = KAFKA_HEARTBEAT_DOING; - t->member->heartbeat_series = heartbeat_series; + member->heartbeat_status = KAFKA_HEARTBEAT_DOING; + member->heartbeat_series = heartbeat_series; } } else { - t->member->cgroup_status = KAFKA_CGROUP_UNINIT; - t->member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; - t->member->heartbeat_series = NULL; + member->cgroup_status = KAFKA_CGROUP_UNINIT; + member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + member->heartbeat_series = NULL; t->finish = true; msg = t; } - max = t->member->cgroup_wait_cnt; + max = member->cgroup_wait_cnt; char name[64]; - snprintf(name, 64, "%p.cgroup", t->member); - t->member->mutex.unlock(); + snprintf(name, 64, "%p.cgroup", member); + member->mutex.unlock(); WFTaskFactory::signal_by_name(name, msg, max); @@ -789,16 +791,17 @@ bool KafkaClientTask::compare_topics(KafkaClientTask *task) bool KafkaClientTask::check_cgroup() { - if (this->member->cgroup_outdated && - this->member->cgroup_status != KAFKA_CGROUP_DOING) + KafkaMember *member = this->member; + + if (member->cgroup_outdated && member->cgroup_status != KAFKA_CGROUP_DOING) { - this->member->cgroup_outdated = false; - this->member->cgroup_status = KAFKA_CGROUP_UNINIT; - this->member->heartbeat_series = NULL; - this->member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; + member->cgroup_outdated = false; + member->cgroup_status = KAFKA_CGROUP_UNINIT; + member->heartbeat_series = NULL; + member->heartbeat_status = KAFKA_HEARTBEAT_UNINIT; } - if (this->member->cgroup_status == KAFKA_CGROUP_DOING) + if (member->cgroup_status == KAFKA_CGROUP_DOING) { WFConditional *cond; char name[64]; @@ -806,27 +809,27 @@ bool KafkaClientTask::check_cgroup() this->wait_cgroup = true; cond = WFTaskFactory::create_conditional(name, this, &this->msg); series_of(this)->push_front(cond); - this->member->cgroup_wait_cnt++; + member->cgroup_wait_cnt++; return false; } if ((this->api_type == Kafka_Fetch || this->api_type == Kafka_OffsetCommit) && - (this->member->cgroup_status == KAFKA_CGROUP_UNINIT)) + (member->cgroup_status == KAFKA_CGROUP_UNINIT)) { __WFKafkaTask *task; - task = __WFKafkaTaskFactory::create_kafka_task(this->url, + task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, this->retry_max, kafka_cgroup_callback); task->user_data = this; task->get_req()->set_config(this->config); task->get_req()->set_api_type(Kafka_FindCoordinator); - task->get_req()->set_cgroup(this->member->cgroup); - task->get_req()->set_meta_list(this->member->meta_list); + task->get_req()->set_cgroup(member->cgroup); + task->get_req()->set_meta_list(member->meta_list); series_of(this)->push_front(this); series_of(this)->push_front(task); - this->member->cgroup_status = KAFKA_CGROUP_DOING; - this->member->cgroup_wait_cnt = 0; + member->cgroup_status = KAFKA_CGROUP_DOING; + member->cgroup_wait_cnt = 0; return false; } @@ -835,12 +838,13 @@ bool KafkaClientTask::check_cgroup() bool KafkaClientTask::check_meta() { + KafkaMember *member = this->member; KafkaMetaList *uninit_meta_list; if (this->get_meta_status(&uninit_meta_list)) return true; - if (this->member->meta_doing) + if (member->meta_doing) { WFConditional *cond; char name[64]; @@ -848,13 +852,13 @@ bool KafkaClientTask::check_meta() this->wait_cgroup = false; cond = WFTaskFactory::create_conditional(name, this, &this->msg); series_of(this)->push_front(cond); - this->member->meta_wait_cnt++; + member->meta_wait_cnt++; } else { __WFKafkaTask *task; - task = __WFKafkaTaskFactory::create_kafka_task(this->url, + task = __WFKafkaTaskFactory::create_kafka_task(this->url, member->ssl_ctx, this->retry_max, kafka_meta_callback); task->user_data = this; @@ -863,8 +867,8 @@ bool KafkaClientTask::check_meta() task->get_req()->set_meta_list(*uninit_meta_list); series_of(this)->push_front(this); series_of(this)->push_front(task); - this->member->meta_wait_cnt = 0; - this->member->meta_doing = true; + member->meta_wait_cnt = 0; + member->meta_doing = true; } delete uninit_meta_list; @@ -921,6 +925,7 @@ int KafkaClientTask::dispatch_locked() task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), + member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); @@ -956,6 +961,7 @@ int KafkaClientTask::dispatch_locked() task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), + member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); @@ -991,6 +997,7 @@ int KafkaClientTask::dispatch_locked() task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), + member->ssl_ctx, this->get_userinfo(), this->retry_max, kafka_offsetcommit_callback); @@ -1020,6 +1027,7 @@ int KafkaClientTask::dispatch_locked() task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, coordinator->get_host(), coordinator->get_port(), + member->ssl_ctx, this->get_userinfo(), 0, kafka_leavegroup_callback); task->user_data = this; @@ -1051,6 +1059,7 @@ int KafkaClientTask::dispatch_locked() task = __WFKafkaTaskFactory::create_kafka_task(member->transport_type, broker->get_host(), broker->get_port(), + member->ssl_ctx, this->get_userinfo(), this->retry_max, std::move(cb)); @@ -1580,7 +1589,7 @@ SubTask *WFKafkaTask::done() return series->pop(); } -int WFKafkaClient::init(const std::string& broker) +int WFKafkaClient::init(const std::string& broker, SSL_CTX *ssl_ctx) { std::vector broker_hosts; std::string::size_type ppos = 0; @@ -1620,6 +1629,7 @@ int WFKafkaClient::init(const std::string& broker) this->member = new KafkaMember; this->member->broker_hosts = std::move(broker_hosts); + this->member->ssl_ctx = ssl_ctx; if (use_ssl) { this->member->transport_type = TT_TCP_SSL; @@ -1629,9 +1639,10 @@ int WFKafkaClient::init(const std::string& broker) return 0; } -int WFKafkaClient::init(const std::string& broker, const std::string& group) +int WFKafkaClient::init(const std::string& broker, const std::string& group, + SSL_CTX *ssl_ctx) { - if (this->init(broker) < 0) + if (this->init(broker, ssl_ctx) < 0) return -1; this->member->cgroup.set_group(group); @@ -1652,8 +1663,7 @@ WFKafkaTask *WFKafkaClient::create_kafka_task(const std::string& query, int retry_max, kafka_callback_t cb) { - WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb), - this); + WFKafkaTask *task = new KafkaClientTask(query, retry_max, std::move(cb), this); return task; } diff --git a/src/client/WFKafkaClient.h b/src/client/WFKafkaClient.h index b7777f14..61ec25e3 100644 --- a/src/client/WFKafkaClient.h +++ b/src/client/WFKafkaClient.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "WFTask.h" #include "KafkaMessage.h" #include "KafkaResult.h" @@ -145,9 +146,22 @@ public: // example: kafka://kafka.sogou // example: kafka.sogou:9090 // example: kafka://10.160.23.23:9000,10.123.23.23,kafka://kafka.sogou - int init(const std::string& broker_url); + // example: kafkas://kafka.sogou -> kafka over TLS + int init(const std::string& broker_url) + { + return this->init(broker_url, NULL); + } - int init(const std::string& broker_url, const std::string& group); + int init(const std::string& broker_url, const std::string& group) + { + return this->init(broker_url, group, NULL); + } + + // With a specific SSL_CTX. Effective only on brokers over TLS. + int init(const std::string& broker_url, SSL_CTX *ssl_ctx); + + int init(const std::string& broker_url, const std::string& group, + SSL_CTX *ssl_ctx); int deinit(); diff --git a/src/client/WFMySQLConnection.cc b/src/client/WFMySQLConnection.cc index d93694e6..ff66a57a 100644 --- a/src/client/WFMySQLConnection.cc +++ b/src/client/WFMySQLConnection.cc @@ -16,6 +16,7 @@ Author: Xie Han (xiehan@sogou-inc.com) */ +#include #include #include #include @@ -23,7 +24,7 @@ #include "URIParser.h" #include "WFMySQLConnection.h" -int WFMySQLConnection::init(const std::string& url) +int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx) { std::string query; ParsedURI uri; @@ -42,9 +43,12 @@ int WFMySQLConnection::init(const std::string& url) if (uri.query) { this->uri = std::move(uri); + this->ssl_ctx = ssl_ctx; return 0; } } + else if (uri.state == URI_STATE_INVALID) + errno = EINVAL; return -1; } diff --git a/src/client/WFMySQLConnection.h b/src/client/WFMySQLConnection.h index ef0ee033..6b7566b9 100644 --- a/src/client/WFMySQLConnection.h +++ b/src/client/WFMySQLConnection.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "URIParser.h" #include "WFTaskFactory.h" @@ -31,20 +32,50 @@ public: /* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8 * IP string is recommmended in url. When using a domain name, the first * address resovled will be used. Don't use upstream name as a host. */ - int init(const std::string& url); + int init(const std::string& url) + { + return init(url, NULL); + } + + int init(const std::string& url, SSL_CTX *ssl_ctx); + + void deinit() { } public: WFMySQLTask *create_query_task(const std::string& query, - mysql_callback_t callback); + mysql_callback_t callback) + { + WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, + std::move(callback)); + this->set_ssl_ctx(task); + task->get_req()->set_query(query); + return task; + } -public: /* If you don't disconnect manually, the TCP connection will be * kept alive after this object is deleted, and maybe reused by * another WFMySQLConnection object with same id and url. */ - WFMySQLTask *create_disconnect_task(mysql_callback_t callback); + WFMySQLTask *create_disconnect_task(mysql_callback_t callback) + { + WFMySQLTask *task = this->create_query_task("", std::move(callback)); + this->set_ssl_ctx(task); + task->set_keep_alive(0); + return task; + } + +protected: + void set_ssl_ctx(WFMySQLTask *task) const + { + using MySQLRequest = protocol::MySQLRequest; + using MySQLResponse = protocol::MySQLResponse; + auto *t = (WFComplexClientTask *)task; + /* 'ssl_ctx' can be NULL and will use default. */ + t->set_ssl_ctx(this->ssl_ctx); + } protected: ParsedURI uri; + SSL_CTX *ssl_ctx; int id; public: @@ -54,25 +85,5 @@ public: virtual ~WFMySQLConnection() { } }; -inline WFMySQLTask * -WFMySQLConnection::create_query_task(const std::string& query, - mysql_callback_t callback) -{ - WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, - std::move(callback)); - task->get_req()->set_query(query); - return task; -} - -inline WFMySQLTask * -WFMySQLConnection::create_disconnect_task(mysql_callback_t callback) -{ - WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, - std::move(callback)); - task->get_req()->set_query(""); - task->set_keep_alive(0); - return task; -} - #endif diff --git a/src/factory/DnsTaskImpl.cc b/src/factory/DnsTaskImpl.cc index 46de4c89..18b70494 100644 --- a/src/factory/DnsTaskImpl.cc +++ b/src/factory/DnsTaskImpl.cc @@ -17,9 +17,11 @@ */ #include +#include +#include "DnsMessage.h" #include "WFTaskError.h" #include "WFTaskFactory.h" -#include "DnsMessage.h" +#include "WFServer.h" using namespace protocol; @@ -31,6 +33,7 @@ class ComplexDnsTask : public WFComplexClientTask> { static struct addrinfo hints; + static std::atomic seq; public: ComplexDnsTask(int retry_max, dns_callback_t&& cb): @@ -54,24 +57,21 @@ private: struct addrinfo ComplexDnsTask::hints = { - /*.ai_flags =*/ AI_NUMERICSERV | AI_NUMERICHOST, - /*.ai_family =*/ AF_UNSPEC, - /*.ai_socktype =*/ SOCK_STREAM, - /*.ai_protocol =*/ 0, - /*.ai_addrlen =*/ 0, - /*.ai_addr =*/ NULL, - /*.ai_canonname =*/ NULL, - /*.ai_next =*/ NULL + /*.ai_flags =*/ AI_NUMERICSERV | AI_NUMERICHOST, + /*.ai_family =*/ AF_UNSPEC, + /*.ai_socktype =*/ SOCK_STREAM }; +std::atomic ComplexDnsTask::seq(0); + CommMessageOut *ComplexDnsTask::message_out() { DnsRequest *req = this->get_req(); DnsResponse *resp = this->get_resp(); - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); if (req->get_id() == 0) - req->set_id((this->get_seq() + 1) * 99991 % 65535 + 1); + req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1); resp->set_request_id(req->get_id()); resp->set_request_name(req->get_question_name()); req->set_single_packet(type == TT_UDP); @@ -93,21 +93,22 @@ bool ComplexDnsTask::init_success() if (!this->route_result_.request_object) { - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); struct addrinfo *addr; int ret; ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr); if (ret != 0) { - this->state = WFT_STATE_TASK_ERROR; - this->error = WFT_ERR_URI_PARSE_FAILED; + this->state = WFT_STATE_DNS_ERROR; + this->error = ret; return false; } auto *ep = &WFGlobal::get_global_settings()->dns_server_params; ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep, - uri_.host, route_result_); + uri_.host, ssl_ctx_, + route_result_); freeaddrinfo(addr); if (ret < 0) { @@ -146,7 +147,7 @@ bool ComplexDnsTask::finish_once() bool ComplexDnsTask::need_redirect() { DnsResponse *client_resp = this->get_resp(); - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); if (type == TT_UDP && client_resp->get_tc() == 1) { @@ -189,3 +190,42 @@ WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri, return task; } + +/**********Server**********/ + +class WFDnsServerTask : public WFServerTask +{ +public: + WFDnsServerTask(CommService *service, + std::function& proc) : + WFServerTask(service, WFGlobal::get_scheduler(), proc) + { + // this->type = ((WFServerBase *)service)->get_params()->transport_type; + this->type = TT_TCP; + } + +protected: + virtual CommMessageIn *message_in() + { + this->get_req()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_in(); + } + + virtual CommMessageOut *message_out() + { + this->get_resp()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_out(); + } + +protected: + enum TransportType type; +}; + +/**********Server Factory**********/ + +WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service, + std::function& proc) +{ + return new WFDnsServerTask(service, proc); +} + diff --git a/src/factory/HttpTaskImpl.cc b/src/factory/HttpTaskImpl.cc index 4e612659..73c8dba7 100644 --- a/src/factory/HttpTaskImpl.cc +++ b/src/factory/HttpTaskImpl.cc @@ -484,7 +484,8 @@ private: int ComplexHttpProxyTask::init_ssl_connection() { - SSL *ssl = __create_ssl(WFGlobal::get_ssl_client_ctx()); + static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); + SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); WFConnection *conn; if (!ssl) diff --git a/src/factory/KafkaTaskImpl.cc b/src/factory/KafkaTaskImpl.cc index 30b71bad..ccd45304 100644 --- a/src/factory/KafkaTaskImpl.cc +++ b/src/factory/KafkaTaskImpl.cc @@ -19,9 +19,9 @@ #include #include +#include #include #include -#include #include #include #include "StringUtil.h" @@ -715,12 +715,14 @@ bool __ComplexKafkaTask::finish_once() /**********Factory**********/ // kafka://user:password:sasl@host:port/api=type&topic=name __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url, + SSL_CTX *ssl_ctx, int retry_max, __kafka_callback_t callback) { auto *task = new __ComplexKafkaTask(retry_max, std::move(callback)); - ParsedURI uri; + task->set_ssl_ctx(ssl_ctx); + ParsedURI uri; URIParser::parse(url, uri); task->init(std::move(uri)); task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); @@ -728,10 +730,12 @@ __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const std::string& url, } __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri, + SSL_CTX *ssl_ctx, int retry_max, __kafka_callback_t callback) { auto *task = new __ComplexKafkaTask(retry_max, std::move(callback)); + task->set_ssl_ctx(ssl_ctx); task->init(uri); task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); @@ -741,22 +745,36 @@ __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(const ParsedURI& uri, __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type, const char *host, unsigned short port, + SSL_CTX *ssl_ctx, const std::string& info, int retry_max, __kafka_callback_t callback) { auto *task = new __ComplexKafkaTask(retry_max, std::move(callback)); - - std::string url = (type == TT_TCP_SSL ? "kafkas://" : "kafka://"); - - if (!info.empty()) - url += info + "@"; - - url += host; - url += ":" + std::to_string(port); + task->set_ssl_ctx(ssl_ctx); ParsedURI uri; - URIParser::parse(url, uri); + char buf[32]; + + if (type == TT_TCP_SSL) + uri.scheme = strdup("kafkas"); + else + uri.scheme = strdup("kafka"); + + if (!info.empty()) + uri.userinfo = strdup(info.c_str()); + + uri.host = strdup(host); + sprintf(buf, "%u", port); + uri.port = strdup(buf); + + if (!uri.scheme || !uri.host || !uri.port || + (!info.empty() && !uri.userinfo)) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } + task->init(std::move(uri)); task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); return task; diff --git a/src/factory/KafkaTaskImpl.inl b/src/factory/KafkaTaskImpl.inl index 9b600444..83ec4c78 100644 --- a/src/factory/KafkaTaskImpl.inl +++ b/src/factory/KafkaTaskImpl.inl @@ -16,6 +16,7 @@ Authors: Wang Zhulei (wangzhulei@sogou-inc.com) */ +#include #include "WFTaskFactory.h" #include "KafkaMessage.h" @@ -32,16 +33,19 @@ public: * user task. */ static __WFKafkaTask *create_kafka_task(const ParsedURI& uri, + SSL_CTX *ssl_ctx, int retry_max, __kafka_callback_t callback); static __WFKafkaTask *create_kafka_task(const std::string& url, + SSL_CTX *ssl_ctx, int retry_max, __kafka_callback_t callback); static __WFKafkaTask *create_kafka_task(enum TransportType type, const char *host, unsigned short port, + SSL_CTX *ssl_ctx, const std::string& info, int retry_max, __kafka_callback_t callback); diff --git a/src/factory/MySQLTaskImpl.cc b/src/factory/MySQLTaskImpl.cc index c7134874..2d62bd0e 100644 --- a/src/factory/MySQLTaskImpl.cc +++ b/src/factory/MySQLTaskImpl.cc @@ -241,7 +241,7 @@ CommMessageOut *ComplexMySQLTask::message_out() break; case ST_FIRST_USER_REQUEST: - if (this->is_fixed_addr()) + if (this->is_fixed_conn()) { auto *target = (RouteManager::RouteTarget *)this->target; @@ -350,7 +350,21 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp) if (is_ssl_) { - if (!(resp->get_capability_flags() & 0x800)) + if (resp->get_capability_flags() & 0x800) + { + static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); + + ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); + if (!ssl) + { + state_ = WFT_STATE_SYS_ERROR; + error_ = errno; + return 0; + } + + SSL_set_connect_state(ssl); + } + else { this->resp = std::move(*(MySQLResponse *)resp); state_ = WFT_STATE_TASK_ERROR; @@ -358,15 +372,6 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp) return 0; } - ssl = __create_ssl(WFGlobal::get_ssl_client_ctx()); - if (!ssl) - { - state_ = WFT_STATE_SYS_ERROR; - error_ = errno; - return 0; - } - - SSL_set_connect_state(ssl); } auto *conn = this->get_connection(); @@ -712,9 +717,9 @@ bool ComplexMySQLTask::init_success() if (!transaction.empty()) { - this->WFComplexClientTask::set_info(std::string("?maxconn=1&") + - info + "|txn:" + transaction); this->set_fixed_addr(true); + this->set_fixed_conn(true); + this->WFComplexClientTask::set_info(info + ("|txn:" + transaction)); } else this->WFComplexClientTask::set_info(info); @@ -741,7 +746,7 @@ bool ComplexMySQLTask::finish_once() return false; } - if (this->is_fixed_addr()) + if (this->is_fixed_conn()) { if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0) { @@ -767,7 +772,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url, URIParser::parse(url, uri); task->init(std::move(uri)); - if (task->is_fixed_addr()) + if (task->is_fixed_conn()) task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); else task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); @@ -782,7 +787,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri, auto *task = new ComplexMySQLTask(retry_max, std::move(callback)); task->init(uri); - if (task->is_fixed_addr()) + if (task->is_fixed_conn()) task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); else task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); diff --git a/src/factory/WFTaskFactory.h b/src/factory/WFTaskFactory.h index 48573c21..ced15df7 100644 --- a/src/factory/WFTaskFactory.h +++ b/src/factory/WFTaskFactory.h @@ -407,6 +407,13 @@ public: int retry_max, std::function callback); + static T *create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + SSL_CTX *ssl_ctx, + int retry_max, + std::function callback); + public: static T *create_server_task(CommService *service, std::function& process); diff --git a/src/factory/WFTaskFactory.inl b/src/factory/WFTaskFactory.inl index 37763b11..dd90d49c 100644 --- a/src/factory/WFTaskFactory.inl +++ b/src/factory/WFTaskFactory.inl @@ -26,6 +26,7 @@ #include #include #include +#include #include "PlatformSocket.h" #include "WFGlobal.h" #include "Workflow.h" @@ -73,7 +74,9 @@ public: WFClientTask(NULL, WFGlobal::get_scheduler(), std::move(cb)) { type_ = TT_TCP; + ssl_ctx_ = NULL; fixed_addr_ = false; + fixed_conn_ = false; retry_max_ = retry_max; retry_times_ = 0; redirect_ = false; @@ -102,17 +105,19 @@ public: init_with_uri(); } - void init(TransportType type, + void init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info); - void set_transport_type(TransportType type) + void set_transport_type(enum TransportType type) { type_ = type; } - TransportType get_transport_type() const { return type_; } + enum TransportType get_transport_type() const { return type_; } + + void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; } virtual const ParsedURI *get_current_uri() const { return &uri_; } @@ -122,7 +127,7 @@ public: init(uri); } - void set_redirect(TransportType type, const struct sockaddr *addr, + void set_redirect(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) { redirect_ = true; @@ -131,9 +136,13 @@ public: bool is_fixed_addr() const { return this->fixed_addr_; } + bool is_fixed_conn() const { return this->fixed_conn_; } + protected: void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; } + void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; } + void set_info(const std::string& info) { info_.assign(info); @@ -163,10 +172,12 @@ protected: } protected: - TransportType type_; + enum TransportType type_; ParsedURI uri_; std::string info_; + SSL_CTX *ssl_ctx_; bool fixed_addr_; + bool fixed_conn_; bool redirect_; CTX ctx_; int retry_max_; @@ -205,7 +216,7 @@ void WFComplexClientTask::clear_prev_state() } template -void WFComplexClientTask::init(TransportType type, +void WFComplexClientTask::init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) @@ -216,7 +227,6 @@ void WFComplexClientTask::init(TransportType type, auto params = WFGlobal::get_global_settings()->endpoint_params; struct addrinfo addrinfo = { }; addrinfo.ai_family = addr->sa_family; - addrinfo.ai_socktype = SOCK_STREAM; addrinfo.ai_addr = (struct sockaddr *)addr; addrinfo.ai_addrlen = addrlen; @@ -224,7 +234,7 @@ void WFComplexClientTask::init(TransportType type, info_.assign(info); params.use_tls_sni = false; if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, ¶ms, - "", route_result_) < 0) + "", ssl_ctx_, route_result_) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -277,10 +287,10 @@ template void WFComplexClientTask::init_with_uri() { if (redirect_) - { - clear_prev_state(); - ns_policy_ = WFGlobal::get_dns_resolver(); - } + { + clear_prev_state(); + ns_policy_ = WFGlobal::get_dns_resolver(); + } if (uri_.state == URI_STATE_SUCCESS) { @@ -311,12 +321,14 @@ WFRouterTask *WFComplexClientTask::route() this, std::placeholders::_1); struct WFNSParams params = { - /*.type =*/ type_, - /*.uri =*/ uri_, - /*.info =*/ info_.c_str(), - /*.fixed_addr =*/ fixed_addr_, - /*.retry_times =*/ retry_times_, - /*.tracing =*/ &tracing_, + /*.type =*/ type_, + /*.uri =*/ uri_, + /*.info =*/ info_.c_str(), + /*.ssl_ctx =*/ ssl_ctx_, + /*.fixed_addr =*/ fixed_addr_, + /*.fixed_conn =*/ fixed_conn_, + /*.retry_times =*/ retry_times_, + /*.tracing =*/ &tracing_, }; if (!ns_policy_) @@ -475,22 +487,26 @@ SubTask *WFComplexClientTask::done() template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& host, unsigned short port, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); - char buf[8]; - std::string url = "scheme://"; ParsedURI uri; + char buf[32]; sprintf(buf, "%u", port); - url += host; - url += ":"; - url += buf; - URIParser::parse(url, uri); + uri.scheme = strdup("scheme"); + uri.host = strdup(host.c_str()); + uri.port = strdup(buf); + if (!uri.scheme || !uri.host || !uri.port) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } + task->init(std::move(uri)); task->set_transport_type(type); return task; @@ -498,7 +514,7 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& url, int retry_max, std::function *)> callback) @@ -514,7 +530,7 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const ParsedURI& uri, int retry_max, std::function *)> callback) @@ -528,7 +544,7 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, int retry_max, @@ -540,6 +556,22 @@ WFNetworkTaskFactory::create_client_task(TransportType type, return task; } +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + SSL_CTX *ssl_ctx, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + + task->set_ssl_ctx(ssl_ctx); + task->init(type, addr, addrlen, ""); + return task; +} + template WFNetworkTask * WFNetworkTaskFactory::create_server_task(CommService *service, @@ -553,6 +585,9 @@ WFNetworkTaskFactory::create_server_task(CommService *service, class WFServerTaskFactory { public: + static WFDnsTask *create_dns_task(CommService *service, + std::function& proc); + static WFHttpTask *create_http_task(CommService *service, std::function& proc) { @@ -670,26 +705,24 @@ void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); - static_cast<__WFGoTask *>(task)->set_go_func(std::move(tmp)); + ((__WFGoTask *)task)->set_go_func(std::move(tmp)); } /**********Create go task with nullptr func**********/ -template<> -inline WFGoTask *WFTaskFactory::create_go_task - (const std::string& queue_name, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, + std::nullptr_t&& func) { return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_timedgo_task - (time_t seconds, long nanoseconds, - const std::string& queue_name, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + std::nullptr_t&& func) { return new __WFTimedGoTask(seconds, nanoseconds, WFGlobal::get_exec_queue(queue_name), @@ -697,28 +730,25 @@ inline WFGoTask *WFTaskFactory::create_timedgo_task nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_go_task - (ExecQueue *queue, Executor *executor, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) { return new __WFGoTask(queue, executor, nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_timedgo_task - (time_t seconds, long nanoseconds, - ExecQueue *queue, Executor *executor, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) { return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr); } -template<> -inline void WFTaskFactory::reset_go_task - (WFGoTask *task, std::nullptr_t&& func) +template<> inline +void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&& func) { - static_cast<__WFGoTask *>(task)->set_go_func(nullptr); + ((__WFGoTask *)task)->set_go_func(nullptr); } /**********Template Thread Task Factory**********/ diff --git a/src/kernel/Communicator.cc b/src/kernel/Communicator.cc index d81fc5e5..41c72bb4 100644 --- a/src/kernel/Communicator.cc +++ b/src/kernel/Communicator.cc @@ -74,8 +74,8 @@ static inline int __set_fd_nonblock(int fd) return flags; } -static int __bind_and_listen(int sockfd, const struct sockaddr *addr, - socklen_t addrlen) +static int __bind_sockaddr(int sockfd, const struct sockaddr *addr, + socklen_t addrlen) { struct sockaddr_storage ss; socklen_t len; @@ -97,7 +97,7 @@ static int __bind_and_listen(int sockfd, const struct sockaddr *addr, return -1; } - return listen(sockfd, SOMAXCONN < 4096 ? 4096 : SOMAXCONN); + return 0; } static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) @@ -119,6 +119,48 @@ static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) return -1; } +static int __send_to_conn(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + const struct sockaddr *addr; + socklen_t addrlen; + int ret; + + if (!entry->ssl) + { + entry->target->get_addr(&addr, &addrlen); + return sendto(entry->sockfd, buf, size, 0, addr, addrlen); + } + + if (size == 0) + return 0; + + ret = SSL_write(entry->ssl, buf, size); + if (ret <= 0) + { + ret = SSL_get_error(entry->ssl, ret); + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + ret = -1; + } + + return ret; +} + +static void __release_conn(struct CommConnEntry *entry) +{ + delete entry->conn; + if (!entry->service) + pthread_mutex_destroy(&entry->mutex); + + if (entry->ssl) + SSL_free(entry->ssl); + + close(entry->sockfd); + free(entry); +} + #define SSL_WRITE_BUFSIZE 8192 static int __ssl_writev(SSL *ssl, const struct iovec vectors[], int cnt) @@ -186,26 +228,7 @@ void CommTarget::deinit() int CommMessageIn::feedback(const void *buf, size_t size) { - struct CommConnEntry *entry = this->entry; - int ret; - - if (!entry->ssl) - return write(entry->sockfd, buf, size); - - if (size == 0) - return 0; - - ret = SSL_write(entry->ssl, buf, size); - if (ret <= 0) - { - ret = SSL_get_error(entry->ssl, ret); - if (ret != SSL_ERROR_SYSCALL) - errno = -ret; - - ret = -1; - } - - return ret; + return __send_to_conn(buf, size, this->entry); } void CommMessageIn::renew() @@ -305,6 +328,9 @@ public: } } +public: + int shutdown(); + private: int sockfd; int ref; @@ -322,36 +348,50 @@ private: friend class Communicator; }; -CommSession::~CommSession() +int CommServiceTarget::shutdown() { struct CommConnEntry *entry; - struct list_head *pos; - CommTarget *target; int errno_bak; + int ret = 0; - if (!this->passive) - return; - - target = this->target; - if (this->passive == 1) + pthread_mutex_lock(&this->mutex); + if (!list_empty(&this->idle_list)) { - pthread_mutex_lock(&target->mutex); - if (!list_empty(&target->idle_list)) - { - pos = target->idle_list.next; - entry = list_entry(pos, struct CommConnEntry, list); - list_del(pos); + entry = list_entry(this->idle_list.next, struct CommConnEntry, list); + list_del(&entry->list); + if (this->service->reliable) + { errno_bak = errno; mpoller_del(entry->sockfd, entry->mpoller); entry->state = CONN_STATE_CLOSING; errno = errno_bak; } + else + { + __release_conn(entry); + this->decref(); + } - pthread_mutex_unlock(&target->mutex); + ret = 1; } - ((CommServiceTarget *)target)->decref(); + pthread_mutex_unlock(&this->mutex); + return ret; +} + +CommSession::~CommSession() +{ + CommServiceTarget *target; + + if (!this->passive) + return; + + target = (CommServiceTarget *)this->target; + if (this->passive == 1) + target->shutdown(); + + target->decref(); } inline int Communicator::first_timeout(CommSession *session) @@ -404,19 +444,6 @@ int Communicator::first_timeout_recv(CommSession *session) return Communicator::first_timeout(session); } -void Communicator::release_conn(struct CommConnEntry *entry) -{ - delete entry->conn; - if (!entry->service) - pthread_mutex_destroy(&entry->mutex); - - if (entry->ssl) - SSL_free(entry->ssl); - - close(entry->sockfd); - free(entry); -} - void Communicator::shutdown_service(CommService *service) { close(service->listen_fd); @@ -670,7 +697,7 @@ void Communicator::handle_incoming_request(struct poller_result *res) if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } } @@ -752,7 +779,7 @@ void Communicator::handle_incoming_reply(struct poller_result *res) } if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); } } @@ -824,7 +851,7 @@ void Communicator::handle_reply_result(struct poller_result *res) session->handle(state, res->error); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } @@ -877,7 +904,7 @@ void Communicator::handle_request_result(struct poller_result *res) /* do nothing */ pthread_mutex_unlock(&entry->mutex); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); break; } @@ -910,7 +937,7 @@ struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = service; entry->target = target; entry->ssl = NULL; @@ -996,7 +1023,7 @@ void Communicator::handle_connect_result(struct poller_result *res) target->release(); session->handle(state, res->error); - this->release_conn(entry); + __release_conn(entry); break; } } @@ -1012,9 +1039,10 @@ void Communicator::handle_listen_result(struct poller_result *res) { case PR_ST_SUCCESS: target = (CommServiceTarget *)res->data.result; - entry = this->accept_conn(target, service); + entry = Communicator::accept_conn(target, service); if (entry) { + entry->mpoller = this->mpoller; if (service->ssl_ctx) { if (__create_ssl(service->ssl_ctx, entry) >= 0 && @@ -1045,7 +1073,7 @@ void Communicator::handle_listen_result(struct poller_result *res) } } - this->release_conn(entry); + __release_conn(entry); } else close(target->sockfd); @@ -1064,6 +1092,54 @@ void Communicator::handle_listen_result(struct poller_result *res) } } +void Communicator::handle_recvfrom_result(struct poller_result *res) +{ + CommService *service = (CommService *)res->data.context; + struct CommConnEntry *entry; + CommTarget *target; + int state, error; + + switch (res->state) + { + case PR_ST_SUCCESS: + entry = (struct CommConnEntry *)res->data.result; + target = entry->target; + if (entry->state == CONN_STATE_SUCCESS) + { + state = CS_STATE_TOREPLY; + error = 0; + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + else + { + state = CS_STATE_ERROR; + if (entry->state == CONN_STATE_ERROR) + error = entry->error; + else + error = EBADMSG; + } + + entry->session->handle(state, error); + if (state == CS_STATE_ERROR) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + + break; + + case PR_ST_DELETED: + this->shutdown_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + void Communicator::handle_ssl_accept_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; @@ -1087,7 +1163,7 @@ void Communicator::handle_ssl_accept_result(struct poller_result *res) case PR_ST_DELETED: case PR_ST_ERROR: case PR_ST_STOPPED: - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); break; } @@ -1184,6 +1260,9 @@ void Communicator::handler_thread_routine(void *context) case PD_OP_LISTEN: comm->handle_listen_result(res); break; + case PD_OP_RECVFROM: + comm->handle_recvfrom_result(res); + break; case PD_OP_SSL_ACCEPT: comm->handle_ssl_accept_result(res); break; @@ -1343,6 +1422,58 @@ poller_message_t *Communicator::create_reply(void *context) return session->in; } +int Communicator::recv_request(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + CommService *service = entry->service; + CommTarget *target = entry->target; + CommSession *session; + size_t n; + int ret; + + session = service->new_session(entry->seq, entry->conn); + if (!session) + return -1; + + session->passive = 1; + entry->session = session; + session->target = target; + session->conn = entry->conn; + session->seq = entry->seq++; + session->out = NULL; + session->in = NULL; + + entry->state = CONN_STATE_RECEIVING; + + ((CommServiceTarget *)target)->incref(); + + session->in = session->message_in(); + if (session->in) + { + session->in->entry = entry; + do + { + n = size; + ret = session->in->append(buf, &n); + if (ret == 0) + { + size -= n; + buf = (const char *)buf + n; + } + else if (ret < 0) + { + entry->error = errno; + entry->state = CONN_STATE_ERROR; + } + else + entry->state = CONN_STATE_SUCCESS; + + } while (ret == 0 && size > 0); + } + + return 0; +} + int Communicator::partial_written(size_t n, void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; @@ -1378,6 +1509,40 @@ void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen, return NULL; } +void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context) +{ + CommService *service = (CommService *)context; + struct CommConnEntry *entry; + CommServiceTarget *target; + void *result; + int sockfd; + + sockfd = dup(service->listen_fd); + if (sockfd >= 0) + { + result = Communicator::accept(addr, addrlen, sockfd, context); + if (result) + { + target = (CommServiceTarget *)result; + entry = Communicator::accept_conn(target, service); + if (entry) + { + if (Communicator::recv_request(buf, size, entry) >= 0) + return entry; + + __release_conn(entry); + } + else + close(sockfd); + + target->decref(); + } + } + + return NULL; +} + void Communicator::callback(struct poller_result *res, void *context) { msgqueue_t *msgqueue = (msgqueue_t *)context; @@ -1502,7 +1667,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, int sockfd; int ret; - sockfd = this->nonblock_connect(target); + sockfd = Communicator::nonblock_connect(target); if (sockfd >= 0) { entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry)); @@ -1515,7 +1680,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = NULL; entry->target = target; entry->session = session; @@ -1598,9 +1763,10 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) struct poller_data data; int timeout; - entry = this->launch_conn(session, target); + entry = Communicator::launch_conn(session, target); if (entry) { + entry->mpoller = this->mpoller; session->conn = entry->conn; session->seq = entry->seq++; data.operation = PD_OP_CONNECT; @@ -1611,7 +1777,7 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) if (mpoller_add(&data, timeout, this->mpoller) >= 0) return 0; - this->release_conn(entry); + __release_conn(entry); } return -1; @@ -1648,15 +1814,21 @@ int Communicator::request(CommSession *session, CommTarget *target) int Communicator::nonblock_listen(CommService *service) { int sockfd = service->create_listen_fd(); + int ret; if (sockfd >= 0) { if (__set_fd_nonblock(sockfd) >= 0) { - if (__bind_and_listen(sockfd, service->bind_addr, - service->addrlen) >= 0) + if (__bind_sockaddr(sockfd, service->bind_addr, + service->addrlen) >= 0) { - return sockfd; + ret = listen(sockfd, SOMAXCONN); + if (ret >= 0 || errno == EOPNOTSUPP) + { + service->reliable = (ret >= 0); + return sockfd; + } } } @@ -1669,6 +1841,7 @@ int Communicator::nonblock_listen(CommService *service) int Communicator::bind(CommService *service) { struct poller_data data; + int errno_bak = errno; int sockfd; sockfd = this->nonblock_listen(service); @@ -1676,13 +1849,25 @@ int Communicator::bind(CommService *service) { service->listen_fd = sockfd; service->ref = 1; - data.operation = PD_OP_LISTEN; data.fd = sockfd; - data.accept = Communicator::accept; data.context = service; data.result = NULL; + if (service->reliable) + { + data.operation = PD_OP_LISTEN; + data.accept = Communicator::accept; + } + else + { + data.operation = PD_OP_RECVFROM; + data.recvfrom = Communicator::recvfrom; + } + if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0) + { + errno = errno_bak; return 0; + } close(sockfd); } @@ -1702,7 +1887,7 @@ void Communicator::unbind(CommService *service) } } -int Communicator::reply_idle_conn(CommSession *session, CommTarget *target) +int Communicator::reply_reliable(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct list_head *pos; @@ -1734,25 +1919,85 @@ int Communicator::reply_idle_conn(CommSession *session, CommTarget *target) return ret; } +int Communicator::reply_message_unreliable(struct CommConnEntry *entry) +{ + struct iovec vectors[ENCODE_IOV_MAX]; + int cnt; + + cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); + if ((unsigned int)cnt > ENCODE_IOV_MAX) + { + if (cnt > ENCODE_IOV_MAX) + errno = EOVERFLOW; + return -1; + } + + if (cnt > 0) + { + struct msghdr message = { + .msg_name = entry->target->addr, + .msg_namelen = entry->target->addrlen, + .msg_iov = vectors, +#ifdef __linux__ + .msg_iovlen = (size_t)cnt, +#else + .msg_iovlen = cnt, +#endif + }; + if (sendmsg(entry->sockfd, &message, 0) < 0) + return -1; + } + + return 0; +} + +int Communicator::reply_unreliable(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + session->out = session->message_out(); + if (session->out) + { + if (this->reply_message_unreliable(entry) >= 0) + return 0; + } + + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + else + errno = ENOENT; + + return -1; +} + int Communicator::reply(CommSession *session) { struct CommConnEntry *entry; - CommTarget *target; + CommServiceTarget *target; int errno_bak; int ret; if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } errno_bak = errno; session->passive = 2; - target = session->target; - ret = this->reply_idle_conn(session, target); - if (ret < 0) - return -1; + target = (CommServiceTarget *)session->target; + if (target->service->reliable) + ret = this->reply_reliable(session, target); + else + ret = this->reply_unreliable(session, target); if (ret == 0) { @@ -1760,10 +2005,12 @@ int Communicator::reply(CommSession *session) session->handle(CS_STATE_SUCCESS, 0); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); - ((CommServiceTarget *)target)->decref(); + __release_conn(entry); + target->decref(); } } + else if (ret < 0) + return -1; errno = errno_bak; return 0; @@ -1777,7 +2024,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } @@ -1785,22 +2032,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) if (!list_empty(&target->idle_list)) { entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - if (!entry->ssl) - ret = write(entry->sockfd, buf, size); - else if (size == 0) - ret = 0; - else - { - ret = SSL_write(entry->ssl, buf, size); - if (ret <= 0) - { - ret = SSL_get_error(entry->ssl, ret); - if (ret != SSL_ERROR_SYSCALL) - errno = -ret; - - ret = -1; - } - } + ret = __send_to_conn(buf, size, entry); } else { @@ -1814,33 +2046,23 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) int Communicator::shutdown(CommSession *session) { - CommTarget *target = session->target; - struct CommConnEntry *entry; - int ret; + CommServiceTarget *target; if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } session->passive = 2; - pthread_mutex_lock(&target->mutex); - if (!list_empty(&target->idle_list)) - { - entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - list_del(&entry->list); - ret = mpoller_del(entry->sockfd, entry->mpoller); - entry->state = CONN_STATE_CLOSING; - } - else + target = (CommServiceTarget *)session->target; + if (!target->shutdown()) { errno = ENOENT; - ret = -1; + return -1; } - pthread_mutex_unlock(&target->mutex); - return ret; + return 0; } int Communicator::sleep(SleepSession *session) diff --git a/src/kernel/Communicator.h b/src/kernel/Communicator.h index 7552b539..dbbc2ba7 100644 --- a/src/kernel/Communicator.h +++ b/src/kernel/Communicator.h @@ -31,9 +31,8 @@ class CommConnection { -protected: +public: virtual ~CommConnection() { } - friend class Communicator; }; class CommTarget @@ -91,7 +90,7 @@ private: public: virtual ~CommTarget() { } - friend class CommSession; + friend class CommServiceTarget; friend class Communicator; }; @@ -223,6 +222,7 @@ private: void decref(); private: + int reliable; int listen_fd; int ref; @@ -298,16 +298,6 @@ private: int create_handler_threads(size_t handler_threads); - int nonblock_connect(CommTarget *target); - int nonblock_listen(CommService *service); - - struct CommConnEntry *launch_conn(CommSession *session, - CommTarget *target); - struct CommConnEntry *accept_conn(class CommServiceTarget *target, - CommService *service); - - void release_conn(struct CommConnEntry *entry); - void shutdown_service(CommService *service); void shutdown_io_service(IOService *service); @@ -319,10 +309,13 @@ private: int send_message(struct CommConnEntry *entry); - int request_idle_conn(CommSession *session, CommTarget *target); - int reply_idle_conn(CommSession *session, CommTarget *target); - int request_new_conn(CommSession *session, CommTarget *target); + int request_idle_conn(CommSession *session, CommTarget *target); + + int reply_message_unreliable(struct CommConnEntry *entry); + + int reply_reliable(CommSession *session, CommTarget *target); + int reply_unreliable(CommSession *session, CommTarget *target); void handle_incoming_request(struct poller_result *res); void handle_incoming_reply(struct poller_result *res); @@ -336,6 +329,8 @@ private: void handle_connect_result(struct poller_result *res); void handle_listen_result(struct poller_result *res); + void handle_recvfrom_result(struct poller_result *res); + void handle_ssl_accept_result(struct poller_result *res); void handle_sleep_result(struct poller_result *res); @@ -344,6 +339,14 @@ private: static void handler_thread_routine(void *context); + static int nonblock_connect(CommTarget *target); + static int nonblock_listen(CommService *service); + + static struct CommConnEntry *launch_conn(CommSession *session, + CommTarget *target); + static struct CommConnEntry *accept_conn(class CommServiceTarget *target, + CommService *service); + static int first_timeout(CommSession *session); static int next_timeout(CommSession *session); @@ -358,11 +361,17 @@ private: static poller_message_t *create_request(void *context); static poller_message_t *create_reply(void *context); + static int recv_request(const void *buf, size_t size, + struct CommConnEntry *entry); + static int partial_written(size_t n, void *context); static void *accept(const struct sockaddr *addr, socklen_t addrlen, int sockfd, void *context); + static void *recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context); + static void callback(struct poller_result *res, void *context); public: diff --git a/src/kernel/poller.c b/src/kernel/poller.c index c2881cf1..4da9db9d 100644 --- a/src/kernel/poller.c +++ b/src/kernel/poller.c @@ -651,6 +651,55 @@ static void __poller_handle_connect(struct __poller_node *node, poller->callback((struct poller_result *)node, poller->context); } +static void __poller_handle_recvfrom(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + struct sockaddr_storage ss; + struct sockaddr *addr = (struct sockaddr *)&ss; + socklen_t addrlen; + void *result; + ssize_t n; + + while (1) + { + addrlen = sizeof (struct sockaddr_storage); + n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0, + addr, &addrlen); + if (n < 0) + { + if (errno == EAGAIN) + return; + else + break; + } + + result = node->data.recvfrom(addr, addrlen, poller->buf, n, + node->data.context); + if (!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + + if (__poller_remove_node(node, poller)) + return; + + node->error = errno; + node->state = PR_ST_ERROR; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + static void __poller_handle_ssl_accept(struct __poller_node *node, poller_t *poller) { @@ -849,55 +898,6 @@ static void __poller_handle_notify(struct __poller_node *node, poller->callback((struct poller_result *)node, poller->context); } -static void __poller_handle_recvfrom(struct __poller_node *node, - poller_t *poller) -{ - struct __poller_node *res = node->res; - struct sockaddr_storage ss; - struct sockaddr *addr = (struct sockaddr *)&ss; - socklen_t addrlen; - void *result; - ssize_t n; - - while (1) - { - addrlen = sizeof (struct sockaddr_storage); - n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0, - addr, &addrlen); - if (n < 0) - { - if (errno == EAGAIN) - return; - else - break; - } - - result = node->data.recvfrom(addr, addrlen, poller->buf, n, - node->data.context); - if (!result) - break; - - res->data = node->data; - res->data.result = result; - res->error = 0; - res->state = PR_ST_SUCCESS; - poller->callback((struct poller_result *)res, poller->context); - - res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); - node->res = res; - if (!res) - break; - } - - if (__poller_remove_node(node, poller)) - return; - - node->error = errno; - node->state = PR_ST_ERROR; - free(node->res); - poller->callback((struct poller_result *)node, poller->context); -} - static int __poller_handle_pipe(poller_t *poller) { struct __poller_node **node = (struct __poller_node **)poller->buf; @@ -1055,6 +1055,9 @@ static void *__poller_thread_routine(void *arg) case PD_OP_CONNECT: __poller_handle_connect(node, poller); break; + case PD_OP_RECVFROM: + __poller_handle_recvfrom(node, poller); + break; case PD_OP_SSL_ACCEPT: __poller_handle_ssl_accept(node, poller); break; @@ -1070,9 +1073,6 @@ static void *__poller_thread_routine(void *arg) case PD_OP_NOTIFY: __poller_handle_notify(node, poller); break; - case PD_OP_RECVFROM: - __poller_handle_recvfrom(node, poller); - break; } } @@ -1282,6 +1282,9 @@ static int __poller_data_get_event(int *event, const struct poller_data *data) case PD_OP_CONNECT: *event = EPOLLOUT | EPOLLET; return 0; + case PD_OP_RECVFROM: + *event = EPOLLIN | EPOLLET; + return 1; case PD_OP_SSL_ACCEPT: *event = EPOLLIN | EPOLLET; return 0; @@ -1297,9 +1300,6 @@ static int __poller_data_get_event(int *event, const struct poller_data *data) case PD_OP_NOTIFY: *event = EPOLLIN | EPOLLET; return 1; - case PD_OP_RECVFROM: - *event = EPOLLIN | EPOLLET; - return 1; default: errno = EINVAL; return -1; diff --git a/src/kernel/poller.h b/src/kernel/poller.h index 89831277..71ff70cc 100644 --- a/src/kernel/poller.h +++ b/src/kernel/poller.h @@ -40,14 +40,14 @@ struct poller_data #define PD_OP_WRITE 2 #define PD_OP_LISTEN 3 #define PD_OP_CONNECT 4 +#define PD_OP_RECVFROM 5 #define PD_OP_SSL_READ PD_OP_READ #define PD_OP_SSL_WRITE PD_OP_WRITE -#define PD_OP_SSL_ACCEPT 5 -#define PD_OP_SSL_CONNECT 6 -#define PD_OP_SSL_SHUTDOWN 7 -#define PD_OP_EVENT 8 -#define PD_OP_NOTIFY 9 -#define PD_OP_RECVFROM 10 +#define PD_OP_SSL_ACCEPT 6 +#define PD_OP_SSL_CONNECT 7 +#define PD_OP_SSL_SHUTDOWN 8 +#define PD_OP_EVENT 9 +#define PD_OP_NOTIFY 10 short operation; unsigned short iovcnt; int fd; @@ -57,10 +57,10 @@ struct poller_data poller_message_t *(*create_message)(void *); int (*partial_written)(size_t, void *); void *(*accept)(const struct sockaddr *, socklen_t, int, void *); - void *(*event)(void *); - void *(*notify)(void *, void *); void *(*recvfrom)(const struct sockaddr *, socklen_t, const void *, size_t, void *); + void *(*event)(void *); + void *(*notify)(void *, void *); }; void *context; union diff --git a/src/manager/DnsCache.cc b/src/manager/DnsCache.cc index fbafd40d..3c0cee9c 100644 --- a/src/manager/DnsCache.cc +++ b/src/manager/DnsCache.cc @@ -23,42 +23,31 @@ #define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() -#define CONFIDENT_INC 10 -#define TTL_INC 10 +#define TTL_INC 5 -const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, int type) +const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, + int type) { - int64_t cur_time = GET_CURRENT_SECOND; + int64_t cur = GET_CURRENT_SECOND; std::lock_guard lock(mutex_); const DnsHandle *handle = cache_pool_.get(host_port); - if (handle) + if (handle && ((type == GET_TYPE_TTL && cur > handle->value.expire_time) || + (type == GET_TYPE_CONFIDENT && cur > handle->value.confident_time))) { - switch (type) + if (!handle->value.delayed()) { - case GET_TYPE_TTL: - if (cur_time > handle->value.expire_time) - { - const_cast(handle)->value.expire_time += TTL_INC; - cache_pool_.release(handle); - return NULL; - } + DnsHandle *h = const_cast(handle); + if (type == GET_TYPE_TTL) + h->value.expire_time += TTL_INC; + else + h->value.confident_time += TTL_INC; - break; - - case GET_TYPE_CONFIDENT: - if (cur_time > handle->value.confident_time) - { - const_cast(handle)->value.confident_time += CONFIDENT_INC; - cache_pool_.release(handle); - return NULL; - } - - break; - - default: - break; + h->value.addrinfo->ai_flags |= 2; } + + cache_pool_.release(handle); + return NULL; } return handle; @@ -90,3 +79,29 @@ const DnsCache::DnsHandle *DnsCache::put(const HostPort& host_port, return cache_pool_.put(host_port, {addrinfo, confident_time, expire_time}); } +const DnsCache::DnsHandle *DnsCache::get(const DnsCache::HostPort& host_port) +{ + std::lock_guard lock(mutex_); + return cache_pool_.get(host_port); +} + +void DnsCache::release(const DnsCache::DnsHandle *handle) +{ + std::lock_guard lock(mutex_); + cache_pool_.release(handle); +} + +void DnsCache::del(const DnsCache::HostPort& key) +{ + std::lock_guard lock(mutex_); + cache_pool_.del(key); +} + +DnsCache::DnsCache() +{ +} + +DnsCache::~DnsCache() +{ +} + diff --git a/src/manager/DnsCache.h b/src/manager/DnsCache.h index 7e864de2..a3720fc3 100644 --- a/src/manager/DnsCache.h +++ b/src/manager/DnsCache.h @@ -35,6 +35,11 @@ struct DnsCacheValue struct addrinfo *addrinfo; int64_t confident_time; int64_t expire_time; + + bool delayed() const + { + return addrinfo->ai_flags & 2; + } }; // RAII: NO. Release handle by user @@ -47,27 +52,10 @@ public: using DnsHandle = LRUHandle; public: - // release handle by get/put - void release(DnsHandle *handle) - { - std::lock_guard lock(mutex_); - cache_pool_.release(handle); - } - - void release(const DnsHandle *handle) - { - std::lock_guard lock(mutex_); - cache_pool_.release(handle); - } - // get handler // Need call release when handle no longer needed //Handle *get(const KEY &key); - const DnsHandle *get(const HostPort& host_port) - { - std::lock_guard lock(mutex_); - return cache_pool_.get(host_port); - } + const DnsHandle *get(const HostPort& host_port); const DnsHandle *get(const std::string& host, unsigned short port) { @@ -132,12 +120,11 @@ public: return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min); } + // release handle by get/put + void release(const DnsHandle *handle); + // delete from cache, deleter delay called when all inuse-handle release. - void del(const HostPort& key) - { - std::lock_guard lock(mutex_); - cache_pool_.del(key); - } + void del(const HostPort& key); void del(const std::string& host, unsigned short port) { @@ -161,14 +148,22 @@ private: { struct addrinfo *ai = value.addrinfo; - if (ai && (ai->ai_flags & AI_PASSIVE)) - freeaddrinfo(ai); - else - protocol::DnsUtil::freeaddrinfo(ai); + if (ai) + { + if (ai->ai_flags) + freeaddrinfo(ai); + else + protocol::DnsUtil::freeaddrinfo(ai); + } } }; LRUCache cache_pool_; + +public: + // To prevent inline calling LRUCache's constructor and deconstructor. + DnsCache(); + ~DnsCache(); }; #endif diff --git a/src/manager/EndpointParams.h b/src/manager/EndpointParams.h index aa3f8928..cd264c8f 100644 --- a/src/manager/EndpointParams.h +++ b/src/manager/EndpointParams.h @@ -20,6 +20,7 @@ #define _ENDPOINTPARAMS_H_ #include +#include "PlatformSocket.h" /** * @file EndpointParams.h @@ -37,6 +38,7 @@ enum TransportType struct EndpointParams { + int address_family; size_t max_connections; int connect_timeout; int response_timeout; @@ -46,6 +48,7 @@ struct EndpointParams static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { +/* address_family = */ AF_INET, /* .max_connections = */ 200, /* .connect_timeout = */ 10 * 1000, /* .response_timeout = */ 10 * 1000, diff --git a/src/manager/RouteManager.cc b/src/manager/RouteManager.cc index 8a99f335..526ffea0 100644 --- a/src/manager/RouteManager.cc +++ b/src/manager/RouteManager.cc @@ -76,7 +76,7 @@ private: }; /* To support TLS SNI. */ -class RouteTargetSNI : public RouteManager::RouteTarget +class RouteTargetTCPSNI : public RouteTargetTCP { private: virtual int init_ssl(SSL *ssl) @@ -91,7 +91,27 @@ private: std::string hostname; public: - RouteTargetSNI(const std::string& name) : hostname(name) + RouteTargetTCPSNI(const std::string& name) : hostname(name) + { + } +}; + +class RouteTargetSCTPSNI : public RouteTargetSCTP +{ +private: + virtual int init_ssl(SSL *ssl) + { + if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) + return 0; + else + return -1; + } + +private: + std::string hostname; + +public: + RouteTargetSCTPSNI(const std::string& name) : hostname(name) { } }; @@ -101,11 +121,11 @@ public: struct RouteParams { - TransportType transport_type; + enum TransportType transport_type; const struct addrinfo *addrinfo; uint64_t key; SSL_CTX *ssl_ctx; - unsigned int max_connections; + size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; @@ -120,7 +140,7 @@ public: CommSchedObject *request_object; CommSchedGroup *group; std::mutex mutex; - std::vector targets; + std::vector targets; struct list_head breaker_list; uint64_t key; int nleft; @@ -139,34 +159,35 @@ public: int init(const struct RouteParams *params); void deinit(); - void notify_unavailable(CommSchedTarget *target); - void notify_available(CommSchedTarget *target); + void notify_unavailable(RouteManager::RouteTarget *target); + void notify_available(RouteManager::RouteTarget *target); void check_breaker(); private: void free_list(); - CommSchedTarget *create_target(const struct RouteParams *params, - const struct addrinfo *addrinfo); + RouteManager::RouteTarget *create_target(const struct RouteParams *params, + const struct addrinfo *addrinfo); int add_group_targets(const struct RouteParams *params); }; struct __breaker_node { - CommSchedTarget *target; + RouteManager::RouteTarget *target; int64_t timeout; struct list_head breaker_list; }; -CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *params, - const struct addrinfo *addr) +RouteManager::RouteTarget * +RouteResultEntry::create_target(const struct RouteParams *params, + const struct addrinfo *addr) { - CommSchedTarget *target; + RouteManager::RouteTarget *target; switch (params->transport_type) { case TT_TCP_SSL: if (params->use_tls_sni) - target = new RouteTargetSNI(params->hostname); + target = new RouteTargetTCPSNI(params->hostname); else case TT_TCP: target = new RouteTargetTCP(); @@ -174,16 +195,19 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param case TT_UDP: target = new RouteTargetUDP(); break; - case TT_SCTP: case TT_SCTP_SSL: - target = new RouteTargetSCTP(); + if (params->use_tls_sni) + target = new RouteTargetSCTPSNI(params->hostname); + else + case TT_SCTP: + target = new RouteTargetSCTP(); break; default: errno = EINVAL; return NULL; } - if (target->init(addr->ai_addr, (socklen_t)addr->ai_addrlen, params->ssl_ctx, + if (target->init(addr->ai_addr, addr->ai_addrlen, params->ssl_ctx, params->connect_timeout, params->ssl_connect_timeout, params->response_timeout, params->max_connections) < 0) { @@ -197,7 +221,7 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param int RouteResultEntry::init(const struct RouteParams *params) { const struct addrinfo *addr = params->addrinfo; - CommSchedTarget *target; + RouteManager::RouteTarget *target; if (addr == NULL)//0 { @@ -238,8 +262,8 @@ int RouteResultEntry::init(const struct RouteParams *params) int RouteResultEntry::add_group_targets(const struct RouteParams *params) { + RouteManager::RouteTarget *target; const struct addrinfo *addr; - CommSchedTarget *target; for (addr = params->addrinfo; addr; addr = addr->ai_next) { @@ -298,7 +322,7 @@ void RouteResultEntry::deinit() } } -void RouteResultEntry::notify_unavailable(CommSchedTarget *target) +void RouteResultEntry::notify_unavailable(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1) return; @@ -324,7 +348,7 @@ void RouteResultEntry::notify_unavailable(CommSchedTarget *target) this->nleft--; } -void RouteResultEntry::notify_available(CommSchedTarget *target) +void RouteResultEntry::notify_available(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1 || this->nbreak == 0) return; @@ -396,23 +420,26 @@ static uint64_t __fnv_hash(const unsigned char *data, size_t size) return hash; } -static uint64_t __generate_key(TransportType type, +static uint64_t __generate_key(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname) + const std::string& hostname, + SSL_CTX *ssl_ctx) { - std::string buf((const char *)&type, sizeof (TransportType)); - unsigned int max_conn = ep_params->max_connections; + const int params[] = { + ep_params->address_family, (int)ep_params->max_connections, + ep_params->connect_timeout, ep_params->response_timeout + }; + std::string buf((const char *)&type, sizeof (enum TransportType)); if (!other_info.empty()) buf += other_info; - buf.append((const char *)&max_conn, sizeof (unsigned int)); - buf.append((const char *)&ep_params->connect_timeout, sizeof (int)); - buf.append((const char *)&ep_params->response_timeout, sizeof (int)); - if (type == TT_TCP_SSL) + buf.append((const char *)params, sizeof params); + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) { + buf.append((const char *)&ssl_ctx, sizeof (void *)); buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int)); if (ep_params->use_tls_sni) { @@ -462,15 +489,24 @@ RouteManager::~RouteManager() } } -int RouteManager::get(TransportType type, +int RouteManager::get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname, + const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result) { - uint64_t key = __generate_key(type, addrinfo, other_info, - ep_params, hostname); + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) + { + static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx(); + if (ssl_ctx == NULL) + ssl_ctx = global_client_ctx; + } + else + ssl_ctx = NULL; + + uint64_t key = __generate_key(type, addrinfo, other_info, ep_params, + hostname, ssl_ctx); struct rb_node **p = &cache_.rb_node; struct rb_node *parent = NULL; RouteResultEntry *bound = NULL; @@ -497,37 +533,19 @@ int RouteManager::get(TransportType type, } else { - int ssl_connect_timeout = 0; - SSL_CTX *ssl_ctx = NULL; - - if (type == TT_TCP_SSL || type == TT_SCTP_SSL) - { - static SSL_CTX *client_ssl_ctx = WFGlobal::get_ssl_client_ctx(); - - ssl_ctx = client_ssl_ctx; - ssl_connect_timeout = ep_params->ssl_connect_timeout; - } - struct RouteParams params = { - /* .transport_type = */ type, - /* .addrinfo = */ addrinfo, - /* .key = */ key, - /* .ssl_ctx = */ ssl_ctx, - /* .max_connections = */ (unsigned int)ep_params->max_connections, - /* .connect_timeout = */ ep_params->connect_timeout, - /* .response_timeout = */ ep_params->response_timeout, - /* .ssl_connect_timeout = */ ssl_connect_timeout, - /* .use_tls_sni = */ ep_params->use_tls_sni, - /* .hostname = */ hostname, + /*.transport_type =*/ type, + /*.addrinfo =*/ addrinfo, + /*.key =*/ key, + /*.ssl_ctx =*/ ssl_ctx, + /*.max_connections =*/ ep_params->max_connections, + /*.connect_timeout =*/ ep_params->connect_timeout, + /*.response_timeout =*/ ep_params->response_timeout, + /*.ssl_connect_timeout =*/ ep_params->ssl_connect_timeout, + /*.use_tls_sni =*/ ep_params->use_tls_sni, + /*.hostname =*/ hostname, }; - if (StringUtil::start_with(other_info, "?maxconn=")) - { - int maxconn = atoi(other_info.c_str() + 9); - if (maxconn > 0) - params.max_connections = maxconn; - } - entry = new RouteResultEntry; if (entry->init(¶ms) >= 0) { @@ -549,12 +567,12 @@ int RouteManager::get(TransportType type, void RouteManager::notify_unavailable(void *cookie, CommTarget *target) { if (cookie && target) - ((RouteResultEntry *)cookie)->notify_unavailable((CommSchedTarget *)target); + ((RouteResultEntry *)cookie)->notify_unavailable((RouteTarget *)target); } void RouteManager::notify_available(void *cookie, CommTarget *target) { if (cookie && target) - ((RouteResultEntry *)cookie)->notify_available((CommSchedTarget *)target); + ((RouteResultEntry *)cookie)->notify_available((RouteTarget *)target); } diff --git a/src/manager/RouteManager.h b/src/manager/RouteManager.h index 2271f787..3265196d 100644 --- a/src/manager/RouteManager.h +++ b/src/manager/RouteManager.h @@ -43,6 +43,32 @@ public: class RouteTarget : public CommSchedTarget { +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + public: + int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, + int connect_timeout, int ssl_connect_timeout, int response_timeout, + size_t max_connections) + { + int ret = this->CommSchedTarget::init(addr, addrlen, ssl_ctx, + connect_timeout, ssl_connect_timeout, + response_timeout, max_connections); + + if (ret >= 0 && ssl_ctx) + SSL_CTX_up_ref(ssl_ctx); + + return ret; + } + + void deinit() + { + SSL_CTX *ssl_ctx = this->get_ssl_ctx(); + + this->CommSchedTarget::deinit(); + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); + } +#endif + public: int state; @@ -57,11 +83,11 @@ public: }; public: - int get(TransportType type, + int get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname, + const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result); RouteManager() diff --git a/src/nameservice/WFDnsResolver.cc b/src/nameservice/WFDnsResolver.cc index ea85bdbc..79cb51d2 100644 --- a/src/nameservice/WFDnsResolver.cc +++ b/src/nameservice/WFDnsResolver.cc @@ -40,35 +40,21 @@ #define HOSTS_LINEBUF_INIT_SIZE 128 #define PORT_STR_MAX 5 -static constexpr struct addrinfo __ai_hints = -{ -#ifdef AI_ADDRCONFIG - /*.ai_flags = */ AI_ADDRCONFIG, -#else - /*.ai_flags = */ 0, -#endif - /*.ai_family = */ AF_UNSPEC, - /*.ai_socktype = */ SOCK_STREAM, - /*.ai_protocol = */ 0, - /*.ai_addrlen = */ 0, - /*.ai_addr = */ NULL, - /*.ai_canonname = */ NULL, - /*.ai_next = */ NULL -}; - class DnsInput { public: DnsInput() : port_(0), - numeric_host_(false) + numeric_host_(false), + family_(AF_UNSPEC) {} DnsInput(const std::string& host, unsigned short port, - bool numeric_host) : + bool numeric_host, int family) : host_(host), port_(port), - numeric_host_(numeric_host) + numeric_host_(numeric_host), + family_(family) {} void reset(const std::string& host, unsigned short port) @@ -76,14 +62,16 @@ public: host_.assign(host); port_ = port; numeric_host_ = false; + family_ = AF_UNSPEC; } void reset(const std::string& host, unsigned short port, - bool numeric_host) + bool numeric_host, int family) { host_.assign(host); port_ = port; numeric_host_ = numeric_host; + family_ = family; } const std::string& get_host() const { return host_; } @@ -94,6 +82,7 @@ protected: std::string host_; unsigned short port_; bool numeric_host_; + int family_; friend class DnsRoutine; }; @@ -109,7 +98,12 @@ public: ~DnsOutput() { if (addrinfo_) - freeaddrinfo(addrinfo_); + { + if (addrinfo_->ai_flags) + freeaddrinfo(addrinfo_); + else + free(addrinfo_); + } } int get_error() const { return error_; } @@ -137,7 +131,12 @@ public: static void create(DnsOutput *out, int error, struct addrinfo *ai) { if (out->addrinfo_) - freeaddrinfo(out->addrinfo_); + { + if (out->addrinfo_->ai_flags) + freeaddrinfo(out->addrinfo_); + else + free(out->addrinfo_); + } out->error_ = error; out->addrinfo_ = ai; @@ -146,21 +145,21 @@ public: void DnsRoutine::run(const DnsInput *in, DnsOutput *out) { - if (!in->host_.empty() && in->host_[0] == '/') - return; - - struct addrinfo hints = __ai_hints; + struct addrinfo hints = { + /*.ai_flags =*/ AI_ADDRCONFIG | AI_NUMERICSERV, + /*.ai_family =*/ in->family_, + /*.ai_socktype =*/ SOCK_STREAM, + }; char port_str[PORT_STR_MAX + 1]; - hints.ai_flags |= AI_NUMERICSERV; if (in->is_numeric_host()) hints.ai_flags |= AI_NUMERICHOST; snprintf(port_str, PORT_STR_MAX + 1, "%u", in->port_); - out->error_ = getaddrinfo(in->host_.c_str(), - port_str, - &hints, - &out->addrinfo_); + out->error_ = getaddrinfo(in->host_.c_str(), port_str, + &hints, &out->addrinfo_); + if (out->error_ == 0) + out->addrinfo_->ai_flags = 1; } // Dns Thread task. For internal usage only. @@ -178,13 +177,18 @@ struct DnsContext static int __default_family() { + struct addrinfo hints = { + /*.ai_flags =*/ AI_ADDRCONFIG, + /*.ai_family =*/ AF_UNSPEC, + /*.ai_socktype =*/ SOCK_STREAM, + }; struct addrinfo *res; struct addrinfo *cur; int family = AF_UNSPEC; bool v4 = false; bool v6 = false; - if (getaddrinfo(NULL, "1", &__ai_hints, &res) == 0) + if (getaddrinfo(NULL, "1", &hints, &res) == 0) { for (cur = res; cur; cur = cur->ai_next) { @@ -294,18 +298,9 @@ static int __readaddrinfo(const char *path, return ret; } -// Add AI_PASSIVE to point that this addrinfo is alloced by getaddrinfo -static void __add_passive_flags(struct addrinfo *ai) -{ - while (ai) - { - ai->ai_flags |= AI_PASSIVE; - ai = ai->ai_next; - } -} - static ThreadDnsTask *__create_thread_dns_task(const std::string& host, unsigned short port, + int family, thread_dns_callback_t callback) { auto *task = WFThreadTaskFactory:: @@ -314,12 +309,46 @@ static ThreadDnsTask *__create_thread_dns_task(const std::string& host, DnsRoutine::run, std::move(callback)); - task->get_input()->reset(host, port); + task->get_input()->reset(host, port, false, family); return task; } +static std::string __get_cache_host(const std::string& hostname, + int family) +{ + char c; + + if (family == AF_UNSPEC) + c = '*'; + else if (family == AF_INET) + c = '4'; + else if (family == AF_INET6) + c = '6'; + else + c = '?'; + + return hostname + c; +} + +static std::string __get_guard_name(const std::string& cache_host, + unsigned short port) +{ + std::string guard_name("INTERNAL-dns:"); + guard_name.append(cache_host).append(":"); + guard_name.append(std::to_string(port)); + return guard_name; +} + void WFResolverTask::dispatch() { + if (this->msg_) + { + this->state = WFT_STATE_DNS_ERROR; + this->error = (intptr_t)msg_; + this->subtask_done(); + return; + } + const ParsedURI& uri = ns_params_.uri; host_ = uri.host ? uri.host : ""; port_ = uri.port ? atoi(uri.port) : 0; @@ -327,11 +356,22 @@ void WFResolverTask::dispatch() DnsCache *dns_cache = WFGlobal::get_dns_cache(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); if (ns_params_.retry_times == 0) - addr_handle = dns_cache->get_ttl(hostname, port_); + addr_handle = dns_cache->get_ttl(cache_host, port_); else - addr_handle = dns_cache->get_confident(hostname, port_); + addr_handle = dns_cache->get_confident(cache_host, port_); + + if (in_guard_ && (addr_handle == NULL || addr_handle->value.delayed())) + { + if (addr_handle) + dns_cache->release(addr_handle); + + this->request_dns(); + return; + } if (addr_handle) { @@ -347,7 +387,8 @@ void WFResolverTask::dispatch() } if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, - &ep_params_, hostname, this->result) < 0) + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -378,11 +419,11 @@ void WFResolverTask::dispatch() if (ret == 1) { - DnsInput dns_in(hostname, port_, true); // 'true' means numeric host + // 'true' means numeric host + DnsInput dns_in(hostname, port_, true, AF_UNSPEC); DnsOutput dns_out; DnsRoutine::run(&dns_in, &dns_out); - __add_passive_flags((struct addrinfo *)dns_out.get_addrinfo()); dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1); this->subtask_done(); return; @@ -392,32 +433,53 @@ void WFResolverTask::dispatch() const char *hosts = WFGlobal::get_global_settings()->hosts_path; if (hosts) { + struct addrinfo hints = { + /*.ai_flags =*/ AI_ADDRCONFIG | AI_NUMERICSERV | AI_NUMERICHOST, + /*.ai_family =*/ ep_params_.address_family, + /*.ai_socktype =*/ SOCK_STREAM, + }; struct addrinfo *ai; - int ret = __readaddrinfo(hosts, host_, port_, &__ai_hints, &ai); + int ret; + ret = __readaddrinfo(hosts, host_, port_, &hints, &ai); if (ret == 0) { DnsOutput out; DnsRoutine::create(&out, ret, ai); - __add_passive_flags((struct addrinfo *)out.get_addrinfo()); dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); this->subtask_done(); return; } } + std::string guard_name = __get_guard_name(cache_host, port_); + WFConditional *guard = WFTaskFactory::create_guard(guard_name, this, &msg_); + + in_guard_ = true; + has_next_ = true; + + series_of(this)->push_front(guard); + this->subtask_done(); +} + +void WFResolverTask::request_dns() +{ WFDnsClient *client = WFGlobal::get_dns_client(); if (client) { - static int family = __default_family(); + static int default_family = __default_family(); WFResourcePool *respool = WFGlobal::get_dns_respool(); + int family = ep_params_.address_family; + if (family == AF_UNSPEC) + family = default_family; + if (family == AF_INET || family == AF_INET6) { auto&& cb = std::bind(&WFResolverTask::dns_single_callback, this, std::placeholders::_1); - WFDnsTask *dns_task = client->create_dns_task(hostname, std::move(cb)); + WFDnsTask *dns_task = client->create_dns_task(host_, std::move(cb)); if (family == AF_INET6) dns_task->get_req()->set_question_type(DNS_TYPE_AAAA); @@ -437,10 +499,10 @@ void WFResolverTask::dispatch() dctx[0].port = port_; dctx[1].port = port_; - task_v4 = client->create_dns_task(hostname, dns_partial_callback); + task_v4 = client->create_dns_task(host_, dns_partial_callback); task_v4->user_data = dctx; - task_v6 = client->create_dns_task(hostname, dns_partial_callback); + task_v6 = client->create_dns_task(host_, dns_partial_callback); task_v6->get_req()->set_question_type(DNS_TYPE_AAAA); task_v6->user_data = dctx + 1; @@ -461,11 +523,13 @@ void WFResolverTask::dispatch() } else { + ThreadDnsTask *dns_task; auto&& cb = std::bind(&WFResolverTask::thread_dns_callback, this, std::placeholders::_1); - ThreadDnsTask *dns_task = __create_thread_dns_task(hostname, port_, - std::move(cb)); + dns_task = __create_thread_dns_task(host_, port_, + ep_params_.address_family, + std::move(cb)); series_of(this)->push_front(dns_task); } @@ -478,12 +542,7 @@ SubTask *WFResolverTask::done() SeriesWork *series = series_of(this); if (!has_next_) - { - if (this->callback) - this->callback(this); - - delete this; - } + task_callback(); else has_next_ = false; @@ -499,7 +558,7 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output, if (dns_error) { - if (dns_error == /*EAI_SYSTEM*/ EAI_FAIL) + if (dns_error == EAI_FAIL) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -517,12 +576,15 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output, struct addrinfo *addrinfo = dns_out->move_addrinfo(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); - addr_handle = dns_cache->put(hostname, port_, addrinfo, + addr_handle = dns_cache->put(cache_host, port_, addrinfo, (unsigned int)ttl_default, (unsigned int)ttl_min); if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, - &ep_params_, hostname, this->result) < 0) + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -555,10 +617,7 @@ void WFResolverTask::dns_single_callback(void *net_dns_task) this->error = dns_task->get_error(); } - if (this->callback) - this->callback(this); - - delete this; + task_callback(); } void WFResolverTask::dns_partial_callback(void *net_dns_task) @@ -618,10 +677,7 @@ void WFResolverTask::dns_parallel_callback(const void *parallel) delete[] c4; - if (this->callback) - this->callback(this); - - delete this; + task_callback(); } void WFResolverTask::thread_dns_callback(void *thrd_dns_task) @@ -631,7 +687,6 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task) if (dns_task->get_state() == WFT_STATE_SUCCESS) { DnsOutput *out = dns_task->get_output(); - __add_passive_flags((struct addrinfo *)out->get_addrinfo()); dns_callback_internal(out, dns_ttl_default_, dns_ttl_min_); } else @@ -640,6 +695,23 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task) this->error = dns_task->get_error(); } + task_callback(); +} + +void WFResolverTask::task_callback() +{ + if (in_guard_) + { + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(host_, family); + std::string guard_name = __get_guard_name(cache_host, port_); + + if (this->state == WFT_STATE_DNS_ERROR) + msg_ = (void *)(intptr_t)this->error; + + WFTaskFactory::release_guard_safe(guard_name, msg_); + } + if (this->callback) this->callback(this); diff --git a/src/nameservice/WFDnsResolver.h b/src/nameservice/WFDnsResolver.h index d0272141..4584d96a 100644 --- a/src/nameservice/WFDnsResolver.h +++ b/src/nameservice/WFDnsResolver.h @@ -35,9 +35,14 @@ public: ns_params_(*ns_params), ep_params_(*ep_params) { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + dns_ttl_default_ = dns_ttl_default; dns_ttl_min_ = dns_ttl_min; has_next_ = false; + in_guard_ = false; + msg_ = NULL; } WFResolverTask(const struct WFNSParams *ns_params, @@ -45,7 +50,12 @@ public: WFRouterTask(std::move(cb)), ns_params_(*ns_params) { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + has_next_ = false; + in_guard_ = false; + msg_ = NULL; } protected: @@ -62,6 +72,9 @@ private: unsigned int ttl_default, unsigned int ttl_min); + void request_dns(); + void task_callback(); + protected: struct WFNSParams ns_params_; unsigned int dns_ttl_default_; @@ -72,6 +85,8 @@ private: const char *host_; unsigned short port_; bool has_next_; + bool in_guard_; + void *msg_; }; class WFDnsResolver : public WFNSPolicy diff --git a/src/nameservice/WFNameService.h b/src/nameservice/WFNameService.h index a6ef8e3f..268f46dc 100644 --- a/src/nameservice/WFNameService.h +++ b/src/nameservice/WFNameService.h @@ -74,10 +74,12 @@ public: struct WFNSParams { - TransportType type; + enum TransportType type; ParsedURI& uri; const char *info; + SSL_CTX *ssl_ctx; bool fixed_addr; + bool fixed_conn; int retry_times; WFNSTracing *tracing; }; diff --git a/src/protocol/dns_parser.c b/src/protocol/dns_parser.c index 901ae203..ab3bdb7c 100644 --- a/src/protocol/dns_parser.c +++ b/src/protocol/dns_parser.c @@ -25,7 +25,6 @@ #define DNS_LABELS_MAX 63 #define DNS_NAMES_MAX 256 #define DNS_MSGBASE_INIT_SIZE 514 // 512 + 2(leading length) -#define DNS_HEADER_SIZE sizeof (struct dns_header) #define MAX(x, y) ((x) <= (y) ? (y) : (x)) struct __dns_record_entry @@ -102,11 +101,16 @@ static int __dns_parser_parse_host(char *phost, dns_parser_t *parser) else if ((len & 0xC0) == 0xC0) { pointer = __dns_parser_uint16(*cur) & 0x3FFF; - *cur += 2; if (pointer >= parser->msgsize) return -2; + // pointer must point to a prior position + if ((const char *)parser->msgbase + pointer >= *cur) + return -2; + + *cur += 2; + // backup cur only when the first pointer occurs if (curbackup == NULL) curbackup = *cur; @@ -707,7 +711,7 @@ void dns_parser_init(dns_parser_t *parser) parser->bufsize = 0; parser->complete = 0; parser->single_packet = 0; - memset(&parser->header, 0, DNS_HEADER_SIZE); + memset(&parser->header, 0, sizeof (struct dns_header)); memset(&parser->question, 0, sizeof (struct dns_question)); INIT_LIST_HEAD(&parser->answer_list); INIT_LIST_HEAD(&parser->authority_list); @@ -770,16 +774,16 @@ int dns_parser_parse_all(dns_parser_t *parser) parser->cur = (const char *)parser->msgbase; h = &parser->header; - if (parser->msgsize < DNS_HEADER_SIZE) + if (parser->msgsize < sizeof (struct dns_header)) return -2; - memcpy(h, parser->msgbase, DNS_HEADER_SIZE); + memcpy(h, parser->msgbase, sizeof (struct dns_header)); h->id = ntohs(h->id); h->qdcount = ntohs(h->qdcount); h->ancount = ntohs(h->ancount); h->nscount = ntohs(h->nscount); h->arcount = ntohs(h->arcount); - parser->cur += DNS_HEADER_SIZE; + parser->cur += sizeof (struct dns_header); ret = __dns_parser_parse_question(parser); if (ret < 0) @@ -911,6 +915,186 @@ int dns_record_cursor_find_cname(const char *name, return 1; } +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list) +{ + struct __dns_record_entry *entry; + size_t entry_size = sizeof (struct __dns_record_entry) + rlen; + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.name = strdup(name); + if (!entry->record.name) + { + free(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rlen; + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, rdata, rlen); + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list) +{ + size_t rlen = strlen(rdata); + // record.rdlength has no meaning for parsed record types, ignore its + // correctness, same for soa/srv/mx record + return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list); +} + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char *pname, *pmname, *prname; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + soa = (struct dns_record_soa *)(entry->record.rdata); + + pname = strdup(name); + pmname = strdup(mname); + prname = strdup(rname); + + if (!pname || !pmname || !prname) + { + free(pname); + free(pmname); + free(prname); + free(entry); + return -1; + } + + soa->mname = pmname; + soa->rname = prname; + soa->serial = serial; + soa->refresh = refresh; + soa->retry = retry; + soa->expire = expire; + soa->minimum = minimum; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SOA; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char *pname, *ptarget; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + srv = (struct dns_record_srv *)(entry->record.rdata); + + pname = strdup(name); + ptarget = strdup(target); + + if (!pname || !ptarget) + { + free(pname); + free(ptarget); + free(entry); + return -1; + } + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = ptarget; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SRV; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char *pname, *pexchange; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + mx = (struct dns_record_mx *)(entry->record.rdata); + + pname = strdup(name); + pexchange = strdup(exchange); + + if (!pname || !pexchange) + { + free(pname); + free(pexchange); + free(entry); + return -1; + } + + mx->preference = preference; + mx->exchange = pexchange; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_MX; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + const char *dns_type2str(int type) { switch (type) diff --git a/src/protocol/dns_parser.h b/src/protocol/dns_parser.h index 892af6c7..e950767d 100644 --- a/src/protocol/dns_parser.h +++ b/src/protocol/dns_parser.h @@ -78,11 +78,19 @@ enum DNS_RCODE_REFUSED }; +enum +{ + DNS_ANSWER_SECTION = 1, + DNS_AUTHORITY_SECTION = 2, + DNS_ADDITIONAL_SECTION = 3, +}; + /** * dns_header_t is a struct to describe the header of a dns * request or response packet, but the byte order is not * transformed. */ +#pragma pack(1) struct dns_header { uint16_t id; @@ -112,6 +120,7 @@ struct dns_header uint16_t nscount; uint16_t arcount; }; +#pragma pack() struct dns_question { @@ -205,6 +214,29 @@ int dns_record_cursor_find_cname(const char *name, const char **cname, dns_record_cursor_t *cursor); +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list); + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list); + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list); + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list); + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list); + const char *dns_type2str(int type); const char *dns_class2str(int dnsclass); const char *dns_opcode2str(int opcode);