websocket :

1. check payload_length in close;
2. add get_status_code() and set_close_message();
3. optimize the malloc codes;
This commit is contained in:
liyingxin
2023-08-16 23:27:54 +08:00
parent 59144544b3
commit f19516f147
6 changed files with 119 additions and 55 deletions

View File

@@ -93,9 +93,9 @@ public:
- **文本**,通过``set_text_data()``这类接口设置;
- **二进制**,通过``set_binary_data()``这类接口设置;
注意这些均为**拷贝接口**,消息在发出之前需要用户来保证data在内存的生命周期
注意这些均为**拷贝接口**,消息会在task里拷贝一份
这两类接口都有一个带``bool fin``参数的接口表示本消息是否finish。因为**WebSocket**协议的数据包允许分段传输,如果你要发送一个完整的消息想分多次发送,则可以使用带``bool fin``的接口,并且把``fin``值设置为``false``。
这两类接口都有一个带``bool fin``参数的接口表示本消息是否finish。因为**WebSocket**协议的数据包允许分段传输,如果你要发送一个完整的消息想分多次发送,则可以使用带``bool fin``的接口,并且把``fin``值设置为``false``,默认值是``true``
#### 4. callback
@@ -169,6 +169,12 @@ wait_group.wait();
这里发起了一个close任务由于close是异步的因此在``task->start()``之后当前线程会退出,我们在当前线程结合一个了``wait_group``进行不占线程的阻塞并在close任务的回调函数里唤醒然后当前线程就可以安全调用``client.deinit()``、删除client实例以及退出了。
开发者可以对close任务设置status_code和close_reason以表示主动关闭的原因。默认status_code为`WSStatusCodeNormal`,如需设置,接口参考:
```cpp
bool WebSocketFrame::set_close_message(uint16_t status_code, const char *data, size_t size);
```
需要注意的是如果不主动发起close任务直接删除client实例那么底层使用的那个网络连接还会存在直到超时或其他原因断开
而``client.deinit()``是个等待内部网络资源完全释放的同步接口需要手动调用以保证程序退出前client的所有资源安全释放。

View File

@@ -153,6 +153,7 @@ WFWebSocketTask *WebSocketClient::create_close_task(websocket_callback_t cb)
protocol::WebSocketFrame *msg = close_task->get_msg();
msg->set_opcode(WebSocketFrameConnectionClose);
msg->set_masking_key(this->channel->gen_masking_key());
msg->set_close_message(WSStatusCodeNormal, "");
return close_task;
}

View File

@@ -118,7 +118,6 @@ int WebSocketFrame::encode(struct iovec vectors[], int max)
uint16_t tmp = htons(this->parser->payload_length);
memcpy(p, &tmp, sizeof(tmp));
p += 2;
}
else
{
@@ -189,96 +188,133 @@ bool WebSocketFrame::set_binary_data(const char *data, size_t size)
bool WebSocketFrame::set_binary_data(const char *data, size_t size, bool fin)
{
bool ret = true;
// -1/0/text/bin. Cannot set into close/ping/pong.
if (this->parser->opcode > WebSocketFrameBinary)
return false;
void *payload_data = this->parser->payload_data;
if (!payload_data)
payload_data = malloc(size);
else if (this->parser->payload_length < size)
payload_data = realloc(payload_data, size);
if (!payload_data)
return false;
this->parser->payload_data = payload_data;
memcpy(this->parser->payload_data, data, size);
this->parser->payload_length = size;
this->parser->opcode = WebSocketFrameBinary;
this->parser->fin = fin;
if (this->parser->payload_length && this->parser->payload_data)
{
ret = false;
free(this->parser->payload_data);
}
this->parser->payload_data = (char *)malloc(size);
memcpy(this->parser->payload_data, data, size);
this->parser->payload_length = size;
return ret;
return true;
}
bool WebSocketFrame::set_text_data(const char *data)
{
return set_text_data(data, strlen(data), true);
return this->set_text_data(data, strlen(data), true);
}
bool WebSocketFrame::set_text_data(const char *data, size_t size, bool fin)
{
bool ret = true;
if (this->parser->opcode > WebSocketFrameBinary)
return false;
void *payload_data = this->parser->payload_data;
if (!payload_data)
payload_data = malloc(size);
else if (this->parser->payload_length < size)
payload_data = realloc(payload_data, size);
if (!payload_data)
return false;
this->parser->payload_data = payload_data;
memcpy(this->parser->payload_data, data, size);
this->parser->payload_length = size;
this->parser->opcode = WebSocketFrameText;
this->parser->fin = fin;
if (this->parser->payload_length && this->parser->payload_data)
{
ret = false;
free(this->parser->payload_data);
}
return true;
}
this->parser->payload_data = (char *)malloc(size);
memcpy(this->parser->payload_data, data, size);
this->parser->payload_length = size;
bool WebSocketFrame::set_close_message(uint16_t status_code, const char *data)
{
return this->set_close_message(status_code, data, strlen(data));
}
return ret;
bool WebSocketFrame::set_close_message(uint16_t status_code,
const char *data, size_t size)
{
if (this->parser->opcode != WebSocketFrameConnectionClose)
return false;
size_t payload_length = size + sizeof(uint16_t);
void *payload_data = this->parser->payload_data;
if (!payload_data)
payload_data = malloc(payload_length);
else if (this->parser->payload_length < payload_length)
payload_data = realloc(payload_data, payload_length);
if (!payload_data)
return false;
this->parser->payload_data = payload_data;
this->parser->payload_length = payload_length;
// this->parser->status_code = status_code;
uint16_t tmp = htons(status_code);
memcpy(this->parser->payload_data, &tmp, sizeof(uint16_t));
memcpy((char *)this->parser->payload_data + sizeof(uint16_t), data, size);
return true;
}
bool WebSocketFrame::set_data(const websocket_parser_t *parser)
{
bool ret = true;
unsigned char *p;
if (this->parser->payload_length && this->parser->payload_data)
{
ret = false;
free(this->parser->payload_data);
}
// this->parser->status_code = parser->status_code;
this->parser->payload_length = parser->payload_length;
void *payload_data = this->parser->payload_data;
size_t payload_length = parser->payload_length;
if (this->parser->opcode == WebSocketFrameConnectionClose &&
parser->status_code != WSStatusCodeUndefined)
{
this->parser->payload_length += 2;
payload_length += sizeof(uint16_t);
}
this->parser->payload_data = malloc(this->parser->payload_length);
p = (unsigned char *)this->parser->payload_data;
if (!payload_data)
payload_data = malloc(payload_length);
else if (this->parser->payload_length < payload_length)
payload_data = realloc(payload_data, payload_length);
if (!payload_data)
return false;
this->parser->payload_data = payload_data;
this->parser->payload_length = payload_length;
this->parser->status_code = parser->status_code;
if (this->parser->opcode == WebSocketFrameConnectionClose &&
parser->status_code != WSStatusCodeUndefined)
{
uint16_t tmp = htons(parser->status_code);
memcpy(p, &tmp, sizeof(tmp));
p += 2;
memcpy(this->parser->payload_data, &tmp, sizeof(uint16_t));
payload_data = (char *)payload_data + sizeof(uint16_t);
}
memcpy(p, parser->payload_data, parser->payload_length);
memcpy(payload_data, parser->payload_data, parser->payload_length);
return ret;
return true;
}
bool WebSocketFrame::get_data(const char **data, size_t *size) const
void WebSocketFrame::get_data(const char **data, size_t *size) const
{
if (this->parser->status_code == WSStatusCodeUnsupportedData ||
this->parser->status_code == WSStatusCodeProtocolError)
{
return false;
}
*data = (char *)this->parser->payload_data;
*size = this->parser->payload_length;
return true;
}
bool WebSocketFrame::finished() const
@@ -286,5 +322,10 @@ bool WebSocketFrame::finished() const
return this->parser->fin;
}
uint16_t WebSocketFrame::get_status_code() const
{
return this->parser->status_code;
}
} // end namespace protocol

View File

@@ -44,7 +44,12 @@ public:
bool set_binary_data(const char *data, size_t size);
bool set_binary_data(const char *data, size_t size, bool fin);
bool get_data(const char **data, size_t *size) const;
bool set_close_message(uint16_t code, const char *data);
bool set_close_message(uint16_t code, const char *data, size_t size);
void get_data(const char **data, size_t *size) const;
uint16_t get_status_code() const;
bool finished() const;

View File

@@ -171,10 +171,13 @@ int websocket_parser_parse(websocket_parser_t *parser)
p = (unsigned char *)parser->payload_data;
if (parser->opcode == WebSocketFrameConnectionClose)
if (parser->opcode == WebSocketFrameConnectionClose &&
parser->payload_length >= 2)
{
parser->status_code = ntohs(*((uint16_t*)p));
p = malloc(parser->payload_length - 2);
if (!p)
return -1;
memcpy(p, (unsigned char *)parser->payload_data + 2,
parser->payload_length - 2);
free(parser->payload_data);

View File

@@ -38,9 +38,15 @@ void process(WFWebSocketTask *task)
task->get_msg()->get_data(&data, &size);
fprintf(stderr, "get text message: [%.*s]\n", (int)size, data);
}
else if (task->get_msg()->get_opcode() == WebSocketFrameConnectionClose)
{
task->get_msg()->get_data(&data, &size);
fprintf(stderr, "close message: [%.*s] status code: %u\n",
(int)size, data, task->get_msg()->get_status_code());
}
else
{
fprintf(stderr, "process opcode=%d\n", task->get_msg()->get_opcode());
fprintf(stderr, "process opcode: %d\n", task->get_msg()->get_opcode());
}
}
@@ -73,12 +79,14 @@ int main(int argc, char *argv[])
wg.done();
return;
}
auto *ping_task = client.create_ping_task(nullptr);
auto *timer_task = WFTaskFactory::create_timer_task(3000000 /* 3s */, nullptr);
auto *close_task = client.create_close_task([&wg] (WFWebSocketTask *task) {
wg.done();
});
close_task->get_msg()->set_close_message(WSStatusCodeNormal, "close after 3 seconds");
series_of(task)->push_back(ping_task);
series_of(task)->push_back(timer_task);
series_of(task)->push_back(close_task);