Skip to content

Commit 55fa377

Browse files
committed
P2P: hole punching
1 parent 75f03ee commit 55fa377

File tree

1 file changed

+60
-10
lines changed

1 file changed

+60
-10
lines changed

core/src/p2p.rs

+60-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use std::{error::Error, net::SocketAddr, fmt::Display};
1+
use std::{error::Error, net::SocketAddr, fmt::Display, str::FromStr};
22

33
use quinn::{Endpoint, SendStream, RecvStream};
44
use tokio::net::{tcp::{ReadHalf, WriteHalf}, TcpStream};
55

6-
use crate::{p2p_utils::{make_client_endpoint, make_server_endpoint}, unsafe_quic_client, i};
6+
use crate::{p2p_utils::{make_client_endpoint, make_server_endpoint}, unsafe_quic_client, *};
77

88
pub fn get_client_endpoint(bind: Option<&str>) -> Result<Endpoint, Box<dyn Error>> {
99
let client_addr = bind.unwrap_or("0.0.0.0:0").parse().unwrap();
@@ -34,7 +34,7 @@ pub enum NatType {
3434
/// 可以作为quic服务器
3535
Server,
3636
/// 通过端口增量猜测公网地址进行侦听
37-
Nat4Increment(u16),
37+
Nat4Increment(i32),
3838
/// 无法主动侦听
3939
Nat4Random,
4040
}
@@ -76,7 +76,7 @@ pub async fn question_stun(udp: &Endpoint, server_addr: &str) -> (NatType, Strin
7676
let port = port.parse::<u16>().unwrap() + 1;
7777
let my_udp_addr2 = get_stun_addr(udp, &format!("{host}:{port}")).await;
7878

79-
let nat_type = match (my_udp_addr1.ip().eq(&my_udp_addr2.ip()), my_udp_addr2.port() - my_udp_addr1.port()) {
79+
let nat_type = match (my_udp_addr1.ip().eq(&my_udp_addr2.ip()), my_udp_addr2.port() as i32 - my_udp_addr1.port() as i32) {
8080
(true, 0) => NatType::Server,
8181
(true, increment) if increment < 10 => NatType::Nat4Increment(increment),
8282
_ => NatType::Nat4Random,
@@ -95,6 +95,12 @@ pub async fn get_stun_addr(udp: &Endpoint, server_addr: &str) -> SocketAddr {
9595
my_udp_addr.parse().unwrap()
9696
}
9797

98+
pub fn increment_port(peer_udp_addr: &SocketAddr, increment: i32) -> SocketAddr {
99+
let mut peer_udp_addr = peer_udp_addr.clone();
100+
peer_udp_addr.set_port((peer_udp_addr.port() as i32 + increment) as u16);
101+
peer_udp_addr
102+
}
103+
98104
pub async fn bridge(udp: Endpoint, my_nat_type: NatType, my_udp_addr: &str, peer_nat_type: NatType, peer_udp_addr: &str, mut tcp: TcpStream,
99105
be_server_if_both_can: bool) {
100106
let hole_addr = udp.local_addr().unwrap();
@@ -108,13 +114,36 @@ pub async fn bridge(udp: Endpoint, my_nat_type: NatType, my_udp_addr: &str, peer
108114
(matches!(my_nat_type, NatType::Nat4Increment(_)) && peer_nat_type == NatType::Nat4Random) || // 对方绝无可能作为服务器,我可猜测端口
109115
(matches!(my_nat_type, NatType::Nat4Increment(_)) && matches!(peer_nat_type, NatType::Nat4Increment(_)) && be_server_if_both_can) // 双方都需要猜测端口,由函数参数决断
110116
;
117+
let peer_udp_addr = SocketAddr::from_str(peer_udp_addr).unwrap();
111118
if should_be_server {
112119
udp.rebind(std::net::UdpSocket::bind("0.0.0.0:0").unwrap()).unwrap(); // drop old client port
113120
// Make sure the server has a chance to clean up
114121
udp.wait_idle().await;
122+
// 非开放型NAT需要打洞,这种情况下peer不能是随机型NAT
123+
let hole = std::net::UdpSocket::bind(hole_addr).unwrap();
124+
match peer_nat_type {
125+
NatType::Server => {
126+
let _hole = hole.send_to(b"Hello", peer_udp_addr);
127+
i!("send_to {peer_udp_addr}");
128+
},
129+
NatType::Nat4Increment(increment) => {
130+
for i in 0..5 {
131+
let peer_udp_addr = increment_port(&peer_udp_addr, i * increment);
132+
let _hole = hole.send_to(b"Hello", peer_udp_addr);
133+
i!("send_to {peer_udp_addr}");
134+
}
135+
},
136+
// 非开放型NAT碰到随机型NAT,束手无策
137+
NatType::Nat4Random => {
138+
let _hole = hole.send_to(b"Hello", peer_udp_addr);
139+
i!("send_to {peer_udp_addr}");
140+
},
141+
}
142+
drop(hole);
143+
// quic server
115144
let udp = get_server_endpoint(Some(&hole_addr.to_string())).unwrap();
116145
i!("UDP({my_udp_addr}) -> await connect");
117-
let incoming_conn = udp.accept().await.unwrap(); // 非开放型NAT可能堵塞
146+
let incoming_conn = udp.accept().await.unwrap();
118147
let visitor = incoming_conn.remote_address().to_string();
119148
i!("UDP({my_udp_addr}) -> {visitor} incoming");
120149
// assert_eq!(visitor, udp_addr);
@@ -124,15 +153,36 @@ pub async fn bridge(udp: Endpoint, my_nat_type: NatType, my_udp_addr: &str, peer
124153
let a = tcp.split();
125154
let b = (s, r);
126155
tcp2udp(a, b).await;
127-
}
128-
if !should_be_server {
129-
let udp_conn = udp.connect(peer_udp_addr.parse().unwrap(), "localhost").unwrap()
130-
.await.expect("无法连接UDP服务器");
156+
} else { // To be client
157+
tokio::time::sleep(std::time::Duration::from_millis(500)).await; // 等待打洞
158+
let udp_conn = match peer_nat_type {
159+
NatType::Nat4Increment(increment) => {
160+
let mut i = -1;
161+
let mut f = || {
162+
let peer_udp_addr = increment_port(&peer_udp_addr, { i+=1; i } * increment);
163+
i!("try connecting: {peer_udp_addr}");
164+
udp.connect(peer_udp_addr, "localhost").unwrap()
165+
};
166+
tokio::select! {
167+
Ok(conn) = f() => conn,
168+
Ok(conn) = f() => conn,
169+
Ok(conn) = f() => conn,
170+
Ok(conn) = f() => conn,
171+
Ok(conn) = f() => conn,
172+
else => {
173+
return e!("无法连接UDP服务器");
174+
}
175+
}
176+
},
177+
_ => {
178+
i!("try connecting: {peer_udp_addr}");
179+
udp.connect(peer_udp_addr, "localhost").unwrap().await.unwrap()
180+
},
181+
};
131182
let (s, mut r) = udp_conn.accept_bi().await.expect("无法读取UDP数据");
132183
let mut buf = vec![0; 5];
133184
r.read_exact(&mut buf).await.unwrap();
134185
let _hello = String::from_utf8_lossy(&buf).to_string();
135-
// wtf!(_hello);
136186
// assert_eq!(_hello, "Hello");
137187
let a = tcp.split();
138188
let b = (s, r);

0 commit comments

Comments
 (0)