mirror of
https://github.com/sogou/workflow.git
synced 2026-02-08 01:33:17 +08:00
355 lines · 7.6 KiB · C++
/*
|
|
Copyright (c) 2019 Sogou, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
|
|
Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
|
|
*/
|
|
|
|
#include <stdint.h>
#include <string.h>
#include <errno.h>
#include <sys/uio.h>
#include <string>
#include <utility>
#include <openssl/sha.h>
#include "MySQLMessage.h"
#include "mysql_types.h"
|
|
|
|
namespace protocol
|
|
{
|
|
|
|
/* Largest payload a single MySQL packet may carry: 2^24 - 1 (0xFFFFFF),
   the maximum of the 3-byte length field written by int3store() in
   MySQLMessage::encode(). Longer payloads are split across packets. */
#define MYSQL_PAYLOAD_MAX ((1 << 24) - 1)
|
|
|
|
/*
 * Release the incremental parser and stream state, then free the heap
 * objects that held them (both are allocated with `new`, see the move
 * operations).
 */
MySQLMessage::~MySQLMessage()
{
	mysql_parser_t *parser = this->parser_;
	mysql_stream_t *stream = this->stream_;

	mysql_parser_deinit(parser);
	mysql_stream_deinit(stream);
	delete parser;
	delete stream;
}
|
|
|
|
/*
 * Move constructor.
 *
 * Takes over the parser/stream objects, the pending outgoing payload
 * (buf_), the sequence id and the received-size counter from `move`.
 * `move` is then given a freshly initialised parser/stream pair and an
 * empty payload, so it stays usable and safely destructible.
 */
MySQLMessage::MySQLMessage(MySQLMessage&& move)
{
	this->size_limit = move.size_limit;
	move.size_limit = (size_t)-1;

	parser_ = move.parser_;
	stream_ = move.stream_;
	seqid_ = move.seqid_;
	// The encoded payload must travel with the message; otherwise a moved
	// request (e.g. after set_query()) would encode nothing.
	buf_ = std::move(move.buf_);
	cur_size_ = move.cur_size_;

	move.parser_ = new mysql_parser_t;
	move.stream_ = new mysql_stream_t;
	move.seqid_ = 0;
	// A moved-from std::string is only "valid but unspecified"; make it empty.
	move.buf_.clear();
	move.cur_size_ = 0;
	mysql_parser_init(move.parser_);
	mysql_stream_init(move.stream_);
}
|
|
|
|
/*
 * Move assignment.
 *
 * Frees this message's own parser/stream, then takes over the parser,
 * stream, pending outgoing payload (buf_), sequence id and size counter
 * from `move`. `move` is left with a fresh, initialised parser/stream
 * pair and an empty payload. Self-assignment is a no-op.
 */
MySQLMessage& MySQLMessage::operator= (MySQLMessage&& move)
{
	if (this != &move)
	{
		this->size_limit = move.size_limit;
		move.size_limit = (size_t)-1;

		mysql_parser_deinit(parser_);
		mysql_stream_deinit(stream_);
		delete parser_;
		delete stream_;

		parser_ = move.parser_;
		stream_ = move.stream_;
		seqid_ = move.seqid_;
		// Transfer the encoded payload as well; it was previously dropped.
		buf_ = std::move(move.buf_);
		cur_size_ = move.cur_size_;

		move.parser_ = new mysql_parser_t;
		move.stream_ = new mysql_stream_t;
		move.seqid_ = 0;
		move.buf_.clear();
		move.cur_size_ = 0;
		mysql_parser_init(move.parser_);
		mysql_stream_init(move.stream_);
	}

	return *this;
}
|
|
|
|
int MySQLMessage::append(const void *buf, size_t *size)
|
|
{
|
|
const void *stream_buf;
|
|
size_t stream_len;
|
|
int ret;
|
|
|
|
cur_size_ += *size;
|
|
if (cur_size_ > this->size_limit)
|
|
{
|
|
errno = EMSGSIZE;
|
|
return -1;
|
|
}
|
|
|
|
ret = mysql_stream_write(buf, *size, stream_);
|
|
if (ret > 0)
|
|
{
|
|
seqid_ = mysql_stream_get_seq(stream_);
|
|
mysql_stream_get_buf(&stream_buf, &stream_len, stream_);
|
|
ret = decode_packet((const char *)stream_buf, stream_len);
|
|
if (ret == -2)
|
|
{
|
|
errno = EBADMSG;
|
|
ret = -1;
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
int MySQLMessage::encode(struct iovec vectors[], int max)
|
|
{
|
|
const char *p = buf_.c_str();
|
|
size_t nleft = buf_.size();
|
|
uint8_t seqid_start = seqid_;
|
|
char *head;
|
|
uint32_t length;
|
|
int i = 0;
|
|
|
|
if (nleft == 0)
|
|
return 0;
|
|
|
|
do
|
|
{
|
|
length = (nleft >= MYSQL_PAYLOAD_MAX ? MYSQL_PAYLOAD_MAX
|
|
: (uint32_t)nleft);
|
|
head = heads_[seqid_];
|
|
int3store(head, length);
|
|
head[3] = seqid_++;
|
|
vectors[i].iov_base = head;
|
|
vectors[i].iov_len = 4;
|
|
i++;
|
|
vectors[i].iov_base = const_cast<char *>(p);
|
|
vectors[i].iov_len = length;
|
|
i++;
|
|
|
|
if (i > max)//overflow
|
|
break;
|
|
|
|
if (nleft < MYSQL_PAYLOAD_MAX)
|
|
return i;
|
|
|
|
nleft -= MYSQL_PAYLOAD_MAX;
|
|
p += length;
|
|
} while (seqid_ != seqid_start);
|
|
|
|
errno = EOVERFLOW;
|
|
return -1;
|
|
}
|
|
|
|
void MySQLRequest::set_query(const char *query, size_t length)
|
|
{
|
|
set_command(MYSQL_COM_QUERY);
|
|
buf_.resize(length + 1);
|
|
char *buffer = const_cast<char *>(buf_.c_str());
|
|
|
|
buffer[0] = MYSQL_COM_QUERY;
|
|
if (length > 0)
|
|
memcpy(buffer + 1, query, length);
|
|
}
|
|
|
|
std::string MySQLRequest::get_query() const
|
|
{
|
|
size_t len = buf_.size();
|
|
if (len <= 1 || buf_[0] != MYSQL_COM_QUERY)
|
|
return "";
|
|
|
|
return std::string(buf_.c_str() + 1);
|
|
}
|
|
|
|
int MySQLHandshakeResponse::encode(struct iovec vectors[], int max)
|
|
{
|
|
const char empty13[13] = {0};
|
|
|
|
buf_.clear();
|
|
buf_.append((const char *)&protocol_version_, 1);
|
|
buf_.append(server_version_.c_str(), server_version_.size() + 1);
|
|
buf_.append((const char *)&connection_id_, 4);
|
|
buf_.append((const char *)auth_plugin_data_part_1_, 8);
|
|
buf_.append(empty13, 1);
|
|
buf_.append(empty13, 2);
|
|
buf_.append((const char *)&character_set_, 1);
|
|
buf_.append((const char *)&status_flags_, 2);
|
|
buf_.append(empty13, 13);
|
|
buf_.append((const char *)auth_plugin_data_part_2_, 12);
|
|
return this->MySQLMessage::encode(vectors, max);
|
|
}
|
|
|
|
int MySQLHandshakeResponse::decode_packet(const char *buf, size_t buflen)
|
|
{
|
|
const char *end = buf + buflen;
|
|
const char *pos;
|
|
|
|
if (buflen == 0)
|
|
return -2;
|
|
|
|
protocol_version_ = *buf;
|
|
if (protocol_version_ == 255)
|
|
{
|
|
if (buflen >= 4)
|
|
{
|
|
const_cast<char *>(buf)[3] = '#';
|
|
if (mysql_parser_parse(buf, buflen, parser_) == 1)
|
|
{
|
|
disallowed_ = true;
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
errno = EBADMSG;
|
|
return -1;
|
|
}
|
|
|
|
pos = ++buf;
|
|
while (pos < end && *pos)
|
|
pos++;
|
|
|
|
if (pos >= end || end - pos < 43)
|
|
return -2;
|
|
|
|
server_version_.assign(buf, pos - buf);
|
|
buf = pos + 1;
|
|
connection_id_ = uint4korr(buf);
|
|
buf += 4;
|
|
memcpy(auth_plugin_data_part_1_, buf, 8);
|
|
buf += 9;
|
|
buf += 2;
|
|
character_set_ = *buf++;
|
|
status_flags_ = uint2korr(buf);
|
|
buf += 2;
|
|
buf += 13;
|
|
memcpy(auth_plugin_data_part_2_, buf, 12);
|
|
return 1;
|
|
}
|
|
|
|
static inline void __sha1(const std::string& str, unsigned char *md)
|
|
{
|
|
SHA_CTX ctx;
|
|
SHA1_Init(&ctx);
|
|
SHA1_Update(&ctx, str.c_str(), str.size());
|
|
SHA1_Final(md, &ctx);
|
|
}
|
|
|
|
static inline std::string __sha1_bin(const std::string& str)
|
|
{
|
|
unsigned char md[20];
|
|
|
|
__sha1(str, md);
|
|
return std::string((const char *)md, 20);
|
|
}
|
|
|
|
/* Client capability bits. The full set is OR-ed into the first four bytes
   of MySQLAuthRequest::encode(); PROTOCOL_41 is required by
   MySQLAuthRequest::decode_packet(). Listed in ascending bit order. */
#define MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB		0x00000008
#define MYSQL_CAPFLAG_CLIENT_LOCAL_FILES			0x00000080
#define MYSQL_CAPFLAG_CLIENT_PROTOCOL_41			0x00000200
#define MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION		0x00008000
#define MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS		0x00010000
#define MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS			0x00020000
#define MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS		0x00040000
#define MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH			0x00080000
|
|
|
|
int MySQLAuthRequest::encode(struct iovec vectors[], int max)
|
|
{
|
|
std::string native;
|
|
char header[32] = {0};
|
|
char *pos = header;
|
|
|
|
int4store(pos, MYSQL_CAPFLAG_CLIENT_PROTOCOL_41 |
|
|
MYSQL_CAPFLAG_CLIENT_SECURE_CONNECTION |
|
|
MYSQL_CAPFLAG_CLIENT_CONNECT_WITH_DB |
|
|
MYSQL_CAPFLAG_CLIENT_PLUGIN_AUTH |
|
|
MYSQL_CAPFLAG_CLIENT_MULTI_RESULTS|
|
|
MYSQL_CAPFLAG_CLIENT_LOCAL_FILES |
|
|
MYSQL_CAPFLAG_CLIENT_MULTI_STATEMENTS |
|
|
MYSQL_CAPFLAG_CLIENT_PS_MULTI_RESULTS);
|
|
pos += 4;
|
|
int4store(pos, 0);
|
|
pos += 4;
|
|
*pos = (uint8_t)character_set_;
|
|
|
|
if (password_.empty())
|
|
native.push_back((char)0);
|
|
else
|
|
{
|
|
native.push_back((char)20);
|
|
std::string first = __sha1_bin(password_);
|
|
std::string second = __sha1_bin(challenge_ + __sha1_bin(first));
|
|
|
|
for (int i = 0; i < 20; i++)
|
|
native.push_back(first[i] ^ second[i]);
|
|
}
|
|
|
|
buf_.clear();
|
|
buf_.append(header, 32);
|
|
buf_.append(username_.c_str(), username_.size() + 1);
|
|
buf_.append(native);
|
|
buf_.append(db_.c_str(), db_.size() + 1);
|
|
buf_.append("mysql_native_password", 22);
|
|
return this->MySQLMessage::encode(vectors, max);
|
|
}
|
|
|
|
int MySQLAuthRequest::decode_packet(const char *buf, size_t buflen)
|
|
{
|
|
const char *end = buf + buflen;
|
|
const char *pos;
|
|
|
|
if (buflen < 32)
|
|
return -2;
|
|
|
|
uint32_t flags = uint4korr(buf);
|
|
|
|
if (!(flags & MYSQL_CAPFLAG_CLIENT_PROTOCOL_41))
|
|
return -2;
|
|
|
|
buf += 8;
|
|
character_set_ = *buf++;
|
|
buf += 23;
|
|
|
|
pos = buf;
|
|
while (pos < end && *pos)
|
|
pos++;
|
|
|
|
if (pos >= end)
|
|
return -2;
|
|
|
|
username_.assign(buf, pos - buf);
|
|
buf = pos + 1;
|
|
|
|
return 1;
|
|
}
|
|
|
|
void MySQLResponse::set_ok_packet()
|
|
{
|
|
uint16_t zero16 = 0;
|
|
buf_.clear();
|
|
buf_.push_back(0x00);
|
|
buf_.push_back(0x00);
|
|
buf_.push_back(0x00);
|
|
buf_.append((const char *)&zero16, 2);
|
|
buf_.append((const char *)&zero16, 2);
|
|
buf_.push_back(0x00);
|
|
}
|
|
|
|
int MySQLResponse::decode_packet(const char *buf, size_t buflen)
|
|
{
|
|
return mysql_parser_parse(buf, buflen, parser_);
|
|
}
|
|
|
|
}
|
|
|