github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/quic-go/quic-go"
     9  
    10  	"github.com/v2fly/v2ray-core/v5/common"
    11  	"github.com/v2fly/v2ray-core/v5/common/net"
    12  	"github.com/v2fly/v2ray-core/v5/common/task"
    13  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    14  	"github.com/v2fly/v2ray-core/v5/transport/internet/tls"
    15  )
    16  
    // connectionContext bundles one QUIC connection with the sysConn it was
    // dialed over, so that both can be closed when the connection is retired.
    17  type connectionContext struct {
    18  	rawConn *sysConn // packet socket wrapper handed to quic.Transport; closed alongside conn
    19  	conn    quic.Connection // QUIC connection on which streams are opened
    20  }
    21  
    // errConnectionClosed is returned by connectionContext.openStream when the
    // connection is no longer active.
    22  var errConnectionClosed = newError("connection closed")
    23  
    24  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    25  	if !isActive(c.conn) {
    26  		return nil, errConnectionClosed
    27  	}
    28  
    29  	stream, err := c.conn.OpenStream()
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	conn := &interConn{
    35  		stream: stream,
    36  		local:  c.conn.LocalAddr(),
    37  		remote: destAddr,
    38  	}
    39  
    40  	return conn, nil
    41  }
    42  
    // clientConnections is a pool of outbound QUIC connections grouped by
    // destination.
    43  type clientConnections struct {
    44  	access  sync.Mutex // held by cleanConnections and openConnection while they touch conns
    45  	conns   map[net.Destination][]*connectionContext // pooled connections per destination
    46  	cleanup *task.Periodic // periodic task; for the package-level client it runs cleanConnections
    47  }
    48  
    49  func isActive(s quic.Connection) bool {
    50  	select {
    51  	case <-s.Context().Done():
    52  		return false
    53  	default:
    54  		return true
    55  	}
    56  }
    57  
    58  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    59  	activeConnections := make([]*connectionContext, 0, len(conns))
    60  	for _, s := range conns {
    61  		if isActive(s.conn) {
    62  			activeConnections = append(activeConnections, s)
    63  			continue
    64  		}
    65  		if err := s.conn.CloseWithError(0, ""); err != nil {
    66  			newError("failed to close connection").Base(err).WriteToLog()
    67  		}
    68  		if err := s.rawConn.Close(); err != nil {
    69  			newError("failed to close raw connection").Base(err).WriteToLog()
    70  		}
    71  	}
    72  
    73  	if len(activeConnections) < len(conns) {
    74  		return activeConnections
    75  	}
    76  
    77  	return conns
    78  }
    79  
    80  func openStream(conns []*connectionContext, destAddr net.Addr) *interConn {
    81  	for _, s := range conns {
    82  		if !isActive(s.conn) {
    83  			continue
    84  		}
    85  
    86  		conn, err := s.openStream(destAddr)
    87  		if err != nil {
    88  			continue
    89  		}
    90  
    91  		return conn
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func (s *clientConnections) cleanConnections() error {
    98  	s.access.Lock()
    99  	defer s.access.Unlock()
   100  
   101  	if len(s.conns) == 0 {
   102  		return nil
   103  	}
   104  
   105  	newConnMap := make(map[net.Destination][]*connectionContext)
   106  
   107  	for dest, conns := range s.conns {
   108  		conns = removeInactiveConnections(conns)
   109  		if len(conns) > 0 {
   110  			newConnMap[dest] = conns
   111  		}
   112  	}
   113  
   114  	s.conns = newConnMap
   115  	return nil
   116  }
   117  
   118  func (s *clientConnections) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
   119  	s.access.Lock()
   120  	defer s.access.Unlock()
   121  
   122  	if s.conns == nil {
   123  		s.conns = make(map[net.Destination][]*connectionContext)
   124  	}
   125  
   126  	dest := net.DestinationFromAddr(destAddr)
   127  
   128  	var conns []*connectionContext
   129  	if s, found := s.conns[dest]; found {
   130  		conns = s
   131  	}
   132  
   133  	{
   134  		conn := openStream(conns, destAddr)
   135  		if conn != nil {
   136  			return conn, nil
   137  		}
   138  	}
   139  
   140  	conns = removeInactiveConnections(conns)
   141  
   142  	newError("dialing QUIC to ", dest).WriteToLog()
   143  
   144  	rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
   145  		IP:   []byte{0, 0, 0, 0},
   146  		Port: 0,
   147  	}, sockopt)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	quicConfig := &quic.Config{
   153  		HandshakeIdleTimeout: time.Second * 8,
   154  		MaxIdleTimeout:       time.Second * 30,
   155  		KeepAlivePeriod:      time.Second * 15,
   156  	}
   157  
   158  	sysConn, err := wrapSysConn(rawConn.(*net.UDPConn), config)
   159  	if err != nil {
   160  		rawConn.Close()
   161  		return nil, err
   162  	}
   163  
   164  	tr := quic.Transport{
   165  		Conn:               sysConn,
   166  		ConnectionIDLength: 12,
   167  	}
   168  
   169  	conn, err := tr.Dial(context.Background(), destAddr, tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   170  	if err != nil {
   171  		sysConn.Close()
   172  		return nil, err
   173  	}
   174  
   175  	context := &connectionContext{
   176  		conn:    conn,
   177  		rawConn: sysConn,
   178  	}
   179  	s.conns[dest] = append(conns, context)
   180  	return context.openStream(destAddr)
   181  }
   182  
   // client is the package-level connection pool that Dial opens connections
   // through.
   183  var client clientConnections
   184  
   185  func init() {
   186  	client.conns = make(map[net.Destination][]*connectionContext)
   187  	client.cleanup = &task.Periodic{
   188  		Interval: time.Minute,
   189  		Execute:  client.cleanConnections,
   190  	}
   191  	common.Must(client.cleanup.Start())
   192  }
   193  
   194  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   195  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   196  	if tlsConfig == nil {
   197  		tlsConfig = &tls.Config{
   198  			ServerName:    internalDomain,
   199  			AllowInsecure: true,
   200  		}
   201  	}
   202  
   203  	var destAddr *net.UDPAddr
   204  	if dest.Address.Family().IsIP() {
   205  		destAddr = &net.UDPAddr{
   206  			IP:   dest.Address.IP(),
   207  			Port: int(dest.Port),
   208  		}
   209  	} else {
   210  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   211  		if err != nil {
   212  			return nil, err
   213  		}
   214  		destAddr = addr
   215  	}
   216  
   217  	config := streamSettings.ProtocolSettings.(*Config)
   218  
   219  	return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
   220  }
   221  
   222  func init() {
   223  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   224  }