github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/grid/muxserver.go (about) 1 // Copyright (c) 2015-2023 MinIO, Inc. 2 // 3 // This file is part of MinIO Object Storage stack 4 // 5 // This program is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Affero General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Affero General Public License for more details. 14 // 15 // You should have received a copy of the GNU Affero General Public License 16 // along with this program. If not, see <http://www.gnu.org/licenses/>. 17 18 package grid 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "sync" 25 "sync/atomic" 26 "time" 27 28 xioutil "github.com/minio/minio/internal/ioutil" 29 "github.com/minio/minio/internal/logger" 30 ) 31 32 const lastPingThreshold = 4 * clientPingInterval 33 34 type muxServer struct { 35 ID uint64 36 LastPing int64 37 SendSeq, RecvSeq uint32 38 Resp chan []byte 39 BaseFlags Flags 40 ctx context.Context 41 cancel context.CancelFunc 42 inbound chan []byte 43 parent *Connection 44 sendMu sync.Mutex 45 recvMu sync.Mutex 46 outBlock chan struct{} 47 } 48 49 func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer { 50 var cancel context.CancelFunc 51 ctx = setCaller(ctx, c.remote) 52 if msg.DeadlineMS > 0 { 53 ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond) 54 } else { 55 ctx, cancel = context.WithCancel(ctx) 56 } 57 m := muxServer{ 58 ID: msg.MuxID, 59 RecvSeq: msg.Seq + 1, 60 SendSeq: msg.Seq, 61 ctx: ctx, 62 cancel: cancel, 63 parent: c, 64 LastPing: time.Now().Unix(), 65 BaseFlags: c.baseFlags, 66 } 67 go func() { 68 // TODO: Handle 69 }() 70 71 return &m 72 } 73 74 func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer { 75 var cancel context.CancelFunc 76 ctx = setCaller(ctx, c.remote) 77 if len(handler.Subroute) > 0 { 78 ctx = setSubroute(ctx, handler.Subroute) 79 } 80 if msg.DeadlineMS > 0 { 81 ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond+c.addDeadline) 82 } else { 83 ctx, cancel = context.WithCancel(ctx) 84 } 85 86 send := make(chan []byte) 87 inboundCap, outboundCap := handler.InCapacity, handler.OutCapacity 88 if outboundCap <= 0 { 89 outboundCap = 1 90 } 91 92 m := muxServer{ 93 ID: msg.MuxID, 94 RecvSeq: msg.Seq + 1, 95 SendSeq: msg.Seq, 96 ctx: ctx, 97 cancel: cancel, 98 parent: c, 99 inbound: nil, 100 outBlock: make(chan struct{}, outboundCap), 101 LastPing: time.Now().Unix(), 102 BaseFlags: c.baseFlags, 103 } 104 // Acknowledge Mux created. 105 // Send async. 106 var wg sync.WaitGroup 107 wg.Add(1) 108 go func() { 109 defer wg.Done() 110 var ack message 111 ack.Op = OpAckMux 112 ack.Flags = m.BaseFlags 113 ack.MuxID = m.ID 114 m.send(ack) 115 if debugPrint { 116 fmt.Println("connected stream mux:", ack.MuxID) 117 } 118 }() 119 120 // Data inbound to the handler 121 var handlerIn chan []byte 122 if inboundCap > 0 { 123 m.inbound = make(chan []byte, inboundCap) 124 handlerIn = make(chan []byte, 1) 125 go func(inbound chan []byte) { 126 wg.Wait() 127 defer xioutil.SafeClose(handlerIn) 128 m.handleInbound(c, inbound, handlerIn) 129 }(m.inbound) 130 } 131 // Fill outbound block. 132 // Each token represents a message that can be sent to the client without blocking. 133 // The client will refill the tokens as they confirm delivery of the messages. 134 for i := 0; i < outboundCap; i++ { 135 m.outBlock <- struct{}{} 136 } 137 138 // Handler goroutine. 139 var handlerErr atomic.Pointer[RemoteErr] 140 go func() { 141 wg.Wait() 142 defer xioutil.SafeClose(send) 143 err := m.handleRequests(ctx, msg, send, handler, handlerIn) 144 if err != nil { 145 handlerErr.Store(err) 146 } 147 }() 148 149 // Response sender goroutine... 150 go func(outBlock <-chan struct{}) { 151 wg.Wait() 152 defer m.parent.deleteMux(true, m.ID) 153 m.sendResponses(ctx, send, c, &handlerErr, outBlock) 154 }(m.outBlock) 155 156 // Remote aliveness check if needed. 157 if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) { 158 go func() { 159 wg.Wait() 160 m.checkRemoteAlive() 161 }() 162 } 163 return &m 164 } 165 166 // handleInbound sends unblocks when we have delivered the message to the handler. 167 func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) { 168 for in := range inbound { 169 handlerIn <- in 170 m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags}) 171 } 172 } 173 174 // sendResponses will send responses to the client. 175 func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) { 176 for { 177 // Process outgoing message. 178 var payload []byte 179 var ok bool 180 select { 181 case payload, ok = <-toSend: 182 case <-ctx.Done(): 183 return 184 } 185 select { 186 case <-ctx.Done(): 187 return 188 case <-outBlock: 189 } 190 msg := message{ 191 MuxID: m.ID, 192 Op: OpMuxServerMsg, 193 Flags: c.baseFlags, 194 } 195 if !ok { 196 hErr := handlerErr.Load() 197 if debugPrint { 198 fmt.Println("muxServer: Mux", m.ID, "send EOF", hErr) 199 } 200 msg.Flags |= FlagEOF 201 if hErr != nil { 202 msg.Flags |= FlagPayloadIsErr 203 msg.Payload = []byte(*hErr) 204 } 205 msg.setZeroPayloadFlag() 206 m.send(msg) 207 return 208 } 209 msg.Payload = payload 210 msg.setZeroPayloadFlag() 211 m.send(msg) 212 } 213 } 214 215 // handleRequests will handle the requests from the client and call the handler function. 216 func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- []byte, handler StreamHandler, handlerIn <-chan []byte) (handlerErr *RemoteErr) { 217 start := time.Now() 218 defer func() { 219 if debugPrint { 220 fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond)) 221 } 222 if r := recover(); r != nil { 223 logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r)) 224 err := RemoteErr(fmt.Sprintf("handler panic: %v", r)) 225 handlerErr = &err 226 } 227 if debugPrint { 228 fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr) 229 } 230 }() 231 // handlerErr is guarded by 'send' channel. 232 handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send) 233 return handlerErr 234 } 235 236 // checkRemoteAlive will check if the remote is alive. 237 func (m *muxServer) checkRemoteAlive() { 238 t := time.NewTicker(lastPingThreshold / 4) 239 defer t.Stop() 240 for { 241 select { 242 case <-m.ctx.Done(): 243 return 244 case <-t.C: 245 last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0)) 246 if last > lastPingThreshold { 247 logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last)) 248 m.close() 249 return 250 } 251 } 252 } 253 } 254 255 // checkSeq will check if sequence number is correct and increment it by 1. 256 func (m *muxServer) checkSeq(seq uint32) (ok bool) { 257 if seq != m.RecvSeq { 258 if debugPrint { 259 fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq) 260 } 261 m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq)) 262 return false 263 } 264 m.RecvSeq++ 265 return true 266 } 267 268 func (m *muxServer) message(msg message) { 269 if debugPrint { 270 fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload)) 271 } 272 m.recvMu.Lock() 273 defer m.recvMu.Unlock() 274 if cap(m.inbound) == 0 { 275 m.disconnect("did not expect inbound message") 276 return 277 } 278 if !m.checkSeq(msg.Seq) { 279 return 280 } 281 // Note, on EOF no value can be sent. 282 if msg.Flags&FlagEOF != 0 { 283 if len(msg.Payload) > 0 { 284 logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload")) 285 } 286 if m.inbound != nil { 287 xioutil.SafeClose(m.inbound) 288 m.inbound = nil 289 } 290 return 291 } 292 293 select { 294 case <-m.ctx.Done(): 295 case m.inbound <- msg.Payload: 296 if debugPrint { 297 fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq) 298 } 299 default: 300 m.disconnect("handler blocked") 301 } 302 } 303 304 func (m *muxServer) unblockSend(seq uint32) { 305 if !m.checkSeq(seq) { 306 return 307 } 308 m.recvMu.Lock() 309 defer m.recvMu.Unlock() 310 if m.outBlock == nil { 311 // Closed 312 return 313 } 314 select { 315 case m.outBlock <- struct{}{}: 316 default: 317 logger.LogIf(m.ctx, errors.New("output unblocked overflow")) 318 } 319 } 320 321 func (m *muxServer) ping(seq uint32) pongMsg { 322 if !m.checkSeq(seq) { 323 msg := fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq) 324 return pongMsg{Err: &msg} 325 } 326 select { 327 case <-m.ctx.Done(): 328 err := context.Cause(m.ctx).Error() 329 return pongMsg{Err: &err} 330 default: 331 atomic.StoreInt64(&m.LastPing, time.Now().Unix()) 332 return pongMsg{} 333 } 334 } 335 336 func (m *muxServer) disconnect(msg string) { 337 if debugPrint { 338 fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg) 339 } 340 if msg != "" { 341 m.send(message{Op: OpMuxServerMsg, MuxID: m.ID, Flags: FlagPayloadIsErr | FlagEOF, Payload: []byte(msg)}) 342 } else { 343 m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID}) 344 } 345 m.parent.deleteMux(true, m.ID) 346 } 347 348 func (m *muxServer) send(msg message) { 349 m.sendMu.Lock() 350 defer m.sendMu.Unlock() 351 msg.MuxID = m.ID 352 msg.Seq = m.SendSeq 353 m.SendSeq++ 354 if debugPrint { 355 fmt.Printf("Mux %d, Sending %+v\n", m.ID, msg) 356 } 357 logger.LogIf(m.ctx, m.parent.queueMsg(msg, nil)) 358 } 359 360 func (m *muxServer) close() { 361 m.cancel() 362 m.recvMu.Lock() 363 defer m.recvMu.Unlock() 364 365 if m.inbound != nil { 366 xioutil.SafeClose(m.inbound) 367 m.inbound = nil 368 } 369 370 if m.outBlock != nil { 371 xioutil.SafeClose(m.outBlock) 372 m.outBlock = nil 373 374 } 375 }