github.com/xraypb/Xray-core@v1.8.1/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/quic-go/quic-go"
    10  	"github.com/quic-go/quic-go/logging"
    11  	"github.com/quic-go/quic-go/qlog"
    12  	"github.com/xraypb/Xray-core/common"
    13  	"github.com/xraypb/Xray-core/common/net"
    14  	"github.com/xraypb/Xray-core/common/task"
    15  	"github.com/xraypb/Xray-core/transport/internet"
    16  	"github.com/xraypb/Xray-core/transport/internet/stat"
    17  	"github.com/xraypb/Xray-core/transport/internet/tls"
    18  )
    19  
    20  type connectionContext struct {
    21  	rawConn *sysConn
    22  	conn    quic.Connection
    23  }
    24  
    25  var errConnectionClosed = newError("connection closed")
    26  
    27  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    28  	if !isActive(c.conn) {
    29  		return nil, errConnectionClosed
    30  	}
    31  
    32  	stream, err := c.conn.OpenStream()
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	conn := &interConn{
    38  		stream: stream,
    39  		local:  c.conn.LocalAddr(),
    40  		remote: destAddr,
    41  	}
    42  
    43  	return conn, nil
    44  }
    45  
    46  type clientConnections struct {
    47  	access  sync.Mutex
    48  	conns   map[net.Destination][]*connectionContext
    49  	cleanup *task.Periodic
    50  }
    51  
    52  func isActive(s quic.Connection) bool {
    53  	select {
    54  	case <-s.Context().Done():
    55  		return false
    56  	default:
    57  		return true
    58  	}
    59  }
    60  
    61  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    62  	activeConnections := make([]*connectionContext, 0, len(conns))
    63  	for i, s := range conns {
    64  		if isActive(s.conn) {
    65  			activeConnections = append(activeConnections, s)
    66  			continue
    67  		}
    68  
    69  		newError("closing quic connection at index: ", i).WriteToLog()
    70  		if err := s.conn.CloseWithError(0, ""); err != nil {
    71  			newError("failed to close connection").Base(err).WriteToLog()
    72  		}
    73  		if err := s.rawConn.Close(); err != nil {
    74  			newError("failed to close raw connection").Base(err).WriteToLog()
    75  		}
    76  	}
    77  
    78  	if len(activeConnections) < len(conns) {
    79  		newError("active quic connection reduced from ", len(conns), " to ", len(activeConnections)).WriteToLog()
    80  		return activeConnections
    81  	}
    82  
    83  	return conns
    84  }
    85  
    86  func (s *clientConnections) cleanConnections() error {
    87  	s.access.Lock()
    88  	defer s.access.Unlock()
    89  
    90  	if len(s.conns) == 0 {
    91  		return nil
    92  	}
    93  
    94  	newConnMap := make(map[net.Destination][]*connectionContext)
    95  
    96  	for dest, conns := range s.conns {
    97  		conns = removeInactiveConnections(conns)
    98  		if len(conns) > 0 {
    99  			newConnMap[dest] = conns
   100  		}
   101  	}
   102  
   103  	s.conns = newConnMap
   104  	return nil
   105  }
   106  
   107  func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) {
   108  	s.access.Lock()
   109  	defer s.access.Unlock()
   110  
   111  	if s.conns == nil {
   112  		s.conns = make(map[net.Destination][]*connectionContext)
   113  	}
   114  
   115  	dest := net.DestinationFromAddr(destAddr)
   116  
   117  	var conns []*connectionContext
   118  	if s, found := s.conns[dest]; found {
   119  		conns = s
   120  	}
   121  
   122  	if len(conns) > 0 {
   123  		s := conns[len(conns)-1]
   124  		if isActive(s.conn) {
   125  			conn, err := s.openStream(destAddr)
   126  			if err == nil {
   127  				return conn, nil
   128  			}
   129  			newError("failed to openStream: ").Base(err).WriteToLog()
   130  		} else {
   131  			newError("current quic connection is not active!").WriteToLog()
   132  		}
   133  	}
   134  
   135  	conns = removeInactiveConnections(conns)
   136  	newError("dialing quic to ", dest).WriteToLog()
   137  	rawConn, err := internet.DialSystem(ctx, dest, sockopt)
   138  	if err != nil {
   139  		return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
   140  	}
   141  
   142  	quicConfig := &quic.Config{
   143  		ConnectionIDLength:   12,
   144  		KeepAlivePeriod:      0,
   145  		HandshakeIdleTimeout: time.Second * 8,
   146  		MaxIdleTimeout:       time.Second * 300,
   147  		Tracer: qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
   148  			return &QlogWriter{connID: connID}
   149  		}),
   150  	}
   151  
   152  	udpConn, _ := rawConn.(*net.UDPConn)
   153  	if udpConn == nil {
   154  		udpConn = rawConn.(*internet.PacketConnWrapper).Conn.(*net.UDPConn)
   155  	}
   156  	sysConn, err := wrapSysConn(udpConn, config)
   157  	if err != nil {
   158  		rawConn.Close()
   159  		return nil, err
   160  	}
   161  
   162  	conn, err := quic.DialContext(context.Background(), sysConn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   163  	if err != nil {
   164  		sysConn.Close()
   165  		return nil, err
   166  	}
   167  
   168  	context := &connectionContext{
   169  		conn:    conn,
   170  		rawConn: sysConn,
   171  	}
   172  	s.conns[dest] = append(conns, context)
   173  	return context.openStream(destAddr)
   174  }
   175  
   176  var client clientConnections
   177  
   178  func init() {
   179  	client.conns = make(map[net.Destination][]*connectionContext)
   180  	client.cleanup = &task.Periodic{
   181  		Interval: time.Minute,
   182  		Execute:  client.cleanConnections,
   183  	}
   184  	common.Must(client.cleanup.Start())
   185  }
   186  
   187  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   188  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   189  	if tlsConfig == nil {
   190  		tlsConfig = &tls.Config{
   191  			ServerName:    internalDomain,
   192  			AllowInsecure: true,
   193  		}
   194  	}
   195  
   196  	var destAddr *net.UDPAddr
   197  	if dest.Address.Family().IsIP() {
   198  		destAddr = &net.UDPAddr{
   199  			IP:   dest.Address.IP(),
   200  			Port: int(dest.Port),
   201  		}
   202  	} else {
   203  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  		destAddr = addr
   208  	}
   209  
   210  	config := streamSettings.ProtocolSettings.(*Config)
   211  
   212  	return client.openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings)
   213  }
   214  
   215  func init() {
   216  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   217  }