Merge pull request #232 from sogou/dev

Refactor name service. Refactor Upstream module based on WFNameService.
This commit is contained in:
liyingxin
2021-01-31 18:45:12 +08:00
committed by GitHub
22 changed files with 2250 additions and 1759 deletions

View File

@@ -82,6 +82,9 @@ set(INCLUDE_HEADERS
src/factory/WFAlgoTaskFactory.inl
src/factory/Workflow.h
src/factory/WFOperator.h
src/factory/WFNameService.h
src/factory/WFDNSResolver.h
src/manager/UpstreamPolicies.h
)
if(KAFKA STREQUAL "y")

View File

@@ -8,6 +8,8 @@ set(SRC
MySQLTaskImpl.cc
WFTaskFactory.cc
Workflow.cc
WFNameService.cc
WFDNSResolver.cc
)
add_library(${PROJECT_NAME} OBJECT ${SRC})

View File

@@ -471,13 +471,13 @@ bool ComplexMySQLTask::init_success()
transaction_state_ = TRANSACTION_OUT;
this->WFComplexClientTask::set_info(std::string("?maxconn=1&") +
info + "|txn:" + transaction);
this->first_addr_only_ = true;
this->fixed_addr_ = true;
}
else
{
transaction_state_ = NO_TRANSACTION;
this->WFComplexClientTask::set_info(info);
this->first_addr_only_ = false;
this->fixed_addr_ = false;
}
delete []info;

View File

@@ -0,0 +1,282 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
*/
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <stdlib.h>
#include <ctype.h>
#include <utility>
#include <string>
#include "DNSRoutine.h"
#include "EndpointParams.h"
#include "RouteManager.h"
#include "WFGlobal.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
#include "WFNameService.h"
#include "WFDNSResolver.h"
/*
DNS_CACHE_LEVEL_0 -> NO cache
DNS_CACHE_LEVEL_1 -> TTL MIN
DNS_CACHE_LEVEL_2 -> TTL [DEFAULT]
DNS_CACHE_LEVEL_3 -> Forever
*/
#define DNS_CACHE_LEVEL_0 0
#define DNS_CACHE_LEVEL_1 1
#define DNS_CACHE_LEVEL_2 2
#define DNS_CACHE_LEVEL_3 3
class WFResolverTask : public WFRouterTask
{
public:
WFResolverTask(const struct WFNSParams *params, int dns_cache_level,
unsigned int dns_ttl_default, unsigned int dns_ttl_min,
const struct EndpointParams *endpoint_params,
router_callback_t&& cb) :
WFRouterTask(std::move(cb)),
type_(params->type),
host_(params->uri.host ? params->uri.host : ""),
port_(params->uri.port ? atoi(params->uri.port) : 0),
info_(params->info),
dns_cache_level_(dns_cache_level),
dns_ttl_default_(dns_ttl_default),
dns_ttl_min_(dns_ttl_min),
endpoint_params_(*endpoint_params),
first_addr_only_(params->fixed_addr)
{
}
private:
virtual void dispatch();
virtual SubTask *done();
void dns_callback(WFDNSTask *dns_task);
void dns_callback_internal(DNSOutput *dns_task,
unsigned int ttl_default,
unsigned int ttl_min);
private:
TransportType type_;
std::string host_;
unsigned short port_;
std::string info_;
int dns_cache_level_;
unsigned int dns_ttl_default_;
unsigned int dns_ttl_min_;
struct EndpointParams endpoint_params_;
bool first_addr_only_;
bool insert_dns_;
};
void WFResolverTask::dispatch()
{
insert_dns_ = true;
if (dns_cache_level_ != DNS_CACHE_LEVEL_0)
{
auto *dns_cache = WFGlobal::get_dns_cache();
const DNSCache::DNSHandle *addr_handle = NULL;
switch (dns_cache_level_)
{
case DNS_CACHE_LEVEL_1:
addr_handle = dns_cache->get_confident(host_, port_);
break;
case DNS_CACHE_LEVEL_2:
addr_handle = dns_cache->get_ttl(host_, port_);
break;
case DNS_CACHE_LEVEL_3:
addr_handle = dns_cache->get(host_, port_);
break;
default:
break;
}
if (addr_handle)
{
auto *route_manager = WFGlobal::get_route_manager();
struct addrinfo *addrinfo = addr_handle->value.addrinfo;
struct addrinfo first;
if (first_addr_only_ && addrinfo->ai_next)
{
first = *addrinfo;
first.ai_next = NULL;
addrinfo = &first;
}
if (route_manager->get(type_, addrinfo, info_, &endpoint_params_,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else
this->state = WFT_STATE_SUCCESS;
insert_dns_ = false;
dns_cache->release(addr_handle);
}
}
if (insert_dns_ && !host_.empty())
{
char front = host_.front();
char back = host_.back();
struct in6_addr addr;
int ret;
if (host_.find(':') != std::string::npos)
ret = inet_pton(AF_INET6, host_.c_str(), &addr);
else if (isdigit(back) && isdigit(front))
ret = inet_pton(AF_INET, host_.c_str(), &addr);
else if (front == '/')
ret = 1;
else
ret = 0;
if (ret == 1)
{
DNSInput dns_in;
DNSOutput dns_out;
dns_in.reset(host_, port_);
DNSRoutine::run(&dns_in, &dns_out);
dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1);
insert_dns_ = false;
}
}
if (insert_dns_)
{
auto&& cb = std::bind(&WFResolverTask::dns_callback,
this,
std::placeholders::_1);
WFDNSTask *dns_task = WFTaskFactory::create_dns_task(host_, port_,
std::move(cb));
series_of(this)->push_front(dns_task);
}
this->subtask_done();
}
SubTask *WFResolverTask::done()
{
SeriesWork *series = series_of(this);
if (!insert_dns_)
{
if (this->callback)
this->callback(this);
delete this;
}
return series->pop();
}
void WFResolverTask::dns_callback_internal(DNSOutput *dns_out,
unsigned int ttl_default,
unsigned int ttl_min)
{
int dns_error = dns_out->get_error();
if (dns_error)
{
if (dns_error == EAI_SYSTEM)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else
{
this->state = WFT_STATE_DNS_ERROR;
this->error = dns_error;
}
}
else
{
auto *route_manager = WFGlobal::get_route_manager();
auto *dns_cache = WFGlobal::get_dns_cache();
struct addrinfo *addrinfo = dns_out->move_addrinfo();
const DNSCache::DNSHandle *addr_handle;
addr_handle = dns_cache->put(host_, port_, addrinfo,
(unsigned int)ttl_default,
(unsigned int)ttl_min);
if (route_manager->get(type_, addrinfo, info_, &endpoint_params_,
this->result) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else
this->state = WFT_STATE_SUCCESS;
dns_cache->release(addr_handle);
}
}
void WFResolverTask::dns_callback(WFDNSTask *dns_task)
{
if (dns_task->get_state() == WFT_STATE_SUCCESS)
dns_callback_internal(dns_task->get_output(), dns_ttl_default_, dns_ttl_min_);
else
{
this->state = dns_task->get_state();
this->error = dns_task->get_error();
}
if (this->callback)
this->callback(this);
delete this;
}
WFRouterTask *
WFDNSResolver::create(const struct WFNSParams *params, int dns_cache_level,
unsigned int dns_ttl_default, unsigned int dns_ttl_min,
const struct EndpointParams *endpoint_params,
router_callback_t&& callback)
{
return new WFResolverTask(params, dns_cache_level,
dns_ttl_default, dns_ttl_min,
endpoint_params, std::move(callback));
}
WFRouterTask *WFDNSResolver::create_router_task(const struct WFNSParams *params,
router_callback_t callback)
{
const auto *settings = WFGlobal::get_global_settings();
unsigned int dns_ttl_default = settings->dns_ttl_default;
unsigned int dns_ttl_min = settings->dns_ttl_min;
const struct EndpointParams *endpoint_params = &settings->endpoint_params;
int dns_cache_level = params->retry_times == 0 ? DNS_CACHE_LEVEL_2 :
DNS_CACHE_LEVEL_1;
return create(params, dns_cache_level, dns_ttl_default, dns_ttl_min,
endpoint_params, std::move(callback));
}

View File

@@ -0,0 +1,39 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFDNSRESOLVER_H_
#define _WFDNSRESOLVER_H_
#include "EndpointParams.h"
#include "WFNameService.h"
class WFDNSResolver : public WFNSPolicy
{
public:
virtual WFRouterTask *create_router_task(const struct WFNSParams *params,
router_callback_t callback);
protected:
WFRouterTask *create(const struct WFNSParams *params, int dns_cache_level,
unsigned int dns_ttl_default, unsigned int dns_ttl_min,
const struct EndpointParams *endpoint_params,
router_callback_t&& callback);
};
#endif

View File

@@ -0,0 +1,128 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#include <errno.h>
#include <stddef.h>
#include <string.h>
#include <pthread.h>
#include "rbtree.h"
#include "WFNameService.h"
struct WFNSPolicyEntry
{
struct rb_node rb;
WFNSPolicy *policy;
char name[1];
};
int WFNameService::add_policy(const char *name, WFNSPolicy *policy)
{
struct rb_node **p = &this->root.rb_node;
struct rb_node *parent = NULL;
struct WFNSPolicyEntry *entry;
int n, ret = -1;
pthread_rwlock_wrlock(&this->rwlock);
while (*p)
{
parent = *p;
entry = rb_entry(*p, struct WFNSPolicyEntry, rb);
n = strcasecmp(name, entry->name);
if (n < 0)
p = &(*p)->rb_left;
else if (n > 0)
p = &(*p)->rb_right;
else
break;
}
if (!*p)
{
size_t len = strlen(name);
size_t size = offsetof(struct WFNSPolicyEntry, name) + len + 1;
entry = (struct WFNSPolicyEntry *)malloc(size);
if (entry)
{
memcpy(entry->name, name, len + 1);
entry->policy = policy;
rb_link_node(&entry->rb, parent, p);
rb_insert_color(&entry->rb, &this->root);
ret = 0;
}
}
else
errno = EEXIST;
pthread_rwlock_unlock(&this->rwlock);
return ret;
}
inline struct WFNSPolicyEntry *WFNameService::get_policy_entry(const char *name)
{
struct rb_node *p = this->root.rb_node;
struct WFNSPolicyEntry *entry;
int n;
while (p)
{
entry = rb_entry(p, struct WFNSPolicyEntry, rb);
n = strcasecmp(name, entry->name);
if (n < 0)
p = p->rb_left;
else if (n > 0)
p = p->rb_right;
else
return entry;
}
return NULL;
}
WFNSPolicy *WFNameService::get_policy(const char *name)
{
WFNSPolicy *policy = this->default_policy;
struct WFNSPolicyEntry *entry;
pthread_rwlock_rdlock(&this->rwlock);
entry = this->get_policy_entry(name);
if (entry)
policy = entry->policy;
pthread_rwlock_unlock(&this->rwlock);
return policy;
}
WFNSPolicy *WFNameService::del_policy(const char *name)
{
WFNSPolicy *policy = NULL;
struct WFNSPolicyEntry *entry;
pthread_rwlock_wrlock(&this->rwlock);
entry = this->get_policy_entry(name);
if (entry)
{
policy = entry->policy;
rb_erase(&entry->rb, &this->root);
}
pthread_rwlock_unlock(&this->rwlock);
free(entry);
return policy;
}

126
src/factory/WFNameService.h Normal file
View File

@@ -0,0 +1,126 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Xie Han (xiehan@sogou-inc.com)
*/
#ifndef _WFNAMESERVICE_H_
#define _WFNAMESERVICE_H_
#include <pthread.h>
#include <functional>
#include <utility>
#include "rbtree.h"
#include "Communicator.h"
#include "Workflow.h"
#include "WFTask.h"
#include "RouteManager.h"
#include "URIParser.h"
#include "EndpointParams.h"
class WFRouterTask : public WFGenericTask
{
public:
RouteManager::RouteResult *get_result() { return &this->result; }
void *get_cookie() const { return this->cookie; }
protected:
RouteManager::RouteResult result;
void *cookie;
std::function<void (WFRouterTask *)> callback;
public:
void set_cookie(void *cookie) { this->cookie = cookie; }
protected:
virtual SubTask *done()
{
SeriesWork *series = series_of(this);
if (this->callback)
this->callback(this);
delete this;
return series->pop();
}
public:
WFRouterTask(std::function<void (WFRouterTask *)>&& cb) :
callback(std::move(cb))
{
this->cookie = NULL;
}
};
struct WFNSParams
{
TransportType type;
ParsedURI& uri;
const char *info;
bool fixed_addr;
int retry_times;
};
using router_callback_t = std::function<void (WFRouterTask *)>;
class WFNSPolicy
{
public:
virtual WFRouterTask *create_router_task(const struct WFNSParams *params,
router_callback_t callback) = 0;
virtual void success(RouteManager::RouteResult *result, void *cookie,
CommTarget *target)
{
RouteManager::notify_available(result->cookie, target);
}
virtual void failed(RouteManager::RouteResult *result, void *cookie,
CommTarget *target)
{
if (target)
RouteManager::notify_unavailable(result->cookie, target);
}
public:
virtual ~WFNSPolicy() { }
};
class WFNameService
{
public:
int add_policy(const char *name, WFNSPolicy *policy);
WFNSPolicy *get_policy(const char *name);
WFNSPolicy *del_policy(const char *name);
private:
WFNSPolicy *default_policy;
struct rb_root root;
pthread_rwlock_t rwlock;
private:
struct WFNSPolicyEntry *get_policy_entry(const char *name);
public:
WFNameService(WFNSPolicy *default_policy) :
rwlock(PTHREAD_RWLOCK_INITIALIZER)
{
this->root.rb_node = NULL;
this->default_policy = default_policy;
}
};
#endif

View File

@@ -18,17 +18,11 @@
*/
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <arpa/inet.h>
#include <ctype.h>
#include <string>
#include <mutex>
#include "list.h"
#include "rbtree.h"
#include "DNSRoutine.h"
#include "WFGlobal.h"
#include "WFTaskError.h"
#include "WFTaskFactory.h"
class __WFCounterTask;
@@ -470,200 +464,3 @@ WFFileSyncTask *WFTaskFactory::create_fdsync_task(int fd,
std::move(callback));
}
/********RouterTask*************/
void WFRouterTask::dispatch()
{
insert_dns_ = true;
if (dns_cache_level_ != DNS_CACHE_LEVEL_0)
{
auto *dns_cache = WFGlobal::get_dns_cache();
const DNSHandle *addr_handle = NULL;
switch (dns_cache_level_)
{
case DNS_CACHE_LEVEL_1:
addr_handle = dns_cache->get_confident(host_, port_);
break;
case DNS_CACHE_LEVEL_2:
addr_handle = dns_cache->get_ttl(host_, port_);
break;
case DNS_CACHE_LEVEL_3:
addr_handle = dns_cache->get(host_, port_);
break;
default:
break;
}
if (addr_handle)
{
if (addr_handle->value.addrinfo)
{
auto *route_manager = WFGlobal::get_route_manager();
struct addrinfo *addrinfo = addr_handle->value.addrinfo;
struct addrinfo first;
if (first_addr_only_ && addrinfo->ai_next)
{
first = *addrinfo;
first.ai_next = NULL;
addrinfo = &first;
}
if (route_manager->get(type_, addrinfo,
info_, &endpoint_params_,
route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else if (!route_result_.request_object)
{
//should not happen
this->state = WFT_STATE_SYS_ERROR;
this->error = EAGAIN;
}
else
this->state = WFT_STATE_SUCCESS;
insert_dns_ = false;
}
dns_cache->release(addr_handle);
}
}
if (insert_dns_ && !host_.empty())
{
char front = host_.front();
char back = host_.back();
struct in6_addr addr;
int ret;
if (host_.find(':') != std::string::npos)
ret = inet_pton(AF_INET6, host_.c_str(), &addr);
else if (isdigit(back) && isdigit(front))
ret = inet_pton(AF_INET, host_.c_str(), &addr);
else if (front == '/')
ret = 1;
else
ret = 0;
if (ret == 1)
{
DNSInput dns_in;
DNSOutput dns_out;
dns_in.reset(host_, port_);
DNSRoutine::run(&dns_in, &dns_out);
dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1);
insert_dns_ = false;
}
}
if (insert_dns_)
{
auto&& cb = std::bind(&WFRouterTask::dns_callback,
this,
std::placeholders::_1);
WFDNSTask *dns_task = WFTaskFactory::create_dns_task(host_, port_,
std::move(cb));
series_of(this)->push_front(dns_task);
}
this->subtask_done();
}
SubTask* WFRouterTask::done()
{
SeriesWork *series = series_of(this);
if (!insert_dns_)
{
if (callback_)
callback_(this);
delete this;
}
return series->pop();
}
void WFRouterTask::dns_callback_internal(DNSOutput *dns_out,
unsigned int ttl_default,
unsigned int ttl_min)
{
int dns_error = dns_out->get_error();
if (dns_error)
{
if (dns_error == EAI_SYSTEM)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else
{
this->state = WFT_STATE_DNS_ERROR;
this->error = dns_error;
}
}
else
{
struct addrinfo *addrinfo = dns_out->move_addrinfo();
const DNSHandle *addr_handle;
if (addrinfo)
{
auto *route_manager = WFGlobal::get_route_manager();
auto *dns_cache = WFGlobal::get_dns_cache();
addr_handle = dns_cache->put(host_, port_, addrinfo,
(unsigned int)ttl_default,
(unsigned int)ttl_min);
if (route_manager->get(type_, addrinfo, info_, &endpoint_params_,
route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else if (!route_result_.request_object)
{
//should not happen
this->state = WFT_STATE_SYS_ERROR;
this->error = EAGAIN;
}
else
this->state = WFT_STATE_SUCCESS;
dns_cache->release(addr_handle);
}
else
{
//system promise addrinfo not null, here should not happen
this->state = WFT_STATE_SYS_ERROR;
this->error = EINVAL;
}
}
}
void WFRouterTask::dns_callback(WFDNSTask *dns_task)
{
if (dns_task->get_state() == WFT_STATE_SUCCESS)
dns_callback_internal(dns_task->get_output(), dns_ttl_default_, dns_ttl_min_);
else
{
this->state = dns_task->get_state();
this->error = dns_task->get_error();
}
if (callback_)
callback_(this);
delete this;
}

View File

@@ -88,15 +88,15 @@ using counter_callback_t = std::function<void (WFCounterTask *)>;
// Graph (DAG) task.
using graph_callback_t = std::function<void (WFGraphTask *)>;
// DNS task. For internal usage only.
using WFDNSTask = WFThreadTask<DNSInput, DNSOutput>;
using dns_callback_t = std::function<void (WFDNSTask *)>;
using WFEmptyTask = WFGenericTask;
using WFDynamicTask = WFGenericTask;
using dynamic_create_t = std::function<SubTask *(WFDynamicTask *)>;
// DNS task. For internal usage only.
using WFDNSTask = WFThreadTask<DNSInput, DNSOutput>;
using dns_callback_t = std::function<void (WFDNSTask *)>;
class WFTaskFactory
{
public:

View File

@@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
Xie Han (xiehan@sogou-inc.com)
Authors: Xie Han (xiehan@sogou-inc.com)
Wu Jiaxu (wujiaxu@sogou-inc.com)
Li Yingxin (liyingxin@sogou-inc.com)
*/
@@ -31,11 +31,11 @@
#include "WFGlobal.h"
#include "Workflow.h"
#include "WFTask.h"
#include "UpstreamManager.h"
#include "RouteManager.h"
#include "URIParser.h"
#include "WFTaskError.h"
#include "EndpointParams.h"
#include "WFNameService.h"
class __WFTimerTask : public WFTimerTask
{
@@ -131,102 +131,6 @@ WFTaskFactory::create_dynamic_task(dynamic_create_t create)
return new __WFDynamicTask(std::move(create));
}
/**********WFComplexClientTask**********/
// If you design Derived WFComplexClientTask, You have two choices:
// 1) First choice will upstream by uri, then dns/dns-cache
// 2) Second choice will directly communicate without upstream/dns/dns-cache
// 1) First choice:
// step 1. Child-Constructor call Father-Constructor to new WFComplexClientTask
// step 2. call init(uri)
// step 3. call set_type(type)
// step 4. call set_info(info) or do nothing with info
// 2) Second choice:
// step 1. Child-Constructor call Father-Constructor to new WFComplexClientTask
// step 2. call init(type, addr, addrlen, info)
// Some optional APIs for you to implement:
// [WFComplexTask]
// [ChildrenComplexTask]
// 1. init()
// init_succ() or init_failed(); // default: return true;
// 2. dispatch();
// check_request(); // default: return true;
// route(); // default:DNS; goto 1;
// 3. message_out();
// 4. message_in();
// 5. keep_alive_timeout();
// 6. done();
// finish_once(); // default: return true; means this is user request.
// // If redirect or retry: goto 1;
/*
DNS_CACHE_LEVEL_0 -> NO cache
DNS_CACHE_LEVEL_1 -> TTL MIN
DNS_CACHE_LEVEL_2 -> TTL [DEFAULT]
DNS_CACHE_LEVEL_3 -> Forever
*/
#define DNS_CACHE_LEVEL_0 0
#define DNS_CACHE_LEVEL_1 1
#define DNS_CACHE_LEVEL_2 2
#define DNS_CACHE_LEVEL_3 3
class WFRouterTask : public WFGenericTask
{
private:
using router_callback_t = std::function<void (WFRouterTask *)>;
using WFDNSTask = WFThreadTask<DNSInput, DNSOutput>;
public:
RouteManager::RouteResult route_result_;
WFRouterTask(TransportType type,
const std::string& host,
unsigned short port,
const std::string& info,
int dns_cache_level,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min,
const struct EndpointParams *endpoint_params,
bool first_addr_only,
router_callback_t&& callback) :
type_(type),
host_(host),
port_(port),
info_(info),
dns_cache_level_(dns_cache_level),
dns_ttl_default_(dns_ttl_default),
dns_ttl_min_(dns_ttl_min),
endpoint_params_(*endpoint_params),
first_addr_only_(first_addr_only),
callback_(std::move(callback))
{}
private:
virtual void dispatch();
virtual SubTask *done();
void dns_callback(WFDNSTask *dns_task);
void dns_callback_internal(DNSOutput *dns_task,
unsigned int ttl_default,
unsigned int ttl_min);
private:
TransportType type_;
std::string host_;
unsigned short port_;
std::string info_;
int dns_cache_level_;
unsigned int dns_ttl_default_;
unsigned int dns_ttl_min_;
struct EndpointParams endpoint_params_;
bool first_addr_only_;
bool insert_dns_;
router_callback_t callback_;
};
template<class REQ, class RESP, typename CTX = bool>
class WFComplexClientTask : public WFClientTask<REQ, RESP>
{
@@ -234,24 +138,24 @@ protected:
using task_callback_t = std::function<void (WFNetworkTask<REQ, RESP> *)>;
public:
WFComplexClientTask(int retry_max, task_callback_t&& callback):
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(),
std::move(callback)),
retry_max_(retry_max),
first_addr_only_(false),
router_task_(NULL),
type_(TT_TCP),
retry_times_(0),
is_retry_(false),
redirect_(false)
{}
WFComplexClientTask(int retry_max, task_callback_t&& cb):
WFClientTask<REQ, RESP>(NULL, WFGlobal::get_scheduler(), std::move(cb))
{
type_ = TT_TCP;
fixed_addr_ = false;
retry_max_ = retry_max;
retry_times_ = 0;
redirect_ = false;
ns_policy_ = NULL;
router_task_ = NULL;
}
protected:
// new api for children
virtual bool init_success() { return true; }
virtual void init_failed() {}
virtual bool check_request() { return true; }
virtual SubTask *route();
virtual WFRouterTask *route();
virtual bool finish_once() { return true; }
public:
@@ -272,16 +176,12 @@ public:
void init(const ParsedURI& uri)
{
is_sockaddr_ = false;
init_state_ = 0;
uri_ = uri;
init_with_uri();
}
void init(ParsedURI&& uri)
{
is_sockaddr_ = false;
init_state_ = 0;
uri_ = std::move(uri);
init_with_uri();
}
@@ -309,25 +209,6 @@ public:
}
protected:
void set_redirect()
{
redirect_ = true;
retry_times_ = 0;
}
void set_retry(const ParsedURI& uri)
{
redirect_ = true;
init(uri);
retry_times_++;
}
void set_retry()
{
redirect_ = true;
retry_times_++;
}
virtual void dispatch();
virtual SubTask *done();
@@ -347,13 +228,19 @@ protected:
TransportType get_transport_type() const { return type_; }
protected:
TransportType type_;
ParsedURI uri_;
int retry_max_;
bool is_sockaddr_;
bool first_addr_only_;
std::string info_;
bool fixed_addr_;
bool redirect_;
CTX ctx_;
SubTask *router_task_;
int retry_max_;
int retry_times_;
WFNSPolicy *ns_policy_;
WFRouterTask *router_task_;
RouteManager::RouteResult route_result_;
void *cookie_;
public:
CTX *get_mutable_ctx() { return &ctx_; }
@@ -361,21 +248,8 @@ public:
private:
void init_with_uri();
bool set_port();
void router_callback(SubTask *task); // default: DNS
void router_callback(WFRouterTask *task);
void switch_callback(WFTimerTask *task);
RouteManager::RouteResult route_result_;
UpstreamManager::UpstreamResult upstream_result_;
TransportType type_;
std::string info_;
int retry_times_;
/* state 0: uninited or failed; 1: inited but not checked; 2: checked. */
char init_state_;
bool is_retry_;
bool redirect_;
};
template<class REQ, class RESP, typename CTX>
@@ -384,42 +258,34 @@ void WFComplexClientTask<REQ, RESP, CTX>::init(TransportType type,
socklen_t addrlen,
const std::string& info)
{
is_sockaddr_ = true;
init_state_ = 0;
type_ = type;
info_.assign(info);
struct addrinfo addrinfo;
const auto *params = &WFGlobal::get_global_settings()->endpoint_params;
if (redirect_)
{
ns_policy_ = NULL;
route_result_.clear();
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = TOR_NOT_TIMEOUT;
}
addrinfo.ai_addrlen = addrlen;
addrinfo.ai_addr = (struct sockaddr *)addr;
addrinfo.ai_canonname = NULL;
addrinfo.ai_next = NULL;
addrinfo.ai_flags = 0;
const auto *params = &WFGlobal::get_global_settings()->endpoint_params;
struct addrinfo addrinfo = { };
addrinfo.ai_family = addr->sa_family;
addrinfo.ai_socktype = SOCK_STREAM;
addrinfo.ai_protocol = 0;
addrinfo.ai_addr = (struct sockaddr *)addr;
addrinfo.ai_addrlen = addrlen;
type_ = type;
info_.assign(info);
if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, params,
route_result_) < 0)
route_result_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else if (!route_result_.request_object)
{
//should not happen
this->state = WFT_STATE_SYS_ERROR;
this->error = EAGAIN;
}
else
{
init_state_ = this->init_success() ? 1 : 0;
else if (this->init_success())
return;
}
this->init_failed();
return;
}
template<class REQ, class RESP, typename CTX>
@@ -469,164 +335,97 @@ bool WFComplexClientTask<REQ, RESP, CTX>::set_port()
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::init_with_uri()
{
route_result_.clear();
if (uri_.state == URI_STATE_SUCCESS && this->set_port())
if (redirect_)
{
int ret = UpstreamManager::choose(uri_, upstream_result_);
ns_policy_ = NULL;
route_result_.clear();
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = TOR_NOT_TIMEOUT;
}
if (ret < 0)
if (uri_.state == URI_STATE_SUCCESS)
{
if (this->set_port())
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
}
else if (upstream_result_.state == UPSTREAM_ALL_DOWN)
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_UPSTREAM_UNAVAILABLE;
}
else
{
init_state_ = this->init_success() ? 1 : 0;
return;
if (this->init_success())
return;
}
}
else if (uri_.state == URI_STATE_ERROR)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = uri_.error;
}
else
{
if (uri_.state == URI_STATE_ERROR)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = uri_.error;
}
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PARSE_FAILED;
}
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PARSE_FAILED;
}
this->init_failed();
return;
}
template<class REQ, class RESP, typename CTX>
SubTask *WFComplexClientTask<REQ, RESP, CTX>::route()
WFRouterTask *WFComplexClientTask<REQ, RESP, CTX>::route()
{
unsigned int dns_ttl_default;
unsigned int dns_ttl_min;
const struct EndpointParams *endpoint_params;
int dns_cache_level = (is_retry_ ? DNS_CACHE_LEVEL_1
: DNS_CACHE_LEVEL_2);
WFNameService *ns = WFGlobal::get_name_service();
auto&& cb = std::bind(&WFComplexClientTask::router_callback,
this,
std::placeholders::_1);
is_retry_ = false;//route means refresh DNS cache level
if (upstream_result_.state == UPSTREAM_SUCCESS)
{
const auto *params = upstream_result_.address_params;
dns_ttl_default = params->dns_ttl_default;
dns_ttl_min = params->dns_ttl_min;
endpoint_params = &params->endpoint_params;
}
else
{
const auto *params = WFGlobal::get_global_settings();
dns_ttl_default = params->dns_ttl_default;
dns_ttl_min = params->dns_ttl_min;
endpoint_params = &params->endpoint_params;
}
return new WFRouterTask(type_, uri_.host ? uri_.host : "",
uri_.port ? atoi(uri_.port) : 0, info_,
dns_cache_level, dns_ttl_default, dns_ttl_min,
endpoint_params, first_addr_only_, std::move(cb));
struct WFNSParams params = {
.type = type_,
.uri = uri_,
.info = info_.c_str(),
.fixed_addr = fixed_addr_,
.retry_times = retry_times_,
};
ns_policy_ = ns->get_policy(uri_.host ? uri_.host : "");
return ns_policy_->create_router_task(&params, cb);
}
/*
* router callback`s obligation:
* if success:
* 1. set route_result_ or call this->init()
* 2. series->push_front(ORIGIN_TASK)
* if failed:
* 1. this->finish_once() is optional;
* 2. this->callback() is necessary;
*/
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::router_callback(SubTask *task)
void WFComplexClientTask<REQ, RESP, CTX>::router_callback(WFRouterTask *task)
{
WFRouterTask *router_task = static_cast<WFRouterTask *>(task);
int state = router_task->get_state();
if (state == WFT_STATE_SUCCESS)
route_result_ = router_task->route_result_;
else
this->state = task->get_state();
if (this->state == WFT_STATE_SUCCESS)
{
this->state = state;
this->error = router_task->get_error();
route_result_ = std::move(*task->get_result());
cookie_ = task->get_cookie();
}
if (route_result_.request_object)
series_of(this)->push_front(this);
else
else if (this->state == WFT_STATE_UNDEFINED)
{
UpstreamManager::notify_unavailable(upstream_result_.cookie);
if (this->callback)
this->callback(this);
if (redirect_)
{
init_state_ = this->init_success() ? 1 : 0;
redirect_ = false;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = TOR_NOT_TIMEOUT;
series_of(this)->push_front(this);
}
else
delete this;
/* should not happend */
this->state = WFT_STATE_SYS_ERROR;
this->error = ENOSYS;
}
else
this->error = task->get_error();
}
template<class REQ, class RESP, typename CTX>
void WFComplexClientTask<REQ, RESP, CTX>::dispatch()
{
// 1. children check_request()
if (init_state_ == 1)
init_state_ = this->check_request() ? 2 : 0;
if (init_state_)
switch (this->state)
{
if (route_result_.request_object)
case WFT_STATE_UNDEFINED:
if (this->check_request())
{
// 2. origin task dispatch()
this->set_request_object(route_result_.request_object);
this->WFClientTask<REQ, RESP>::dispatch();
return;
}
if (this->route_result_.request_object)
{
case WFT_STATE_SUCCESS:
this->set_request_object(route_result_.request_object);
this->WFClientTask<REQ, RESP>::dispatch();
return;
}
if (is_sockaddr_ || uri_.state == URI_STATE_SUCCESS)
{
// 3. DNS route() or children route()
router_task_ = this->route();
series_of(this)->push_front(this);
series_of(this)->push_front(router_task_);
}
else
{
if (uri_.state == URI_STATE_ERROR)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = uri_.error;
}
else
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_PARSE_FAILED;
}
}
default:
break;
}
this->subtask_done();
@@ -643,21 +442,15 @@ void WFComplexClientTask<REQ, RESP, CTX>::switch_callback(WFTimerTask *)
this->error = -this->error;
}
// 4. children finish before user callback
if (this->callback)
this->callback(this);
}
if (redirect_)
{
init_state_ = this->init_success() ? 1 : 0;
redirect_ = false;
clear_resp();
this->target = NULL;
this->timeout_reason = TOR_NOT_TIMEOUT;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
series_of(this)->push_front(this);
}
else
@@ -669,53 +462,43 @@ SubTask *WFComplexClientTask<REQ, RESP, CTX>::done()
{
SeriesWork *series = series_of(this);
// 1. routing
if (router_task_)
{
router_task_ = NULL;
return series->pop();
}
if (init_state_)
bool is_user_request = this->finish_once();
if (ns_policy_ && route_result_.request_object)
{
// 2. children can set_redirect() here
bool is_user_request = this->finish_once();
// 3. complex task success
if (this->state == WFT_STATE_SUCCESS)
{
RouteManager::notify_available(route_result_.cookie, this->target);
UpstreamManager::notify_available(upstream_result_.cookie);
upstream_result_.clear();
// 4. children message out sth. else
if (!is_user_request)
return this;
}
else if (this->state == WFT_STATE_SYS_ERROR)
{
if (this->target)
{
RouteManager::notify_unavailable(route_result_.cookie,
this->target);
}
if (this->state == WFT_STATE_SYS_ERROR)
ns_policy_->failed(&route_result_, cookie_, this->target);
else
ns_policy_->success(&route_result_, cookie_, this->target);
}
UpstreamManager::notify_unavailable(upstream_result_.cookie);
// 5. complex task failed: retry
if (retry_times_ < retry_max_)
{
if (is_sockaddr_)
set_retry();
else
set_retry(uri_);
is_retry_ = true; // will influence next round dns cache time
}
if (this->state == WFT_STATE_SUCCESS)
{
if (!is_user_request)
return this;
}
else if (this->state == WFT_STATE_SYS_ERROR)
{
if (retry_times_ < retry_max_)
{
redirect_ = true;
this->state = WFT_STATE_UNDEFINED;
this->error = 0;
this->timeout_reason = 0;
retry_times_++;
}
}
/*
* When target is NULL, it's very likely that we are still in the
* 'dispatch' thread. Running a timer will switch callback function
* to a handler thread, and this can prevent stack overflow.
* When target is NULL, it's very likely that we are in the caller's
* thread or DNS thread (dns failed). Running a timer will switch callback
* function to a handler thread, and this can prevent stack overflow.
*/
if (!this->target)
{
@@ -723,7 +506,6 @@ SubTask *WFComplexClientTask<REQ, RESP, CTX>::done()
this,
std::placeholders::_1);
WFTimerTask *timer = WFTaskFactory::create_timer_task(0, std::move(cb));
series->push_front(timer);
}
else

View File

@@ -6,6 +6,7 @@ set(SRC
UpstreamManager.cc
RouteManager.cc
WFGlobal.cc
UpstreamPolicies.cc
)
add_library(${PROJECT_NAME} OBJECT ${SRC})

View File

@@ -25,7 +25,7 @@
#define CONFIDENT_INC 10
#define TTL_INC 10
const DNSHandle *DNSCache::get_inner(const HostPort& host_port, int type)
const DNSCache::DNSHandle *DNSCache::get_inner(const HostPort& host_port, int type)
{
const DNSHandle *handle = cache_pool_.get(host_port);
@@ -73,10 +73,10 @@ const DNSHandle *DNSCache::get_inner(const HostPort& host_port, int type)
return handle;
}
const DNSHandle *DNSCache::put(const HostPort& host_port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min)
const DNSCache::DNSHandle *DNSCache::put(const HostPort& host_port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min)
{
int64_t expire_time;
int64_t confident_time;

View File

@@ -47,36 +47,75 @@ public:
}
};
typedef std::pair<std::string, unsigned short> HostPort;
typedef LRUHandle<HostPort, DNSCacheValue> DNSHandle;
// RAII: NO. Release handle by user
// Thread safety: YES
// MUST call release when handle no longer used
class DNSCache
{
public:
using HostPort = std::pair<std::string, unsigned short>;
using DNSHandle = LRUHandle<HostPort, DNSCacheValue>;
public:
// release handle by get/put
void release(DNSHandle *handle);
void release(const DNSHandle *handle);
void release(DNSHandle *handle)
{
cache_pool_.release(handle);
}
void release(const DNSHandle *handle)
{
cache_pool_.release(handle);
}
// get handler
// Need call release when handle no longer needed
//Handle *get(const KEY &key);
const DNSHandle *get(const HostPort& host_port);
const DNSHandle *get(const std::string& host, unsigned short port);
const DNSHandle *get(const char *host, unsigned short port);
const DNSHandle *get(const HostPort& host_port)
{
return cache_pool_.get(host_port);
}
const DNSHandle *get_ttl(const HostPort& host_port);
const DNSHandle *get_ttl(const std::string& host, unsigned short port);
const DNSHandle *get_ttl(const char *host, unsigned short port);
const DNSHandle *get(const std::string& host, unsigned short port)
{
return get(HostPort(host, port));
}
const DNSHandle *get_confident(const HostPort& host_port);
const DNSHandle *get_confident(const std::string& host, unsigned short port);
const DNSHandle *get_confident(const char *host, unsigned short port);
const DNSHandle *get(const char *host, unsigned short port)
{
return get(std::string(host), port);
}
const DNSHandle *get_ttl(const HostPort& host_port)
{
return get_inner(host_port, GET_TYPE_TTL);
}
const DNSHandle *get_ttl(const std::string& host, unsigned short port)
{
return get_ttl(HostPort(host, port));
}
const DNSHandle *get_ttl(const char *host, unsigned short port)
{
return get_ttl(std::string(host), port);
}
const DNSHandle *get_confident(const HostPort& host_port)
{
return get_inner(host_port, GET_TYPE_CONFIDENT);
}
const DNSHandle *get_confident(const std::string& host, unsigned short port)
{
return get_confident(HostPort(host, port));
}
const DNSHandle *get_confident(const char *host, unsigned short port)
{
return get_confident(std::string(host), port);
}
// put copy
// Need call release when handle no longer needed
const DNSHandle *put(const HostPort& host_port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
@@ -86,18 +125,35 @@ public:
unsigned short port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min);
unsigned int dns_ttl_min)
{
return put(HostPort(host, port), addrinfo, dns_ttl_default, dns_ttl_min);
}
const DNSHandle *put(const char *host,
unsigned short port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min);
unsigned int dns_ttl_min)
{
return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min);
}
// delete from cache, deleter delay called when all inuse-handle release.
void del(const HostPort& key);
void del(const std::string& host, unsigned short port);
void del(const char *host, unsigned short port);
void del(const HostPort& key)
{
cache_pool_.del(key);
}
void del(const std::string& host, unsigned short port)
{
del(HostPort(host, port));
}
void del(const char *host, unsigned short port)
{
del(std::string(host), port);
}
private:
const DNSHandle *get_inner(const HostPort& host_port, int type);
@@ -106,95 +162,5 @@ private:
LRUCache<HostPort, DNSCacheValue, ValueDeleter> cache_pool_;
};
////////////////////
inline void DNSCache::release(DNSHandle *handle)
{
cache_pool_.release(handle);
}
inline void DNSCache::release(const DNSHandle *handle)
{
cache_pool_.release(handle);
}
inline const DNSHandle *DNSCache::get(const HostPort& host_port)
{
return cache_pool_.get(host_port);
}
inline const DNSHandle *DNSCache::get(const std::string& host, unsigned short port)
{
return get(HostPort(host, port));
}
inline const DNSHandle *DNSCache::get(const char *host, unsigned short port)
{
return get(std::string(host), port);
}
inline const DNSHandle *DNSCache::get_ttl(const HostPort& host_port)
{
return get_inner(host_port, GET_TYPE_TTL);
}
inline const DNSHandle *DNSCache::get_ttl(const std::string& host, unsigned short port)
{
return get_ttl(HostPort(host, port));
}
inline const DNSHandle *DNSCache::get_ttl(const char *host, unsigned short port)
{
return get_ttl(std::string(host), port);
}
inline const DNSHandle *DNSCache::get_confident(const HostPort& host_port)
{
return get_inner(host_port, GET_TYPE_CONFIDENT);
}
inline const DNSHandle *DNSCache::get_confident(const std::string& host, unsigned short port)
{
return get_confident(HostPort(host, port));
}
inline const DNSHandle *DNSCache::get_confident(const char *host, unsigned short port)
{
return get_confident(std::string(host), port);
}
inline const DNSHandle *DNSCache::put(const std::string& host,
unsigned short port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min)
{
return put(HostPort(host, port), addrinfo, dns_ttl_default, dns_ttl_min);
}
inline const DNSHandle *DNSCache::put(const char *host,
unsigned short port,
struct addrinfo *addrinfo,
unsigned int dns_ttl_default,
unsigned int dns_ttl_min)
{
return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min);
}
inline void DNSCache::del(const HostPort& key)
{
cache_pool_.del(key);
}
inline void DNSCache::del(const std::string& host, unsigned short port)
{
del(HostPort(host, port));
}
inline void DNSCache::del(const char *host, unsigned short port)
{
del(std::string(host), port);
}
#endif

View File

@@ -39,7 +39,7 @@ public:
CommSchedObject *request_object;
public:
// RouteResult(): cookie(NULL), request_object(NULL) { }
RouteResult(): cookie(NULL), request_object(NULL) { }
void clear() { cookie = NULL; request_object = NULL; }
};

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,7 @@
#include <functional>
#include "URIParser.h"
#include "EndpointParams.h"
#include "WFGlobal.h"
/**
* @file UpstreamManager.h
@@ -234,39 +235,6 @@ public:
const std::string& address,
const struct AddressParams *address_params);
public:
/// @brief Internal use only
class UpstreamResult
{
public:
void *cookie;
const struct AddressParams *address_params;
#define UPSTREAM_SUCCESS 0
#define UPSTREAM_NOTFOUND 1
#define UPSTREAM_ALL_DOWN 2
int state;
public:
UpstreamResult():
cookie(NULL),
address_params(NULL),
state(UPSTREAM_NOTFOUND)
{}
void clear()
{
cookie = NULL;
address_params = NULL;
state = UPSTREAM_NOTFOUND;
}
};
/// @brief Internal use only
static int choose(ParsedURI& uri, UpstreamResult& result);
/// @brief Internal use only
static void notify_unavailable(void *cookie);
/// @brief Internal use only
static void notify_available(void *cookie);
};
#endif

View File

@@ -0,0 +1,796 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/
#include <algorithm>
#include "StringUtil.h"
#include "UpstreamPolicies.h"
#include "WFDNSResolver.h"
#define DNS_CACHE_LEVEL_1 1
#define DNS_CACHE_LEVEL_2 2
class WFSelectorFailTask : public WFRouterTask
{
public:
WFSelectorFailTask(router_callback_t&& cb)
: WFRouterTask(std::move(cb))
{
}
virtual void dispatch()
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_UPSTREAM_UNAVAILABLE;
return this->subtask_done();
}
};
static bool copy_host_port(ParsedURI& uri, const EndpointAddress *addr)
{
char *host = NULL;
char *port = NULL;
if (!addr->host.empty())
{
host = strdup(addr->host.c_str());
if (!host)
return false;
}
if (addr->port_value > 0)
{
port = strdup(addr->port.c_str());
if (!port)
{
free(host);
return false;
}
free(uri.port);
uri.port = port;
}
free(uri.host);
uri.host = host;
return true;
}
EndpointAddress::EndpointAddress(const std::string& address,
const struct AddressParams *address_params)
{
std::vector<std::string> arr = StringUtil::split(address, ':');
this->params = *address_params;
this->address = address;
this->list.next = NULL;
this->fail_count = 0;
static std::hash<std::string> std_hash;
for (int i = 0; i < VIRTUAL_GROUP_SIZE; i++)
this->consistent_hash[i] = std_hash(address + "|v" + std::to_string(i));
if (this->params.weight == 0)
this->params.weight = 1;
if (this->params.max_fails == 0)
this->params.max_fails = 1;
if (this->params.group_id < 0)
this->params.group_id = -1;
if (arr.size() == 0)
this->host = "";
else
this->host = arr[0];
if (arr.size() <= 1)
{
this->port = "";
this->port_value = 0;
}
else
{
this->port = arr[1];
this->port_value = atoi(arr[1].c_str());
}
}
WFRouterTask *UPSPolicy::create_router_task(const struct WFNSParams *params,
router_callback_t callback)
{
EndpointAddress *addr;
WFRouterTask *task;
if (this->select(params->uri, &addr) && copy_host_port(params->uri, addr))
{
unsigned int dns_ttl_default = addr->params.dns_ttl_default;
unsigned int dns_ttl_min = addr->params.dns_ttl_min;
const struct EndpointParams *endpoint_params = &addr->params.endpoint_params;
int dns_cache_level = params->retry_times == 0 ? DNS_CACHE_LEVEL_2 :
DNS_CACHE_LEVEL_1;
task = this->create(params, dns_cache_level, dns_ttl_default, dns_ttl_min,
endpoint_params, std::move(callback));
task->set_cookie(addr);
}
else
task = new WFSelectorFailTask(std::move(callback));
return task;
}
inline void UPSPolicy::recover_server_from_breaker(EndpointAddress *addr)
{
addr->fail_count = 0;
pthread_mutex_lock(&this->breaker_lock);
if (addr->list.next)
{
list_del(&addr->list);
addr->list.next = NULL;
this->recover_one_server(addr);
//this->server_list_change();
}
pthread_mutex_unlock(&this->breaker_lock);
}
inline void UPSPolicy::fuse_server_to_breaker(EndpointAddress *addr)
{
pthread_mutex_lock(&this->breaker_lock);
if (!addr->list.next)
{
addr->broken_timeout = GET_CURRENT_SECOND + MTTR_SECOND;
list_add_tail(&addr->list, &this->breaker_list);
this->fuse_one_server(addr);
//this->server_list_change();
}
pthread_mutex_unlock(&this->breaker_lock);
}
void UPSPolicy::success(RouteManager::RouteResult *result, void *cookie,
CommTarget *target)
{
pthread_rwlock_rdlock(&this->rwlock);
this->recover_server_from_breaker((EndpointAddress *)cookie);
pthread_rwlock_unlock(&this->rwlock);
WFDNSResolver::success(result, NULL, target);
}
void UPSPolicy::failed(RouteManager::RouteResult *result, void *cookie,
CommTarget *target)
{
EndpointAddress *server = (EndpointAddress *)cookie;
pthread_rwlock_rdlock(&this->rwlock);
size_t fail_count = ++server->fail_count;
if (fail_count == server->params.max_fails)
this->fuse_server_to_breaker(server);
pthread_rwlock_unlock(&this->rwlock);
WFDNSResolver::failed(result, NULL, target);
}
void UPSPolicy::check_breaker()
{
pthread_mutex_lock(&this->breaker_lock);
if (!list_empty(&this->breaker_list))
{
int64_t cur_time = GET_CURRENT_SECOND;
struct list_head *pos, *tmp;
EndpointAddress *addr;
list_for_each_safe(pos, tmp, &this->breaker_list)
{
addr = list_entry(pos, EndpointAddress, list);
if (cur_time >= addr->broken_timeout)
{
if (addr->fail_count >= addr->params.max_fails)
{
addr->fail_count = addr->params.max_fails - 1;
this->recover_one_server(addr);
}
list_del(pos);
addr->list.next = NULL;
}
}
}
pthread_mutex_unlock(&this->breaker_lock);
//this->server_list_change();
}
const EndpointAddress *UPSPolicy::first_stradegy(const ParsedURI& uri)
{
unsigned int idx = rand() % this->servers.size();
return this->servers[idx];
}
const EndpointAddress *UPSPolicy::another_stradegy(const ParsedURI& uri)
{
return this->first_stradegy(uri);
}
bool UPSPolicy::select(const ParsedURI& uri, EndpointAddress **addr)
{
pthread_rwlock_rdlock(&this->rwlock);
unsigned int n = (unsigned int)this->servers.size();
if (n == 0)
{
pthread_rwlock_unlock(&this->rwlock);
return false;
}
this->check_breaker();
if (this->nalives == 0)
{
pthread_rwlock_unlock(&this->rwlock);
return false;
}
// select_addr == NULL will only happened in consistent_hash
const EndpointAddress *select_addr = this->first_stradegy(uri);
if (!select_addr || select_addr->fail_count >= select_addr->params.max_fails)
{
if (this->try_another)
select_addr = this->another_stradegy(uri);
}
pthread_rwlock_unlock(&this->rwlock);
if (select_addr)
{
*addr = (EndpointAddress *)select_addr;
return true;
}
return false;
}
void UPSPolicy::add_server_locked(EndpointAddress *addr)
{
this->addresses.push_back(addr);
this->server_map[addr->address].push_back(addr);
this->servers.push_back(addr);
this->recover_one_server(addr);
}
int UPSPolicy::remove_server_locked(const std::string& address)
{
const auto map_it = this->server_map.find(address);
if (map_it != this->server_map.cend())
{
for (EndpointAddress *addr : map_it->second)
{
// or not: it has already been -- in nalives
if (addr->fail_count < addr->params.max_fails)
this->fuse_one_server(addr);
}
this->server_map.erase(map_it);
}
size_t n = this->servers.size();
size_t new_n = 0;
for (size_t i = 0; i < n; i++)
{
if (this->servers[i]->address != address)
{
if (new_n != i)
this->servers[new_n++] = this->servers[i];
else
new_n++;
}
}
int ret = 0;
if (new_n < n)
{
this->servers.resize(new_n);
ret = n - new_n;
}
return ret;
}
void UPSPolicy::add_server(const std::string& address,
const AddressParams *address_params)
{
EndpointAddress *addr = new EndpointAddress(address, address_params);
pthread_rwlock_wrlock(&this->rwlock);
this->add_server_locked(addr);
pthread_rwlock_unlock(&this->rwlock);
}
int UPSPolicy::remove_server(const std::string& address)
{
int ret;
pthread_rwlock_wrlock(&this->rwlock);
ret = this->remove_server_locked(address);
pthread_rwlock_unlock(&this->rwlock);
return ret;
}
int UPSPolicy::replace_server(const std::string& address,
const AddressParams *address_params)
{
int ret;
EndpointAddress *addr = new EndpointAddress(address, address_params);
pthread_rwlock_wrlock(&this->rwlock);
this->add_server_locked(addr);
ret = this->remove_server_locked(address);
pthread_rwlock_unlock(&this->rwlock);
return ret;
}
void UPSPolicy::enable_server(const std::string& address)
{
pthread_rwlock_rdlock(&this->rwlock);
const auto map_it = this->server_map.find(address);
if (map_it != this->server_map.cend())
{
for (EndpointAddress *addr : map_it->second)
this->recover_server_from_breaker(addr);
}
pthread_rwlock_unlock(&this->rwlock);
}
void UPSPolicy::disable_server(const std::string& address)
{
pthread_rwlock_rdlock(&this->rwlock);
const auto map_it = this->server_map.find(address);
if (map_it != this->server_map.cend())
{
for (EndpointAddress *addr : map_it->second)
{
addr->fail_count = addr->params.max_fails;
this->fuse_server_to_breaker(addr);
}
}
pthread_rwlock_unlock(&this->rwlock);
}
void UPSPolicy::get_main_address(std::vector<std::string>& addr_list)
{
pthread_rwlock_rdlock(&this->rwlock);
for (const EndpointAddress *server : this->servers)
addr_list.push_back(server->address);
pthread_rwlock_unlock(&this->rwlock);
}
UPSGroupPolicy::UPSGroupPolicy()
{
this->group_map.rb_node = NULL;
this->default_group = new EndpointGroup(-1, this);
rb_link_node(&this->default_group->rb, NULL, &this->group_map.rb_node);
rb_insert_color(&this->default_group->rb, &this->group_map);
}
UPSGroupPolicy::~UPSGroupPolicy()
{
EndpointGroup *group;
while (this->group_map.rb_node)
{
group = rb_entry(this->group_map.rb_node, EndpointGroup, rb);
rb_erase(this->group_map.rb_node, &this->group_map);
delete group;
}
}
bool UPSGroupPolicy::select(const ParsedURI& uri, EndpointAddress **addr)
{
pthread_rwlock_rdlock(&this->rwlock);
unsigned int n = (unsigned int)this->servers.size();
if (n == 0)
{
pthread_rwlock_unlock(&this->rwlock);
return false;
}
this->check_breaker();
if (this->nalives == 0)
{
pthread_rwlock_unlock(&this->rwlock);
return false;
}
// select_addr == NULL will only happened in consistent_hash
const EndpointAddress *select_addr = this->first_stradegy(uri);
if (!select_addr || select_addr->fail_count >= select_addr->params.max_fails)
{
if (select_addr)
select_addr = this->check_and_get(select_addr, true);
if (!select_addr && this->try_another)
{
select_addr = this->another_stradegy(uri);
select_addr = this->check_and_get(select_addr, false);
}
}
if (!select_addr)
this->default_group->get_one_backup();
pthread_rwlock_unlock(&this->rwlock);
if (select_addr)
{
*addr = (EndpointAddress *)select_addr;
return true;
}
return false;
}
// flag true : guarantee addr != NULL, and please return an available one
// flag false : means addr maybe useful but want one any way. addr may be NULL
inline const EndpointAddress *UPSGroupPolicy::check_and_get(const EndpointAddress *addr,
bool flag)
{
if (flag == true) // && addr->fail_count >= addr->params.max_fails
{
if (addr->params.group_id == -1)
return NULL;
return addr->group->get_one();
}
if (addr && addr->fail_count >= addr->params.max_fails &&
addr->params.group_id >= 0)
{
const EndpointAddress *tmp = addr->group->get_one();
if (tmp)
addr = tmp;
}
return addr;
}
const EndpointAddress *EndpointGroup::get_one()
{
if (this->nalives == 0)
return NULL;
const EndpointAddress *addr = NULL;
pthread_mutex_lock(&this->mutex);
std::random_shuffle(this->mains.begin(), this->mains.end());
for (size_t i = 0; i < this->mains.size(); i++)
{
if (this->mains[i]->fail_count < this->mains[i]->params.max_fails)
{
addr = this->mains[i];
break;
}
}
if (!addr)
{
std::random_shuffle(this->backups.begin(), this->backups.end());
for (size_t i = 0; i < this->backups.size(); i++)
{
if (this->backups[i]->fail_count < this->backups[i]->params.max_fails)
{
addr = this->backups[i];
break;
}
}
}
pthread_mutex_unlock(&this->mutex);
return addr;
}
const EndpointAddress *EndpointGroup::get_one_backup()
{
if (this->nalives == 0)
return NULL;
const EndpointAddress *addr = NULL;
pthread_mutex_lock(&this->mutex);
std::random_shuffle(this->backups.begin(), this->backups.end());
for (size_t i = 0; i < this->backups.size(); i++)
{
if (this->backups[i]->fail_count < this->backups[i]->params.max_fails)
{
addr = this->backups[i];
break;
}
}
pthread_mutex_unlock(&this->mutex);
return addr;
}
void UPSGroupPolicy::add_server_locked(EndpointAddress *addr)
{
int group_id = addr->params.group_id;
rb_node **p = &this->group_map.rb_node;
rb_node *parent = NULL;
EndpointGroup *group;
this->addresses.push_back(addr);
this->server_map[addr->address].push_back(addr);
if (addr->params.server_type == 0)
this->servers.push_back(addr);
while (*p)
{
parent = *p;
group = rb_entry(*p, EndpointGroup, rb);
if (group_id < group->id)
p = &(*p)->rb_left;
else if (group_id > group->id)
p = &(*p)->rb_right;
else
break;
}
if (*p == NULL)
{
group = new EndpointGroup(group_id, this);
rb_link_node(&group->rb, parent, p);
rb_insert_color(&group->rb, &this->group_map);
}
pthread_mutex_lock(&group->mutex);
addr->group = group;
this->recover_one_server(addr);
if (addr->params.server_type == 0)
{
group->mains.push_back(addr);
group->weight += addr->params.weight;
}
else
group->backups.push_back(addr);
pthread_mutex_unlock(&group->mutex);
return;
}
int UPSGroupPolicy::remove_server_locked(const std::string& address)
{
const auto map_it = this->server_map.find(address);
if (map_it != this->server_map.cend())
{
for (EndpointAddress *addr : map_it->second)
{
EndpointGroup *group = addr->group;
std::vector<EndpointAddress *> *vec;
if (addr->params.server_type == 0)
vec = &group->mains;
else
vec = &group->backups;
//std::lock_guard<std::mutex> lock(group->mutex);
pthread_mutex_lock(&group->mutex);
if (addr->fail_count < addr->params.max_fails)
this->fuse_one_server(addr);
if (addr->params.server_type == 0)
group->weight -= addr->params.weight;
for (auto it = vec->begin(); it != vec->end(); ++it)
{
if (*it == addr)
{
vec->erase(it);
break;
}
}
pthread_mutex_unlock(&group->mutex);
}
this->server_map.erase(map_it);
}
size_t n = this->servers.size();
size_t new_n = 0;
for (size_t i = 0; i < n; i++)
{
if (this->servers[i]->address != address)
{
if (new_n != i)
this->servers[new_n++] = this->servers[i];
else
new_n++;
}
}
int ret = 0;
if (new_n < n)
{
this->servers.resize(new_n);
ret = n - new_n;
}
return ret;
}
const EndpointAddress *UPSGroupPolicy::consistent_hash_with_group(unsigned int hash)
{
const EndpointAddress *addr = NULL;
unsigned int min_dis = (unsigned int)-1;
for (const EndpointAddress *server : this->servers)
{
if (this->is_alive_or_group_alive(server))
{
for (int i = 0; i < VIRTUAL_GROUP_SIZE; i++)
{
unsigned int dis = std::min<unsigned int>
(hash - server->consistent_hash[i],
server->consistent_hash[i] - hash);
if (dis < min_dis)
{
min_dis = dis;
addr = server;
}
}
}
}
return this->check_and_get(addr, false);
}
void UPSWeightedRandomPolicy::add_server_locked(EndpointAddress *addr)
{
UPSGroupPolicy::add_server_locked(addr);
if (addr->params.server_type == 0)
this->total_weight += addr->params.weight;
return;
}
int UPSWeightedRandomPolicy::remove_server_locked(const std::string& address)
{
const auto map_it = this->server_map.find(address);
if (map_it != this->server_map.cend())
{
for (EndpointAddress *addr : map_it->second)
{
if (addr->params.server_type == 0)
this->total_weight -= addr->params.weight;
}
}
return UPSGroupPolicy::remove_server_locked(address);
}
const EndpointAddress *UPSWeightedRandomPolicy::first_stradegy(const ParsedURI& uri)
{
int x = 0;
int s = 0;
size_t idx;
int temp_weight = this->total_weight;
if (temp_weight > 0)
x = rand() % temp_weight;
for (idx = 0; idx < this->servers.size(); idx++)
{
s += this->servers[idx]->params.weight;
if (s > x)
break;
}
if (idx == this->servers.size())
idx--;
return this->servers[idx];
}
const EndpointAddress *UPSWeightedRandomPolicy::another_stradegy(const ParsedURI& uri)
{
int temp_weight = this->available_weight;
if (temp_weight == 0)
return NULL;
const EndpointAddress *addr = NULL;
int x = rand() % temp_weight;
int s = 0;
for (const EndpointAddress *server : this->servers)
{
if (this->is_alive_or_group_alive(server))
{
addr = server;
s += server->params.weight;
if (s > x)
break;
}
}
return this->check_and_get(addr, false);
}
void UPSWeightedRandomPolicy::recover_one_server(const EndpointAddress *addr)
{
this->nalives++;
if (addr->group->nalives++ == 0 && addr->group->id > 0)
this->available_weight += addr->group->weight;
if (addr->params.group_id < 0 && addr->params.server_type == 0)
this->available_weight += addr->params.weight;
}
void UPSWeightedRandomPolicy::fuse_one_server(const EndpointAddress *addr)
{
this->nalives--;
if (--addr->group->nalives == 0 && addr->group->id > 0)
this->available_weight -= addr->group->weight;
if (addr->params.group_id < 0 && addr->params.server_type == 0)
this->available_weight -= addr->params.weight;
}
const EndpointAddress *UPSConsistentHashPolicy::first_stradegy(const ParsedURI& uri)
{
unsigned int hash_value;
if (this->consistent_hash)
hash_value = this->consistent_hash(uri.path ? uri.path : "",
uri.query ? uri.query : "",
uri.fragment ? uri.fragment : "");
else
hash_value = this->default_consistent_hash(uri.path ? uri.path : "",
uri.query ? uri.query : "",
uri.fragment ? uri.fragment : "");
return this->consistent_hash_with_group(hash_value);
}
const EndpointAddress *UPSManualPolicy::first_stradegy(const ParsedURI& uri)
{
unsigned int idx = this->manual_select(uri.path ? uri.path : "",
uri.query ? uri.query : "",
uri.fragment ? uri.fragment : "");
if (idx >= this->servers.size())
idx %= this->servers.size();
return this->servers[idx];
}
const EndpointAddress *UPSManualPolicy::another_stradegy(const ParsedURI& uri)
{
unsigned int hash_value;
if (this->try_another_select)
hash_value = this->try_another_select(uri.path ? uri.path : "",
uri.query ? uri.query : "",
uri.fragment ? uri.fragment : "");
else
hash_value = UPSConsistentHashPolicy::default_consistent_hash(uri.path ? uri.path : "",
uri.query ? uri.query : "",
uri.fragment ? uri.fragment : "");
return this->consistent_hash_with_group(hash_value);
}

View File

@@ -0,0 +1,274 @@
/*
Copyright (c) 2021 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/
#ifndef _UPSTREAM_POLICIES_H_
#define _UPSTREAM_POLICIES_H_
#include "EndpointParams.h"
#include "WFNameService.h"
#include "WFDNSResolver.h"
#include "WFGlobal.h"
#include "WFTaskError.h"
#include "UpstreamManager.h"
#include <unordered_map>
#include <vector>
#define MTTR_SECOND 30
#define VIRTUAL_GROUP_SIZE 16
#define GET_CURRENT_SECOND std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now().time_since_epoch()).count()
class EndpointGroup;
class UPSPolicy;
class UPSGroupPolicy;
class EndpointAddress
{
public:
EndpointGroup *group;
AddressParams params;
std::string address;
std::string host;
std::string port;
short port_value;
struct list_head list;
std::atomic<unsigned int> fail_count;
int64_t broken_timeout;
unsigned int consistent_hash[VIRTUAL_GROUP_SIZE];
public:
EndpointAddress(const std::string& address,
const struct AddressParams *address_params);
};
class EndpointGroup
{
public:
int id;
UPSGroupPolicy *policy;
struct rb_node rb;
pthread_mutex_t mutex;
std::vector<EndpointAddress *> mains;
std::vector<EndpointAddress *> backups;
std::atomic<int> nalives;
int weight;
EndpointGroup(int group_id, UPSGroupPolicy *policy) :
mutex(PTHREAD_MUTEX_INITIALIZER)
{
this->id = group_id;
this->policy = policy;
this->nalives = 0;
this->weight = 0;
}
public:
const EndpointAddress *get_one();
const EndpointAddress *get_one_backup();
};
class UPSPolicy : public WFDNSResolver
{
public:
virtual WFRouterTask *create_router_task(const struct WFNSParams *params,
router_callback_t callback);
virtual void success(RouteManager::RouteResult *result, void *cookie,
CommTarget *target);
virtual void failed(RouteManager::RouteResult *result, void *cookie,
CommTarget *target);
void add_server(const std::string& address, const AddressParams *address_params);
int remove_server(const std::string& address);
int replace_server(const std::string& address, const AddressParams *address_params);
virtual void enable_server(const std::string& address);
virtual void disable_server(const std::string& address);
virtual void get_main_address(std::vector<std::string>& addr_list);
// virtual void server_list_change(/* std::vector<server> status */) {}
public:
UPSPolicy() :
breaker_lock(PTHREAD_MUTEX_INITIALIZER),
rwlock(PTHREAD_RWLOCK_INITIALIZER)
{
this->nalives = 0;
this->try_another = false;
INIT_LIST_HEAD(&this->breaker_list);
}
virtual ~UPSPolicy()
{
for (EndpointAddress *addr : this->addresses)
delete addr;
}
private:
virtual bool select(const ParsedURI& uri, EndpointAddress **addr);
virtual void recover_one_server(const EndpointAddress *addr)
{
this->nalives++;
}
virtual void fuse_one_server(const EndpointAddress *addr)
{
this->nalives--;
}
virtual void add_server_locked(EndpointAddress *addr);
virtual int remove_server_locked(const std::string& address);
void recover_server_from_breaker(EndpointAddress *addr);
void fuse_server_to_breaker(EndpointAddress *addr);
struct list_head breaker_list;
pthread_mutex_t breaker_lock;
protected:
virtual const EndpointAddress *first_stradegy(const ParsedURI& uri);
virtual const EndpointAddress *another_stradegy(const ParsedURI& uri);
void check_breaker();
std::vector<EndpointAddress *> servers; // current servers
std::vector<EndpointAddress *> addresses; // memory management
std::unordered_map<std::string,
std::vector<EndpointAddress *>> server_map;
pthread_rwlock_t rwlock;
std::atomic<int> nalives;
bool try_another;
};
class UPSGroupPolicy : public UPSPolicy
{
public:
UPSGroupPolicy();
~UPSGroupPolicy();
protected:
struct rb_root group_map;
EndpointGroup *default_group;
private:
virtual void recover_one_server(const EndpointAddress *addr)
{
this->nalives++;
addr->group->nalives++;
}
virtual void fuse_one_server(const EndpointAddress *addr)
{
this->nalives--;
addr->group->nalives--;
}
virtual bool select(const ParsedURI& uri, EndpointAddress **addr);
protected:
virtual void add_server_locked(EndpointAddress *addr);
virtual int remove_server_locked(const std::string& address);
const EndpointAddress *consistent_hash_with_group(unsigned int hash);
const EndpointAddress *check_and_get(const EndpointAddress *addr, bool flag);
inline bool is_alive_or_group_alive(const EndpointAddress *addr) const
{
return ((addr->params.group_id < 0 &&
addr->fail_count < addr->params.max_fails) ||
(addr->params.group_id >= 0 &&
addr->group->nalives > 0));
}
};
class UPSWeightedRandomPolicy : public UPSGroupPolicy
{
public:
UPSWeightedRandomPolicy(bool try_another)
{
this->total_weight = 0;
this->available_weight = 0;
this->try_another = try_another;
}
const EndpointAddress *first_stradegy(const ParsedURI& uri);
const EndpointAddress *another_stradegy(const ParsedURI& uri);
protected:
int total_weight;
int available_weight;
private:
virtual void recover_one_server(const EndpointAddress *addr);
virtual void fuse_one_server(const EndpointAddress *addr);
virtual void add_server_locked(EndpointAddress *addr);
virtual int remove_server_locked(const std::string& address);
};
class UPSConsistentHashPolicy : public UPSGroupPolicy
{
public:
UPSConsistentHashPolicy()
{
this->consistent_hash = this->default_consistent_hash;
}
UPSConsistentHashPolicy(upstream_route_t consistent_hash)
{
this->consistent_hash = std::move(consistent_hash);
}
protected:
const EndpointAddress *first_stradegy(const ParsedURI& uri);
private:
upstream_route_t consistent_hash;
public:
static unsigned int default_consistent_hash(const char *path,
const char *query,
const char *fragment)
{
static std::hash<std::string> std_hash;
std::string str(path);
str += query;
str += fragment;
return std_hash(str);
}
};
class UPSManualPolicy : public UPSGroupPolicy
{
public:
UPSManualPolicy(bool try_another, upstream_route_t select,
upstream_route_t try_another_select)
{
this->try_another = try_another;
this->manual_select = select;
this->try_another_select = try_another_select;
}
const EndpointAddress *first_stradegy(const ParsedURI& uri);
const EndpointAddress *another_stradegy(const ParsedURI& uri);
private:
upstream_route_t manual_select;
upstream_route_t try_another_select;
};
#endif

View File

@@ -39,6 +39,8 @@
#include "Executor.h"
#include "WFTask.h"
#include "WFTaskError.h"
#include "WFNameService.h"
#include "WFDNSResolver.h"
class __WFGlobal
{
@@ -515,6 +517,28 @@ private:
Executor compute_executor_;
};
class __NameServiceManager
{
public:
static __NameServiceManager *get_instance()
{
static __NameServiceManager kInstance;
return &kInstance;
}
public:
WFNameService *get_name_service() { return &service_; }
private:
static WFDNSResolver resolver_;
WFNameService service_;
public:
__NameServiceManager() : service_(&__NameServiceManager::resolver_) { }
};
WFDNSResolver __NameServiceManager::resolver_;
CommScheduler *WFGlobal::get_scheduler()
{
return __CommManager::get_instance()->get_scheduler();
@@ -565,6 +589,11 @@ Executor *WFGlobal::get_dns_executor()
return __CommManager::get_instance()->get_dns_executor();
}
WFNameService *WFGlobal::get_name_service()
{
return __NameServiceManager::get_instance()->get_name_service();
}
const char *WFGlobal::get_default_port(const std::string& scheme)
{
return __WFGlobal::get_instance()->get_default_port(scheme);

View File

@@ -31,6 +31,7 @@
#include "RouteManager.h"
#include "Executor.h"
#include "EndpointParams.h"
#include "WFNameService.h"
/**
* @file WFGlobal.h
@@ -105,27 +106,18 @@ public:
static const char *get_error_string(int state, int error);
public:
/// @brief Internal use only
// Internal usage only
static CommScheduler *get_scheduler();
/// @brief Internal use only
static DNSCache *get_dns_cache();
/// @brief Internal use only
static RouteManager *get_route_manager();
/// @brief Internal use only
static SSL_CTX *get_ssl_client_ctx();
/// @brief Internal use only
static SSL_CTX *get_ssl_server_ctx();
/// @brief Internal use only
static ExecQueue *get_exec_queue(const std::string& queue_name);
/// @brief Internal use only
static Executor *get_compute_executor();
/// @brief Internal use only
static IOService *get_io_service();
/// @brief Internal use only
static ExecQueue *get_dns_queue();
/// @brief Internal use only
static Executor *get_dns_executor();
/// @brief Internal use only
static WFNameService *get_name_service();
static void sync_operation_begin();
static void sync_operation_end();
};

View File

@@ -37,6 +37,7 @@ set(TEST_LIST
facilities_unittest
graph_unittest
memory_unittest
upstream_unittest
)
if (APPLE)

237
test/upstream_unittest.cc Normal file
View File

@@ -0,0 +1,237 @@
/*
Copyright (c) 2020 Sogou, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Li Yingxin (liyingxin@sogou-inc.com)
*/
#include <gtest/gtest.h>
#include "workflow/UpstreamManager.h"
#include "workflow/WFHttpServer.h"
#include "workflow/WFTaskFactory.h"
#include "workflow/WFFacilities.h"
#define REDIRECT_MAX 3
#define RETRY_MAX 3
#define MTTR 30
#define MAX_FAILS 200
static void __http_process1(WFHttpTask *task)
{
auto *resp = task->get_resp();
resp->add_header_pair("Content-Type", "text/plain");
resp->append_output_body_nocopy("server1", 7);
}
static void __http_process2(WFHttpTask *task)
{
auto *resp = task->get_resp();
resp->add_header_pair("Content-Type", "text/plain");
resp->append_output_body_nocopy("server2", 7);
}
WFHttpServer http_server1(__http_process1);
WFHttpServer http_server2(__http_process2);
void register_upstream_hosts()
{
UpstreamManager::upstream_create_weighted_random("weighted.random", false);
AddressParams address_params = ADDRESS_PARAMS_DEFAULT;
address_params.weight = 1000;
UpstreamManager::upstream_add_server("weighted.random",
"127.0.0.1:8001",
&address_params);
address_params.weight = 1;
UpstreamManager::upstream_add_server("weighted.random",
"127.0.0.1:8002",
&address_params);
UpstreamManager::upstream_create_consistent_hash(
"hash",
[](const char *path, const char *query, const char *fragment) -> unsigned int {
return 1;
});
UpstreamManager::upstream_add_server("hash", "127.0.0.1:8001");
UpstreamManager::upstream_add_server("hash", "127.0.0.1:8002");
UpstreamManager::upstream_create_manual(
"manual",
[](const char *path, const char *query, const char *fragment) -> unsigned int {
return 0;
},
false, nullptr);
UpstreamManager::upstream_add_server("manual", "127.0.0.1:8001");
UpstreamManager::upstream_add_server("manual", "127.0.0.1:8002");
UpstreamManager::upstream_create_weighted_random("try_another", true);
address_params.weight = 1000;
UpstreamManager::upstream_add_server("try_another",
"127.0.0.1:8001",
&address_params);
address_params.weight = 1;
UpstreamManager::upstream_add_server("try_another",
"127.0.0.1:8002",
&address_params);
}
void basic_callback(WFHttpTask *task, std::string& message)
{
auto state = task->get_state();
EXPECT_EQ(state, WFT_STATE_SUCCESS);
if (state == WFT_STATE_SUCCESS && message.compare(""))
{
const void *body;
size_t body_len;
task->get_resp()->get_parsed_body(&body, &body_len);
std::string buffer((char *)body, body_len);
EXPECT_EQ(buffer, message);
}
WFFacilities::WaitGroup *wait_group = (WFFacilities::WaitGroup *)task->user_data;
wait_group->done();
}
TEST(upstream_unittest, BasicPolicy)
{
WFFacilities::WaitGroup wait_group(3);
register_upstream_hosts();
char url[3][30] = {"http://weighted.random", "http://hash", "http://manual"};
http_callback_t cb = std::bind(basic_callback, std::placeholders::_1,
std::string("server1"));
for (int i = 0; i < 3; i++)
{
WFHttpTask *task = WFTaskFactory::create_http_task(url[i],
REDIRECT_MAX, RETRY_MAX, cb);
task->user_data = &wait_group;
task->start();
}
wait_group.wait();
}
TEST(upstream_unittest, EnableAndDisable)
{
WFFacilities::WaitGroup wait_group(1);
UpstreamManager::upstream_disable_server("weighted.random", "127.0.0.1:8001");
//fprintf(stderr, "disable server and try......................\n");
std::string url = "http://weighted.random";
WFHttpTask *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX,
[&wait_group, &url](WFHttpTask *task){
auto state = task->get_state();
EXPECT_EQ(state, WFT_STATE_TASK_ERROR);
EXPECT_EQ(task->get_error(), WFT_ERR_UPSTREAM_UNAVAILABLE);
UpstreamManager::upstream_enable_server("weighted.random", "127.0.0.1:8001");
//fprintf(stderr, "ensable server and try......................\n");
auto *task2 = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX,
std::bind(basic_callback,
std::placeholders::_1,
std::string("server1")));
task2->user_data = &wait_group;
series_of(task)->push_back(task2);
});
task->user_data = &wait_group;
task->start();
wait_group.wait();
}
TEST(upstream_unittest, FuseAndRecover)
{
WFFacilities::WaitGroup wait_group(1);
WFHttpTask *task;
SeriesWork *series;
protocol::HttpRequest *req;
std::string url = "http://weighted.random";
int batch = MAX_FAILS + 50;
int timeout = (MTTR + 3) * 1000000;
http_server1.stop();
fprintf(stderr, "server 1 stopped start %d tasks to fuse it\n", batch);
ParallelWork *pwork = Workflow::create_parallel_work(
[](const ParallelWork *pwork) {
fprintf(stderr, "parallel finished\n");
});
for (int i = 0; i < batch; i++)
{
task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX,
nullptr);
req = task->get_req();
req->add_header_pair("Connection", "keep-alive");
series = Workflow::create_series_work(task, nullptr);
pwork->add_series(series);
}
series = Workflow::create_series_work(pwork, nullptr);
WFTimerTask *timer = WFTaskFactory::create_timer_task(timeout,
[](WFTimerTask *task) {
fprintf(stderr, "timer_finished and start server1\n");
EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0)
<< "http server start failed";
});
series->push_back(timer);
task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX,
std::bind(basic_callback,
std::placeholders::_1,
std::string("server1")));
task->user_data = &wait_group;
series->push_back(task);
series->start();
wait_group.wait();
}
TEST(upstream_unittest, TryAnother)
{
WFFacilities::WaitGroup wait_group(1);
UpstreamManager::upstream_disable_server("try_another", "127.0.0.1:8001");
std::string url = "http://try_another";
WFHttpTask *task = WFTaskFactory::create_http_task(url, REDIRECT_MAX, RETRY_MAX,
std::bind(basic_callback,
std::placeholders::_1,
std::string("server2")));
task->user_data = &wait_group;
task->start();
wait_group.wait();
UpstreamManager::upstream_enable_server("try_another", "127.0.0.1:8001");
}
int main(int argc, char* argv[])
{
::testing::InitGoogleTest(&argc, argv);
EXPECT_TRUE(http_server1.start("127.0.0.1", 8001) == 0)
<< "http server start failed";
EXPECT_TRUE(http_server2.start("127.0.0.1", 8002) == 0)
<< "http server start failed";
EXPECT_EQ(RUN_ALL_TESTS(), 0);
http_server1.stop();
http_server2.stop();
return 0;
}