github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/mux/client.go (about) 1 package mux 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xtls/xray-core/common" 10 "github.com/xtls/xray-core/common/buf" 11 "github.com/xtls/xray-core/common/errors" 12 "github.com/xtls/xray-core/common/net" 13 "github.com/xtls/xray-core/common/protocol" 14 "github.com/xtls/xray-core/common/session" 15 "github.com/xtls/xray-core/common/signal/done" 16 "github.com/xtls/xray-core/common/task" 17 "github.com/xtls/xray-core/common/xudp" 18 "github.com/xtls/xray-core/proxy" 19 "github.com/xtls/xray-core/transport" 20 "github.com/xtls/xray-core/transport/internet" 21 "github.com/xtls/xray-core/transport/pipe" 22 ) 23 24 type ClientManager struct { 25 Enabled bool // wheather mux is enabled from user config 26 Picker WorkerPicker 27 } 28 29 func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error { 30 for i := 0; i < 16; i++ { 31 worker, err := m.Picker.PickAvailable() 32 if err != nil { 33 return err 34 } 35 if worker.Dispatch(ctx, link) { 36 return nil 37 } 38 } 39 40 return newError("unable to find an available mux client").AtWarning() 41 } 42 43 type WorkerPicker interface { 44 PickAvailable() (*ClientWorker, error) 45 } 46 47 type IncrementalWorkerPicker struct { 48 Factory ClientWorkerFactory 49 50 access sync.Mutex 51 workers []*ClientWorker 52 cleanupTask *task.Periodic 53 } 54 55 func (p *IncrementalWorkerPicker) cleanupFunc() error { 56 p.access.Lock() 57 defer p.access.Unlock() 58 59 if len(p.workers) == 0 { 60 return newError("no worker") 61 } 62 63 p.cleanup() 64 return nil 65 } 66 67 func (p *IncrementalWorkerPicker) cleanup() { 68 var activeWorkers []*ClientWorker 69 for _, w := range p.workers { 70 if !w.Closed() { 71 activeWorkers = append(activeWorkers, w) 72 } 73 } 74 p.workers = activeWorkers 75 } 76 77 func (p *IncrementalWorkerPicker) findAvailable() int { 78 for idx, w := range p.workers { 79 if !w.IsFull() { 80 return idx 81 } 82 } 83 84 return -1 85 } 86 87 func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) { 88 p.access.Lock() 89 defer p.access.Unlock() 90 91 idx := p.findAvailable() 92 if idx >= 0 { 93 n := len(p.workers) 94 if n > 1 && idx != n-1 { 95 p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1] 96 } 97 return p.workers[idx], false, nil 98 } 99 100 p.cleanup() 101 102 worker, err := p.Factory.Create() 103 if err != nil { 104 return nil, false, err 105 } 106 p.workers = append(p.workers, worker) 107 108 if p.cleanupTask == nil { 109 p.cleanupTask = &task.Periodic{ 110 Interval: time.Second * 30, 111 Execute: p.cleanupFunc, 112 } 113 } 114 115 return worker, true, nil 116 } 117 118 func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) { 119 worker, start, err := p.pickInternal() 120 if start { 121 common.Must(p.cleanupTask.Start()) 122 } 123 124 return worker, err 125 } 126 127 type ClientWorkerFactory interface { 128 Create() (*ClientWorker, error) 129 } 130 131 type DialingWorkerFactory struct { 132 Proxy proxy.Outbound 133 Dialer internet.Dialer 134 Strategy ClientStrategy 135 } 136 137 func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { 138 opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)} 139 uplinkReader, upLinkWriter := pipe.New(opts...) 140 downlinkReader, downlinkWriter := pipe.New(opts...) 141 142 c, err := NewClientWorker(transport.Link{ 143 Reader: downlinkReader, 144 Writer: upLinkWriter, 145 }, f.Strategy) 146 if err != nil { 147 return nil, err 148 } 149 150 go func(p proxy.Outbound, d internet.Dialer, c common.Closable) { 151 outbounds := []*session.Outbound{{ 152 Target: net.TCPDestination(muxCoolAddress, muxCoolPort), 153 }} 154 ctx := session.ContextWithOutbounds(context.Background(), outbounds) 155 ctx, cancel := context.WithCancel(ctx) 156 157 if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { 158 errors.New("failed to handler mux client connection").Base(err).WriteToLog() 159 } 160 common.Must(c.Close()) 161 cancel() 162 }(f.Proxy, f.Dialer, c.done) 163 164 return c, nil 165 } 166 167 type ClientStrategy struct { 168 MaxConcurrency uint32 169 MaxConnection uint32 170 } 171 172 type ClientWorker struct { 173 sessionManager *SessionManager 174 link transport.Link 175 done *done.Instance 176 strategy ClientStrategy 177 } 178 179 var ( 180 muxCoolAddress = net.DomainAddress("v1.mux.cool") 181 muxCoolPort = net.Port(9527) 182 ) 183 184 // NewClientWorker creates a new mux.Client. 185 func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) { 186 c := &ClientWorker{ 187 sessionManager: NewSessionManager(), 188 link: stream, 189 done: done.New(), 190 strategy: s, 191 } 192 193 go c.fetchOutput() 194 go c.monitor() 195 196 return c, nil 197 } 198 199 func (m *ClientWorker) TotalConnections() uint32 { 200 return uint32(m.sessionManager.Count()) 201 } 202 203 func (m *ClientWorker) ActiveConnections() uint32 { 204 return uint32(m.sessionManager.Size()) 205 } 206 207 // Closed returns true if this Client is closed. 208 func (m *ClientWorker) Closed() bool { 209 return m.done.Done() 210 } 211 212 func (m *ClientWorker) monitor() { 213 timer := time.NewTicker(time.Second * 16) 214 defer timer.Stop() 215 216 for { 217 select { 218 case <-m.done.Wait(): 219 m.sessionManager.Close() 220 common.Close(m.link.Writer) 221 common.Interrupt(m.link.Reader) 222 return 223 case <-timer.C: 224 size := m.sessionManager.Size() 225 if size == 0 && m.sessionManager.CloseIfNoSession() { 226 common.Must(m.done.Close()) 227 } 228 } 229 } 230 } 231 232 func writeFirstPayload(reader buf.Reader, writer *Writer) error { 233 err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100) 234 if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout { 235 return writer.WriteMultiBuffer(buf.MultiBuffer{}) 236 } 237 238 if err != nil { 239 return err 240 } 241 242 return nil 243 } 244 245 func fetchInput(ctx context.Context, s *Session, output buf.Writer) { 246 outbounds := session.OutboundsFromContext(ctx) 247 ob := outbounds[len(outbounds) - 1] 248 transferType := protocol.TransferTypeStream 249 if ob.Target.Network == net.Network_UDP { 250 transferType = protocol.TransferTypePacket 251 } 252 s.transferType = transferType 253 writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) 254 defer s.Close(false) 255 defer writer.Close() 256 257 newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx)) 258 if err := writeFirstPayload(s.input, writer); err != nil { 259 newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) 260 writer.hasError = true 261 return 262 } 263 264 if err := buf.Copy(s.input, writer); err != nil { 265 newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx)) 266 writer.hasError = true 267 return 268 } 269 } 270 271 func (m *ClientWorker) IsClosing() bool { 272 sm := m.sessionManager 273 if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) { 274 return true 275 } 276 return false 277 } 278 279 func (m *ClientWorker) IsFull() bool { 280 if m.IsClosing() || m.Closed() { 281 return true 282 } 283 284 sm := m.sessionManager 285 if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) { 286 return true 287 } 288 return false 289 } 290 291 func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool { 292 if m.IsFull() || m.Closed() { 293 return false 294 } 295 296 sm := m.sessionManager 297 s := sm.Allocate() 298 if s == nil { 299 return false 300 } 301 s.input = link.Reader 302 s.output = link.Writer 303 go fetchInput(ctx, s, m.link.Writer) 304 return true 305 } 306 307 func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error { 308 if meta.Option.Has(OptionData) { 309 return buf.Copy(NewStreamReader(reader), buf.Discard) 310 } 311 return nil 312 } 313 314 func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error { 315 if meta.Option.Has(OptionData) { 316 return buf.Copy(NewStreamReader(reader), buf.Discard) 317 } 318 return nil 319 } 320 321 func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error { 322 if !meta.Option.Has(OptionData) { 323 return nil 324 } 325 326 s, found := m.sessionManager.Get(meta.SessionID) 327 if !found { 328 // Notify remote peer to close this session. 329 closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream) 330 closingWriter.Close() 331 332 return buf.Copy(NewStreamReader(reader), buf.Discard) 333 } 334 335 rr := s.NewReader(reader, &meta.Target) 336 err := buf.Copy(rr, s.output) 337 if err != nil && buf.IsWriteError(err) { 338 newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog() 339 s.Close(false) 340 return buf.Copy(rr, buf.Discard) 341 } 342 343 return err 344 } 345 346 func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error { 347 if s, found := m.sessionManager.Get(meta.SessionID); found { 348 s.Close(false) 349 } 350 if meta.Option.Has(OptionData) { 351 return buf.Copy(NewStreamReader(reader), buf.Discard) 352 } 353 return nil 354 } 355 356 func (m *ClientWorker) fetchOutput() { 357 defer func() { 358 common.Must(m.done.Close()) 359 }() 360 361 reader := &buf.BufferedReader{Reader: m.link.Reader} 362 363 var meta FrameMetadata 364 for { 365 err := meta.Unmarshal(reader) 366 if err != nil { 367 if errors.Cause(err) != io.EOF { 368 newError("failed to read metadata").Base(err).WriteToLog() 369 } 370 break 371 } 372 373 switch meta.SessionStatus { 374 case SessionStatusKeepAlive: 375 err = m.handleStatueKeepAlive(&meta, reader) 376 case SessionStatusEnd: 377 err = m.handleStatusEnd(&meta, reader) 378 case SessionStatusNew: 379 err = m.handleStatusNew(&meta, reader) 380 case SessionStatusKeep: 381 err = m.handleStatusKeep(&meta, reader) 382 default: 383 status := meta.SessionStatus 384 newError("unknown status: ", status).AtError().WriteToLog() 385 return 386 } 387 388 if err != nil { 389 newError("failed to process data").Base(err).WriteToLog() 390 return 391 } 392 } 393 }