github.com/moqsien/xraycore@v1.8.5/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/quic-go/quic-go"
     9  	"github.com/quic-go/quic-go/logging"
    10  	"github.com/quic-go/quic-go/qlog"
    11  	"github.com/moqsien/xraycore/common"
    12  	"github.com/moqsien/xraycore/common/net"
    13  	"github.com/moqsien/xraycore/common/task"
    14  	"github.com/moqsien/xraycore/transport/internet"
    15  	"github.com/moqsien/xraycore/transport/internet/stat"
    16  	"github.com/moqsien/xraycore/transport/internet/tls"
    17  )
    18  
    19  type connectionContext struct {
    20  	rawConn *sysConn
    21  	conn    quic.Connection
    22  }
    23  
    24  var errConnectionClosed = newError("connection closed")
    25  
    26  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    27  	if !isActive(c.conn) {
    28  		return nil, errConnectionClosed
    29  	}
    30  
    31  	stream, err := c.conn.OpenStream()
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	conn := &interConn{
    37  		stream: stream,
    38  		local:  c.conn.LocalAddr(),
    39  		remote: destAddr,
    40  	}
    41  
    42  	return conn, nil
    43  }
    44  
    45  type clientConnections struct {
    46  	access  sync.Mutex
    47  	conns   map[net.Destination][]*connectionContext
    48  	cleanup *task.Periodic
    49  }
    50  
    51  func isActive(s quic.Connection) bool {
    52  	select {
    53  	case <-s.Context().Done():
    54  		return false
    55  	default:
    56  		return true
    57  	}
    58  }
    59  
    60  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    61  	activeConnections := make([]*connectionContext, 0, len(conns))
    62  	for i, s := range conns {
    63  		if isActive(s.conn) {
    64  			activeConnections = append(activeConnections, s)
    65  			continue
    66  		}
    67  
    68  		newError("closing quic connection at index: ", i).WriteToLog()
    69  		if err := s.conn.CloseWithError(0, ""); err != nil {
    70  			newError("failed to close connection").Base(err).WriteToLog()
    71  		}
    72  		if err := s.rawConn.Close(); err != nil {
    73  			newError("failed to close raw connection").Base(err).WriteToLog()
    74  		}
    75  	}
    76  
    77  	if len(activeConnections) < len(conns) {
    78  		newError("active quic connection reduced from ", len(conns), " to ", len(activeConnections)).WriteToLog()
    79  		return activeConnections
    80  	}
    81  
    82  	return conns
    83  }
    84  
    85  func (s *clientConnections) cleanConnections() error {
    86  	s.access.Lock()
    87  	defer s.access.Unlock()
    88  
    89  	if len(s.conns) == 0 {
    90  		return nil
    91  	}
    92  
    93  	newConnMap := make(map[net.Destination][]*connectionContext)
    94  
    95  	for dest, conns := range s.conns {
    96  		conns = removeInactiveConnections(conns)
    97  		if len(conns) > 0 {
    98  			newConnMap[dest] = conns
    99  		}
   100  	}
   101  
   102  	s.conns = newConnMap
   103  	return nil
   104  }
   105  
   106  func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) {
   107  	s.access.Lock()
   108  	defer s.access.Unlock()
   109  
   110  	if s.conns == nil {
   111  		s.conns = make(map[net.Destination][]*connectionContext)
   112  	}
   113  
   114  	dest := net.DestinationFromAddr(destAddr)
   115  
   116  	var conns []*connectionContext
   117  	if s, found := s.conns[dest]; found {
   118  		conns = s
   119  	}
   120  
   121  	if len(conns) > 0 {
   122  		s := conns[len(conns)-1]
   123  		if isActive(s.conn) {
   124  			conn, err := s.openStream(destAddr)
   125  			if err == nil {
   126  				return conn, nil
   127  			}
   128  			newError("failed to openStream: ").Base(err).WriteToLog()
   129  		} else {
   130  			newError("current quic connection is not active!").WriteToLog()
   131  		}
   132  	}
   133  
   134  	conns = removeInactiveConnections(conns)
   135  	newError("dialing quic to ", dest).WriteToLog()
   136  	rawConn, err := internet.DialSystem(ctx, dest, sockopt)
   137  	if err != nil {
   138  		return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
   139  	}
   140  
   141  	quicConfig := &quic.Config{
   142  		KeepAlivePeriod:      0,
   143  		HandshakeIdleTimeout: time.Second * 8,
   144  		MaxIdleTimeout:       time.Second * 300,
   145  		Tracer: func(ctx context.Context, p logging.Perspective, ci quic.ConnectionID) logging.ConnectionTracer {
   146  			return qlog.NewConnectionTracer(&QlogWriter{connID: ci}, p, ci)
   147  		},
   148  	}
   149  	udpConn, _ := rawConn.(*net.UDPConn)
   150  	if udpConn == nil {
   151  		udpConn = rawConn.(*internet.PacketConnWrapper).Conn.(*net.UDPConn)
   152  	}
   153  	sysConn, err := wrapSysConn(udpConn, config)
   154  	if err != nil {
   155  		rawConn.Close()
   156  		return nil, err
   157  	}
   158  	tr := quic.Transport{
   159  		ConnectionIDLength: 12,
   160  		Conn:               sysConn,
   161  	}
   162  	conn, err := tr.Dial(context.Background(), destAddr, tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   163  	if err != nil {
   164  		sysConn.Close()
   165  		return nil, err
   166  	}
   167  
   168  	context := &connectionContext{
   169  		conn:    conn,
   170  		rawConn: sysConn,
   171  	}
   172  	s.conns[dest] = append(conns, context)
   173  	return context.openStream(destAddr)
   174  }
   175  
   176  var client clientConnections
   177  
   178  func init() {
   179  	client.conns = make(map[net.Destination][]*connectionContext)
   180  	client.cleanup = &task.Periodic{
   181  		Interval: time.Minute,
   182  		Execute:  client.cleanConnections,
   183  	}
   184  	common.Must(client.cleanup.Start())
   185  }
   186  
   187  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   188  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   189  	if tlsConfig == nil {
   190  		tlsConfig = &tls.Config{
   191  			ServerName:    internalDomain,
   192  			AllowInsecure: true,
   193  		}
   194  	}
   195  
   196  	var destAddr *net.UDPAddr
   197  	if dest.Address.Family().IsIP() {
   198  		destAddr = &net.UDPAddr{
   199  			IP:   dest.Address.IP(),
   200  			Port: int(dest.Port),
   201  		}
   202  	} else {
   203  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  		destAddr = addr
   208  	}
   209  
   210  	config := streamSettings.ProtocolSettings.(*Config)
   211  
   212  	return client.openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings)
   213  }
   214  
   215  func init() {
   216  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   217  }