github.com/v2fly/v2ray-core/v4@v4.45.2/transport/internet/quic/dialer.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package quic
     5  
     6  import (
     7  	"context"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/lucas-clemente/quic-go"
    12  
    13  	"github.com/v2fly/v2ray-core/v4/common"
    14  	"github.com/v2fly/v2ray-core/v4/common/net"
    15  	"github.com/v2fly/v2ray-core/v4/common/task"
    16  	"github.com/v2fly/v2ray-core/v4/transport/internet"
    17  	"github.com/v2fly/v2ray-core/v4/transport/internet/tls"
    18  )
    19  
    20  type connectionContext struct {
    21  	rawConn *sysConn
    22  	conn    quic.Connection
    23  }
    24  
    25  var errConnectionClosed = newError("connection closed")
    26  
    27  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    28  	if !isActive(c.conn) {
    29  		return nil, errConnectionClosed
    30  	}
    31  
    32  	stream, err := c.conn.OpenStream()
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	conn := &interConn{
    38  		stream: stream,
    39  		local:  c.conn.LocalAddr(),
    40  		remote: destAddr,
    41  	}
    42  
    43  	return conn, nil
    44  }
    45  
    46  type clientConnections struct {
    47  	access  sync.Mutex
    48  	conns   map[net.Destination][]*connectionContext
    49  	cleanup *task.Periodic
    50  }
    51  
    52  func isActive(s quic.Connection) bool {
    53  	select {
    54  	case <-s.Context().Done():
    55  		return false
    56  	default:
    57  		return true
    58  	}
    59  }
    60  
    61  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    62  	activeConnections := make([]*connectionContext, 0, len(conns))
    63  	for _, s := range conns {
    64  		if isActive(s.conn) {
    65  			activeConnections = append(activeConnections, s)
    66  			continue
    67  		}
    68  		if err := s.conn.CloseWithError(0, ""); err != nil {
    69  			newError("failed to close connection").Base(err).WriteToLog()
    70  		}
    71  		if err := s.rawConn.Close(); err != nil {
    72  			newError("failed to close raw connection").Base(err).WriteToLog()
    73  		}
    74  	}
    75  
    76  	if len(activeConnections) < len(conns) {
    77  		return activeConnections
    78  	}
    79  
    80  	return conns
    81  }
    82  
    83  func openStream(conns []*connectionContext, destAddr net.Addr) *interConn {
    84  	for _, s := range conns {
    85  		if !isActive(s.conn) {
    86  			continue
    87  		}
    88  
    89  		conn, err := s.openStream(destAddr)
    90  		if err != nil {
    91  			continue
    92  		}
    93  
    94  		return conn
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func (s *clientConnections) cleanConnections() error {
   101  	s.access.Lock()
   102  	defer s.access.Unlock()
   103  
   104  	if len(s.conns) == 0 {
   105  		return nil
   106  	}
   107  
   108  	newConnMap := make(map[net.Destination][]*connectionContext)
   109  
   110  	for dest, conns := range s.conns {
   111  		conns = removeInactiveConnections(conns)
   112  		if len(conns) > 0 {
   113  			newConnMap[dest] = conns
   114  		}
   115  	}
   116  
   117  	s.conns = newConnMap
   118  	return nil
   119  }
   120  
   121  func (s *clientConnections) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
   122  	s.access.Lock()
   123  	defer s.access.Unlock()
   124  
   125  	if s.conns == nil {
   126  		s.conns = make(map[net.Destination][]*connectionContext)
   127  	}
   128  
   129  	dest := net.DestinationFromAddr(destAddr)
   130  
   131  	var conns []*connectionContext
   132  	if s, found := s.conns[dest]; found {
   133  		conns = s
   134  	}
   135  
   136  	{
   137  		conn := openStream(conns, destAddr)
   138  		if conn != nil {
   139  			return conn, nil
   140  		}
   141  	}
   142  
   143  	conns = removeInactiveConnections(conns)
   144  
   145  	rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
   146  		IP:   []byte{0, 0, 0, 0},
   147  		Port: 0,
   148  	}, sockopt)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	quicConfig := &quic.Config{
   154  		ConnectionIDLength:   12,
   155  		HandshakeIdleTimeout: time.Second * 8,
   156  		MaxIdleTimeout:       time.Second * 30,
   157  		KeepAlive:            true,
   158  	}
   159  
   160  	sysConn, err := wrapSysConn(rawConn.(*net.UDPConn), config)
   161  	if err != nil {
   162  		rawConn.Close()
   163  		return nil, err
   164  	}
   165  
   166  	conn, err := quic.DialContext(context.Background(), sysConn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   167  	if err != nil {
   168  		sysConn.Close()
   169  		return nil, err
   170  	}
   171  
   172  	context := &connectionContext{
   173  		conn:    conn,
   174  		rawConn: sysConn,
   175  	}
   176  	s.conns[dest] = append(conns, context)
   177  	return context.openStream(destAddr)
   178  }
   179  
   180  var client clientConnections
   181  
   182  func init() {
   183  	client.conns = make(map[net.Destination][]*connectionContext)
   184  	client.cleanup = &task.Periodic{
   185  		Interval: time.Minute,
   186  		Execute:  client.cleanConnections,
   187  	}
   188  	common.Must(client.cleanup.Start())
   189  }
   190  
   191  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   192  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   193  	if tlsConfig == nil {
   194  		tlsConfig = &tls.Config{
   195  			ServerName:    internalDomain,
   196  			AllowInsecure: true,
   197  		}
   198  	}
   199  
   200  	var destAddr *net.UDPAddr
   201  	if dest.Address.Family().IsIP() {
   202  		destAddr = &net.UDPAddr{
   203  			IP:   dest.Address.IP(),
   204  			Port: int(dest.Port),
   205  		}
   206  	} else {
   207  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   208  		if err != nil {
   209  			return nil, err
   210  		}
   211  		destAddr = addr
   212  	}
   213  
   214  	config := streamSettings.ProtocolSettings.(*Config)
   215  
   216  	return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
   217  }
   218  
   219  func init() {
   220  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   221  }