github.com/imannamdari/v2ray-core/v5@v5.0.5/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/imannamdari/quic-go"
     9  
    10  	"github.com/imannamdari/v2ray-core/v5/common"
    11  	"github.com/imannamdari/v2ray-core/v5/common/net"
    12  	"github.com/imannamdari/v2ray-core/v5/common/task"
    13  	"github.com/imannamdari/v2ray-core/v5/transport/internet"
    14  	"github.com/imannamdari/v2ray-core/v5/transport/internet/tls"
    15  )
    16  
    17  type connectionContext struct {
    18  	rawConn *sysConn
    19  	conn    quic.Connection
    20  }
    21  
    22  var errConnectionClosed = newError("connection closed")
    23  
    24  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    25  	if !isActive(c.conn) {
    26  		return nil, errConnectionClosed
    27  	}
    28  
    29  	stream, err := c.conn.OpenStream()
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	conn := &interConn{
    35  		stream: stream,
    36  		local:  c.conn.LocalAddr(),
    37  		remote: destAddr,
    38  	}
    39  
    40  	return conn, nil
    41  }
    42  
    43  type clientConnections struct {
    44  	access  sync.Mutex
    45  	conns   map[net.Destination][]*connectionContext
    46  	cleanup *task.Periodic
    47  }
    48  
    49  func isActive(s quic.Connection) bool {
    50  	select {
    51  	case <-s.Context().Done():
    52  		return false
    53  	default:
    54  		return true
    55  	}
    56  }
    57  
    58  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    59  	activeConnections := make([]*connectionContext, 0, len(conns))
    60  	for _, s := range conns {
    61  		if isActive(s.conn) {
    62  			activeConnections = append(activeConnections, s)
    63  			continue
    64  		}
    65  		if err := s.conn.CloseWithError(0, ""); err != nil {
    66  			newError("failed to close connection").Base(err).WriteToLog()
    67  		}
    68  		if err := s.rawConn.Close(); err != nil {
    69  			newError("failed to close raw connection").Base(err).WriteToLog()
    70  		}
    71  	}
    72  
    73  	if len(activeConnections) < len(conns) {
    74  		return activeConnections
    75  	}
    76  
    77  	return conns
    78  }
    79  
    80  func openStream(conns []*connectionContext, destAddr net.Addr) *interConn {
    81  	for _, s := range conns {
    82  		if !isActive(s.conn) {
    83  			continue
    84  		}
    85  
    86  		conn, err := s.openStream(destAddr)
    87  		if err != nil {
    88  			continue
    89  		}
    90  
    91  		return conn
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func (s *clientConnections) cleanConnections() error {
    98  	s.access.Lock()
    99  	defer s.access.Unlock()
   100  
   101  	if len(s.conns) == 0 {
   102  		return nil
   103  	}
   104  
   105  	newConnMap := make(map[net.Destination][]*connectionContext)
   106  
   107  	for dest, conns := range s.conns {
   108  		conns = removeInactiveConnections(conns)
   109  		if len(conns) > 0 {
   110  			newConnMap[dest] = conns
   111  		}
   112  	}
   113  
   114  	s.conns = newConnMap
   115  	return nil
   116  }
   117  
   118  func (s *clientConnections) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
   119  	s.access.Lock()
   120  	defer s.access.Unlock()
   121  
   122  	if s.conns == nil {
   123  		s.conns = make(map[net.Destination][]*connectionContext)
   124  	}
   125  
   126  	dest := net.DestinationFromAddr(destAddr)
   127  
   128  	var conns []*connectionContext
   129  	if s, found := s.conns[dest]; found {
   130  		conns = s
   131  	}
   132  
   133  	{
   134  		conn := openStream(conns, destAddr)
   135  		if conn != nil {
   136  			return conn, nil
   137  		}
   138  	}
   139  
   140  	conns = removeInactiveConnections(conns)
   141  
   142  	newError("dialing QUIC to ", dest).WriteToLog()
   143  
   144  	rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
   145  		IP:   []byte{0, 0, 0, 0},
   146  		Port: 0,
   147  	}, sockopt)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	quicConfig := &quic.Config{
   153  		ConnectionIDLength:   12,
   154  		HandshakeIdleTimeout: time.Second * 8,
   155  		MaxIdleTimeout:       time.Second * 30,
   156  		KeepAlivePeriod:      time.Second * 15,
   157  	}
   158  
   159  	sysConn, err := wrapSysConn(rawConn.(*net.UDPConn), config)
   160  	if err != nil {
   161  		rawConn.Close()
   162  		return nil, err
   163  	}
   164  
   165  	conn, err := quic.DialContext(context.Background(), sysConn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   166  	if err != nil {
   167  		sysConn.Close()
   168  		return nil, err
   169  	}
   170  
   171  	context := &connectionContext{
   172  		conn:    conn,
   173  		rawConn: sysConn,
   174  	}
   175  	s.conns[dest] = append(conns, context)
   176  	return context.openStream(destAddr)
   177  }
   178  
   179  var client clientConnections
   180  
   181  func init() {
   182  	client.conns = make(map[net.Destination][]*connectionContext)
   183  	client.cleanup = &task.Periodic{
   184  		Interval: time.Minute,
   185  		Execute:  client.cleanConnections,
   186  	}
   187  	common.Must(client.cleanup.Start())
   188  }
   189  
   190  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   191  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   192  	if tlsConfig == nil {
   193  		tlsConfig = &tls.Config{
   194  			ServerName:    internalDomain,
   195  			AllowInsecure: true,
   196  		}
   197  	}
   198  
   199  	var destAddr *net.UDPAddr
   200  	if dest.Address.Family().IsIP() {
   201  		destAddr = &net.UDPAddr{
   202  			IP:   dest.Address.IP(),
   203  			Port: int(dest.Port),
   204  		}
   205  	} else {
   206  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  		destAddr = addr
   211  	}
   212  
   213  	config := streamSettings.ProtocolSettings.(*Config)
   214  
   215  	return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
   216  }
   217  
   218  func init() {
   219  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   220  }