github.com/eagleql/xray-core@v1.4.4/common/mux/client.go

package mux

import (
    "context"
    "io"
    "sync"
    "time"

    "github.com/eagleql/xray-core/common"
    "github.com/eagleql/xray-core/common/buf"
    "github.com/eagleql/xray-core/common/errors"
    "github.com/eagleql/xray-core/common/net"
    "github.com/eagleql/xray-core/common/protocol"
    "github.com/eagleql/xray-core/common/session"
    "github.com/eagleql/xray-core/common/signal/done"
    "github.com/eagleql/xray-core/common/task"
    "github.com/eagleql/xray-core/proxy"
    "github.com/eagleql/xray-core/transport"
    "github.com/eagleql/xray-core/transport/internet"
    "github.com/eagleql/xray-core/transport/pipe"
)

type ClientManager struct {
    Enabled bool // whether mux is enabled in the user config
    Picker  WorkerPicker
}

// Dispatch sends the link over an available mux worker, retrying up to 16
// picks before giving up.
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
    for i := 0; i < 16; i++ {
        worker, err := m.Picker.PickAvailable()
        if err != nil {
            return err
        }
        if worker.Dispatch(ctx, link) {
            return nil
        }
    }

    return newError("unable to find an available mux client").AtWarning()
}

type WorkerPicker interface {
    PickAvailable() (*ClientWorker, error)
}

type IncrementalWorkerPicker struct {
    Factory ClientWorkerFactory

    access      sync.Mutex
    workers     []*ClientWorker
    cleanupTask *task.Periodic
}

func (p *IncrementalWorkerPicker) cleanupFunc() error {
    p.access.Lock()
    defer p.access.Unlock()

    if len(p.workers) == 0 {
        return newError("no worker")
    }

    p.cleanup()
    return nil
}

// cleanup drops workers that have already been closed.
func (p *IncrementalWorkerPicker) cleanup() {
    var activeWorkers []*ClientWorker
    for _, w := range p.workers {
        if !w.Closed() {
            activeWorkers = append(activeWorkers, w)
        }
    }
    p.workers = activeWorkers
}

func (p *IncrementalWorkerPicker) findAvailable() int {
    for idx, w := range p.workers {
        if !w.IsFull() {
            return idx
        }
    }

    return -1
}

// pickInternal returns a worker to dispatch on, creating one via the factory
// when no existing worker has spare capacity. The boolean reports whether a
// new worker was created, in which case the caller starts the cleanup task.
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
    p.access.Lock()
    defer p.access.Unlock()

    idx := p.findAvailable()
    if idx >= 0 {
        n := len(p.workers)
        if n > 1 && idx != n-1 {
            p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
        }
        return p.workers[idx], false, nil
    }

    p.cleanup()

    worker, err := p.Factory.Create()
    if err != nil {
        return nil, false, err
    }
    p.workers = append(p.workers, worker)

    if p.cleanupTask == nil {
        p.cleanupTask = &task.Periodic{
            Interval: time.Second * 30,
            Execute:  p.cleanupFunc,
        }
    }

    return worker, true, nil
}

func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
    worker, start, err := p.pickInternal()
    if start {
        common.Must(p.cleanupTask.Start())
    }

    return worker, err
}
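// A minimal sketch, not part of the original file: any type that can hand out
// a usable *ClientWorker satisfies WorkerPicker. The hypothetical picker below
// keeps a single worker and recreates it once it is closed or full; unlike
// IncrementalWorkerPicker it does no locking or pooling.
//
//     type singleWorkerPicker struct {
//         factory ClientWorkerFactory
//         worker  *ClientWorker
//     }
//
//     func (p *singleWorkerPicker) PickAvailable() (*ClientWorker, error) {
//         if p.worker != nil && !p.worker.Closed() && !p.worker.IsFull() {
//             return p.worker, nil
//         }
//         w, err := p.factory.Create()
//         if err != nil {
//             return nil, err
//         }
//         p.worker = w
//         return w, nil
//     }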
type ClientWorkerFactory interface {
    Create() (*ClientWorker, error)
}

type DialingWorkerFactory struct {
    Proxy    proxy.Outbound
    Dialer   internet.Dialer
    Strategy ClientStrategy
}

// Create dials a new Mux.Cool connection through the configured outbound proxy
// and wraps it in a ClientWorker.
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
    opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
    uplinkReader, upLinkWriter := pipe.New(opts...)
    downlinkReader, downlinkWriter := pipe.New(opts...)

    c, err := NewClientWorker(transport.Link{
        Reader: downlinkReader,
        Writer: upLinkWriter,
    }, f.Strategy)
    if err != nil {
        return nil, err
    }

    go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
        ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
            Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
        })
        ctx, cancel := context.WithCancel(ctx)

        if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
            errors.New("failed to handle mux client connection").Base(err).WriteToLog()
        }
        common.Must(c.Close())
        cancel()
    }(f.Proxy, f.Dialer, c.done)

    return c, nil
}

type ClientStrategy struct {
    MaxConcurrency uint32
    MaxConnection  uint32
}

type ClientWorker struct {
    sessionManager *SessionManager
    link           transport.Link
    done           *done.Instance
    strategy       ClientStrategy
}

var muxCoolAddress = net.DomainAddress("v1.mux.cool")

var muxCoolPort = net.Port(9527)

// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
    c := &ClientWorker{
        sessionManager: NewSessionManager(),
        link:           stream,
        done:           done.New(),
        strategy:       s,
    }

    go c.fetchOutput()
    go c.monitor()

    return c, nil
}

func (m *ClientWorker) TotalConnections() uint32 {
    return uint32(m.sessionManager.Count())
}

func (m *ClientWorker) ActiveConnections() uint32 {
    return uint32(m.sessionManager.Size())
}

// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
    return m.done.Done()
}

// monitor releases the worker's resources once it is done, and periodically
// closes the worker when it no longer carries any session.
func (m *ClientWorker) monitor() {
    timer := time.NewTicker(time.Second * 16)
    defer timer.Stop()

    for {
        select {
        case <-m.done.Wait():
            m.sessionManager.Close()
            common.Close(m.link.Writer)
            common.Interrupt(m.link.Reader)
            return
        case <-timer.C:
            size := m.sessionManager.Size()
            if size == 0 && m.sessionManager.CloseIfNoSession() {
                common.Must(m.done.Close())
            }
        }
    }
}

// writeFirstPayload waits up to 100ms for an initial payload so it can be sent
// in the session's first frame; if the reader does not support read timeouts
// or nothing arrives in time, an empty MultiBuffer is written instead so the
// session is still opened promptly.
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
    err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
    if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
        return writer.WriteMultiBuffer(buf.MultiBuffer{})
    }

    if err != nil {
        return err
    }

    return nil
}

// fetchInput copies data from the session's upstream into the shared mux link.
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
    dest := session.OutboundFromContext(ctx).Target
    transferType := protocol.TransferTypeStream
    if dest.Network == net.Network_UDP {
        transferType = protocol.TransferTypePacket
    }
    s.transferType = transferType
    writer := NewWriter(s.ID, dest, output, transferType)
    defer s.Close()
    defer writer.Close()

    newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
    if err := writeFirstPayload(s.input, writer); err != nil {
        newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
        writer.hasError = true
        common.Interrupt(s.input)
        return
    }

    if err := buf.Copy(s.input, writer); err != nil {
        newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
        writer.hasError = true
        common.Interrupt(s.input)
        return
    }
}
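// The two strategy limits used below work together: MaxConcurrency caps how
// many sessions may be active on this worker at once (IsFull), while
// MaxConnection caps the total number of sessions the worker carries over its
// lifetime; once that total is reached, the worker reports IsClosing and the
// picker rotates to a fresh connection.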
func (m *ClientWorker) IsClosing() bool {
    sm := m.sessionManager
    if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
        return true
    }
    return false
}

func (m *ClientWorker) IsFull() bool {
    if m.IsClosing() || m.Closed() {
        return true
    }

    sm := m.sessionManager
    if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
        return true
    }
    return false
}

func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
    if m.IsFull() || m.Closed() {
        return false
    }

    sm := m.sessionManager
    s := sm.Allocate()
    if s == nil {
        return false
    }
    s.input = link.Reader
    s.output = link.Writer
    go fetchInput(ctx, s, m.link.Writer)
    return true
}
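// The handlers below process server-to-client frames read by fetchOutput. Each
// status (KeepAlive, New, Keep, End) may carry an OptionData payload; when a
// frame targets an unknown or failed session, the payload is still drained
// into buf.Discard so the multiplexed stream stays aligned on frame boundaries.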
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
    if meta.Option.Has(OptionData) {
        return buf.Copy(NewStreamReader(reader), buf.Discard)
    }
    return nil
}

func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
    if meta.Option.Has(OptionData) {
        return buf.Copy(NewStreamReader(reader), buf.Discard)
    }
    return nil
}

func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
    if !meta.Option.Has(OptionData) {
        return nil
    }

    s, found := m.sessionManager.Get(meta.SessionID)
    if !found {
        // Notify remote peer to close this session.
        closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
        closingWriter.Close()

        return buf.Copy(NewStreamReader(reader), buf.Discard)
    }

    rr := s.NewReader(reader, &meta.Target)
    err := buf.Copy(rr, s.output)
    if err != nil && buf.IsWriteError(err) {
        newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()

        // Notify remote peer to close this session.
        closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
        closingWriter.Close()

        drainErr := buf.Copy(rr, buf.Discard)
        common.Interrupt(s.input)
        s.Close()
        return drainErr
    }

    return err
}

func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
    if s, found := m.sessionManager.Get(meta.SessionID); found {
        if meta.Option.Has(OptionError) {
            common.Interrupt(s.input)
            common.Interrupt(s.output)
        }
        s.Close()
    }
    if meta.Option.Has(OptionData) {
        return buf.Copy(NewStreamReader(reader), buf.Discard)
    }
    return nil
}

func (m *ClientWorker) fetchOutput() {
    defer func() {
        common.Must(m.done.Close())
    }()

    reader := &buf.BufferedReader{Reader: m.link.Reader}

    var meta FrameMetadata
    for {
        err := meta.Unmarshal(reader)
        if err != nil {
            if errors.Cause(err) != io.EOF {
                newError("failed to read metadata").Base(err).WriteToLog()
            }
            break
        }

        switch meta.SessionStatus {
        case SessionStatusKeepAlive:
            err = m.handleStatueKeepAlive(&meta, reader)
        case SessionStatusEnd:
            err = m.handleStatusEnd(&meta, reader)
        case SessionStatusNew:
            err = m.handleStatusNew(&meta, reader)
        case SessionStatusKeep:
            err = m.handleStatusKeep(&meta, reader)
        default:
            status := meta.SessionStatus
            newError("unknown status: ", status).AtError().WriteToLog()
            return
        }

        if err != nil {
            newError("failed to process data").Base(err).WriteToLog()
            return
        }
    }
}
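// Example wiring (a hedged sketch, not part of the original file): an outbound
// handler would typically assemble these pieces roughly as follows. The
// outboundHandler, dialer, ctx and link values are placeholders for
// illustration only.
//
//     picker := &IncrementalWorkerPicker{
//         Factory: &DialingWorkerFactory{
//             Proxy:    outboundHandler, // some proxy.Outbound implementation
//             Dialer:   dialer,          // some internet.Dialer
//             Strategy: ClientStrategy{MaxConcurrency: 8, MaxConnection: 128},
//         },
//     }
//     manager := &ClientManager{Enabled: true, Picker: picker}
//     if err := manager.Dispatch(ctx, link); err != nil {
//         // no worker could accept the link after 16 attempts
//     }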