Files
ZLMediaKit/webrtc/IceTransport.hpp

759 lines
31 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP
#define ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP
#include <algorithm>
#include <cstring>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "json/json.h"
#include "Util/Byte.hpp"
#include "Poller/Timer.h"
#include "Poller/EventPoller.h"
#include "Network/Socket.h"
#include "Network/UdpClient.h"
#include "Network/Session.h"
#include "logger.h"
#include "StunPacket.hpp"
namespace RTC {
// Computes the 64-bit priority of a candidate pair from two 32-bit candidate priorities.
// The definition is not in this header. It presumably implements RFC 8445 section 6.1.2.3:
//   priority = 2^32 * min(G, D) + 2 * max(G, D) + (G > D ? 1 : 0)
// where G is the controlling agent's candidate priority and D is the controlled agent's.
// TODO confirm against the definition.
uint64_t calCandidatePairPriority(uint32_t G, uint32_t D);
/**
 * A transport address expressed as a host string plus a port.
 * Equality compares both fields textually; no name resolution is performed.
 */
class CandidateAddr {
public:
    bool operator==(const CandidateAddr& other) const {
        // Compare the cheap integer first, then the string.
        if (_port != other._port) {
            return false;
        }
        return _host == other._host;
    }

    bool operator!=(const CandidateAddr& other) const {
        return !operator==(other);
    }

    // Renders the address as "host:port".
    std::string dumpString() const {
        std::string text(_host);
        text += ':';
        text += std::to_string(_port);
        return text;
    }

public:
    std::string _host;
    uint16_t _port = 0;
};
/**
 * Common description of an ICE transport endpoint: address, candidate priority,
 * transport protocol, security, and the ICE credentials (ufrag/pwd) attached to it.
 */
class CandidateTuple {
public:
    using Ptr = std::shared_ptr<CandidateTuple>;
    CandidateTuple() = default;
    virtual ~CandidateTuple() = default;

    enum class AddressType {
        HOST = 1,
        SRFLX, // server reflexive
        PRFLX, // peer reflexive
        RELAY,
    };
    enum class SecureType {
        NOT_SECURE = 1,
        SECURE,
    };
    enum class TransportType {
        UDP = 1,
        TCP,
    };

    // Orders tuples by candidate priority only.
    bool operator<(const CandidateTuple& rhs) const {
        return (_priority < rhs._priority);
    }

    // Equal when address, priority, transport and security all match.
    // _ufrag and _pwd are not compared.
    bool operator==(const CandidateTuple& rhs) const {
        return ((_addr == rhs._addr)
            && (_priority == rhs._priority)
            && (_transport == rhs._transport) && (_secure == rhs._secure));
    }

    // Hash functor for unordered containers.
    struct ClassHash {
        std::size_t operator()(const CandidateTuple& t) const {
            // Fields are joined with a separator. Plain concatenation made distinct tuples
            // share a key: host "1.2.3.4" + port 51 and host "1.2.3.45" + port 1 both gave
            // "1.2.3.451...". The trailing port/transport/secure fields are pure digits, so
            // the joined key is unambiguous even if the host contains '|'.
            // _priority is intentionally not hashed. Tuples that compare equal (operator==
            // also checks _priority) still hash equally, which is all that is required.
            std::string key;
            key.reserve(t._addr._host.size() + 16);
            key += t._addr._host;
            key += '|';
            key += std::to_string(t._addr._port);
            key += '|';
            key += std::to_string(static_cast<uint32_t>(t._transport));
            key += '|';
            key += std::to_string(static_cast<uint32_t>(t._secure));
            return std::hash<std::string>()(key);
        }
    };

    // Equality functor for unordered containers. Derived objects are compared through
    // CandidateTuple::operator== (fields of this class only).
    struct ClassEqual {
        bool operator()(const CandidateTuple& a, const CandidateTuple& b) const {
            return a == b;
        }
    };

public:
    CandidateAddr _addr;
    uint32_t _priority = 0;
    TransportType _transport = TransportType::UDP;
    SecureType _secure = SecureType::NOT_SECURE;
    std::string _ufrag;
    std::string _pwd;
};
/**
 * A gathered or signalled ICE candidate: a CandidateTuple plus its candidate type
 * and base address.
 */
class CandidateInfo : public CandidateTuple {
public:
    using Ptr = std::shared_ptr<CandidateInfo>;
    CandidateInfo() = default;
    virtual ~CandidateInfo() = default;

    enum class AddressType {
        INVALID = 0,
        HOST = 1,
        SRFLX, // server reflexive
        PRFLX, // peer reflexive
        RELAY,
    };

    // Connectivity-check state of a candidate (pair).
    enum class State {
        Frozen = 1, // not checked yet, and no check is needed yet
        Waiting,    // check not sent yet, but no longer Frozen
        InProgress, // check has been sent and is still in progress
        Succeeded,  // check succeeded
        Failed,     // check failed
    };

    // Same as CandidateTuple::operator== but also requires the same candidate type.
    bool operator==(const CandidateInfo& rhs) const {
        return CandidateTuple::operator==(rhs) && (_type == rhs._type);
    }

    std::string getAddressTypeStr() const {
        return getAddressTypeStr(_type);
    }

    // Returns the ICE candidate-type token for `type` ("host", "srflx", "prflx", "relay"),
    // or "invalid" for anything else.
    static std::string getAddressTypeStr(CandidateInfo::AddressType type) {
        switch (type) {
            case CandidateInfo::AddressType::HOST: return "host";
            case CandidateInfo::AddressType::SRFLX: return "srflx";
            // The peer-reflexive token is "prflx". This previously returned "reflx",
            // which is not a valid candidate type.
            case CandidateInfo::AddressType::PRFLX: return "prflx";
            case CandidateInfo::AddressType::RELAY: return "relay";
            default: return "invalid";
        }
    }

    static std::string getStateStr(State state) {
        switch (state) {
            case State::Frozen: return "frozen";
            case State::Waiting: return "waiting";
            case State::InProgress: return "in_progress";
            case State::Succeeded: return "succeeded";
            case State::Failed: return "failed";
            default: break;
        }
        return "unknown";
    }

    // "<type> <host>:<port>", e.g. for logging.
    std::string dumpString() const {
        return getAddressTypeStr() + " " + _addr.dumpString();
    }

public:
    AddressType _type = AddressType::HOST;
    CandidateAddr _base_addr;
};
// ICE STUN/TURN server configuration.
// URL format: (stun|turn)[s]:host:port[?transport=(tcp|udp)]; the default transport is UDP.
// Examples:
//   stun:stun.l.google.com:19302             -> Google's STUN server over UDP
//   turn:turn.example.com:3478?transport=tcp -> TURN server over TCP
//   turns:turn.example.com:5349              -> TURN server over TLS
// Parsing is done by parse(), whose definition is not in this header.
class IceServerInfo : public CandidateTuple {
public:
using Ptr = std::shared_ptr<IceServerInfo>;
IceServerInfo() = default;
virtual ~IceServerInfo() = default;
// Constructs from a URL by delegating to parse().
// NOTE(review): this single-argument constructor is not `explicit`, so a std::string
// converts implicitly to IceServerInfo.
IceServerInfo(const std::string &url) { parse(url); }
// Presumably fills the fields below (and the CandidateTuple base) from `url`.
// Defined elsewhere; TODO confirm.
void parse(const std::string &url);
// Which URL scheme was used (turn/turns vs. stun/stuns).
enum class SchemaType {
TURN = 1,
STUN,
};
public:
// Presumably the URL as given to parse(). TODO confirm.
std::string _full_url;
// Presumably the raw "?..." query parameters. TODO confirm against parse().
std::string _param_strs;
SchemaType _schema = SchemaType::TURN;
};
/**
 * Base class for STUN/TURN message handling shared by IceServer and IceAgent.
 * It dispatches incoming socket data, tracks outstanding requests for retransmission,
 * and keeps TURN permission/channel-binding tables.
 */
class IceTransport : public std::enable_shared_from_this<IceTransport> {
public:
    using Ptr = std::shared_ptr<IceTransport>;

    /**
     * A transport path: a socket, an optional explicit peer host:port, and an optional
     * TURN relayed address.
     */
    class Pair {
    public:
        using Ptr = std::shared_ptr<Pair>;

        Pair() = default;
        Pair(toolkit::SocketHelper::Ptr socket) : _socket(std::move(socket)) {}
        Pair(toolkit::SocketHelper::Ptr socket, std::string peer_host, uint16_t peer_port,
             std::shared_ptr<sockaddr_storage> relayed_addr = nullptr) :
            _socket(std::move(socket)), _peer_host(std::move(peer_host)), _peer_port(peer_port), _relayed_addr(std::move(relayed_addr)) {
        }

        /**
         * Copy constructor. The socket is shared. The relayed address is deep-copied, so
         * the new Pair owns its own sockaddr_storage.
         * Takes `const Pair &`. The previous `Pair &` signature made copying from a const
         * Pair or a temporary ill-formed; non-const sources still bind.
         * NOTE(review): no copy-assignment operator is declared, so the compiler-generated
         * one shares _relayed_addr instead of deep-copying it. Confirm this is intended.
         */
        Pair(const Pair &that) :
            _socket(that._socket), _peer_host(that._peer_host), _peer_port(that._peer_port) {
            if (that._relayed_addr) {
                _relayed_addr = std::make_shared<sockaddr_storage>(*that._relayed_addr);
            }
        }

        virtual ~Pair() = default;

        /**
         * Writes the peer address into `peer_addr`.
         * Uses _peer_host/_peer_port when _peer_host is set. Otherwise it uses the
         * socket's peer address, converting an IPv4-mapped IPv6 address to plain IPv4.
         */
        void get_peer_addr(sockaddr_storage &peer_addr) const {
            if (!_peer_host.empty()) {
                peer_addr = toolkit::SockUtil::make_sockaddr(_peer_host.data(), _peer_port);
            } else {
                auto addr = _socket->get_peer_addr();
                if (addr->sa_family == AF_INET6 && IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)addr)->sin6_addr)) {
                    memset(&peer_addr, 0, sizeof(peer_addr));
                    // Convert the IPv4-mapped IPv6 address (::ffff:a.b.c.d) to IPv4:
                    // the IPv4 address is the last 4 bytes of the IPv6 address.
                    struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)addr;
                    struct sockaddr_in *addr4 = (struct sockaddr_in *)&peer_addr;
                    addr4->sin_family = AF_INET;
                    addr4->sin_port = addr6->sin6_port;
                    memcpy(&addr4->sin_addr, &addr6->sin6_addr.s6_addr[12], 4);
                } else {
                    memcpy(&peer_addr, addr, toolkit::SockUtil::get_sock_len(addr));
                }
            }
        }

        // Copies the relayed address into `peerAddr`. Returns false if there is none.
        bool get_relayed_addr(sockaddr_storage &peerAddr) const {
            if (!_relayed_addr) {
                return false;
            }
            memcpy(&peerAddr, _relayed_addr.get(), sizeof(peerAddr));
            return true;
        }

        std::string get_local_ip() const { return _socket->get_local_ip(); }
        uint16_t get_local_port() const { return _socket->get_local_port(); }
        std::string get_peer_ip() const { return !_peer_host.empty() ? _peer_host : _socket->get_peer_ip(); }
        uint16_t get_peer_port() const { return !_peer_host.empty() ? _peer_port : _socket->get_peer_port(); }
        std::string get_relayed_ip() const { return _relayed_addr ? toolkit::SockUtil::inet_ntoa((const struct sockaddr *)_relayed_addr.get()) : ""; }
        uint16_t get_relayed_port() const { return _relayed_addr ? toolkit::SockUtil::inet_port((const struct sockaddr *)_relayed_addr.get()) : 0; }

        // True when both pairs have equal relayed addresses, or both have none.
        static bool is_same_relayed_addr(Pair *a, Pair *b) {
            if (a->_relayed_addr && b->_relayed_addr) {
                return toolkit::SockUtil::is_same_addr(
                    reinterpret_cast<const struct sockaddr *>(a->_relayed_addr.get()), reinterpret_cast<const struct sockaddr *>(b->_relayed_addr.get()));
            }
            return (a->_relayed_addr == b->_relayed_addr);
        }

        // True when both pairs use the same socket, the same peer ip:port and the same
        // relayed address.
        static bool is_same(Pair* a, Pair* b) {
            // FIXME (translated): "once a->_socket == b->_socket holds, get_peer_ip and
            // get_peer_port must be equal too".
            // NOTE(review): pairs may share a socket with different _peer_host values
            // (see the _peer_host comment below), so the peer checks are kept.
            if ((a->_socket == b->_socket)
                && (a->get_peer_ip() == b->get_peer_ip())
                && (a->get_peer_port() == b->get_peer_port())
                && (is_same_relayed_addr(a, b))) {
                return true;
            }
            return false;
        }

        /**
         * Human-readable description "[tcp |udp ]local_ip:port<arrow>peer_ip:port".
         * `flag` selects the arrow: 0 "<-", 1 "->", 2 "<->". With flag 2 the relayed
         * address is appended when present. Out-of-range flags previously indexed past
         * the arrow table; they now print "?".
         */
        std::string dumpString(uint8_t flag) const {
            toolkit::_StrPrinter sp;
            static const char* fStr[] = { "<-", "->", "<->" };
            const char *arrow = flag < (sizeof(fStr) / sizeof(fStr[0])) ? fStr[flag] : "?";
            sp << (_socket ? (_socket->getSock()->sockType() == toolkit::SockNum::Sock_TCP ? "tcp " : "udp ") : "")
               << get_local_ip() << ":" << get_local_port() << arrow << get_peer_ip() << ":" << get_peer_port();
            if (_relayed_addr && flag == 2) {
                sp << " relay " << get_relayed_ip() << ":" << get_relayed_port();
            }
            return sp;
        }

    public:
        toolkit::SocketHelper::Ptr _socket;
        // Peer host:port. Several pairs may share one socket object, so this can differ
        // from the address the socket was created or bound with.
        std::string _peer_host;
        // Initialised so that pairs built without a peer (and copies of them) never
        // carry an indeterminate value.
        uint16_t _peer_port = 0;
        // Relayed address used for TURN forwarding. When set, it is the real peer
        // address, and _peer_host/_peer_port denote the relay address.
        std::shared_ptr<sockaddr_storage> _relayed_addr;
    };

    // Callbacks delivered by the transport to its owner.
    class Listener {
    public:
        virtual ~Listener() = default;

    public:
        virtual void onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair) = 0;
        virtual void onIceTransportGatheringCandidate(const Pair::Ptr&, const CandidateInfo&) = 0;
        virtual void onIceTransportDisconnected() = 0;
        virtual void onIceTransportCompleted() = 0;
    };

public:
    using MsgHandler = std::function<void(const StunPacket::Ptr&, const Pair::Ptr&)>;

    // Bookkeeping for an outstanding STUN request awaiting a response.
    struct RequestInfo {
        StunPacket::Ptr _request;  // original request packet
        MsgHandler _handler;       // response handler
        Pair::Ptr _pair;           // pair the request was sent on
        uint64_t _send_time;       // first send time (ms)
        uint64_t _next_timeout;    // next timeout deadline (ms)
        uint32_t _retry_count;     // number of retransmissions so far
        uint32_t _rto = 500;       // current RTO (ms); initial RTO is 500 ms
        RequestInfo(StunPacket::Ptr req, MsgHandler h, Pair::Ptr p)
            : _request(std::move(req))
            , _handler(std::move(h))
            , _pair(std::move(p))
            , _retry_count(0) {
            _send_time = toolkit::getCurrentMillisecond();
            _next_timeout = _send_time + _rto;
        }
    };

    IceTransport(Listener* listener, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller);
    virtual ~IceTransport() {}
    virtual void initialize();
    const toolkit::EventPoller::Ptr& getPoller() const { return _poller; }
    const std::string& getIdentifier() const { return _identifier; }
    const std::string& getUfrag() const { return _ufrag; }
    const std::string& getPassword() const { return _password; }
    void setUfrag(std::string ufrag) { _ufrag = std::move(ufrag); }
    void setPassword(std::string password) { _password = std::move(password); }
    virtual bool processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair);
    virtual void sendSocketData(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true);
    void sendSocketData_l(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true);

protected:
    virtual void processStunPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    virtual void processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    virtual void processResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    virtual bool processChannelData(const uint8_t* data, size_t len, const Pair::Ptr& pair);
    virtual StunPacket::Authentication checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    StunPacket::Authentication checkResponseAuthentication(const StunPacket::Ptr& request, const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    void processUnauthorizedResponse(const StunPacket::Ptr& response, const StunPacket::Ptr& request, const Pair::Ptr& pair, MsgHandler handler);
    virtual void handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    // Default implementation ignores channel data; overridden by IceServer/IceAgent.
    virtual void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) {}
    void sendChannelData(uint16_t channel_number, const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair);
    virtual void sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    void sendErrorResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, StunAttrErrorCode::Code errorCode);
    void sendRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair, MsgHandler handler);
    void sendPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
    // For permissions
    bool hasPermission(const sockaddr_storage& addr);
    void addPermission(const sockaddr_storage& addr);
    // For Channel Bind
    bool hasChannelBind(uint16_t channel_number);
    bool hasChannelBind(const sockaddr_storage& addr, uint16_t& channel_number);
    void addChannelBind(uint16_t channel_number, const sockaddr_storage& addr);
    toolkit::SocketHelper::Ptr createSocket(CandidateTuple::TransportType type, const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port = 0);
    toolkit::SocketHelper::Ptr createUdpSocket(const std::string &target_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port);
    void checkRequestTimeouts();
    void retransmitRequest(const std::string& transaction_id, RequestInfo& req_info);

protected:
    std::string _identifier;
    toolkit::EventPoller::Ptr _poller;
    Listener* _listener = nullptr;
    // Outstanding requests keyed by STUN transaction ID.
    std::unordered_map<std::string /*transaction ID*/, RequestInfo> _response_handlers;
    // Handlers for incoming messages keyed by (class, method).
    std::unordered_map<std::pair<StunPacket::Class, StunPacket::Method>, MsgHandler, StunPacket::ClassMethodHash> _request_handlers;
    // Local ICE credentials.
    std::string _ufrag;
    std::string _password;
    // For permissions
    std::unordered_map<sockaddr_storage /*peer ip:port*/, uint64_t /* create or fresh time*/,
        toolkit::SockUtil::SockAddrHash, toolkit::SockUtil::SockAddrEqual> _permissions;
    // For Channel Bind
    std::unordered_map<uint16_t /*channel number*/, sockaddr_storage /*peer ip:port*/> _channel_bindings;
    std::unordered_map<uint16_t /*channel number*/, uint64_t /*bind or fresh time*/> _channel_binding_times;
    // For STUN request retry
    std::shared_ptr<toolkit::Timer> _retry_timer;
};
// Server-side specialisation of IceTransport. It declares handlers for the TURN
// Allocate, Refresh, CreatePermission and ChannelBind requests, Send indications and
// ChannelData, plus helpers for forwarding relayed data. The implementations are
// not in this header.
class IceServer : public IceTransport {
public:
using Ptr = std::shared_ptr<IceServer>;
using WeakPtr = std::weak_ptr<IceServer>;
IceServer(Listener* listener, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller);
virtual ~IceServer() {}
bool processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair) override;
// NOTE(review): "Fowording" is a misspelling of "Forwarding". The name is kept for
// source compatibility. Presumably sends `buffer` toward `peer_addr`; TODO confirm.
void relayForwordingData(const toolkit::Buffer::Ptr& buffer, const sockaddr_storage& peer_addr);
// Presumably relays data received from `peer_addr` back to the client on `pair`.
// TODO confirm.
void relayBackingData(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair, const sockaddr_storage& peer_addr);
protected:
void processRelayPacket(const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair);
void handleAllocateRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleRefreshRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleCreatePermissionRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleChannelBindRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleSendIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) override;
StunPacket::Authentication checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override;
void sendDataIndication(const sockaddr_storage& peer_addr, const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair);
void sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override;
toolkit::SocketHelper::Ptr allocateRelayed(const Pair::Ptr& pair);
toolkit::SocketHelper::Ptr createRelayedUdpSocket(const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port);
protected:
// Presumably the nonces issued in unauthorized (401) responses. TODO confirm.
std::vector<toolkit::BufferLikeString> _nonce_list;
// Keyed by peer address. The value is a (relayed port, relayed pair) tuple.
std::unordered_map<sockaddr_storage /*peer ip:port*/, std::pair<std::shared_ptr<uint16_t> /* port */, Pair::Ptr /*relayed_pairs*/>,
toolkit::SockUtil::SockAddrHash, toolkit::SockUtil::SockAddrEqual> _relayed_pairs;
Pair::Ptr _session_pair;
};
class IceAgent : public IceTransport {
public:
using Ptr = std::shared_ptr<IceAgent>;
// 候选者对信息结构
struct CandidatePair {
Pair::Ptr _local_pair; // 本地候选者对
CandidateInfo _remote_candidate; // 远程候选者信息
CandidateInfo _local_candidate; // 本地候选者信息
uint64_t _priority; // 候选者对优先级64位符合RFC 8445
CandidateInfo::State _state; // 连通性检查状态
bool _nominated = false;
CandidatePair(Pair::Ptr local_pair, CandidateInfo remote, CandidateInfo local)
: _local_pair(std::move(local_pair))
, _remote_candidate(std::move(remote))
, _local_candidate(std::move(local))
, _state(CandidateInfo::State::Frozen) {
_priority = calCandidatePairPriority(local._priority, remote._priority);
}
std::string dumpString() const {
return "local " + _local_candidate.dumpString() + " <-> remote " + _remote_candidate.dumpString();
}
// 比较操作符,用于优先级排序(高优先级在前)
bool operator<(const CandidatePair& other) const {
return _priority > other._priority;
}
};
enum class State {
//checklist state and ice session state
Running = 1, //正在进行候选地址的连通性检测
Nominated, //发起提名,等待应答
Completed, //所有候选地址完成验证,且至少有一路连接检测成功
Failed, //所有候选地址检测失败,连接不可用
};
static const char* stateToString(State state) {
switch (state) {
case State::Running: return "Running";
case State::Completed: return "Completed";
case State::Failed: return "Failed";
default: return "Unknown";
}
}
enum class Role {
Controlling = 1,
Controlled,
};
enum class Implementation {
Lite = 1,
Full,
};
IceAgent(Listener* listener, Implementation implementation, Role role,
std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller);
virtual ~IceAgent() {}
void setIceServer(IceServerInfo::Ptr ice_server) {
_ice_server = std::move(ice_server);
}
void gatheringCandidate(const CandidateTuple::Ptr& candidate_tuple, bool gathering_rflx, bool gathering_relay);
void connectivityCheck(CandidateInfo& candidate);
void nominated(const Pair::Ptr& pair, CandidateTuple& candidate);
void sendSocketData(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true) override;
IceAgent::Implementation getImplementation() const {
return _implementation;
}
void setgetImplementation(IceAgent::Implementation implementation) {
InfoL << (uint32_t)implementation;
_implementation = implementation;
}
IceAgent::Role getRole() const {
return _role;
}
void setRole(IceAgent::Role role) {
InfoL << (uint32_t)role;
_role = role;
}
IceAgent::State getState() const {
return _state;
}
void setState(IceAgent::State state) {
InfoL << stateToString(state);
_state = state;
}
Pair::Ptr getSelectedPair(bool try_last = false) const {
return try_last ? _last_selected_pair.lock() : _selected_pair;
}
bool setSelectedPair(const Pair::Ptr& pair);
void removePair(const toolkit::SocketHelper *socket);
std::vector<Pair::Ptr> getPairs() const;
// 获取checklist信息用于API查询
Json::Value getChecklistInfo() const;
size_t getRecvSpeed();
size_t getRecvTotalBytes();
size_t getSendSpeed();
size_t getSendTotalBytes();
protected:
void gatheringSrflxCandidate(const Pair::Ptr& pair);
void gatheringRelayCandidate(const Pair::Ptr& pair);
void localRelayedConnectivityCheck(CandidateInfo& candidate);
void connectivityCheck(const Pair::Ptr& pair, CandidateTuple& candidate);
void tryTriggerredCheck(const Pair::Ptr& pair);
void sendBindRequest(const Pair::Ptr& pair, MsgHandler handler);
void sendBindRequest(const Pair::Ptr& pair, CandidateTuple& candidate, bool use_candidate, MsgHandler handler);
void sendAllocateRequest(const Pair::Ptr& pair);
void sendCreatePermissionRequest(const Pair::Ptr& pair, const sockaddr_storage& peer_addr);
void sendChannelBindRequest(const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr);
void processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override;
void handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override;
void handleGatheringCandidateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleConnectivityCheckResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate);
void handleNominatedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate);
void handleAllocateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleCreatePermissionResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, const sockaddr_storage& peer_addr);
void handleChannelBindResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr);
void handleDataIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair);
void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) override;
void onGatheringCandidate(const Pair::Ptr& pair, CandidateInfo& candidate);
void onConnected(const Pair::Ptr& pair);
void onCompleted(const Pair::Ptr& pair);
void refreshPermissions();
void refreshChannelBindings();
void sendSendIndication(const sockaddr_storage& peer_addr, const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair);
void sendRelayPacket(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair, bool flush);
private:
CandidateInfo getLocalCandidateInfo(const Pair::Ptr& local_pair);
void addToChecklist(const Pair::Ptr& local_pair, CandidateInfo& remote_candidate);
protected:
IceServerInfo::Ptr _ice_server;
std::shared_ptr<toolkit::Timer> _refresh_timer;
// for candidate
Implementation _implementation = Implementation::Full;
Role _role = Role::Controlling; //ice role
uint64_t _tiebreaker = 0; // 8 bytes unsigned integer.
State _state = IceAgent::State::Running; //ice session state
Pair::Ptr _selected_pair;
Pair::Ptr _nominated_pair;
StunPacket::Ptr _nominated_response;
std::weak_ptr<Pair> _last_selected_pair;
// 双向索引的候选地址管理结构
struct SocketCandidateManager {
// socket -> candidates 的一对多映射
std::unordered_map<toolkit::SocketHelper::Ptr, std::vector<CandidateInfo>> socket_to_candidates;
// candidate -> socket 的映射(用于快速查找)
std::unordered_map<CandidateInfo, toolkit::SocketHelper::Ptr, CandidateTuple::ClassHash, CandidateTuple::ClassEqual> candidate_to_socket;
// 按类型分组的socket列表方便遍历
std::vector<toolkit::SocketHelper::Ptr> _host_sockets; // HOST类型socket
std::vector<toolkit::SocketHelper::Ptr> _relay_sockets; // RELAY类型socket
bool _has_relayed_candidate = false;
// 添加映射关系带5元组重复检查
bool addMapping(toolkit::SocketHelper::Ptr socket, const CandidateInfo& candidate) {
// 检查5元组是否已存在
if (candidate_to_socket.find(candidate) != candidate_to_socket.end()) {
return false; // 已存在相同的5元组
}
socket_to_candidates[socket].push_back(candidate);
candidate_to_socket[candidate] = socket;
// 按类型分组
if (candidate._type != CandidateInfo::AddressType::RELAY) {
addHostSocket(std::move(socket));
} else if (candidate._type == CandidateInfo::AddressType::RELAY) {
addRelaySocket(std::move(socket));
}
return true;
}
// 获取socket对应的所有candidates
std::vector<CandidateInfo> getCandidates(const toolkit::SocketHelper::Ptr& socket) const {
auto it = socket_to_candidates.find(socket);
return (it != socket_to_candidates.end()) ? it->second : std::vector<CandidateInfo>();
}
// 获取candidate对应的socket
toolkit::SocketHelper::Ptr getSocket(const CandidateInfo& candidate) const {
auto it = candidate_to_socket.find(candidate);
return (it != candidate_to_socket.end()) ? it->second : nullptr;
}
// 获取所有socket便于遍历
std::vector<toolkit::SocketHelper::Ptr> getAllSockets() const {
std::vector<toolkit::SocketHelper::Ptr> result;
result.reserve(_host_sockets.size() + _relay_sockets.size());
result.insert(result.end(), _host_sockets.begin(), _host_sockets.end());
result.insert(result.end(), _relay_sockets.begin(), _relay_sockets.end());
return result;
}
// 获取所有candidates便于遍历
std::vector<CandidateInfo> getAllCandidates() const {
std::vector<CandidateInfo> result;
for (auto& pair : candidate_to_socket) {
result.push_back(pair.first);
}
return result;
}
// 直接添加host socket
void addHostSocket(toolkit::SocketHelper::Ptr socket) {
if (std::find(_host_sockets.begin(), _host_sockets.end(), socket) == _host_sockets.end()) {
_host_sockets.emplace_back(std::move(socket));
}
}
// 直接添加relay socket
void addRelaySocket(toolkit::SocketHelper::Ptr socket) {
if (std::find(_relay_sockets.begin(), _relay_sockets.end(), socket) == _relay_sockets.end()) {
_relay_sockets.emplace_back(std::move(socket));
}
}
// 获取host sockets
const std::vector<toolkit::SocketHelper::Ptr>& getHostSockets() const {
return _host_sockets;
}
// 获取relay sockets
const std::vector<toolkit::SocketHelper::Ptr>& getRelaySockets() const {
return _relay_sockets;
}
// 移除host socket
void removeHostSocket(const toolkit::SocketHelper::Ptr& socket) {
auto it = std::find(_host_sockets.begin(), _host_sockets.end(), socket);
if (it != _host_sockets.end()) {
_host_sockets.erase(it);
}
}
// 移除relay socket
void removeRelaySocket(const toolkit::SocketHelper::Ptr& socket) {
auto it = std::find(_relay_sockets.begin(), _relay_sockets.end(), socket);
if (it != _relay_sockets.end()) {
_relay_sockets.erase(it);
}
}
// 清空host sockets
void clearHostSockets() {
_host_sockets.clear();
}
// 清空relay sockets
void clearRelaySockets() {
_relay_sockets.clear();
}
// 获取host socket数量
size_t getHostSocketCount() const {
return _host_sockets.size();
}
// 获取relay socket数量
size_t getRelaySocketCount() const {
return _relay_sockets.size();
}
};
//for GATHERING_CANDIDATE
SocketCandidateManager _socket_candidate_manager; //local candidates
//for CONNECTIVITY_CHECK
using CandidateSet = std::unordered_set<CandidateInfo, CandidateTuple::ClassHash, CandidateTuple::ClassEqual>;
CandidateSet _remote_candidates;
//TODO:当前仅支持多数据流复用一个checklist
std::vector<std::shared_ptr<CandidatePair>> _check_list;
std::vector<std::shared_ptr<CandidatePair>> _valid_list;
std::shared_ptr<CandidatePair> _select_candidate_pair;
};
} // namespace RTC
#endif //ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP