github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/quic-go/quic-go"
     9  	"github.com/quic-go/quic-go/logging"
    10  	"github.com/quic-go/quic-go/qlog"
    11  	"github.com/xmplusdev/xmcore/common"
    12  	"github.com/xmplusdev/xmcore/common/net"
    13  	"github.com/xmplusdev/xmcore/common/task"
    14  	"github.com/xmplusdev/xmcore/transport/internet"
    15  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    16  	"github.com/xmplusdev/xmcore/transport/internet/tls"
    17  )
    18  
    19  type connectionContext struct {
    20  	rawConn *sysConn
    21  	conn    quic.Connection
    22  }
    23  
    24  var errConnectionClosed = newError("connection closed")
    25  
    26  func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) {
    27  	if !isActive(c.conn) {
    28  		return nil, errConnectionClosed
    29  	}
    30  
    31  	stream, err := c.conn.OpenStream()
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	conn := &interConn{
    37  		stream: stream,
    38  		local:  c.conn.LocalAddr(),
    39  		remote: destAddr,
    40  	}
    41  
    42  	return conn, nil
    43  }
    44  
    45  type clientConnections struct {
    46  	access  sync.Mutex
    47  	conns   map[net.Destination][]*connectionContext
    48  	cleanup *task.Periodic
    49  }
    50  
    51  func isActive(s quic.Connection) bool {
    52  	select {
    53  	case <-s.Context().Done():
    54  		return false
    55  	default:
    56  		return true
    57  	}
    58  }
    59  
    60  func removeInactiveConnections(conns []*connectionContext) []*connectionContext {
    61  	activeConnections := make([]*connectionContext, 0, len(conns))
    62  	for i, s := range conns {
    63  		if isActive(s.conn) {
    64  			activeConnections = append(activeConnections, s)
    65  			continue
    66  		}
    67  
    68  		newError("closing quic connection at index: ", i).WriteToLog()
    69  		if err := s.conn.CloseWithError(0, ""); err != nil {
    70  			newError("failed to close connection").Base(err).WriteToLog()
    71  		}
    72  		if err := s.rawConn.Close(); err != nil {
    73  			newError("failed to close raw connection").Base(err).WriteToLog()
    74  		}
    75  	}
    76  
    77  	if len(activeConnections) < len(conns) {
    78  		newError("active quic connection reduced from ", len(conns), " to ", len(activeConnections)).WriteToLog()
    79  		return activeConnections
    80  	}
    81  
    82  	return conns
    83  }
    84  
    85  func (s *clientConnections) cleanConnections() error {
    86  	s.access.Lock()
    87  	defer s.access.Unlock()
    88  
    89  	if len(s.conns) == 0 {
    90  		return nil
    91  	}
    92  
    93  	newConnMap := make(map[net.Destination][]*connectionContext)
    94  
    95  	for dest, conns := range s.conns {
    96  		conns = removeInactiveConnections(conns)
    97  		if len(conns) > 0 {
    98  			newConnMap[dest] = conns
    99  		}
   100  	}
   101  
   102  	s.conns = newConnMap
   103  	return nil
   104  }
   105  
   106  func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) {
   107  	s.access.Lock()
   108  	defer s.access.Unlock()
   109  
   110  	if s.conns == nil {
   111  		s.conns = make(map[net.Destination][]*connectionContext)
   112  	}
   113  
   114  	dest := net.DestinationFromAddr(destAddr)
   115  
   116  	var conns []*connectionContext
   117  	if s, found := s.conns[dest]; found {
   118  		conns = s
   119  	}
   120  
   121  	if len(conns) > 0 {
   122  		s := conns[len(conns)-1]
   123  		if isActive(s.conn) {
   124  			conn, err := s.openStream(destAddr)
   125  			if err == nil {
   126  				return conn, nil
   127  			}
   128  			newError("failed to openStream: ").Base(err).WriteToLog()
   129  		} else {
   130  			newError("current quic connection is not active!").WriteToLog()
   131  		}
   132  	}
   133  
   134  	conns = removeInactiveConnections(conns)
   135  	newError("dialing quic to ", dest).WriteToLog()
   136  	rawConn, err := internet.DialSystem(ctx, dest, sockopt)
   137  	if err != nil {
   138  		return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err)
   139  	}
   140  
   141  	quicConfig := &quic.Config{
   142  		KeepAlivePeriod:      0,
   143  		HandshakeIdleTimeout: time.Second * 8,
   144  		MaxIdleTimeout:       time.Second * 300,
   145  		Tracer: func(ctx context.Context, p logging.Perspective, ci quic.ConnectionID) *logging.ConnectionTracer {
   146  			return qlog.NewConnectionTracer(&QlogWriter{connID: ci}, p, ci)
   147  		},
   148  	}
   149  
   150  	var udpConn *net.UDPConn
   151  	switch conn := rawConn.(type) {
   152  	case *net.UDPConn:
   153  		udpConn = conn
   154  	case *internet.PacketConnWrapper:
   155  		udpConn = conn.Conn.(*net.UDPConn)
   156  	default:
   157  		// TODO: Support sockopt for QUIC
   158  		rawConn.Close()
   159  		return nil, newError("QUIC with sockopt is unsupported").AtWarning()
   160  	}
   161  
   162  	sysConn, err := wrapSysConn(udpConn, config)
   163  	if err != nil {
   164  		rawConn.Close()
   165  		return nil, err
   166  	}
   167  	tr := quic.Transport{
   168  		ConnectionIDLength: 12,
   169  		Conn:               sysConn,
   170  	}
   171  	conn, err := tr.Dial(context.Background(), destAddr, tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   172  	if err != nil {
   173  		sysConn.Close()
   174  		return nil, err
   175  	}
   176  
   177  	context := &connectionContext{
   178  		conn:    conn,
   179  		rawConn: sysConn,
   180  	}
   181  	s.conns[dest] = append(conns, context)
   182  	return context.openStream(destAddr)
   183  }
   184  
   185  var client clientConnections
   186  
   187  func init() {
   188  	client.conns = make(map[net.Destination][]*connectionContext)
   189  	client.cleanup = &task.Periodic{
   190  		Interval: time.Minute,
   191  		Execute:  client.cleanConnections,
   192  	}
   193  	common.Must(client.cleanup.Start())
   194  }
   195  
   196  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   197  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   198  	if tlsConfig == nil {
   199  		tlsConfig = &tls.Config{
   200  			ServerName:    internalDomain,
   201  			AllowInsecure: true,
   202  		}
   203  	}
   204  
   205  	var destAddr *net.UDPAddr
   206  	if dest.Address.Family().IsIP() {
   207  		destAddr = &net.UDPAddr{
   208  			IP:   dest.Address.IP(),
   209  			Port: int(dest.Port),
   210  		}
   211  	}  else {
   212  		dialerIp := internet.DestIpAddress()
   213  		if dialerIp != nil {
   214  			destAddr = &net.UDPAddr{
   215  				IP:   dialerIp,
   216  				Port: int(dest.Port),
   217  			}
   218  			newError("quic Dial use dialer dest addr: ", destAddr).WriteToLog()
   219  		} else {
   220  			addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   221  			if err != nil {
   222  				return nil, err
   223  			}
   224  			destAddr = addr
   225  		}
   226  	}
   227  
   228  	config := streamSettings.ProtocolSettings.(*Config)
   229  
   230  	return client.openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings)
   231  }
   232  
   233  func init() {
   234  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   235  }