github.com/metacubex/mihomo@v1.18.5/transport/tuic/v4/client.go (about) 1 package v4 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/tls" 8 "errors" 9 "net" 10 "runtime" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 atomic2 "github.com/metacubex/mihomo/common/atomic" 16 N "github.com/metacubex/mihomo/common/net" 17 "github.com/metacubex/mihomo/common/pool" 18 C "github.com/metacubex/mihomo/constant" 19 "github.com/metacubex/mihomo/log" 20 "github.com/metacubex/mihomo/transport/tuic/common" 21 22 "github.com/metacubex/quic-go" 23 "github.com/puzpuzpuz/xsync/v3" 24 "github.com/zhangyunhao116/fastrand" 25 ) 26 27 type ClientOption struct { 28 TlsConfig *tls.Config 29 QuicConfig *quic.Config 30 Token [32]byte 31 UdpRelayMode common.UdpRelayMode 32 CongestionController string 33 ReduceRtt bool 34 RequestTimeout time.Duration 35 MaxUdpRelayPacketSize int 36 FastOpen bool 37 MaxOpenStreams int64 38 CWND int 39 } 40 41 type clientImpl struct { 42 *ClientOption 43 udp bool 44 45 quicConn quic.Connection 46 connMutex sync.Mutex 47 48 openStreams atomic.Int64 49 closed atomic.Bool 50 51 udpInputMap *xsync.MapOf[uint32, net.Conn] 52 53 // only ready for PoolClient 54 dialerRef C.Dialer 55 lastVisited atomic2.TypedValue[time.Time] 56 } 57 58 func (t *clientImpl) OpenStreams() int64 { 59 return t.openStreams.Load() 60 } 61 62 func (t *clientImpl) DialerRef() C.Dialer { 63 return t.dialerRef 64 } 65 66 func (t *clientImpl) LastVisited() time.Time { 67 return t.lastVisited.Load() 68 } 69 70 func (t *clientImpl) SetLastVisited(last time.Time) { 71 t.lastVisited.Store(last) 72 } 73 74 func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) { 75 t.connMutex.Lock() 76 defer t.connMutex.Unlock() 77 if t.quicConn != nil { 78 return t.quicConn, nil 79 } 80 transport, addr, err := dialFn(ctx, dialer) 81 if err != nil { 82 return nil, err 83 } 84 var quicConn quic.Connection 85 if t.ReduceRtt { 86 quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig) 87 } else { 88 quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig) 89 } 90 if err != nil { 91 return nil, err 92 } 93 94 common.SetCongestionController(quicConn, t.CongestionController, t.CWND) 95 96 go func() { 97 _ = t.sendAuthentication(quicConn) 98 }() 99 100 if t.udp { 101 go func() { 102 switch t.UdpRelayMode { 103 case common.QUIC: 104 _ = t.handleUniStream(quicConn) 105 default: // native 106 _ = t.handleMessage(quicConn) 107 } 108 }() 109 } 110 111 t.quicConn = quicConn 112 t.openStreams.Store(0) 113 return quicConn, nil 114 } 115 116 func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) { 117 defer func() { 118 t.deferQuicConn(quicConn, err) 119 }() 120 stream, err := quicConn.OpenUniStream() 121 if err != nil { 122 return err 123 } 124 buf := pool.GetBuffer() 125 defer pool.PutBuffer(buf) 126 err = NewAuthenticate(t.Token).WriteTo(buf) 127 if err != nil { 128 return err 129 } 130 _, err = buf.WriteTo(stream) 131 if err != nil { 132 return err 133 } 134 err = stream.Close() 135 if err != nil { 136 return 137 } 138 return nil 139 } 140 141 func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) { 142 defer func() { 143 t.deferQuicConn(quicConn, err) 144 }() 145 for { 146 var stream quic.ReceiveStream 147 stream, err = quicConn.AcceptUniStream(context.Background()) 148 if err != nil { 149 return err 150 } 151 go func() (err error) { 152 var assocId uint32 153 defer func() { 154 t.deferQuicConn(quicConn, err) 155 if err != nil && assocId != 0 { 156 if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { 157 if conn, ok := val.(net.Conn); ok { 158 _ = conn.Close() 159 } 160 } 161 } 162 stream.CancelRead(0) 163 }() 164 reader := bufio.NewReader(stream) 165 commandHead, err := ReadCommandHead(reader) 166 if err != nil { 167 return 168 } 169 switch commandHead.TYPE { 170 case PacketType: 171 var packet Packet 172 packet, err = ReadPacketWithHead(commandHead, reader) 173 if err != nil { 174 return 175 } 176 if t.udp && t.UdpRelayMode == common.QUIC { 177 assocId = packet.ASSOC_ID 178 if val, ok := t.udpInputMap.Load(assocId); ok { 179 if conn, ok := val.(net.Conn); ok { 180 writer := bufio.NewWriterSize(conn, packet.BytesLen()) 181 _ = packet.WriteTo(writer) 182 _ = writer.Flush() 183 } 184 } 185 } 186 } 187 return 188 }() 189 } 190 } 191 192 func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) { 193 defer func() { 194 t.deferQuicConn(quicConn, err) 195 }() 196 for { 197 var message []byte 198 message, err = quicConn.ReceiveDatagram(context.Background()) 199 if err != nil { 200 return err 201 } 202 go func() (err error) { 203 var assocId uint32 204 defer func() { 205 t.deferQuicConn(quicConn, err) 206 if err != nil && assocId != 0 { 207 if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { 208 if conn, ok := val.(net.Conn); ok { 209 _ = conn.Close() 210 } 211 } 212 } 213 }() 214 reader := bytes.NewBuffer(message) 215 commandHead, err := ReadCommandHead(reader) 216 if err != nil { 217 return 218 } 219 switch commandHead.TYPE { 220 case PacketType: 221 var packet Packet 222 packet, err = ReadPacketWithHead(commandHead, reader) 223 if err != nil { 224 return 225 } 226 if t.udp && t.UdpRelayMode == common.NATIVE { 227 assocId = packet.ASSOC_ID 228 if val, ok := t.udpInputMap.Load(assocId); ok { 229 if conn, ok := val.(net.Conn); ok { 230 _, _ = conn.Write(message) 231 } 232 } 233 } 234 } 235 return 236 }() 237 } 238 } 239 240 func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) { 241 var netError net.Error 242 if err != nil && errors.As(err, &netError) { 243 t.forceClose(quicConn, err) 244 } 245 } 246 247 func (t *clientImpl) forceClose(quicConn quic.Connection, err error) { 248 t.connMutex.Lock() 249 defer t.connMutex.Unlock() 250 if quicConn == nil { 251 quicConn = t.quicConn 252 } 253 if quicConn != nil { 254 if quicConn == t.quicConn { 255 t.quicConn = nil 256 } 257 } 258 errStr := "" 259 if err != nil { 260 errStr = err.Error() 261 } 262 if quicConn != nil { 263 _ = quicConn.CloseWithError(ProtocolError, errStr) 264 } 265 udpInputMap := t.udpInputMap 266 udpInputMap.Range(func(key uint32, value net.Conn) bool { 267 conn := value 268 _ = conn.Close() 269 udpInputMap.Delete(key) 270 return true 271 }) 272 } 273 274 func (t *clientImpl) Close() { 275 t.closed.Store(true) 276 if t.openStreams.Load() == 0 { 277 t.forceClose(nil, common.ClientClosed) 278 } 279 } 280 281 func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { 282 quicConn, err := t.getQuicConn(ctx, dialer, dialFn) 283 if err != nil { 284 return nil, err 285 } 286 openStreams := t.openStreams.Add(1) 287 if openStreams >= t.MaxOpenStreams { 288 t.openStreams.Add(-1) 289 return nil, common.TooManyOpenStreams 290 } 291 stream, err := func() (stream net.Conn, err error) { 292 defer func() { 293 t.deferQuicConn(quicConn, err) 294 }() 295 buf := pool.GetBuffer() 296 defer pool.PutBuffer(buf) 297 err = NewConnect(NewAddress(metadata)).WriteTo(buf) 298 if err != nil { 299 return nil, err 300 } 301 quicStream, err := quicConn.OpenStream() 302 if err != nil { 303 return nil, err 304 } 305 stream = common.NewQuicStreamConn( 306 quicStream, 307 quicConn.LocalAddr(), 308 quicConn.RemoteAddr(), 309 func() { 310 time.AfterFunc(C.DefaultTCPTimeout, func() { 311 openStreams := t.openStreams.Add(-1) 312 if openStreams == 0 && t.closed.Load() { 313 t.forceClose(quicConn, common.ClientClosed) 314 } 315 }) 316 }, 317 ) 318 _, err = buf.WriteTo(stream) 319 if err != nil { 320 _ = stream.Close() 321 return nil, err 322 } 323 return stream, err 324 }() 325 if err != nil { 326 return nil, err 327 } 328 329 bufConn := N.NewBufferedConn(stream) 330 response := func() error { 331 if t.RequestTimeout > 0 { 332 _ = bufConn.SetReadDeadline(time.Now().Add(t.RequestTimeout)) 333 } 334 response, err := ReadResponse(bufConn) 335 if err != nil { 336 _ = bufConn.Close() 337 return err 338 } 339 if response.IsFailed() { 340 _ = bufConn.Close() 341 return errors.New("connect failed") 342 } 343 _ = bufConn.SetReadDeadline(time.Time{}) 344 return nil 345 } 346 if t.FastOpen { 347 return N.NewEarlyConn(bufConn, response), nil 348 } 349 err = response() 350 if err != nil { 351 return nil, err 352 } 353 return bufConn, nil 354 } 355 356 func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { 357 quicConn, err := t.getQuicConn(ctx, dialer, dialFn) 358 if err != nil { 359 return nil, err 360 } 361 openStreams := t.openStreams.Add(1) 362 if openStreams >= t.MaxOpenStreams { 363 t.openStreams.Add(-1) 364 return nil, common.TooManyOpenStreams 365 } 366 367 pipe1, pipe2 := N.Pipe() 368 var connId uint32 369 for { 370 connId = fastrand.Uint32() 371 _, loaded := t.udpInputMap.LoadOrStore(connId, pipe1) 372 if !loaded { 373 break 374 } 375 } 376 pc := &quicStreamPacketConn{ 377 connId: connId, 378 quicConn: quicConn, 379 inputConn: N.NewBufferedConn(pipe2), 380 udpRelayMode: t.UdpRelayMode, 381 maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize, 382 deferQuicConnFn: t.deferQuicConn, 383 closeDeferFn: func() { 384 t.udpInputMap.Delete(connId) 385 time.AfterFunc(C.DefaultUDPTimeout, func() { 386 openStreams := t.openStreams.Add(-1) 387 if openStreams == 0 && t.closed.Load() { 388 t.forceClose(quicConn, common.ClientClosed) 389 } 390 }) 391 }, 392 } 393 return pc, nil 394 } 395 396 type Client struct { 397 *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner 398 } 399 400 func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { 401 conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) 402 if err != nil { 403 return nil, err 404 } 405 return N.NewRefConn(conn, t), err 406 } 407 408 func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { 409 pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) 410 if err != nil { 411 return nil, err 412 } 413 return N.NewRefPacketConn(pc, t), nil 414 } 415 416 func (t *Client) forceClose() { 417 t.clientImpl.forceClose(nil, common.ClientClosed) 418 } 419 420 func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client { 421 ci := &clientImpl{ 422 ClientOption: clientOption, 423 udp: udp, 424 dialerRef: dialerRef, 425 udpInputMap: xsync.NewMapOf[uint32, net.Conn](), 426 } 427 c := &Client{ci} 428 runtime.SetFinalizer(c, closeClient) 429 log.Debugln("New TuicV4 Client at %p", c) 430 return c 431 } 432 433 func closeClient(client *Client) { 434 log.Debugln("Close TuicV4 Client at %p", client) 435 client.forceClose() 436 }