github.com/eagleql/xray-core@v1.4.4/transport/internet/quic/dialer.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/eagleql/xray-core/common"
     9  	"github.com/eagleql/xray-core/common/net"
    10  	"github.com/eagleql/xray-core/common/task"
    11  	"github.com/eagleql/xray-core/transport/internet"
    12  	"github.com/eagleql/xray-core/transport/internet/tls"
    13  	"github.com/lucas-clemente/quic-go"
    14  )
    15  
// sessionContext bundles a QUIC session with the wrapped packet connection
// it was dialed over, so both can be closed together when the session is
// no longer active.
type sessionContext struct {
	// rawConn is the wrapped system packet connection handed to quic.DialContext.
	rawConn *sysConn
	// session is the QUIC session on which streams are opened.
	session quic.Session
}
    20  
// errSessionClosed is returned by sessionContext.openStream when the
// underlying session is no longer active.
var errSessionClosed = newError("session closed")
    22  
    23  func (c *sessionContext) openStream(destAddr net.Addr) (*interConn, error) {
    24  	if !isActive(c.session) {
    25  		return nil, errSessionClosed
    26  	}
    27  
    28  	stream, err := c.session.OpenStream()
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	conn := &interConn{
    34  		stream: stream,
    35  		local:  c.session.LocalAddr(),
    36  		remote: destAddr,
    37  	}
    38  
    39  	return conn, nil
    40  }
    41  
// clientSessions keeps the client-side QUIC sessions, grouped by the
// destination they were dialed to.
type clientSessions struct {
	// access is held by cleanSessions and openConnection while they read or
	// modify sessions.
	access sync.Mutex
	// sessions maps a destination to the sessions opened towards it.
	sessions map[net.Destination][]*sessionContext
	// cleanup is the periodic task that runs cleanSessions.
	cleanup *task.Periodic
}
    47  
    48  func isActive(s quic.Session) bool {
    49  	select {
    50  	case <-s.Context().Done():
    51  		return false
    52  	default:
    53  		return true
    54  	}
    55  }
    56  
    57  func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
    58  	activeSessions := make([]*sessionContext, 0, len(sessions))
    59  	for _, s := range sessions {
    60  		if isActive(s.session) {
    61  			activeSessions = append(activeSessions, s)
    62  			continue
    63  		}
    64  		if err := s.session.CloseWithError(0, ""); err != nil {
    65  			newError("failed to close session").Base(err).WriteToLog()
    66  		}
    67  		if err := s.rawConn.Close(); err != nil {
    68  			newError("failed to close raw connection").Base(err).WriteToLog()
    69  		}
    70  	}
    71  
    72  	if len(activeSessions) < len(sessions) {
    73  		return activeSessions
    74  	}
    75  
    76  	return sessions
    77  }
    78  
    79  func openStream(sessions []*sessionContext, destAddr net.Addr) *interConn {
    80  	for _, s := range sessions {
    81  		if !isActive(s.session) {
    82  			continue
    83  		}
    84  
    85  		conn, err := s.openStream(destAddr)
    86  		if err != nil {
    87  			continue
    88  		}
    89  
    90  		return conn
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func (s *clientSessions) cleanSessions() error {
    97  	s.access.Lock()
    98  	defer s.access.Unlock()
    99  
   100  	if len(s.sessions) == 0 {
   101  		return nil
   102  	}
   103  
   104  	newSessionMap := make(map[net.Destination][]*sessionContext)
   105  
   106  	for dest, sessions := range s.sessions {
   107  		sessions = removeInactiveSessions(sessions)
   108  		if len(sessions) > 0 {
   109  			newSessionMap[dest] = sessions
   110  		}
   111  	}
   112  
   113  	s.sessions = newSessionMap
   114  	return nil
   115  }
   116  
   117  func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
   118  	s.access.Lock()
   119  	defer s.access.Unlock()
   120  
   121  	if s.sessions == nil {
   122  		s.sessions = make(map[net.Destination][]*sessionContext)
   123  	}
   124  
   125  	dest := net.DestinationFromAddr(destAddr)
   126  
   127  	var sessions []*sessionContext
   128  	if s, found := s.sessions[dest]; found {
   129  		sessions = s
   130  	}
   131  
   132  	if true {
   133  		conn := openStream(sessions, destAddr)
   134  		if conn != nil {
   135  			return conn, nil
   136  		}
   137  	}
   138  
   139  	sessions = removeInactiveSessions(sessions)
   140  
   141  	rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
   142  		IP:   []byte{0, 0, 0, 0},
   143  		Port: 0,
   144  	}, sockopt)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	quicConfig := &quic.Config{
   150  		ConnectionIDLength: 12,
   151  		KeepAlive:          true,
   152  	}
   153  
   154  	conn, err := wrapSysConn(rawConn, config)
   155  	if err != nil {
   156  		rawConn.Close()
   157  		return nil, err
   158  	}
   159  
   160  	session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   161  	if err != nil {
   162  		conn.Close()
   163  		return nil, err
   164  	}
   165  
   166  	context := &sessionContext{
   167  		session: session,
   168  		rawConn: conn,
   169  	}
   170  	s.sessions[dest] = append(sessions, context)
   171  	return context.openStream(destAddr)
   172  }
   173  
// client is the package-level session store used by Dial.
var client clientSessions
   175  
   176  func init() {
   177  	client.sessions = make(map[net.Destination][]*sessionContext)
   178  	client.cleanup = &task.Periodic{
   179  		Interval: time.Minute,
   180  		Execute:  client.cleanSessions,
   181  	}
   182  	common.Must(client.cleanup.Start())
   183  }
   184  
   185  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   186  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   187  	if tlsConfig == nil {
   188  		tlsConfig = &tls.Config{
   189  			ServerName:    internalDomain,
   190  			AllowInsecure: true,
   191  		}
   192  	}
   193  
   194  	var destAddr *net.UDPAddr
   195  	if dest.Address.Family().IsIP() {
   196  		destAddr = &net.UDPAddr{
   197  			IP:   dest.Address.IP(),
   198  			Port: int(dest.Port),
   199  		}
   200  	} else {
   201  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  		destAddr = addr
   206  	}
   207  
   208  	config := streamSettings.ProtocolSettings.(*Config)
   209  
   210  	return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
   211  }
   212  
   213  func init() {
   214  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   215  }