github.com/xraypb/xray-core@v1.6.6/common/mux/client.go

package mux

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/xraypb/xray-core/common"
	"github.com/xraypb/xray-core/common/buf"
	"github.com/xraypb/xray-core/common/errors"
	"github.com/xraypb/xray-core/common/net"
	"github.com/xraypb/xray-core/common/protocol"
	"github.com/xraypb/xray-core/common/session"
	"github.com/xraypb/xray-core/common/signal/done"
	"github.com/xraypb/xray-core/common/task"
	"github.com/xraypb/xray-core/proxy"
	"github.com/xraypb/xray-core/transport"
	"github.com/xraypb/xray-core/transport/internet"
	"github.com/xraypb/xray-core/transport/pipe"
)

type ClientManager struct {
	Enabled bool // whether mux is enabled from user config
	Picker  WorkerPicker
}

func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
	for i := 0; i < 16; i++ {
		worker, err := m.Picker.PickAvailable()
		if err != nil {
			return err
		}
		if worker.Dispatch(ctx, link) {
			return nil
		}
	}

	return newError("unable to find an available mux client").AtWarning()
}

type WorkerPicker interface {
	PickAvailable() (*ClientWorker, error)
}

type IncrementalWorkerPicker struct {
	Factory ClientWorkerFactory

	access      sync.Mutex
	workers     []*ClientWorker
	cleanupTask *task.Periodic
}

func (p *IncrementalWorkerPicker) cleanupFunc() error {
	p.access.Lock()
	defer p.access.Unlock()

	if len(p.workers) == 0 {
		return newError("no worker")
	}

	p.cleanup()
	return nil
}

func (p *IncrementalWorkerPicker) cleanup() {
	var activeWorkers []*ClientWorker
	for _, w := range p.workers {
		if !w.Closed() {
			activeWorkers = append(activeWorkers, w)
		}
	}
	p.workers = activeWorkers
}

func (p *IncrementalWorkerPicker) findAvailable() int {
	for idx, w := range p.workers {
		if !w.IsFull() {
			return idx
		}
	}

	return -1
}

func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
	p.access.Lock()
	defer p.access.Unlock()

	idx := p.findAvailable()
	if idx >= 0 {
		n := len(p.workers)
		if n > 1 && idx != n-1 {
			p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
		}
		return p.workers[idx], false, nil
	}

	p.cleanup()

	worker, err := p.Factory.Create()
	if err != nil {
		return nil, false, err
	}
	p.workers = append(p.workers, worker)

	if p.cleanupTask == nil {
		p.cleanupTask = &task.Periodic{
			Interval: time.Second * 30,
			Execute:  p.cleanupFunc,
		}
	}

	return worker, true, nil
}

func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
	worker, start, err := p.pickInternal()
	if start {
		common.Must(p.cleanupTask.Start())
	}

	return worker, err
}

type ClientWorkerFactory interface {
	Create() (*ClientWorker, error)
}

type DialingWorkerFactory struct {
	Proxy    proxy.Outbound
	Dialer   internet.Dialer
	Strategy ClientStrategy
}

func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
	uplinkReader, upLinkWriter := pipe.New(opts...)
	downlinkReader, downlinkWriter := pipe.New(opts...)
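
	// The two in-memory pipes cross-connect the mux worker and the underlying
	// outbound handler: the worker writes multiplexed frames to upLinkWriter
	// (consumed by the proxy via uplinkReader) and reads responses from
	// downlinkReader (fed by the proxy via downlinkWriter).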
	c, err := NewClientWorker(transport.Link{
		Reader: downlinkReader,
		Writer: upLinkWriter,
	}, f.Strategy)
	if err != nil {
		return nil, err
	}

	go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
		ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
			Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
		})
		ctx, cancel := context.WithCancel(ctx)

		if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
			errors.New("failed to handle mux client connection").Base(err).WriteToLog()
		}
		common.Must(c.Close())
		cancel()
	}(f.Proxy, f.Dialer, c.done)

	return c, nil
}

type ClientStrategy struct {
	MaxConcurrency uint32
	MaxConnection  uint32
}

type ClientWorker struct {
	sessionManager *SessionManager
	link           transport.Link
	done           *done.Instance
	strategy       ClientStrategy
}

var (
	muxCoolAddress = net.DomainAddress("v1.mux.cool")
	muxCoolPort    = net.Port(9527)
)

// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
	c := &ClientWorker{
		sessionManager: NewSessionManager(),
		link:           stream,
		done:           done.New(),
		strategy:       s,
	}

	go c.fetchOutput()
	go c.monitor()

	return c, nil
}

func (m *ClientWorker) TotalConnections() uint32 {
	return uint32(m.sessionManager.Count())
}

func (m *ClientWorker) ActiveConnections() uint32 {
	return uint32(m.sessionManager.Size())
}

// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
	return m.done.Done()
}

func (m *ClientWorker) monitor() {
	timer := time.NewTicker(time.Second * 16)
	defer timer.Stop()

	for {
		select {
		case <-m.done.Wait():
			m.sessionManager.Close()
			common.Close(m.link.Writer)
			common.Interrupt(m.link.Reader)
			return
		case <-timer.C:
			size := m.sessionManager.Size()
			if size == 0 && m.sessionManager.CloseIfNoSession() {
				common.Must(m.done.Close())
			}
		}
	}
}

func writeFirstPayload(reader buf.Reader, writer *Writer) error {
	err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
	if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
		return writer.WriteMultiBuffer(buf.MultiBuffer{})
	}

	if err != nil {
		return err
	}

	return nil
}

func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
	dest := session.OutboundFromContext(ctx).Target
	transferType := protocol.TransferTypeStream
	if dest.Network == net.Network_UDP {
		transferType = protocol.TransferTypePacket
	}
	s.transferType = transferType
	writer := NewWriter(s.ID, dest, output, transferType)
	defer s.Close()
	defer writer.Close()

	newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
	if err := writeFirstPayload(s.input, writer); err != nil {
		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
		writer.hasError = true
		common.Interrupt(s.input)
		return
	}

	if err := buf.Copy(s.input, writer); err != nil {
		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
		writer.hasError = true
		common.Interrupt(s.input)
		return
	}
}
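// IsClosing reports whether this worker has reached the MaxConnection limit of
// its strategy, i.e. the total number of sessions it has carried over its
// lifetime, and therefore should not accept any new sessions.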
func (m *ClientWorker) IsClosing() bool {
	sm := m.sessionManager
	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
		return true
	}
	return false
}

func (m *ClientWorker) IsFull() bool {
	if m.IsClosing() || m.Closed() {
		return true
	}

	sm := m.sessionManager
	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
		return true
	}
	return false
}

func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
	if m.IsFull() || m.Closed() {
		return false
	}

	sm := m.sessionManager
	s := sm.Allocate()
	if s == nil {
		return false
	}
	s.input = link.Reader
	s.output = link.Writer
	go fetchInput(ctx, s, m.link.Writer)
	return true
}

func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if !meta.Option.Has(OptionData) {
		return nil
	}

	s, found := m.sessionManager.Get(meta.SessionID)
	if !found {
		// Notify remote peer to close this session.
		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		closingWriter.Close()

		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}

	rr := s.NewReader(reader, &meta.Target)
	err := buf.Copy(rr, s.output)
	if err != nil && buf.IsWriteError(err) {
		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()

		// Notify remote peer to close this session.
		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		closingWriter.Close()

		drainErr := buf.Copy(rr, buf.Discard)
		common.Interrupt(s.input)
		s.Close()
		return drainErr
	}

	return err
}

func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if s, found := m.sessionManager.Get(meta.SessionID); found {
		if meta.Option.Has(OptionError) {
			common.Interrupt(s.input)
			common.Interrupt(s.output)
		}
		s.Close()
	}
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

func (m *ClientWorker) fetchOutput() {
	defer func() {
		common.Must(m.done.Close())
	}()

	reader := &buf.BufferedReader{Reader: m.link.Reader}

	var meta FrameMetadata
	for {
		err := meta.Unmarshal(reader)
		if err != nil {
			if errors.Cause(err) != io.EOF {
				newError("failed to read metadata").Base(err).WriteToLog()
			}
			break
		}

		switch meta.SessionStatus {
		case SessionStatusKeepAlive:
			err = m.handleStatueKeepAlive(&meta, reader)
		case SessionStatusEnd:
			err = m.handleStatusEnd(&meta, reader)
		case SessionStatusNew:
			err = m.handleStatusNew(&meta, reader)
		case SessionStatusKeep:
			err = m.handleStatusKeep(&meta, reader)
		default:
			status := meta.SessionStatus
			newError("unknown status: ", status).AtError().WriteToLog()
			return
		}

		if err != nil {
			newError("failed to process data").Base(err).WriteToLog()
			return
		}
	}
}
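
// newMuxClientManager is an editor's sketch (not part of the upstream file)
// showing how the pieces above are typically wired together: a
// DialingWorkerFactory builds workers on demand, an IncrementalWorkerPicker
// grows and prunes the pool, and a ClientManager dispatches connections onto
// it. The handler and dialer arguments are assumptions standing in for
// whatever the caller already has.
func newMuxClientManager(handler proxy.Outbound, dialer internet.Dialer) *ClientManager {
	return &ClientManager{
		Enabled: true,
		Picker: &IncrementalWorkerPicker{
			Factory: &DialingWorkerFactory{
				Proxy:  handler,
				Dialer: dialer,
				// Illustrative limits: at most 8 concurrent sub-connections per
				// worker, and a worker retires after carrying 128 in total.
				Strategy: ClientStrategy{MaxConcurrency: 8, MaxConnection: 128},
			},
		},
	}
}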