github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/transport/internet/quic/dialer.go (about)

     1  // +build !confonly
     2  
     3  package quic
     4  
     5  import (
     6  	"context"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/lucas-clemente/quic-go"
    11  	"v2ray.com/core/common"
    12  	"v2ray.com/core/common/net"
    13  	"v2ray.com/core/common/task"
    14  	"v2ray.com/core/transport/internet"
    15  	"v2ray.com/core/transport/internet/tls"
    16  )
    17  
    18  type sessionContext struct {
    19  	rawConn *sysConn
    20  	session quic.Session
    21  }
    22  
    23  var errSessionClosed = newError("session closed")
    24  
    25  func (c *sessionContext) openStream(destAddr net.Addr) (*interConn, error) {
    26  	if !isActive(c.session) {
    27  		return nil, errSessionClosed
    28  	}
    29  
    30  	stream, err := c.session.OpenStream()
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	conn := &interConn{
    36  		stream: stream,
    37  		local:  c.session.LocalAddr(),
    38  		remote: destAddr,
    39  	}
    40  
    41  	return conn, nil
    42  }
    43  
    44  type clientSessions struct {
    45  	access   sync.Mutex
    46  	sessions map[net.Destination][]*sessionContext
    47  	cleanup  *task.Periodic
    48  }
    49  
    50  func isActive(s quic.Session) bool {
    51  	select {
    52  	case <-s.Context().Done():
    53  		return false
    54  	default:
    55  		return true
    56  	}
    57  }
    58  
    59  func removeInactiveSessions(sessions []*sessionContext) []*sessionContext {
    60  	activeSessions := make([]*sessionContext, 0, len(sessions))
    61  	for _, s := range sessions {
    62  		if isActive(s.session) {
    63  			activeSessions = append(activeSessions, s)
    64  			continue
    65  		}
    66  		if err := s.session.CloseWithError(0, ""); err != nil {
    67  			newError("failed to close session").Base(err).WriteToLog()
    68  		}
    69  		if err := s.rawConn.Close(); err != nil {
    70  			newError("failed to close raw connection").Base(err).WriteToLog()
    71  		}
    72  	}
    73  
    74  	if len(activeSessions) < len(sessions) {
    75  		return activeSessions
    76  	}
    77  
    78  	return sessions
    79  }
    80  
    81  func openStream(sessions []*sessionContext, destAddr net.Addr) *interConn {
    82  	for _, s := range sessions {
    83  		if !isActive(s.session) {
    84  			continue
    85  		}
    86  
    87  		conn, err := s.openStream(destAddr)
    88  		if err != nil {
    89  			continue
    90  		}
    91  
    92  		return conn
    93  	}
    94  
    95  	return nil
    96  }
    97  
    98  func (s *clientSessions) cleanSessions() error {
    99  	s.access.Lock()
   100  	defer s.access.Unlock()
   101  
   102  	if len(s.sessions) == 0 {
   103  		return nil
   104  	}
   105  
   106  	newSessionMap := make(map[net.Destination][]*sessionContext)
   107  
   108  	for dest, sessions := range s.sessions {
   109  		sessions = removeInactiveSessions(sessions)
   110  		if len(sessions) > 0 {
   111  			newSessionMap[dest] = sessions
   112  		}
   113  	}
   114  
   115  	s.sessions = newSessionMap
   116  	return nil
   117  }
   118  
   119  func (s *clientSessions) openConnection(destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (internet.Connection, error) {
   120  	s.access.Lock()
   121  	defer s.access.Unlock()
   122  
   123  	if s.sessions == nil {
   124  		s.sessions = make(map[net.Destination][]*sessionContext)
   125  	}
   126  
   127  	dest := net.DestinationFromAddr(destAddr)
   128  
   129  	var sessions []*sessionContext
   130  	if s, found := s.sessions[dest]; found {
   131  		sessions = s
   132  	}
   133  
   134  	if true {
   135  		conn := openStream(sessions, destAddr)
   136  		if conn != nil {
   137  			return conn, nil
   138  		}
   139  	}
   140  
   141  	sessions = removeInactiveSessions(sessions)
   142  
   143  	rawConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
   144  		IP:   []byte{0, 0, 0, 0},
   145  		Port: 0,
   146  	}, sockopt)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	quicConfig := &quic.Config{
   152  		ConnectionIDLength: 12,
   153  		HandshakeTimeout:   time.Second * 8,
   154  		MaxIdleTimeout:     time.Second * 30,
   155  	}
   156  
   157  	conn, err := wrapSysConn(rawConn, config)
   158  	if err != nil {
   159  		rawConn.Close()
   160  		return nil, err
   161  	}
   162  
   163  	session, err := quic.DialContext(context.Background(), conn, destAddr, "", tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig)
   164  	if err != nil {
   165  		conn.Close()
   166  		return nil, err
   167  	}
   168  
   169  	context := &sessionContext{
   170  		session: session,
   171  		rawConn: conn,
   172  	}
   173  	s.sessions[dest] = append(sessions, context)
   174  	return context.openStream(destAddr)
   175  }
   176  
   177  var client clientSessions
   178  
   179  func init() {
   180  	client.sessions = make(map[net.Destination][]*sessionContext)
   181  	client.cleanup = &task.Periodic{
   182  		Interval: time.Minute,
   183  		Execute:  client.cleanSessions,
   184  	}
   185  	common.Must(client.cleanup.Start())
   186  }
   187  
   188  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   189  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   190  	if tlsConfig == nil {
   191  		tlsConfig = &tls.Config{
   192  			ServerName:    internalDomain,
   193  			AllowInsecure: true,
   194  		}
   195  	}
   196  
   197  	var destAddr *net.UDPAddr
   198  	if dest.Address.Family().IsIP() {
   199  		destAddr = &net.UDPAddr{
   200  			IP:   dest.Address.IP(),
   201  			Port: int(dest.Port),
   202  		}
   203  	} else {
   204  		addr, err := net.ResolveUDPAddr("udp", dest.NetAddr())
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  		destAddr = addr
   209  	}
   210  
   211  	config := streamSettings.ProtocolSettings.(*Config)
   212  
   213  	return client.openConnection(destAddr, config, tlsConfig, streamSettings.SocketSettings)
   214  }
   215  
   216  func init() {
   217  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   218  }