From 65de9c4e6571d0d8e138ba0f3f771b8141753245 Mon Sep 17 00:00:00 2001 From: xiehan <52160700+Barenboim@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:15:53 +0800 Subject: [PATCH] Add WFRedisSubscriber to support redis subscribe/psubscribe. (#1608) --- BUILD | 7 +- CMakeLists_Headers.txt | 2 + src/client/CMakeLists.txt | 7 + src/client/WFRedisSubscriber.cc | 129 +++++++++++++ src/client/WFRedisSubscriber.h | 229 +++++++++++++++++++++++ src/client/xmake.lua | 3 + src/factory/RedisTaskImpl.cc | 139 +++++++++++++- src/factory/RedisTaskImpl.inl | 37 ++++ src/include/workflow/RedisTaskImpl.inl | 1 + src/include/workflow/WFRedisSubscriber.h | 1 + 10 files changed, 552 insertions(+), 3 deletions(-) create mode 100644 src/client/WFRedisSubscriber.cc create mode 100644 src/client/WFRedisSubscriber.h create mode 100644 src/factory/RedisTaskImpl.inl create mode 120000 src/include/workflow/RedisTaskImpl.inl create mode 120000 src/include/workflow/WFRedisSubscriber.h diff --git a/BUILD b/BUILD index b66139d0..a9fd41b9 100644 --- a/BUILD +++ b/BUILD @@ -114,9 +114,11 @@ cc_library( cc_library( name = 'redis', hdrs = [ + 'src/factory/RedisTaskImpl.inl', 'src/protocol/RedisMessage.h', 'src/protocol/redis_parser.h', 'src/server/WFRedisServer.h', + 'src/client/WFRedisSubscriber.h', ], includes = [ 'src/protocol', @@ -126,6 +128,7 @@ cc_library( 'src/factory/RedisTaskImpl.cc', 'src/protocol/RedisMessage.cc', 'src/protocol/redis_parser.c', + 'src/client/WFRedisSubscriber.cc', ], deps = [ ':common', @@ -135,7 +138,6 @@ cc_library( cc_library( name = 'mysql', hdrs = [ - 'src/client/WFMySQLConnection.h', 'src/protocol/MySQLMessage.h', 'src/protocol/MySQLMessage.inl', 'src/protocol/MySQLResult.h', @@ -146,6 +148,7 @@ cc_library( 'src/protocol/mysql_stream.h', 'src/protocol/mysql_types.h', 'src/server/WFMySQLServer.h', + 'src/client/WFMySQLConnection.h', ], includes = [ 'src/protocol', @@ -153,7 +156,6 @@ cc_library( 'src/server', ], srcs = [ - 'src/client/WFMySQLConnection.cc', 'src/factory/MySQLTaskImpl.cc', 'src/protocol/MySQLMessage.cc', 'src/protocol/MySQLResult.cc', @@ -161,6 +163,7 @@ cc_library( 'src/protocol/mysql_byteorder.c', 'src/protocol/mysql_parser.c', 'src/protocol/mysql_stream.c', + 'src/client/WFMySQLConnection.cc', ], deps = [ ':common', diff --git a/CMakeLists_Headers.txt b/CMakeLists_Headers.txt index a3924d03..07a3dcfb 100644 --- a/CMakeLists_Headers.txt +++ b/CMakeLists_Headers.txt @@ -60,6 +60,7 @@ set(INCLUDE_HEADERS src/server/WFRedisServer.h src/server/WFMySQLServer.h src/client/WFMySQLConnection.h + src/client/WFRedisSubscriber.h src/client/WFConsulClient.h src/client/WFDnsClient.h src/manager/DnsCache.h @@ -89,6 +90,7 @@ set(INCLUDE_HEADERS src/factory/WFResourcePool.h src/factory/WFMessageQueue.h src/factory/WFHttpServerTask.h + src/factory/RedisTaskImpl.inl src/nameservice/WFNameService.h src/nameservice/WFDnsResolver.h src/nameservice/WFServiceGovernance.h diff --git a/src/client/CMakeLists.txt b/src/client/CMakeLists.txt index 0c5fdd08..5f074f1c 100644 --- a/src/client/CMakeLists.txt +++ b/src/client/CMakeLists.txt @@ -5,6 +5,13 @@ set(SRC WFDnsClient.cc ) +if (NOT REDIS STREQUAL "n") + set(SRC + ${SRC} + WFRedisSubscriber.cc + ) +endif () + if (NOT MYSQL STREQUAL "n") set(SRC ${SRC} diff --git a/src/client/WFRedisSubscriber.cc b/src/client/WFRedisSubscriber.cc new file mode 100644 index 00000000..33828ecb --- /dev/null +++ b/src/client/WFRedisSubscriber.cc @@ -0,0 +1,129 @@ +/* + Copyright (c) 2024 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#include +#include +#include +#include +#include "URIParser.h" +#include "RedisTaskImpl.inl" +#include "WFRedisSubscriber.h" + +int WFRedisSubscribeTask::sync_send(const std::string& command, + const std::vector& params) +{ + std::string str("*" + std::to_string(1 + params.size()) + "\r\n"); + int ret; + + str += "$" + std::to_string(command.size()) + "\r\n" + command + "\r\n"; + for (const std::string& p : params) + str += "$" + std::to_string(p.size()) + "\r\n" + p + "\r\n"; + + this->mutex.lock(); + if (this->task) + { + ret = this->task->push(str.c_str(), str.size()); + if (ret == (int)str.size()) + ret = 0; + else + { + if (ret >= 0) + errno = ENOBUFS; + ret = -1; + } + } + else + { + errno = ENOENT; + ret = -1; + } + + this->mutex.unlock(); + return ret; +} + +void WFRedisSubscribeTask::task_extract(WFRedisTask *task) +{ + auto *t = (WFRedisSubscribeTask *)task->user_data; + + if (t->extract) + t->extract(t); +} + +void WFRedisSubscribeTask::task_callback(WFRedisTask *task) +{ + auto *t = (WFRedisSubscribeTask *)task->user_data; + + t->mutex.lock(); + t->task = NULL; + t->mutex.unlock(); + + t->state = task->get_state(); + t->error = task->get_error(); + if (t->callback) + t->callback(t); + + t->release(); +} + +int WFRedisSubscriber::init(const std::string& url, SSL_CTX *ssl_ctx) +{ + if (URIParser::parse(url, this->uri) >= 0) + { + this->ssl_ctx = ssl_ctx; + return 0; + } + + if (this->uri.state == URI_STATE_INVALID) + errno = EINVAL; + + return -1; +} + +WFRedisTask * +WFRedisSubscriber::create_redis_task(const std::string& command, + const std::vector& params) +{ + WFRedisTask *task = __WFRedisTaskFactory::create_subscribe_task(this->uri, + WFRedisSubscribeTask::task_extract, + WFRedisSubscribeTask::task_callback); + this->set_ssl_ctx(task); + task->get_req()->set_request(command, params); + return task; +} + +WFRedisSubscribeTask * +WFRedisSubscriber::create_subscribe_task( + const std::vector& channels, + extract_t extract, callback_t callback) +{ + WFRedisTask *task = this->create_redis_task("SUBSCRIBE", channels); + return new WFRedisSubscribeTask(task, std::move(extract), + std::move(callback)); +} + +WFRedisSubscribeTask * +WFRedisSubscriber::create_psubscribe_task( + const std::vector& patterns, + extract_t extract, callback_t callback) +{ + WFRedisTask *task = this->create_redis_task("PSUBSCRIBE", patterns); + return new WFRedisSubscribeTask(task, std::move(extract), + std::move(callback)); +} + diff --git a/src/client/WFRedisSubscriber.h b/src/client/WFRedisSubscriber.h new file mode 100644 index 00000000..d66fd9e1 --- /dev/null +++ b/src/client/WFRedisSubscriber.h @@ -0,0 +1,229 @@ +/* + Copyright (c) 2024 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Author: Xie Han (xiehan@sogou-inc.com) +*/ + +#ifndef _WFREDISSUBSCRIBER_H_ +#define _WFREDISSUBSCRIBER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "WFTask.h" +#include "WFTaskFactory.h" + +class WFRedisSubscribeTask : public WFGenericTask +{ +public: + /* Note: Call 'get_resp()' only in the 'extract' function or + before the task is started to set response size limit. */ + protocol::RedisResponse *get_resp() + { + return this->task->get_resp(); + } + +public: + /* User needs to call 'release()' exactly once, anywhere. */ + void release() + { + if (this->flag.exchange(true)) + delete this; + } + +public: + /* Note: After 'release()' is called, all the requesting functions + should not be called except in 'extract', because the task + point may have been deleted because 'callback' finished. */ + + int subscribe(const std::vector& channels) + { + return this->sync_send("SUBSCRIBE", channels); + } + + int unsubscribe(const std::vector& channels) + { + return this->sync_send("UNSUBSCRIBE", channels); + } + + int unsubscribe_all() + { + return this->unsubscribe(std::vector()); + } + + int psubscribe(const std::vector& patterns) + { + return this->sync_send("PSUBSCRIBE", patterns); + } + + int punsubscribe(const std::vector& patterns) + { + return this->sync_send("PUNSUBSCRIBE", patterns); + } + + int punsubscribe_all() + { + return this->punsubscribe(std::vector()); + } + + int ping() + { + return this->sync_send("PING", std::vector()); + } + +public: + /* All 'timeout' proxy functions can only be called only before + the task is started or in 'extract'. */ + + /* Timeout of waiting for each message. Very useful. If not set, + the max waiting time will be the global 'response_timeout'*/ + void set_watch_timeout(int timeout) + { + this->task->set_watch_timeout(timeout); + } + + /* Timeout of receiving a complete message. */ + void set_recv_timeout(int timeout) + { + this->task->set_receive_timeout(timeout); + } + + /* Timeout of sending the first subscribe request. */ + void set_send_timeout(int timeout) + { + this->task->set_send_timeout(timeout); + } + + /* The default keep alive timeout is 0. If you want to keep + the connection alive, make sure not to send any request + after all channels/patterns were unsubscribed. */ + void set_keep_alive(int timeout) + { + this->task->set_keep_alive(timeout); + } + +public: + /* Call 'set_extract' or 'set_callback' only before the task + is started, or in 'extract'. */ + + void set_extract(std::function ex) + { + this->extract = std::move(ex); + } + + void set_callback(std::function cb) + { + this->callback = std::move(cb); + } + +protected: + virtual void dispatch() + { + series_of(this)->push_front(this->task); + this->subtask_done(); + } + + virtual SubTask *done() + { + return series_of(this)->pop(); + } + +protected: + int sync_send(const std::string& command, + const std::vector& params); + static void task_extract(WFRedisTask *task); + static void task_callback(WFRedisTask *task); + +protected: + WFRedisTask *task; + std::mutex mutex; + std::atomic flag; + std::function extract; + std::function callback; + +protected: + WFRedisSubscribeTask(WFRedisTask *task, + std::function&& ex, + std::function&& cb) : + flag(false), + extract(std::move(ex)), + callback(std::move(cb)) + { + task->user_data = this; + this->task = task; + } + + virtual ~WFRedisSubscribeTask() + { + if (this->task) + this->task->dismiss(); + } + + friend class WFRedisSubscriber; +}; + +class WFRedisSubscriber +{ +public: + int init(const std::string& url) + { + return this->init(url, NULL); + } + + int init(const std::string& url, SSL_CTX *ssl_ctx); + + void deinit() { } + +public: + using extract_t = std::function; + using callback_t = std::function; + +public: + WFRedisSubscribeTask * + create_subscribe_task(const std::vector& channels, + extract_t extract, callback_t callback); + + WFRedisSubscribeTask * + create_psubscribe_task(const std::vector& patterns, + extract_t extract, callback_t callback); + +protected: + void set_ssl_ctx(WFRedisTask *task) const + { + using RedisRequest = protocol::RedisRequest; + using RedisResponse = protocol::RedisResponse; + auto *t = (WFComplexClientTask *)task; + /* 'ssl_ctx' can be NULL and will use default. */ + t->set_ssl_ctx(this->ssl_ctx); + } + +protected: + WFRedisTask *create_redis_task(const std::string& command, + const std::vector& params); + +protected: + ParsedURI uri; + SSL_CTX *ssl_ctx; + +public: + virtual ~WFRedisSubscriber() { } +}; + +#endif + diff --git a/src/client/xmake.lua b/src/client/xmake.lua index 83d72c62..8be59cff 100644 --- a/src/client/xmake.lua +++ b/src/client/xmake.lua @@ -2,6 +2,9 @@ target("client") set_kind("object") add_files("*.cc") remove_files("WFKafkaClient.cc") + if not has_config("redis") then + remove_files("WFRedisSubscriber.cc") + end if not has_config("mysql") then remove_files("WFMySQLConnection.cc") end diff --git a/src/factory/RedisTaskImpl.cc b/src/factory/RedisTaskImpl.cc index d34defa1..87063a71 100644 --- a/src/factory/RedisTaskImpl.cc +++ b/src/factory/RedisTaskImpl.cc @@ -16,14 +16,17 @@ Authors: Wu Jiaxu (wujiaxu@sogou-inc.com) Li Yingxin (liyingxin@sogou-inc.com) Liu Kai (liukaidx@sogou-inc.com) + Xie Han (xiehan@sogou-inc.com) */ #include #include #include +#include "PackageWrapper.h" #include "WFTaskError.h" #include "WFTaskFactory.h" #include "StringUtil.h" +#include "RedisTaskImpl.inl" using namespace protocol; @@ -51,7 +54,7 @@ protected: virtual bool init_success(); virtual bool finish_once(); -private: +protected: bool need_redirect(); std::string username_; @@ -282,6 +285,114 @@ bool ComplexRedisTask::finish_once() return true; } +/****** Redis Subscribe ******/ + +class ComplexRedisSubscribeTask : public ComplexRedisTask +{ +public: + virtual int push(const void *buf, size_t size) + { + if (finished_) + { + errno = ENOENT; + return -1; + } + + if (!watching_) + { + errno = EAGAIN; + return -1; + } + + return this->scheduler->push(buf, size, this); + } + +protected: + virtual CommMessageIn *message_in() + { + if (!is_user_request_) + return this->ComplexRedisTask::message_in(); + + return &wrapper_; + } + + virtual int keep_alive_timeout() + { + if (!is_user_request_) + return this->ComplexRedisTask::keep_alive_timeout(); + + return this->keep_alive_timeo; + } + + virtual int first_timeout() + { + return watching_ ? this->watch_timeo : 0; + } + +protected: + class SubscribeWrapper : public PackageWrapper + { + protected: + virtual ProtocolMessage *next_in(ProtocolMessage *message); + + protected: + ComplexRedisSubscribeTask *task_; + + public: + SubscribeWrapper(ComplexRedisSubscribeTask *task) : + PackageWrapper(task->get_resp()) + { + task_ = task; + } + }; + +protected: + SubscribeWrapper wrapper_; + bool watching_; + bool finished_; + std::function extract_; + +public: + ComplexRedisSubscribeTask(std::function&& extract, + redis_callback_t&& callback) : + ComplexRedisTask(0, std::move(callback)), + wrapper_(this), + extract_(std::move(extract)) + { + watching_ = false; + finished_ = false; + } +}; + +ProtocolMessage * +ComplexRedisSubscribeTask::SubscribeWrapper::next_in(ProtocolMessage *message) +{ + redis_reply_t *reply = ((RedisResponse *)message)->result_ptr(); + + if (reply->type == REDIS_REPLY_TYPE_ARRAY && reply->elements == 3 && + reply->element[0]->type == REDIS_REPLY_TYPE_STRING) + { + const char *str = reply->element[0]->str; + size_t len = reply->element[0]->len; + + if ((len == 11 && strncasecmp(str, "unsubscribe", 11)) == 0 || + (len == 12 && strncasecmp(str, "punsubscribe", 12) == 0)) + { + if (reply->element[2]->type == REDIS_REPLY_TYPE_INTEGER && + reply->element[2]->integer == 0) + { + task_->finished_ = true; + } + } + } + + task_->watching_ = true; + task_->extract_(task_); + + task_->clear_resp(); + return task_->finished_ ? NULL : &task_->resp; +} + /**********Factory**********/ // redis://:password@host:port/db_num @@ -311,3 +422,29 @@ WFRedisTask *WFTaskFactory::create_redis_task(const ParsedURI& uri, return task; } +WFRedisTask * +__WFRedisTaskFactory::create_subscribe_task(const std::string& url, + extract_t extract, + redis_callback_t callback) +{ + auto *task = new ComplexRedisSubscribeTask(std::move(extract), + std::move(callback)); + ParsedURI uri; + + URIParser::parse(url, uri); + task->init(std::move(uri)); + return task; +} + +WFRedisTask * +__WFRedisTaskFactory::create_subscribe_task(const ParsedURI& uri, + extract_t extract, + redis_callback_t callback) +{ + auto *task = new ComplexRedisSubscribeTask(std::move(extract), + std::move(callback)); + + task->init(uri); + return task; +} + diff --git a/src/factory/RedisTaskImpl.inl b/src/factory/RedisTaskImpl.inl new file mode 100644 index 00000000..2329ba8b --- /dev/null +++ b/src/factory/RedisTaskImpl.inl @@ -0,0 +1,37 @@ +/* + Copyright (c) 2024 Sogou, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Authors: Xie Han (xiehan@sogou-inc.com) +*/ + +#include "WFTaskFactory.h" + +// Internal, for WFRedisSubscribeTask only. + +class __WFRedisTaskFactory +{ +private: + using extract_t = std::function; + +public: + static WFRedisTask *create_subscribe_task(const std::string& url, + extract_t extract, + redis_callback_t callback); + + static WFRedisTask *create_subscribe_task(const ParsedURI& uri, + extract_t extract, + redis_callback_t callback); +}; + diff --git a/src/include/workflow/RedisTaskImpl.inl b/src/include/workflow/RedisTaskImpl.inl new file mode 120000 index 00000000..04a7d9ab --- /dev/null +++ b/src/include/workflow/RedisTaskImpl.inl @@ -0,0 +1 @@ +../../factory/RedisTaskImpl.inl \ No newline at end of file diff --git a/src/include/workflow/WFRedisSubscriber.h b/src/include/workflow/WFRedisSubscriber.h new file mode 120000 index 00000000..4687bd3d --- /dev/null +++ b/src/include/workflow/WFRedisSubscriber.h @@ -0,0 +1 @@ +../../client/WFRedisSubscriber.h \ No newline at end of file