github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/quic.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net" 9 "sync" 10 "time" 11 12 "github.com/Asutorufa/yuhaiin/pkg/net/deadline" 13 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 14 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 15 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 16 "github.com/Asutorufa/yuhaiin/pkg/protos/node/point" 17 "github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol" 18 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 19 "github.com/Asutorufa/yuhaiin/pkg/utils/id" 20 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 21 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 22 "github.com/quic-go/quic-go" 23 ) 24 25 type Client struct { 26 netapi.EmptyDispatch 27 28 tlsConfig *tls.Config 29 dialer netapi.Proxy 30 31 session quic.Connection 32 underlying net.PacketConn 33 sessionMu sync.Mutex 34 sessionUnix int64 35 36 packetConn *ConnectionPacketConn 37 natMap syncmap.SyncMap[uint64, *clientPacketConn] 38 39 idg id.IDGenerator 40 41 host *net.UDPAddr 42 } 43 44 func init() { 45 point.RegisterProtocol(NewClient) 46 } 47 48 func NewClient(config *protocol.Protocol_Quic) point.WrapProxy { 49 return func(dialer netapi.Proxy) (netapi.Proxy, error) { 50 51 var host *net.UDPAddr = &net.UDPAddr{IP: net.IPv4zero} 52 53 if config.Quic.Host != "" { 54 addr, err := netapi.ParseAddress(statistic.Type_udp, config.Quic.Host) 55 if err == nil { 56 if ur := addr.UDPAddr(context.TODO()); ur.Err == nil { 57 host = ur.V 58 } 59 } 60 } 61 62 tlsConfig := point.ParseTLSConfig(config.Quic.Tls) 63 if tlsConfig == nil { 64 tlsConfig = &tls.Config{} 65 } 66 67 if point.IsBootstrap(dialer) { 68 dialer = nil 69 } 70 71 c := &Client{ 72 dialer: dialer, 73 tlsConfig: tlsConfig, 74 host: host, 75 } 76 77 return c, nil 78 } 79 } 80 81 func (c *Client) initSession(ctx context.Context) (quic.Connection, error) { 82 session := c.session 83 84 if session != nil { 85 select { 86 case <-session.Context().Done(): 87 default: 88 return session, nil 89 } 90 } 91 92 c.sessionMu.Lock() 93 defer c.sessionMu.Unlock() 94 95 if c.session != nil { 96 select { 97 case <-c.session.Context().Done(): 98 default: 99 return c.session, nil 100 } 101 } 102 103 if c.session != nil { 104 _ = c.session.CloseWithError(0, "") 105 } 106 107 if c.underlying != nil { 108 _ = c.underlying.Close() 109 } 110 111 var conn net.PacketConn 112 var err error 113 114 if c.dialer == nil { 115 conn, err = dialer.ListenPacket("udp", "") 116 } else { 117 conn, err = c.dialer.PacketConn(ctx, netapi.EmptyAddr) 118 } 119 if err != nil { 120 return nil, err 121 } 122 123 tr := quic.Transport{ 124 Conn: conn, 125 ConnectionIDLength: 12, 126 } 127 128 config := &quic.Config{ 129 KeepAlivePeriod: 15 * time.Second, 130 MaxIdleTimeout: nat.IdleTimeout, 131 EnableDatagrams: true, 132 } 133 134 session, err = tr.Dial(ctx, c.host, c.tlsConfig, config) 135 if err != nil { 136 _ = conn.Close() 137 return nil, err 138 } 139 140 pconn := NewConnectionPacketConn(session) 141 142 c.underlying = conn 143 c.session = session 144 c.sessionUnix = time.Now().Unix() 145 146 // Datagram 147 c.packetConn = pconn 148 149 go func() { 150 defer session.CloseWithError(0, "") 151 for { 152 id, data, err := pconn.Receive(context.TODO()) 153 if err != nil { 154 return 155 } 156 157 cchan, ok := c.natMap.Load(id) 158 if !ok { 159 continue 160 } 161 162 select { 163 case <-session.Context().Done(): 164 return 165 case <-cchan.ctx.Done(): 166 case cchan.msg <- data: 167 } 168 } 169 }() 170 return session, nil 171 } 172 173 func (c *Client) Conn(ctx context.Context, s netapi.Address) (net.Conn, error) { 174 session, err := c.initSession(ctx) 175 if err != nil { 176 return nil, err 177 } 178 179 stream, err := session.OpenStream() 180 if err != nil { 181 _ = session.CloseWithError(0, "") 182 return nil, err 183 } 184 185 return &interConn{ 186 Stream: stream, 187 session: session, 188 time: c.sessionUnix, 189 }, nil 190 } 191 192 func (c *Client) PacketConn(ctx context.Context, host netapi.Address) (net.PacketConn, error) { 193 _, err := c.initSession(ctx) 194 if err != nil { 195 return nil, err 196 } 197 198 ctx, cancel := context.WithCancel(context.TODO()) 199 200 cp := &clientPacketConn{ 201 c: c, 202 ctx: ctx, 203 cancel: cancel, 204 session: c.packetConn, 205 id: c.idg.Generate(), 206 msg: make(chan *pool.Buffer, 64), 207 deadline: deadline.NewPipe(), 208 } 209 c.natMap.Store(cp.id, cp) 210 211 return cp, nil 212 } 213 214 var _ net.Conn = (*interConn)(nil) 215 216 type interConn struct { 217 quic.Stream 218 session quic.Connection 219 time int64 220 } 221 222 func (c *interConn) Read(p []byte) (n int, err error) { 223 n, err = c.Stream.Read(p) 224 225 if err != nil && err != io.EOF { 226 qe, ok := err.(*quic.StreamError) 227 if ok && qe.ErrorCode == quic.StreamErrorCode(quic.NoError) { 228 err = io.EOF 229 } 230 } 231 return 232 } 233 234 func (c *interConn) Write(p []byte) (n int, err error) { 235 n, err = c.Stream.Write(p) 236 if err != nil && err != io.EOF { 237 qe, ok := err.(*quic.StreamError) 238 if ok && qe.ErrorCode == quic.StreamErrorCode(quic.NoError) { 239 err = io.EOF 240 } 241 } 242 return 243 } 244 245 func (c *interConn) Close() error { 246 c.Stream.CancelRead(0) 247 return c.Stream.Close() 248 } 249 250 func (c *interConn) LocalAddr() net.Addr { 251 return &QuicAddr{ 252 Addr: c.session.LocalAddr(), 253 ID: c.Stream.StreamID(), 254 time: c.time, 255 } 256 } 257 258 func (c *interConn) RemoteAddr() net.Addr { 259 return &QuicAddr{ 260 Addr: c.session.RemoteAddr(), 261 ID: c.Stream.StreamID(), 262 time: c.time, 263 } 264 } 265 266 type QuicAddr struct { 267 Addr net.Addr 268 ID quic.StreamID 269 time int64 270 } 271 272 func (q *QuicAddr) String() string { 273 if q.time == 0 { 274 return fmt.Sprintf("quic://%d@%v", q.ID, q.Addr) 275 } 276 return fmt.Sprintf("quic://%d-%d@%v", q.time, q.ID, q.Addr) 277 } 278 279 func (q *QuicAddr) Network() string { return "udp" } 280 281 type clientPacketConn struct { 282 c *Client 283 session *ConnectionPacketConn 284 id uint64 285 286 ctx context.Context 287 cancel context.CancelFunc 288 289 msg chan *pool.Buffer 290 291 deadline *deadline.PipeDeadline 292 } 293 294 func (x *clientPacketConn) ReadFrom(p []byte) (n int, _ net.Addr, err error) { 295 select { 296 case <-x.session.Context().Done(): 297 return x.read(p, func() error { 298 x.Close() 299 return x.session.Context().Err() 300 }) 301 case <-x.deadline.ReadContext().Done(): 302 return x.read(p, x.deadline.ReadContext().Err) 303 case <-x.ctx.Done(): 304 return x.read(p, x.ctx.Err) 305 case msg := <-x.msg: 306 defer msg.Free() 307 308 n = copy(p, msg.Bytes()) 309 return n, x.session.conn.RemoteAddr(), nil 310 } 311 } 312 313 func (x *clientPacketConn) read(p []byte, err func() error) (n int, _ net.Addr, _ error) { 314 if len(x.msg) > 0 { 315 select { 316 case msg := <-x.msg: 317 defer msg.Free() 318 319 n = copy(p, msg.Bytes()) 320 return n, x.session.conn.RemoteAddr(), nil 321 default: 322 } 323 } 324 325 return 0, nil, err() 326 } 327 328 func (x *clientPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) { 329 select { 330 case <-x.ctx.Done(): 331 return 0, x.ctx.Err() 332 case <-x.deadline.WriteContext().Done(): 333 return 0, x.deadline.WriteContext().Err() 334 case <-x.session.Context().Done(): 335 return 0, x.session.Context().Err() 336 default: 337 } 338 339 err = x.session.Write(p, x.id) 340 if err != nil { 341 return 0, err 342 } 343 return len(p), nil 344 } 345 346 func (x *clientPacketConn) Close() error { 347 x.cancel() 348 x.deadline.Close() 349 x.c.natMap.Delete(x.id) 350 return nil 351 } 352 353 func (x *clientPacketConn) LocalAddr() net.Addr { 354 return &QuicAddr{ 355 Addr: x.session.conn.LocalAddr(), 356 ID: quic.StreamID(x.id), 357 } 358 } 359 360 func (x *clientPacketConn) SetDeadline(t time.Time) error { 361 select { 362 case <-x.ctx.Done(): 363 return io.EOF 364 default: 365 } 366 367 x.deadline.SetDeadline(t) 368 return nil 369 } 370 371 func (x *clientPacketConn) SetReadDeadline(t time.Time) error { 372 x.deadline.SetReadDeadline(t) 373 return nil 374 } 375 376 func (x *clientPacketConn) SetWriteDeadline(t time.Time) error { 377 x.deadline.SetWriteDeadline(t) 378 return nil 379 }