diff --git a/src/client/WFMySQLConnection.cc b/src/client/WFMySQLConnection.cc index d0103fa1..3eb16b7e 100644 --- a/src/client/WFMySQLConnection.cc +++ b/src/client/WFMySQLConnection.cc @@ -21,10 +21,11 @@ #include #include #include +#include #include "URIParser.h" #include "WFMySQLConnection.h" -int WFMySQLConnection::init(const std::string& url) +int WFMySQLConnection::init(const std::string& url, SSL_CTX *ssl_ctx) { std::string query; ParsedURI uri; @@ -43,6 +44,7 @@ int WFMySQLConnection::init(const std::string& url) if (uri.query) { this->uri = std::move(uri); + this->ssl_ctx = ssl_ctx; return 0; } } diff --git a/src/client/WFMySQLConnection.h b/src/client/WFMySQLConnection.h index 5d92ebf7..9615ba48 100644 --- a/src/client/WFMySQLConnection.h +++ b/src/client/WFMySQLConnection.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "URIParser.h" #include "WFTaskFactory.h" @@ -31,7 +32,12 @@ public: /* example: mysql://username:passwd@127.0.0.1/dbname?character_set=utf8 * IP string is recommmended in url. When using a domain name, the first * address resovled will be used. Don't use upstream name as a host. */ - int init(const std::string& url); + int init(const std::string& url) + { + return init(url, NULL); + } + + int init(const std::string& url, SSL_CTX *ssl_ctx); void deinit() { } @@ -41,6 +47,7 @@ public: { WFMySQLTask *task = WFTaskFactory::create_mysql_task(this->uri, 0, std::move(callback)); + this->set_ssl_ctx(task); task->get_req()->set_query(query); return task; } @@ -51,12 +58,24 @@ public: WFMySQLTask *create_disconnect_task(mysql_callback_t callback) { WFMySQLTask *task = this->create_query_task("", std::move(callback)); + this->set_ssl_ctx(task); task->set_keep_alive(0); return task; } +protected: + void set_ssl_ctx(WFMySQLTask *task) + { + using MySQLRequest = protocol::MySQLRequest; + using MySQLResponse = protocol::MySQLResponse; + auto *t = (WFComplexClientTask *)task; + /* 'ssl_ctx' can be NULL and will use default. */ + t->set_ssl_ctx(this->ssl_ctx); + } + protected: ParsedURI uri; + SSL_CTX *ssl_ctx; int id; public: