From d54f06e6ae0d2763dfd0f53135d3470a9fbebdda Mon Sep 17 00:00:00 2001 From: kedixa <1204837541@qq.com> Date: Tue, 28 May 2024 20:17:25 +0800 Subject: [PATCH] HttpTask support auth when has userinfo (#1554) * HttpTask support auth when has userinfo * clear auth when redirect to other host * HttpProxyTask support auth when has userinfo * remove one StringUtil::url_decode as requested in #1554 --- src/factory/HttpTaskImpl.cc | 177 +++++++++++++++++++++++++++--------- src/util/StringUtil.cc | 31 +++---- src/util/StringUtil.h | 1 - 3 files changed, 149 insertions(+), 60 deletions(-) diff --git a/src/factory/HttpTaskImpl.cc b/src/factory/HttpTaskImpl.cc index f32fde86..ba852f50 100644 --- a/src/factory/HttpTaskImpl.cc +++ b/src/factory/HttpTaskImpl.cc @@ -39,6 +39,23 @@ using namespace protocol; /**********Client**********/ +static int __encode_auth(const char *p, std::string& auth) +{ + size_t len = strlen(p); + size_t base64_len = (len + 2) / 3 * 4; + char *base64 = (char *)malloc(base64_len + 1); + + if (!base64) + return -1; + + EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len); + auth.append("Basic "); + auth.append(base64, base64_len); + + free(base64); + return 0; +} + class ComplexHttpTask : public WFComplexClientTask { public: @@ -64,8 +81,9 @@ protected: virtual bool finish_once(); protected: - bool need_redirect(ParsedURI& uri); - bool redirect_url(HttpResponse *client_resp, ParsedURI& uri); + bool need_redirect(const ParsedURI& uri, ParsedURI& new_uri); + bool redirect_url(HttpResponse *client_resp, + const ParsedURI& uri, ParsedURI& new_uri); void set_empty_request(); void check_response(); @@ -184,6 +202,10 @@ void ComplexHttpTask::set_empty_request() client_req->set_request_uri("/"); cursor.find_and_erase(&header); + + header.name = "Authorization"; + header.name_len = strlen("Authorization"); + cursor.find_and_erase(&header); } void ComplexHttpTask::init_failed() @@ -206,7 +228,6 @@ bool ComplexHttpTask::init_success() { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; - this->set_empty_request(); return false; } @@ -253,10 +274,29 @@ bool ComplexHttpTask::init_success() this->WFComplexClientTask::set_transport_type(is_ssl ? TT_TCP_SSL : TT_TCP); client_req->set_request_uri(request_uri.c_str()); client_req->set_header_pair("Host", header_host.c_str()); + + if (uri_.userinfo && uri_.userinfo[0]) + { + std::string userinfo(uri_.userinfo); + std::string http_auth; + + StringUtil::url_decode(userinfo); + + if (__encode_auth(userinfo.c_str(), http_auth) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + client_req->set_header_pair("Authorization", http_auth.c_str()); + } + return true; } -bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, ParsedURI& uri) +bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, + const ParsedURI& uri, ParsedURI& new_uri) { if (redirect_count_ < redirect_max_) { @@ -284,14 +324,14 @@ bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, ParsedURI& uri) url = uri.scheme + (':' + url); } - URIParser::parse(url, uri); + URIParser::parse(url, new_uri); return true; } return false; } -bool ComplexHttpTask::need_redirect(ParsedURI& uri) +bool ComplexHttpTask::need_redirect(const ParsedURI& uri, ParsedURI& new_uri) { HttpRequest *client_req = this->get_req(); HttpResponse *client_resp = this->get_resp(); @@ -308,7 +348,7 @@ bool ComplexHttpTask::need_redirect(ParsedURI& uri) case 301: case 302: case 303: - if (redirect_url(client_resp, uri)) + if (redirect_url(client_resp, uri, new_uri)) { if (strcasecmp(method, HttpMethodGet) != 0 && strcasecmp(method, HttpMethodHead) != 0) @@ -323,7 +363,7 @@ bool ComplexHttpTask::need_redirect(ParsedURI& uri) case 307: case 308: - if (redirect_url(client_resp, uri)) + if (redirect_url(client_resp, uri, new_uri)) return true; else break; @@ -359,8 +399,31 @@ bool ComplexHttpTask::finish_once() if (this->state == WFT_STATE_SUCCESS) { - if (this->need_redirect(uri_)) - this->set_redirect(uri_); + ParsedURI new_uri; + if (this->need_redirect(uri_, new_uri)) + { + if (uri_.userinfo && strcasecmp(uri_.host, new_uri.host) == 0) + { + if (!new_uri.userinfo) + { + new_uri.userinfo = uri_.userinfo; + uri_.userinfo = NULL; + } + } + else if (uri_.userinfo) + { + HttpRequest *client_req = this->get_req(); + HttpHeaderCursor cursor(client_req); + struct HttpMessageHeader header = { + .name = "Authorization", + .name_len = strlen("Authorization") + }; + + cursor.find_and_erase(&header); + } + + this->set_redirect(new_uri); + } else if (this->state != WFT_STATE_SUCCESS) this->disable_retry(); } @@ -370,23 +433,6 @@ bool ComplexHttpTask::finish_once() /*******Proxy Client*******/ -static int __encode_auth(const char *p, std::string& auth) -{ - size_t len = strlen(p); - size_t base64_len = (len + 2) / 3 * 4; - char *base64 = (char *)malloc(base64_len + 1); - - if (!base64) - return -1; - - EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len); - auth.append("Basic "); - auth.append(base64, base64_len); - - free(base64); - return 0; -} - static SSL *__create_ssl(SSL_CTX *ssl_ctx) { BIO *wbio; @@ -635,7 +681,6 @@ bool ComplexHttpProxyTask::init_success() { this->state = WFT_STATE_TASK_ERROR; this->error = WFT_ERR_URI_SCHEME_INVALID; - this->set_empty_request(); return false; } @@ -653,17 +698,6 @@ bool ComplexHttpProxyTask::init_success() else user_port = is_ssl_ ? 443 : 80; - if (uri_.userinfo && uri_.userinfo[0]) - { - proxy_auth_.clear(); - if (__encode_auth(uri_.userinfo, proxy_auth_) < 0) - { - this->state = WFT_STATE_SYS_ERROR; - this->error = errno; - return false; - } - } - std::string info("http-proxy|remote:"); info += is_ssl_ ? "https://" : "http://"; info += user_uri_.host; @@ -672,8 +706,24 @@ bool ComplexHttpProxyTask::init_success() info += user_uri_.port; else info += is_ssl_ ? "443" : "80"; - info += "|auth:"; - info += proxy_auth_; + + if (uri_.userinfo && uri_.userinfo[0]) + { + std::string userinfo(uri_.userinfo); + + StringUtil::url_decode(userinfo); + proxy_auth_.clear(); + + if (__encode_auth(userinfo.c_str(), proxy_auth_) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + info += "|auth:"; + info += proxy_auth_; + } this->WFComplexClientTask::set_info(info); @@ -704,6 +754,24 @@ bool ComplexHttpProxyTask::init_success() client_req->set_request_uri(request_uri.c_str()); client_req->set_header_pair("Host", header_host.c_str()); this->WFComplexClientTask::set_transport_type(TT_TCP); + + if (user_uri_.userinfo && user_uri_.userinfo[0]) + { + std::string userinfo(user_uri_.userinfo); + std::string http_auth; + + StringUtil::url_decode(userinfo); + + if (__encode_auth(userinfo.c_str(), http_auth) < 0) + { + this->state = WFT_STATE_SYS_ERROR; + this->error = errno; + return false; + } + + client_req->set_header_pair("Authorization", http_auth.c_str()); + } + return true; } @@ -732,8 +800,33 @@ bool ComplexHttpProxyTask::finish_once() if (this->state == WFT_STATE_SUCCESS) { - if (this->need_redirect(user_uri_)) + ParsedURI new_uri; + if (this->need_redirect(user_uri_, new_uri)) + { + if (user_uri_.userinfo && + strcasecmp(user_uri_.host, new_uri.host) == 0) + { + if (!new_uri.userinfo) + { + new_uri.userinfo = user_uri_.userinfo; + user_uri_.userinfo = NULL; + } + } + else if (user_uri_.userinfo) + { + HttpRequest *client_req = this->get_req(); + HttpHeaderCursor cursor(client_req); + struct HttpMessageHeader header = { + .name = "Authorization", + .name_len = strlen("Authorization") + }; + + cursor.find_and_erase(&header); + } + + user_uri_ = std::move(new_uri); this->set_redirect(uri_); + } else if (this->state != WFT_STATE_SUCCESS) this->disable_retry(); } diff --git a/src/util/StringUtil.cc b/src/util/StringUtil.cc index edaf6df8..7fa084d1 100644 --- a/src/util/StringUtil.cc +++ b/src/util/StringUtil.cc @@ -49,20 +49,17 @@ static inline char __itoh(int n) return n + '0'; } -size_t StringUtil::url_decode(char *str, size_t len) +static size_t __url_decode(char *str) { char *dest = str; char *data = str; - while (len--) + while (*data) { - if (*data == '%' && len >= 2 - && isxdigit(*(data + 1)) - && isxdigit(*(data + 2))) + if (*data == '%' && isxdigit(data[1]) && isxdigit(data[2])) { *dest = __htoi((unsigned char *)data + 1); data += 2; - len -= 2; } else if (*data == '+') *dest = ' '; @@ -82,25 +79,25 @@ void StringUtil::url_decode(std::string& str) if (str.empty()) return; - size_t sz = url_decode(const_cast(str.c_str()), str.size()); + size_t sz = __url_decode(const_cast(str.c_str())); str.resize(sz); } std::string StringUtil::url_encode(const std::string& str) { - std::string res; const char *cur = str.c_str(); const char *ed = cur + str.size(); + std::string res; while (cur < ed) { if (*cur == ' ') res += '+'; - else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' - || *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' - || *cur == '(' || *cur == ')' || *cur == ':' || *cur == '/' - || *cur == '@' || *cur == '?' || *cur == '#' || *cur == '&') + else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || + *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' || + *cur == '(' || *cur == ')' || *cur == ':' || *cur == '/' || + *cur == '@' || *cur == '?' || *cur == '#' || *cur == '&') res += *cur; else { @@ -117,17 +114,17 @@ std::string StringUtil::url_encode(const std::string& str) std::string StringUtil::url_encode_component(const std::string& str) { - std::string res; const char *cur = str.c_str(); const char *ed = cur + str.size(); + std::string res; while (cur < ed) { if (*cur == ' ') res += '+'; - else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' - || *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' - || *cur == '(' || *cur == ')') + else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' || + *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' || + *cur == '(' || *cur == ')') res += *cur; else { @@ -144,10 +141,10 @@ std::string StringUtil::url_encode_component(const std::string& str) std::vector StringUtil::split(const std::string& str, char sep) { - std::vector res; std::string::const_iterator start = str.begin(); std::string::const_iterator end = str.end(); std::string::const_iterator next = find(start, end, sep); + std::vector res; while (next != end) { diff --git a/src/util/StringUtil.h b/src/util/StringUtil.h index 2a0d5a39..e40e87a9 100644 --- a/src/util/StringUtil.h +++ b/src/util/StringUtil.h @@ -31,7 +31,6 @@ class StringUtil { public: - static size_t url_decode(char *str, size_t len); static void url_decode(std::string& str); static std::string url_encode(const std::string& str); static std::string url_encode_component(const std::string& str);