github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/common/mux/client.go

package mux

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/xmplusdev/xmcore/common"
	"github.com/xmplusdev/xmcore/common/buf"
	"github.com/xmplusdev/xmcore/common/errors"
	"github.com/xmplusdev/xmcore/common/net"
	"github.com/xmplusdev/xmcore/common/protocol"
	"github.com/xmplusdev/xmcore/common/session"
	"github.com/xmplusdev/xmcore/common/signal/done"
	"github.com/xmplusdev/xmcore/common/task"
	"github.com/xmplusdev/xmcore/common/xudp"
	"github.com/xmplusdev/xmcore/proxy"
	"github.com/xmplusdev/xmcore/transport"
	"github.com/xmplusdev/xmcore/transport/internet"
	"github.com/xmplusdev/xmcore/transport/pipe"
)

// ClientManager dispatches outbound links onto mux client workers selected by Picker.
type ClientManager struct {
	Enabled bool // whether mux is enabled from user config
	Picker  WorkerPicker
}

// Dispatch hands the link to an available worker, retrying the picker up to 16 times.
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
	for i := 0; i < 16; i++ {
		worker, err := m.Picker.PickAvailable()
		if err != nil {
			return err
		}
		if worker.Dispatch(ctx, link) {
			return nil
		}
	}

	return newError("unable to find an available mux client").AtWarning()
}

// WorkerPicker selects a ClientWorker that can accept a new session.
type WorkerPicker interface {
	PickAvailable() (*ClientWorker, error)
}

// IncrementalWorkerPicker creates workers on demand through Factory and
// periodically drops workers that have closed.
type IncrementalWorkerPicker struct {
	Factory ClientWorkerFactory

	access      sync.Mutex
	workers     []*ClientWorker
	cleanupTask *task.Periodic
}

func (p *IncrementalWorkerPicker) cleanupFunc() error {
	p.access.Lock()
	defer p.access.Unlock()

	if len(p.workers) == 0 {
		return newError("no worker")
	}

	p.cleanup()
	return nil
}

// cleanup removes closed workers. The caller must hold p.access.
func (p *IncrementalWorkerPicker) cleanup() {
	var activeWorkers []*ClientWorker
	for _, w := range p.workers {
		if !w.Closed() {
			activeWorkers = append(activeWorkers, w)
		}
	}
	p.workers = activeWorkers
}

// findAvailable returns the index of the first worker that still has capacity, or -1.
func (p *IncrementalWorkerPicker) findAvailable() int {
	for idx, w := range p.workers {
		if !w.IsFull() {
			return idx
		}
	}

	return -1
}

// pickInternal returns a worker with spare capacity, creating one through the factory
// if none exists. The boolean result reports whether a new worker was created, in which
// case the periodic cleanup task must be started.
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
	p.access.Lock()
	defer p.access.Unlock()

	idx := p.findAvailable()
	if idx >= 0 {
		n := len(p.workers)
		if n > 1 && idx != n-1 {
			p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
		}
		return p.workers[idx], false, nil
	}

	p.cleanup()

	worker, err := p.Factory.Create()
	if err != nil {
		return nil, false, err
	}
	p.workers = append(p.workers, worker)

	if p.cleanupTask == nil {
		p.cleanupTask = &task.Periodic{
			Interval: time.Second * 30,
			Execute:  p.cleanupFunc,
		}
	}

	return worker, true, nil
}

func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
	worker, start, err := p.pickInternal()
	if start {
		common.Must(p.cleanupTask.Start())
	}

	return worker, err
}

// ClientWorkerFactory creates new ClientWorkers.
type ClientWorkerFactory interface {
	Create() (*ClientWorker, error)
}

// DialingWorkerFactory creates workers whose mux traffic is carried by Proxy through Dialer.
type DialingWorkerFactory struct {
	Proxy    proxy.Outbound
	Dialer   internet.Dialer
	Strategy ClientStrategy
}

func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
	uplinkReader, upLinkWriter := pipe.New(opts...)
	downlinkReader, downlinkWriter := pipe.New(opts...)

	c, err := NewClientWorker(transport.Link{
		Reader: downlinkReader,
		Writer: upLinkWriter,
	}, f.Strategy)
	if err != nil {
		return nil, err
	}

	// Run the proxy on the far end of the pipes; close the worker once the proxy returns.
	go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
		ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
			Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
		})
		ctx, cancel := context.WithCancel(ctx)

		if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
			errors.New("failed to handle mux client connection").Base(err).WriteToLog()
		}
		common.Must(c.Close())
		cancel()
	}(f.Proxy, f.Dialer, c.done)

	return c, nil
}

// ClientStrategy controls how many sessions each ClientWorker may carry.
type ClientStrategy struct {
	MaxConcurrency uint32 // maximum number of concurrent sessions per worker; 0 means no limit
	MaxConnection  uint32 // maximum number of total sessions per worker; 0 means no limit
}

// ClientWorker multiplexes multiple sessions over a single underlying link.
type ClientWorker struct {
	sessionManager *SessionManager
	link           transport.Link
	done           *done.Instance
	strategy       ClientStrategy
}

var (
	muxCoolAddress = net.DomainAddress("v1.mux.cool")
	muxCoolPort    = net.Port(9527)
)

// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
	c := &ClientWorker{
		sessionManager: NewSessionManager(),
		link:           stream,
		done:           done.New(),
		strategy:       s,
	}

	go c.fetchOutput()
	go c.monitor()

	return c, nil
}

// TotalConnections returns the total number of sessions this worker has carried.
func (m *ClientWorker) TotalConnections() uint32 {
	return uint32(m.sessionManager.Count())
}

// ActiveConnections returns the number of sessions currently open on this worker.
func (m *ClientWorker) ActiveConnections() uint32 {
	return uint32(m.sessionManager.Size())
}

// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
	return m.done.Done()
}

// monitor tears the worker down once done is signalled and, every 16 seconds,
// closes the worker if it no longer carries any session.
func (m *ClientWorker) monitor() {
	timer := time.NewTicker(time.Second * 16)
	defer timer.Stop()

	for {
		select {
		case <-m.done.Wait():
			m.sessionManager.Close()
			common.Close(m.link.Writer)
			common.Interrupt(m.link.Reader)
			return
		case <-timer.C:
			size := m.sessionManager.Size()
			if size == 0 && m.sessionManager.CloseIfNoSession() {
				common.Must(m.done.Close())
			}
		}
	}
}

// writeFirstPayload copies one chunk of the initial payload within 100ms; if the
// reader yields nothing in time, it writes an empty MultiBuffer so the session's
// first frame is still sent.
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
	err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
	if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
		return writer.WriteMultiBuffer(buf.MultiBuffer{})
	}

	if err != nil {
		return err
	}

	return nil
}

// fetchInput pumps data from the session's input into the shared mux link until
// the input is drained or an error occurs.
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
	dest := session.OutboundFromContext(ctx).Target
	transferType := protocol.TransferTypeStream
	if dest.Network == net.Network_UDP {
		transferType = protocol.TransferTypePacket
	}
	s.transferType = transferType
	writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx))
	defer s.Close(false)
	defer writer.Close()

	newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
	if err := writeFirstPayload(s.input, writer); err != nil {
		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
		writer.hasError = true
		return
	}

	if err := buf.Copy(s.input, writer); err != nil {
		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
		writer.hasError = true
		return
	}
}

// IsClosing reports whether the worker has reached its total-session limit and
// should no longer accept new sessions.
func (m *ClientWorker) IsClosing() bool {
	sm := m.sessionManager
	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
		return true
	}
	return false
}

// IsFull reports whether the worker cannot accept another session right now.
func (m *ClientWorker) IsFull() bool {
	if m.IsClosing() || m.Closed() {
		return true
	}

	sm := m.sessionManager
	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
		return true
	}
	return false
}

// Dispatch allocates a session on this worker for the given link and starts
// pumping its input. It returns false if the worker cannot take the link.
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
	if m.IsFull() || m.Closed() {
		return false
	}

	sm := m.sessionManager
	s := sm.Allocate()
	if s == nil {
		return false
	}
	s.input = link.Reader
	s.output = link.Writer
	go fetchInput(ctx, s, m.link.Writer)
	return true
}

// handleStatueKeepAlive discards any data attached to a keep-alive frame.
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

// handleStatusNew discards any data attached to a New frame.
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

// handleStatusKeep delivers frame data to the matching session, or tells the peer
// to close the session if it is no longer known.
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if !meta.Option.Has(OptionData) {
		return nil
	}

	s, found := m.sessionManager.Get(meta.SessionID)
	if !found {
		// Notify remote peer to close this session.
		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		closingWriter.Close()

		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}

	rr := s.NewReader(reader, &meta.Target)
	err := buf.Copy(rr, s.output)
	if err != nil && buf.IsWriteError(err) {
		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
		s.Close(false)
		return buf.Copy(rr, buf.Discard)
	}

	return err
}

// handleStatusEnd closes the matching session and discards any trailing data.
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if s, found := m.sessionManager.Get(meta.SessionID); found {
		s.Close(false)
	}
	if meta.Option.Has(OptionData) {
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}
	return nil
}

// fetchOutput reads frames from the underlying link and routes them to sessions
// until the link fails or an unknown status is received.
func (m *ClientWorker) fetchOutput() {
	defer func() {
		common.Must(m.done.Close())
	}()

	reader := &buf.BufferedReader{Reader: m.link.Reader}

	var meta FrameMetadata
	for {
		err := meta.Unmarshal(reader)
		if err != nil {
			if errors.Cause(err) != io.EOF {
				newError("failed to read metadata").Base(err).WriteToLog()
			}
			break
		}

		switch meta.SessionStatus {
		case SessionStatusKeepAlive:
			err = m.handleStatueKeepAlive(&meta, reader)
		case SessionStatusEnd:
			err = m.handleStatusEnd(&meta, reader)
		case SessionStatusNew:
			err = m.handleStatusNew(&meta, reader)
		case SessionStatusKeep:
			err = m.handleStatusKeep(&meta, reader)
		default:
			status := meta.SessionStatus
			newError("unknown status: ", status).AtError().WriteToLog()
			return
		}

		if err != nil {
			newError("failed to process data").Base(err).WriteToLog()
			return
		}
	}
}
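
// The block below is an illustrative usage sketch, not part of the original file.
// It assumes a configured proxy.Outbound and internet.Dialer are already available
// as outboundHandler and dialer (hypothetical names) and shows how ClientManager,
// IncrementalWorkerPicker, and DialingWorkerFactory are typically wired together,
// so that Dispatch reuses mux workers until each reaches the limits in ClientStrategy:
//
//	manager := &mux.ClientManager{
//		Enabled: true,
//		Picker: &mux.IncrementalWorkerPicker{
//			Factory: &mux.DialingWorkerFactory{
//				Proxy:  outboundHandler, // hypothetical existing outbound proxy
//				Dialer: dialer,          // hypothetical existing dialer
//				Strategy: mux.ClientStrategy{
//					MaxConcurrency: 8,   // active sessions per worker
//					MaxConnection:  128, // total sessions before a worker stops taking new ones
//				},
//			},
//		},
//	}
//
//	// Later, for each outbound link to be multiplexed:
//	// err := manager.Dispatch(ctx, link)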