github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/quic/dialer.go (about) 1 package quic 2 3 import ( 4 "context" 5 "sync" 6 "time" 7 8 "github.com/quic-go/quic-go" 9 "github.com/quic-go/quic-go/logging" 10 "github.com/quic-go/quic-go/qlog" 11 "github.com/xtls/xray-core/common" 12 "github.com/xtls/xray-core/common/net" 13 "github.com/xtls/xray-core/common/task" 14 "github.com/xtls/xray-core/transport/internet" 15 "github.com/xtls/xray-core/transport/internet/stat" 16 "github.com/xtls/xray-core/transport/internet/tls" 17 ) 18 19 type connectionContext struct { 20 rawConn *sysConn 21 conn quic.Connection 22 } 23 24 var errConnectionClosed = newError("connection closed") 25 26 func (c *connectionContext) openStream(destAddr net.Addr) (*interConn, error) { 27 if !isActive(c.conn) { 28 return nil, errConnectionClosed 29 } 30 31 stream, err := c.conn.OpenStream() 32 if err != nil { 33 return nil, err 34 } 35 36 conn := &interConn{ 37 stream: stream, 38 local: c.conn.LocalAddr(), 39 remote: destAddr, 40 } 41 42 return conn, nil 43 } 44 45 type clientConnections struct { 46 access sync.Mutex 47 conns map[net.Destination][]*connectionContext 48 cleanup *task.Periodic 49 } 50 51 func isActive(s quic.Connection) bool { 52 select { 53 case <-s.Context().Done(): 54 return false 55 default: 56 return true 57 } 58 } 59 60 func removeInactiveConnections(conns []*connectionContext) []*connectionContext { 61 activeConnections := make([]*connectionContext, 0, len(conns)) 62 for i, s := range conns { 63 if isActive(s.conn) { 64 activeConnections = append(activeConnections, s) 65 continue 66 } 67 68 newError("closing quic connection at index: ", i).WriteToLog() 69 if err := s.conn.CloseWithError(0, ""); err != nil { 70 newError("failed to close connection").Base(err).WriteToLog() 71 } 72 if err := s.rawConn.Close(); err != nil { 73 newError("failed to close raw connection").Base(err).WriteToLog() 74 } 75 } 76 77 if len(activeConnections) < len(conns) { 78 newError("active quic connection reduced from ", len(conns), " to ", len(activeConnections)).WriteToLog() 79 return activeConnections 80 } 81 82 return conns 83 } 84 85 func (s *clientConnections) cleanConnections() error { 86 s.access.Lock() 87 defer s.access.Unlock() 88 89 if len(s.conns) == 0 { 90 return nil 91 } 92 93 newConnMap := make(map[net.Destination][]*connectionContext) 94 95 for dest, conns := range s.conns { 96 conns = removeInactiveConnections(conns) 97 if len(conns) > 0 { 98 newConnMap[dest] = conns 99 } 100 } 101 102 s.conns = newConnMap 103 return nil 104 } 105 106 func (s *clientConnections) openConnection(ctx context.Context, destAddr net.Addr, config *Config, tlsConfig *tls.Config, sockopt *internet.SocketConfig) (stat.Connection, error) { 107 s.access.Lock() 108 defer s.access.Unlock() 109 110 if s.conns == nil { 111 s.conns = make(map[net.Destination][]*connectionContext) 112 } 113 114 dest := net.DestinationFromAddr(destAddr) 115 116 var conns []*connectionContext 117 if s, found := s.conns[dest]; found { 118 conns = s 119 } 120 121 if len(conns) > 0 { 122 s := conns[len(conns)-1] 123 if isActive(s.conn) { 124 conn, err := s.openStream(destAddr) 125 if err == nil { 126 return conn, nil 127 } 128 newError("failed to openStream: ").Base(err).WriteToLog() 129 } else { 130 newError("current quic connection is not active!").WriteToLog() 131 } 132 } 133 134 conns = removeInactiveConnections(conns) 135 newError("dialing quic to ", dest).WriteToLog() 136 rawConn, err := internet.DialSystem(ctx, dest, sockopt) 137 if err != nil { 138 return nil, newError("failed to dial to dest: ", err).AtWarning().Base(err) 139 } 140 141 quicConfig := &quic.Config{ 142 KeepAlivePeriod: 0, 143 HandshakeIdleTimeout: time.Second * 8, 144 MaxIdleTimeout: time.Second * 300, 145 Tracer: func(ctx context.Context, p logging.Perspective, ci quic.ConnectionID) *logging.ConnectionTracer { 146 return qlog.NewConnectionTracer(&QlogWriter{connID: ci}, p, ci) 147 }, 148 } 149 150 var udpConn *net.UDPConn 151 switch conn := rawConn.(type) { 152 case *net.UDPConn: 153 udpConn = conn 154 case *internet.PacketConnWrapper: 155 udpConn = conn.Conn.(*net.UDPConn) 156 default: 157 // TODO: Support sockopt for QUIC 158 rawConn.Close() 159 return nil, newError("QUIC with sockopt is unsupported").AtWarning() 160 } 161 162 sysConn, err := wrapSysConn(udpConn, config) 163 if err != nil { 164 rawConn.Close() 165 return nil, err 166 } 167 tr := quic.Transport{ 168 ConnectionIDLength: 12, 169 Conn: sysConn, 170 } 171 conn, err := tr.Dial(context.Background(), destAddr, tlsConfig.GetTLSConfig(tls.WithDestination(dest)), quicConfig) 172 if err != nil { 173 sysConn.Close() 174 return nil, err 175 } 176 177 context := &connectionContext{ 178 conn: conn, 179 rawConn: sysConn, 180 } 181 s.conns[dest] = append(conns, context) 182 return context.openStream(destAddr) 183 } 184 185 var client clientConnections 186 187 func init() { 188 client.conns = make(map[net.Destination][]*connectionContext) 189 client.cleanup = &task.Periodic{ 190 Interval: time.Minute, 191 Execute: client.cleanConnections, 192 } 193 common.Must(client.cleanup.Start()) 194 } 195 196 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { 197 tlsConfig := tls.ConfigFromStreamSettings(streamSettings) 198 if tlsConfig == nil { 199 tlsConfig = &tls.Config{ 200 ServerName: internalDomain, 201 AllowInsecure: true, 202 } 203 } 204 205 var destAddr *net.UDPAddr 206 if dest.Address.Family().IsIP() { 207 destAddr = &net.UDPAddr{ 208 IP: dest.Address.IP(), 209 Port: int(dest.Port), 210 } 211 } else { 212 dialerIp := internet.DestIpAddress() 213 if dialerIp != nil { 214 destAddr = &net.UDPAddr{ 215 IP: dialerIp, 216 Port: int(dest.Port), 217 } 218 newError("quic Dial use dialer dest addr: ", destAddr).WriteToLog() 219 } else { 220 addr, err := net.ResolveUDPAddr("udp", dest.NetAddr()) 221 if err != nil { 222 return nil, err 223 } 224 destAddr = addr 225 } 226 } 227 228 config := streamSettings.ProtocolSettings.(*Config) 229 230 return client.openConnection(ctx, destAddr, config, tlsConfig, streamSettings.SocketSettings) 231 } 232 233 func init() { 234 common.Must(internet.RegisterTransportDialer(protocolName, Dial)) 235 }