diff --git a/docs/en/tutorial-05-http_proxy.md b/docs/en/tutorial-05-http_proxy.md index 719052fe..198f4092 100644 --- a/docs/en/tutorial-05-http_proxy.md +++ b/docs/en/tutorial-05-http_proxy.md @@ -41,6 +41,7 @@ In [WFHttpServer.h](/src/server/WFHttpServer.h), the default parameters for an H ~~~cpp static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = { + .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, @@ -49,7 +50,7 @@ static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = .ssl_accept_timeout = 10 * 1000, }; ~~~ - +**transport\_type**: the transport layer protocol. Besides the default type TT_TCP, you may specify TT_UDP, or TT_SCTP on Linux platform. **max\_connections**: the maximum number of connections is 2000. When it is exceeded, the least recently used keep-alive connection will be closed. If there is no keep-alive connection, the server will refuse new connections. **peer\_response\_timeout**: set the maximum duration for reading or sending out a block of data. The default setting is 10 seconds. **receive\_timeout**: set the maximum duration for receiving a complete request; -1 means unlimited time. diff --git a/docs/en/tutorial-10-user_defined_protocol.md b/docs/en/tutorial-10-user_defined_protocol.md index ce849825..b2972157 100644 --- a/docs/en/tutorial-10-user_defined_protocol.md +++ b/docs/en/tutorial-10-user_defined_protocol.md @@ -48,7 +48,6 @@ private: * For the definition of **struct iovec**, please see the system calls **readv** or **writev**. * Normally the return value of the encode function is between 0 and max, indicating how many vector are used in the message. * In case of UDP protocol, please note that the total length must not be more than 64k, and no more than 1024 vectors are used (in Linux, writev writes only 1024 vectors at one time). - * UDP protocol can only be used for a client, and UDP server cannot be realized. * The encode -1 indicates errors. To return -1, you need to set errno. If the return value is > max, you will get an EOVERFLOW error. All errors are obtained in the callback. * For performance reasons, the content pointed to by the iov\_base pointer in the vector will not be copied. So it generally points to the member of the message class. diff --git a/docs/tutorial-05-http_proxy.md b/docs/tutorial-05-http_proxy.md index 438776f6..862a1a02 100644 --- a/docs/tutorial-05-http_proxy.md +++ b/docs/tutorial-05-http_proxy.md @@ -39,6 +39,7 @@ int main(int argc, char *argv[]) ~~~cpp static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = { + .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, @@ -47,6 +48,7 @@ static constexpr struct WFServerParams HTTP_SERVER_PARAMS_DEFAULT = .ssl_accept_timeout = 10 * 1000, }; ~~~ +transport_type:传输层协议,默认为TCP。除了TT_TCP外,可选择的还有TT_UDP和Linux下支持的TT_SCTP。 max_connections:最大连接数2000,达到上限之后会关闭最久未使用的keep-alive连接。没找到keep-alive连接,则拒绝新连接。 peer_response_timeout:每读取到一块数据或发送出一块数据的超时时间为10秒。 receive_timeout:接收一条完整的请求超时时间为-1,无限。 diff --git a/docs/tutorial-10-user_defined_protocol.md b/docs/tutorial-10-user_defined_protocol.md index 821b562d..768124fb 100644 --- a/docs/tutorial-10-user_defined_protocol.md +++ b/docs/tutorial-10-user_defined_protocol.md @@ -45,7 +45,6 @@ private: * 结构体struct iovec定义在请参考系统调用readv和writev。 * encode函数正确情况下的返回值在0到max之间,表示消息使用了多少个vector。 * 如果是UDP协议,请注意总长度不超过64k,并且使用不超过1024个vector(Linux一次writev只能1024个vector)。 - * UDP协议只能用于client,无法实现UDP server。 * encode返回-1表示错误。返回-1时,需要置errno。如果返回值>max,将得到一个EOVERFLOW错误。错误都在callback里得到。 * 为了性能考虑vector里的iov_base指针指向的内容不会被复制。所以一般指向消息类的成员。 diff --git a/src/factory/DnsTaskImpl.cc b/src/factory/DnsTaskImpl.cc index 9495e543..0dc54af7 100644 --- a/src/factory/DnsTaskImpl.cc +++ b/src/factory/DnsTaskImpl.cc @@ -18,9 +18,10 @@ #include #include +#include "DnsMessage.h" #include "WFTaskError.h" #include "WFTaskFactory.h" -#include "DnsMessage.h" +#include "WFServer.h" using namespace protocol; @@ -182,3 +183,42 @@ WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri, return task; } + +/**********Server**********/ + +class WFDnsServerTask : public WFServerTask +{ +public: + WFDnsServerTask(CommService *service, + std::function& proc) : + WFServerTask(service, WFGlobal::get_scheduler(), proc) + { + auto *server = (WFServer *)service; + this->type = server->get_params()->transport_type; + } + +protected: + virtual CommMessageIn *message_in() + { + this->get_req()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_in(); + } + + virtual CommMessageOut *message_out() + { + this->get_resp()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_out(); + } + +protected: + enum TransportType type; +}; + +/**********Server Factory**********/ + +WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service, + std::function& proc) +{ + return new WFDnsServerTask(service, proc); +} + diff --git a/src/factory/WFTaskFactory.inl b/src/factory/WFTaskFactory.inl index acb1e162..e6a9a8cf 100644 --- a/src/factory/WFTaskFactory.inl +++ b/src/factory/WFTaskFactory.inl @@ -547,6 +547,9 @@ WFNetworkTaskFactory::create_server_task(CommService *service, class WFServerTaskFactory { public: + static WFDnsTask *create_dns_task(CommService *service, + std::function& proc); + static WFHttpTask *create_http_task(CommService *service, std::function& proc) { diff --git a/src/kernel/Communicator.cc b/src/kernel/Communicator.cc index 7123c7bd..5735f10e 100644 --- a/src/kernel/Communicator.cc +++ b/src/kernel/Communicator.cc @@ -71,8 +71,8 @@ static inline int __set_fd_nonblock(int fd) return flags; } -static int __bind_and_listen(int sockfd, const struct sockaddr *addr, - socklen_t addrlen) +static int __bind_sockaddr(int sockfd, const struct sockaddr *addr, + socklen_t addrlen) { struct sockaddr_storage ss; socklen_t len; @@ -94,7 +94,17 @@ static int __bind_and_listen(int sockfd, const struct sockaddr *addr, return -1; } - return listen(sockfd, SOMAXCONN < 4096 ? 4096 : SOMAXCONN); + return 0; +} + +static void __release_conn(struct CommConnEntry *entry) +{ + delete entry->conn; + if (!entry->service) + pthread_mutex_destroy(&entry->mutex); + + close(entry->sockfd); + free(entry); } int CommTarget::init(const struct sockaddr *addr, socklen_t addrlen, @@ -132,6 +142,17 @@ void CommTarget::deinit() int CommMessageIn::feedback(const void *buf, size_t size) { struct CommConnEntry *entry = this->entry; + CommSession *session = entry->session; + const struct sockaddr *addr; + socklen_t addrlen; + int ret; + + if (session->passive && !session->reliable) + { + entry->target->get_addr(&addr, &addrlen); + return sendto(entry->sockfd, buf, size, 0, addr, addrlen); + } + return write(entry->sockfd, buf, size); } @@ -266,10 +287,18 @@ CommSession::~CommSession() entry = list_entry(pos, struct CommConnEntry, list); list_del(pos); - errno_bak = errno; - mpoller_del(entry->sockfd, entry->mpoller); - entry->state = CONN_STATE_CLOSING; - errno = errno_bak; + if (this->reliable) + { + errno_bak = errno; + mpoller_del(entry->sockfd, entry->mpoller); + entry->state = CONN_STATE_CLOSING; + errno = errno_bak; + } + else + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } } pthread_mutex_unlock(&target->mutex); @@ -328,16 +357,6 @@ int Communicator::first_timeout_recv(CommSession *session) return Communicator::first_timeout(session); } -void Communicator::release_conn(struct CommConnEntry *entry) -{ - delete entry->conn; - if (!entry->service) - pthread_mutex_destroy(&entry->mutex); - - close(entry->sockfd); - free(entry); -} - void Communicator::shutdown_service(CommService *service) { close(service->listen_fd); @@ -579,7 +598,7 @@ void Communicator::handle_incoming_request(struct poller_result *res) if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } } @@ -661,7 +680,7 @@ void Communicator::handle_incoming_reply(struct poller_result *res) } if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); } } @@ -733,7 +752,7 @@ void Communicator::handle_reply_result(struct poller_result *res) session->handle(state, res->error); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } @@ -786,7 +805,7 @@ void Communicator::handle_request_result(struct poller_result *res) /* do nothing */ pthread_mutex_unlock(&entry->mutex); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); break; } @@ -819,7 +838,7 @@ struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = service; entry->target = target; entry->sockfd = target->sockfd; @@ -887,7 +906,7 @@ void Communicator::handle_connect_result(struct poller_result *res) target->release(); session->handle(state, res->error); - this->release_conn(entry); + __release_conn(entry); break; } } @@ -903,9 +922,10 @@ void Communicator::handle_listen_result(struct poller_result *res) { case PR_ST_SUCCESS: target = (CommServiceTarget *)res->data.result; - entry = this->accept_conn(target, service); + entry = Communicator::accept_conn(target, service); if (entry) { + entry->mpoller = this->mpoller; res->data.operation = PD_OP_READ; res->data.fd = entry->sockfd; res->data.create_message = Communicator::create_request; @@ -919,7 +939,7 @@ void Communicator::handle_listen_result(struct poller_result *res) break; } - this->release_conn(entry); + __release_conn(entry); } else close(target->sockfd); @@ -938,6 +958,54 @@ void Communicator::handle_listen_result(struct poller_result *res) } } +void Communicator::handle_recvfrom_result(struct poller_result *res) +{ + CommService *service = (CommService *)res->data.context; + struct CommConnEntry *entry; + CommTarget *target; + int state, error; + + switch (res->state) + { + case PR_ST_SUCCESS: + entry = (struct CommConnEntry *)res->data.result; + target = entry->target; + if (entry->state == CONN_STATE_SUCCESS) + { + state = CS_STATE_TOREPLY; + error = 0; + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + else + { + state = CS_STATE_ERROR; + if (entry->state == CONN_STATE_ERROR) + error = entry->error; + else + error = EBADMSG; + } + + entry->session->handle(state, error); + if (state == CS_STATE_ERROR) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + + break; + + case PR_ST_DELETED: + this->shutdown_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + void Communicator::handle_sleep_result(struct poller_result *res) { SleepSession *session = (SleepSession *)res->data.context; @@ -1028,6 +1096,9 @@ void Communicator::handler_thread_routine(void *context) case PD_OP_LISTEN: comm->handle_listen_result(res); break; + case PD_OP_RECVFROM: + comm->handle_recvfrom_result(res); + break; case PD_OP_EVENT: case PD_OP_NOTIFY: comm->handle_aio_result(res); @@ -1132,6 +1203,7 @@ poller_message_t *Communicator::create_request(void *context) return NULL; session->passive = 1; + session->reliable = 1; entry->session = session; session->target = target; session->conn = entry->conn; @@ -1184,6 +1256,59 @@ poller_message_t *Communicator::create_reply(void *context) return session->in; } +int Communicator::recv_request(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + CommService *service = entry->service; + CommTarget *target = entry->target; + CommSession *session; + size_t n; + int ret; + + session = service->new_session(entry->seq, entry->conn); + if (!session) + return -1; + + session->passive = 1; + session->reliable = 0; + entry->session = session; + session->target = target; + session->conn = entry->conn; + session->seq = entry->seq++; + session->out = NULL; + session->in = NULL; + + entry->state = CONN_STATE_RECEIVING; + + ((CommServiceTarget *)target)->incref(); + + session->in = session->message_in(); + if (session->in) + { + session->in->entry = entry; + do + { + n = size; + ret = session->in->append(buf, &n); + if (ret == 0) + { + size -= n; + buf = (const char *)buf + n; + } + else if (ret < 0) + { + entry->error = errno; + entry->state = CONN_STATE_ERROR; + } + else + entry->state = CONN_STATE_SUCCESS; + + } while (ret == 0 && size > 0); + } + + return 0; +} + int Communicator::partial_written(size_t n, void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; @@ -1215,6 +1340,41 @@ void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen, delete target; } + close(sockfd); + return NULL; +} + +void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context) +{ + CommService *service = (CommService *)context; + struct CommConnEntry *entry; + CommServiceTarget *target; + void *result; + int sockfd; + + sockfd = dup(service->listen_fd); + if (sockfd >= 0) + { + result = Communicator::accept(addr, addrlen, sockfd, context); + if (result) + { + target = (CommServiceTarget *)result; + entry = Communicator::accept_conn(target, service); + if (entry) + { + if (Communicator::recv_request(buf, size, entry) >= 0) + return entry; + + __release_conn(entry); + } + else + close(sockfd); + + target->decref(); + } + } + return NULL; } @@ -1342,7 +1502,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, int sockfd; int ret; - sockfd = this->nonblock_connect(target); + sockfd = Communicator::nonblock_connect(target); if (sockfd >= 0) { entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry)); @@ -1355,7 +1515,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = NULL; entry->target = target; entry->session = session; @@ -1437,9 +1597,10 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) struct poller_data data; int timeout; - entry = this->launch_conn(session, target); + entry = Communicator::launch_conn(session, target); if (entry) { + entry->mpoller = this->mpoller; session->conn = entry->conn; session->seq = entry->seq++; data.operation = PD_OP_CONNECT; @@ -1449,7 +1610,7 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) if (mpoller_add(&data, timeout, this->mpoller) >= 0) return 0; - this->release_conn(entry); + __release_conn(entry); } return -1; @@ -1483,18 +1644,24 @@ int Communicator::request(CommSession *session, CommTarget *target) return 0; } -int Communicator::nonblock_listen(CommService *service) +int Communicator::nonblock_listen(CommService *service, int *reliable) { int sockfd = service->create_listen_fd(); + int ret; if (sockfd >= 0) { if (__set_fd_nonblock(sockfd) >= 0) { - if (__bind_and_listen(sockfd, service->bind_addr, - service->addrlen) >= 0) + if (__bind_sockaddr(sockfd, service->bind_addr, + service->addrlen) >= 0) { - return sockfd; + ret = listen(sockfd, SOMAXCONN); + if (ret >= 0 || errno == EOPNOTSUPP) + { + *reliable = (ret >= 0); + return sockfd; + } } } @@ -1507,20 +1674,34 @@ int Communicator::nonblock_listen(CommService *service) int Communicator::bind(CommService *service) { struct poller_data data; + int errno_bak = errno; + int reliable; int sockfd; - sockfd = this->nonblock_listen(service); + sockfd = this->nonblock_listen(service, &reliable); if (sockfd >= 0) { service->listen_fd = sockfd; service->ref = 1; - data.operation = PD_OP_LISTEN; data.fd = sockfd; - data.accept = Communicator::accept; data.context = service; data.result = NULL; + if (reliable) + { + data.operation = PD_OP_LISTEN; + data.accept = Communicator::accept; + } + else + { + data.operation = PD_OP_RECVFROM; + data.recvfrom = Communicator::recvfrom; + } + if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0) + { + errno = errno_bak; return 0; + } close(sockfd); } @@ -1572,6 +1753,62 @@ int Communicator::reply_idle_conn(CommSession *session, CommTarget *target) return ret; } +int Communicator::reply_message_unreliable(struct CommConnEntry *entry) +{ + struct iovec vectors[ENCODE_IOV_MAX]; + int cnt; + + cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); + if ((unsigned int)cnt > ENCODE_IOV_MAX) + { + if (cnt > ENCODE_IOV_MAX) + errno = EOVERFLOW; + return -1; + } + + if (cnt > 0) + { + struct msghdr message = { + .msg_name = entry->target->addr, + .msg_namelen = entry->target->addrlen, + .msg_iov = vectors, + .msg_iovlen = (size_t)cnt, + }; + if (sendmsg(entry->sockfd, &message, 0) < 0) + return -1; + } + + return 0; +} + +int Communicator::reply_unreliable(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + target = entry->target; + session->out = session->message_out(); + if (session->out) + { + if (this->reply_message_unreliable(entry) >= 0) + return 0; + } + + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + else + errno = ENOENT; + + return -1; +} + int Communicator::reply(CommSession *session) { struct CommConnEntry *entry; @@ -1588,9 +1825,10 @@ int Communicator::reply(CommSession *session) errno_bak = errno; session->passive = 2; target = session->target; - ret = this->reply_idle_conn(session, target); - if (ret < 0) - return -1; + if (session->reliable) + ret = this->reply_idle_conn(session, target); + else + ret = this->reply_unreliable(session, target); if (ret == 0) { @@ -1598,10 +1836,12 @@ int Communicator::reply(CommSession *session) session->handle(CS_STATE_SUCCESS, 0); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } } + else if (ret < 0) + return -1; errno = errno_bak; return 0; @@ -1623,7 +1863,13 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) if (!list_empty(&target->idle_list)) { entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - ret = write(entry->sockfd, buf, size); + if (session->reliable) + ret = write(entry->sockfd, buf, size); + else + { + ret = sendto(entry->sockfd, buf, size, 0, + target->addr, target->addrlen); + } } else { @@ -1639,6 +1885,7 @@ int Communicator::shutdown(CommSession *session) { CommTarget *target = session->target; struct CommConnEntry *entry; + struct list_head *pos; int ret; if (session->passive != 1) @@ -1651,10 +1898,21 @@ int Communicator::shutdown(CommSession *session) pthread_mutex_lock(&target->mutex); if (!list_empty(&target->idle_list)) { - entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - list_del(&entry->list); - ret = mpoller_del(entry->sockfd, entry->mpoller); - entry->state = CONN_STATE_CLOSING; + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + if (session->reliable) + { + ret = mpoller_del(entry->sockfd, entry->mpoller); + entry->state = CONN_STATE_CLOSING; + } + else + { + ret = 0; + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } } else { diff --git a/src/kernel/Communicator.h b/src/kernel/Communicator.h index 46611b4f..05cb485b 100644 --- a/src/kernel/Communicator.h +++ b/src/kernel/Communicator.h @@ -30,9 +30,8 @@ class CommConnection { -protected: +public: virtual ~CommConnection() { } - friend class Communicator; }; class CommTarget @@ -144,7 +143,8 @@ private: private: struct timespec begin_time; int timeout; - int passive; + short passive; + short reliable; public: CommSession() { this->passive = 0; } @@ -271,16 +271,6 @@ private: int create_handler_threads(size_t handler_threads); - int nonblock_connect(CommTarget *target); - int nonblock_listen(CommService *service); - - struct CommConnEntry *launch_conn(CommSession *session, - CommTarget *target); - struct CommConnEntry *accept_conn(class CommServiceTarget *target, - CommService *service); - - void release_conn(struct CommConnEntry *entry); - void shutdown_service(CommService *service); void shutdown_io_service(IOService *service); @@ -292,10 +282,14 @@ private: int send_message(struct CommConnEntry *entry); + int request_new_conn(CommSession *session, CommTarget *target); + int request_idle_conn(CommSession *session, CommTarget *target); int reply_idle_conn(CommSession *session, CommTarget *target); - int request_new_conn(CommSession *session, CommTarget *target); + int reply_message_unreliable(struct CommConnEntry *entry); + + int reply_unreliable(CommSession *session, CommTarget *target); void handle_incoming_request(struct poller_result *res); void handle_incoming_reply(struct poller_result *res); @@ -309,12 +303,22 @@ private: void handle_connect_result(struct poller_result *res); void handle_listen_result(struct poller_result *res); + void handle_recvfrom_result(struct poller_result *res); + void handle_sleep_result(struct poller_result *res); void handle_aio_result(struct poller_result *res); static void handler_thread_routine(void *context); + static int nonblock_connect(CommTarget *target); + static int nonblock_listen(CommService *service, int *reliable); + + static struct CommConnEntry *launch_conn(CommSession *session, + CommTarget *target); + static struct CommConnEntry *accept_conn(class CommServiceTarget *target, + CommService *service); + static int first_timeout(CommSession *session); static int next_timeout(CommSession *session); @@ -329,11 +333,17 @@ private: static poller_message_t *create_request(void *context); static poller_message_t *create_reply(void *context); + static int recv_request(const void *buf, size_t size, + struct CommConnEntry *entry); + static int partial_written(size_t n, void *context); static void *accept(const struct sockaddr *addr, socklen_t addrlen, int sockfd, void *context); + static void *recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context); + static void callback(struct poller_result *res, void *context); public: diff --git a/src/kernel/poller.c b/src/kernel/poller.c index baede535..65a82d0f 100644 --- a/src/kernel/poller.c +++ b/src/kernel/poller.c @@ -535,10 +535,7 @@ static void __poller_handle_listen(struct __poller_node *node, result = node->data.accept(addr, addrlen, sockfd, node->data.context); if (!result) - { - close(sockfd); break; - } res->data = node->data; res->data.result = result; diff --git a/src/protocol/DnsMessage.cc b/src/protocol/DnsMessage.cc index fd925ed3..59372176 100644 --- a/src/protocol/DnsMessage.cc +++ b/src/protocol/DnsMessage.cc @@ -37,6 +37,157 @@ static inline void __append_uint16(std::string& s, uint16_t tmp) s.append((const char *)&tmp, sizeof (uint16_t)); } +static inline void __append_uint32(std::string& s, uint32_t tmp) +{ + tmp = htonl(tmp); + s.append((const char *)&tmp, sizeof (uint32_t)); +} + +static inline int __append_name(std::string& s, const char *p) +{ + const char *name; + size_t len; + + while (*p) + { + name = p; + while (*p && *p != '.') + p++; + + len = p - name; + if (len > DNS_LABELS_MAX || (len == 0 && *p && *(p + 1))) + { + errno = EINVAL; + return -1; + } + + if (len > 0) + { + __append_uint8(s, len); + s.append(name, len); + } + + if (*p == '.') + p++; + } + + len = 0; + __append_uint8(s, len); + + return 0; +} + +static inline int __append_record_list(std::string& s, int *count, + dns_record_cursor_t *cursor) +{ + int cnt = 0; + struct dns_record *record; + std::string record_buf; + std::string rdata_buf; + int ret; + + while (dns_record_cursor_next(&record, cursor) == 0) + { + record_buf.clear(); + ret = __append_name(record_buf, record->name); + if (ret < 0) + return ret; + + __append_uint16(record_buf, record->type); + __append_uint16(record_buf, record->rclass); + __append_uint32(record_buf, record->ttl); + + switch (record->type) + { + case DNS_TYPE_A: + case DNS_TYPE_AAAA: + __append_uint16(record_buf, record->rdlength); + record_buf.append((const char *)record->rdata, record->rdlength); + break; + + case DNS_TYPE_NS: + case DNS_TYPE_CNAME: + case DNS_TYPE_PTR: + rdata_buf.clear(); + ret = __append_name(rdata_buf, (const char *)record->rdata); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + + case DNS_TYPE_SOA: + { + auto *soa = (struct dns_record_soa *)record->rdata; + + rdata_buf.clear(); + ret = __append_name(rdata_buf, soa->mname); + if (ret < 0) + return ret; + ret = __append_name(rdata_buf, soa->rname); + if (ret < 0) + return ret; + + __append_uint32(rdata_buf, soa->serial); + __append_uint32(rdata_buf, soa->refresh); + __append_uint32(rdata_buf, soa->retry); + __append_uint32(rdata_buf, soa->expire); + __append_uint32(rdata_buf, soa->minimum); + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + } + + case DNS_TYPE_SRV: + { + auto *srv = (struct dns_record_srv *)record->rdata; + + rdata_buf.clear(); + __append_uint16(rdata_buf, srv->priority); + __append_uint16(rdata_buf, srv->weight); + __append_uint16(rdata_buf, srv->port); + ret = __append_name(rdata_buf, srv->target); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + } + case DNS_TYPE_MX: + { + auto *mx = (struct dns_record_mx *)record->rdata; + rdata_buf.clear(); + __append_uint16(rdata_buf, mx->preference); + ret = __append_name(rdata_buf, mx->exchange); + if (ret < 0) + return ret; + + __append_uint16(record_buf, rdata_buf.size()); + record_buf.append(rdata_buf); + + break; + } + default: + // TODO not implement + continue; + } + + cnt++; + s.append(record_buf); + } + + if (count) + *count = cnt; + + return 0; +} + DnsMessage::DnsMessage(DnsMessage&& msg) : ProtocolMessage(std::move(msg)) { @@ -70,54 +221,58 @@ DnsMessage& DnsMessage::operator = (DnsMessage&& msg) int DnsMessage::encode_reply() { + dns_record_cursor_t cursor; struct dns_header h; - const char *name; + std::string tmpbuf; const char *p; - size_t len; + int ancount; + int nscount; + int arcount; + int ret; msgbuf.clear(); msgsize = 0; - // TODO encode other field + // TODO + // this is an incomplete and inefficient way, compress not used, // pointers can only be used for occurances of a domain name where // the format is not class specific + dns_answer_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &ancount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + + dns_authority_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &nscount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + + dns_additional_cursor_init(&cursor, this->parser); + ret = __append_record_list(tmpbuf, &arcount, &cursor); + dns_record_cursor_deinit(&cursor); + if (ret < 0) + return ret; + h = this->parser->header; h.id = htons(h.id); h.qdcount = htons(1); - h.ancount = htons(0); - h.nscount = htons(0); - h.arcount = htons(0); + h.ancount = htons(ancount); + h.nscount = htons(nscount); + h.arcount = htons(arcount); msgbuf.append((const char *)&h, sizeof (struct dns_header)); p = parser->question.qname ? parser->question.qname : "."; - while (*p) - { - name = p; - while (*p && *p != '.') - p++; + ret = __append_name(msgbuf, p); + if (ret < 0) + return ret; - len = p - name; - if (len > DNS_LABELS_MAX || (len == 0 && *p && *(p + 1))) - { - errno = EINVAL; - return -1; - } - - if (len > 0) - { - __append_uint8(msgbuf, len); - msgbuf.append(name, len); - } - - if (*p == '.') - p++; - } - - len = 0; - __append_uint8(msgbuf, len); __append_uint16(msgbuf, parser->question.qtype); __append_uint16(msgbuf, parser->question.qclass); + msgbuf.append(tmpbuf); + if (msgbuf.size() >= (1 << 16)) { errno = EOVERFLOW; diff --git a/src/server/WFDnsServer.h b/src/server/WFDnsServer.h index 15f0daf5..be90810a 100644 --- a/src/server/WFDnsServer.h +++ b/src/server/WFDnsServer.h @@ -29,6 +29,7 @@ using WFDnsServer = WFServer inline +CommSession *WFDnsServer::new_session(long long seq, CommConnection *conn) +{ + WFDnsTask *task; + + task = WFServerTaskFactory::create_dns_task(this, this->process); + task->set_keep_alive(this->params.keep_alive_timeout); + task->set_receive_timeout(this->params.receive_timeout); + task->get_req()->set_size_limit(this->params.request_size_limit); + + return task; +} + #endif diff --git a/src/server/WFHttpServer.h b/src/server/WFHttpServer.h index e4af07d9..c696c96f 100644 --- a/src/server/WFHttpServer.h +++ b/src/server/WFHttpServer.h @@ -30,6 +30,7 @@ using WFHttpServer = WFServer #include #include "CommScheduler.h" +#include "EndpointParams.h" #include "WFConnection.h" #include "WFGlobal.h" #include "WFServer.h" @@ -72,10 +73,32 @@ int WFServerBase::create_listen_fd() { const struct sockaddr *bind_addr; socklen_t addrlen; + int type, protocol; int reuse = 1; + switch (this->params.transport_type) + { + case TT_TCP: + type = SOCK_STREAM; + protocol = 0; + break; + case TT_UDP: + type = SOCK_DGRAM; + protocol = 0; + break; +#ifdef IPPROTO_SCTP + case TT_SCTP: + type = SOCK_STREAM; + protocol = IPPROTO_SCTP; + break; +#endif + default: + errno = EPROTONOSUPPORT; + return -1; + } + this->get_addr(&bind_addr, &addrlen); - this->listen_fd = socket(bind_addr->sa_family, SOCK_STREAM, 0); + this->listen_fd = socket(bind_addr->sa_family, type, protocol); if (this->listen_fd >= 0) { setsockopt(this->listen_fd, SOL_SOCKET, SO_REUSEADDR, diff --git a/src/server/WFServer.h b/src/server/WFServer.h index 2dcd7ca2..ccbde3c1 100644 --- a/src/server/WFServer.h +++ b/src/server/WFServer.h @@ -27,10 +27,12 @@ #include #include #include +#include "EndpointParams.h" #include "WFTaskFactory.h" struct WFServerParams { + enum TransportType transport_type; size_t max_connections; int peer_response_timeout; /* timeout of each read or write operation */ int receive_timeout; /* timeout of receiving the whole message */ @@ -40,6 +42,7 @@ struct WFServerParams static constexpr struct WFServerParams SERVER_PARAMS_DEFAULT = { + .transport_type = TT_TCP, .max_connections = 2000, .peer_response_timeout = 10 * 1000, .receive_timeout = -1, @@ -116,6 +119,8 @@ public: return -1; } + const struct WFServerParams *get_params() const { return &this->params; } + protected: WFServerParams params; diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt index 5e3b2af2..e12051a3 100644 --- a/tutorial/CMakeLists.txt +++ b/tutorial/CMakeLists.txt @@ -32,6 +32,7 @@ else () endif () set(TUTORIAL_LIST + dns_proxy tutorial-00-helloworld tutorial-01-wget tutorial-04-http_echo_server diff --git a/tutorial/tutorial-05-http_proxy.cc b/tutorial/tutorial-05-http_proxy.cc index 22861308..a4ed9479 100644 --- a/tutorial/tutorial-05-http_proxy.cc +++ b/tutorial/tutorial-05-http_proxy.cc @@ -149,6 +149,12 @@ int main(int argc, char *argv[]) port = atoi(argv[1]); signal(SIGINT, sig_handler); + struct WFGlobalSettings settings = GLOBAL_SETTINGS_DEFAULT; + settings.resolv_conf_path = "./resolv.conf"; + settings.dns_ttl_default = 5; + settings.dns_ttl_min = 1; + WORKFLOW_library_init(&settings); + struct WFServerParams params = HTTP_SERVER_PARAMS_DEFAULT; /* for safety, limit request size to 8MB. */ params.request_size_limit = 8 * 1024 * 1024;