github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/transport/recv.go (about) 1 // Package transport provides long-lived http/tcp connections for 2 // intra-cluster communications (see README for details and usage example). 3 /* 4 * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved. 5 */ 6 package transport 7 8 import ( 9 "fmt" 10 "io" 11 "math" 12 "net/http" 13 "path" 14 "runtime" 15 "strconv" 16 "sync" 17 "time" 18 19 "github.com/NVIDIA/aistore/api/apc" 20 "github.com/NVIDIA/aistore/cmn" 21 "github.com/NVIDIA/aistore/cmn/atomic" 22 "github.com/NVIDIA/aistore/cmn/cos" 23 "github.com/NVIDIA/aistore/cmn/debug" 24 "github.com/NVIDIA/aistore/cmn/mono" 25 "github.com/NVIDIA/aistore/cmn/nlog" 26 "github.com/NVIDIA/aistore/hk" 27 "github.com/NVIDIA/aistore/memsys" 28 "github.com/OneOfOne/xxhash" 29 "github.com/pierrec/lz4/v3" 30 ) 31 32 const sessionIsOld = time.Hour 33 34 // private types 35 type ( 36 rxStats interface { 37 addOff(int64) 38 incNum() 39 } 40 iterator struct { 41 body io.Reader 42 handler handler 43 pdu *rpdu 44 stats rxStats 45 hbuf []byte 46 } 47 objReader struct { 48 body io.Reader 49 pdu *rpdu 50 loghdr string 51 hdr ObjHdr 52 off int64 53 } 54 55 handler interface { 56 recv(hdr *ObjHdr, objReader io.Reader, err error) error // RecvObj 57 stats(*http.Request, string) (rxStats, uint64, string) 58 unreg() 59 addOld(uint64) 60 getStats() RxStats 61 } 62 hdl struct { 63 rxObj RecvObj 64 trname string 65 now int64 66 } 67 hdlExtra struct { 68 hdl 69 hkName string 70 sessions sync.Map 71 oldSessions sync.Map 72 } 73 ) 74 75 // interface guard 76 var ( 77 _ handler = (*hdl)(nil) 78 _ handler = (*hdlExtra)(nil) 79 ) 80 81 // global 82 var ( 83 nextSessionID atomic.Int64 // next unique session ID 84 ) 85 86 // main Rx objects 87 func RxAnyStream(w http.ResponseWriter, r *http.Request) { 88 var ( 89 reader io.Reader = r.Body 90 lz4Reader *lz4.Reader 91 trname = path.Base(r.URL.Path) 92 mm = memsys.PageMM() 93 ) 94 // Rx handler 95 h, err := oget(trname) 96 if err != nil { 97 // 98 // Try reading `nextProtoHdr` containing transport.ObjHdr - 99 // that's because low-level `Stream.Fin` (sending graceful `opcFin`) 100 // could be the cause for `errUnknownTrname` and `errAlreadyClosedTrname`. 101 // Secondly, `errAlreadyClosedTrname` is considered benign, attributed 102 // to xaction abort and such - the fact that'd be difficult to confirm 103 // at the lowest level (and with no handler and its rxObj cb). 104 // 105 if _, ok := err.(*errAlreadyClosedTrname); ok { 106 if verbose { 107 nlog.Errorln(err) 108 } 109 } else { 110 cmn.WriteErr(w, r, err, 0) 111 } 112 return 113 } 114 // compression 115 if compressionType := r.Header.Get(apc.HdrCompress); compressionType != "" { 116 debug.Assert(compressionType == apc.LZ4Compression) 117 lz4Reader = lz4.NewReader(r.Body) 118 reader = lz4Reader 119 } 120 121 stats, uid, loghdr := h.stats(r, trname) 122 it := &iterator{handler: h, body: reader, stats: stats} 123 it.hbuf, _ = mm.AllocSize(dfltMaxHdr) 124 125 // receive loop 126 err = it.rxloop(uid, loghdr, mm) 127 128 // cleanup 129 if lz4Reader != nil { 130 lz4Reader.Reset(nil) 131 } 132 if it.pdu != nil { 133 it.pdu.free(mm) 134 } 135 mm.Free(it.hbuf) 136 137 // if err != io.EOF { 138 if !cos.IsEOF(err) { 139 cmn.WriteErr(w, r, err) 140 } 141 } 142 143 //////////////// 144 // Rx handler // 145 //////////////// 146 147 // begin t2t session 148 func (h *hdl) stats(r *http.Request, trname string) (rxStats, uint64, string) { 149 debug.Assertf(h.trname == trname, "%q vs %q", h.trname, trname) 150 sid := r.Header.Get(apc.HdrSessID) 151 loghdr := h.trname + "[" + r.RemoteAddr + ":" + sid + "]" 152 return nopRxStats{}, 0, loghdr 153 } 154 155 // ditto, with Rx stats 156 func (h *hdlExtra) stats(r *http.Request, trname string) (rxStats, uint64, string) { 157 debug.Assertf(h.trname == trname, "%q vs %q", h.trname, trname) 158 sid := r.Header.Get(apc.HdrSessID) 159 160 sessID, err := strconv.ParseInt(sid, 10, 64) 161 if err != nil || sessID == 0 { 162 err = fmt.Errorf("%s[:%q]: invalid session ID, err %v", h.trname, sid, err) 163 cos.AssertNoErr(err) 164 } 165 166 // yet another id to index optional h.sessions & h.oldSessions sync.Maps 167 uid := uniqueID(r, sessID) 168 statsif, _ := h.sessions.LoadOrStore(uid, &Stats{}) 169 170 xxh, _ := UID2SessID(uid) 171 loghdr := fmt.Sprintf("%s[%d:%d]", h.trname, xxh, sessID) 172 if verbose { 173 nlog.Infof("%s: start-of-stream from %s", loghdr, r.RemoteAddr) 174 } 175 return statsif.(rxStats), uid, loghdr 176 } 177 178 func (*hdl) unreg() {} 179 func (h *hdlExtra) unreg() { hk.Unreg(h.hkName + hk.NameSuffix) } 180 181 func (*hdl) addOld(uint64) {} 182 func (h *hdlExtra) addOld(uid uint64) { h.oldSessions.Store(uid, mono.NanoTime()) } 183 184 func (h *hdlExtra) cleanup() time.Duration { 185 h.now = mono.NanoTime() 186 h.oldSessions.Range(h.cl) 187 return sessionIsOld 188 } 189 190 func (h *hdlExtra) cl(key, value any) bool { 191 timeClosed := value.(int64) 192 if time.Duration(h.now-timeClosed) > sessionIsOld { 193 uid := key.(uint64) 194 h.oldSessions.Delete(uid) 195 h.sessions.Delete(uid) 196 } 197 return true 198 } 199 200 func (h *hdl) recv(hdr *ObjHdr, objReader io.Reader, err error) error { 201 return h.rxObj(hdr, objReader, err) 202 } 203 204 func (*hdl) getStats() RxStats { return nil } 205 206 func (h *hdlExtra) getStats() (s RxStats) { 207 s = make(RxStats, 4) 208 h.sessions.Range(s.f) 209 return 210 } 211 212 func (s RxStats) f(key, value any) bool { 213 out := &Stats{} 214 uid := key.(uint64) 215 in := value.(*Stats) 216 out.Num.Store(in.Num.Load()) // via rxStats.incNum 217 out.Offset.Store(in.Offset.Load()) // via rxStats.addOff 218 s[uid] = out 219 return true 220 } 221 222 ////////////////////////////////// 223 // next(obj, msg, pdu) iterator // 224 ////////////////////////////////// 225 226 func (it *iterator) Read(p []byte) (n int, err error) { return it.body.Read(p) } 227 228 func (it *iterator) rxloop(uid uint64, loghdr string, mm *memsys.MMSA) (err error) { 229 for err == nil { 230 var ( 231 flags uint64 232 hlen int 233 ) 234 hlen, flags, err = it.nextProtoHdr(loghdr) 235 if err != nil { 236 break 237 } 238 if hlen > cap(it.hbuf) { 239 if hlen > maxSizeHeader { 240 err = fmt.Errorf("sbr1 %s: hlen %d exceeds maximum %d", loghdr, hlen, maxSizeHeader) 241 break 242 } 243 // grow 244 nlog.Warningf("%s: header length %d exceeds the current buffer %d", loghdr, hlen, cap(it.hbuf)) 245 mm.Free(it.hbuf) 246 it.hbuf, _ = mm.AllocSize(min(int64(hlen)<<1, maxSizeHeader)) 247 } 248 249 it.stats.addOff(int64(hlen + sizeProtoHdr)) 250 debug.Assert(flags&msgFl == 0) // messaging: not used, removed 251 if flags&pduStreamFl != 0 { 252 if it.pdu == nil { 253 pbuf, _ := mm.AllocSize(maxSizePDU) 254 it.pdu = newRecvPDU(it.body, pbuf) 255 } else { 256 it.pdu.reset() 257 } 258 } 259 err = it.rxObj(loghdr, hlen) 260 } 261 262 it.handler.addOld(uid) 263 return 264 } 265 266 func (it *iterator) rxObj(loghdr string, hlen int) (err error) { 267 var ( 268 obj *objReader 269 h = it.handler 270 ) 271 obj, err = it.nextObj(loghdr, hlen) 272 if obj != nil { 273 if !obj.hdr.IsHeaderOnly() { 274 obj.pdu = it.pdu 275 } 276 err = eofOK(err) 277 size, off := obj.hdr.ObjAttrs.Size, obj.off 278 if errCb := h.recv(&obj.hdr, obj, err); errCb != nil { 279 err = errCb 280 } 281 // stats 282 if err == nil { 283 it.stats.incNum() // 1. this stream stats 284 g.tstats.Inc(InObjCount) // 2. stats/target_stats.go 285 286 if size >= 0 { 287 g.tstats.Add(InObjSize, size) 288 } else { 289 debug.Assert(size == SizeUnknown) 290 g.tstats.Add(InObjSize, obj.off-off) 291 } 292 } 293 } else if err != nil && err != io.EOF { 294 if errCb := h.recv(&ObjHdr{}, nil, err); errCb != nil { 295 err = errCb 296 } 297 } 298 return 299 } 300 301 func eofOK(err error) error { 302 if err == io.EOF { 303 err = nil 304 } 305 return err 306 } 307 308 // nextProtoHdr receives and handles 16 bytes of the protocol header (not to confuse with transport.Obj.Hdr) 309 // returns hlen, which is header length - for transport.Obj (and formerly, message length for transport.Msg) 310 func (it *iterator) nextProtoHdr(loghdr string) (hlen int, flags uint64, err error) { 311 var n int 312 n, err = it.Read(it.hbuf[:sizeProtoHdr]) 313 if n < sizeProtoHdr { 314 if err == nil { 315 err = fmt.Errorf("sbr3 %s: failed to receive proto hdr (n=%d)", loghdr, n) 316 } 317 return 318 } 319 // extract and validate hlen 320 hlen, flags, err = extProtoHdr(it.hbuf, loghdr) 321 return 322 } 323 324 func (it *iterator) nextObj(loghdr string, hlen int) (obj *objReader, err error) { 325 var n int 326 n, err = it.Read(it.hbuf[:hlen]) 327 if n < hlen { 328 if err == nil { 329 // [retry] insist on receiving the full length 330 var m int 331 for range maxInReadRetries { 332 runtime.Gosched() 333 m, err = it.Read(it.hbuf[n:hlen]) 334 if err != nil { 335 break 336 } 337 n += m 338 if n == hlen { 339 break 340 } 341 } 342 } 343 if n < hlen { 344 err = fmt.Errorf("sbr4 %s: failed to receive obj hdr (%d < %d)", loghdr, n, hlen) 345 return 346 } 347 } 348 hdr := ExtObjHeader(it.hbuf, hlen) 349 if hdr.isFin() { 350 err = io.EOF 351 return 352 } 353 obj = allocRecv() 354 obj.body, obj.hdr, obj.loghdr = it.body, hdr, loghdr 355 return 356 } 357 358 /////////////// 359 // objReader // 360 /////////////// 361 362 func (obj *objReader) Read(b []byte) (n int, err error) { 363 if obj.pdu != nil { 364 return obj.readPDU(b) 365 } 366 debug.Assert(obj.Size() >= 0) 367 rem := obj.Size() - obj.off 368 if rem < int64(len(b)) { 369 b = b[:int(rem)] 370 } 371 n, err = obj.body.Read(b) 372 obj.off += int64(n) // NOTE: `GORACE` complaining here can be safely ignored 373 switch err { 374 case nil: 375 if obj.off >= obj.Size() { 376 err = io.EOF 377 } 378 case io.EOF: 379 if obj.off != obj.Size() { 380 err = fmt.Errorf("sbr6 %s: premature eof %d != %s, err %w", obj.loghdr, obj.off, obj, err) 381 } 382 default: 383 err = fmt.Errorf("sbr7 %s: off %d, obj %s, err %w", obj.loghdr, obj.off, obj, err) 384 } 385 return 386 } 387 388 func (obj *objReader) String() string { 389 return fmt.Sprintf("%s(size=%d)", obj.hdr.Cname(), obj.Size()) 390 } 391 392 func (obj *objReader) Size() int64 { return obj.hdr.ObjSize() } 393 func (obj *objReader) IsUnsized() bool { return obj.hdr.IsUnsized() } 394 395 // 396 // pduReader 397 // 398 399 func (obj *objReader) readPDU(b []byte) (n int, err error) { 400 pdu := obj.pdu 401 if pdu.woff == 0 { 402 err = pdu.readHdr(obj.loghdr) 403 if err != nil { 404 return 405 } 406 } 407 for !pdu.done { 408 if _, err = pdu.readFrom(); err != nil && err != io.EOF { 409 err = fmt.Errorf("sbr8 %s: failed to receive PDU, err %w, obj %s", obj.loghdr, err, obj) 410 break 411 } 412 debug.Assert(err == nil || (err == io.EOF && pdu.done)) 413 if !pdu.done { 414 runtime.Gosched() 415 } 416 } 417 n = pdu.read(b) 418 obj.off += int64(n) 419 420 if err != nil { 421 return 422 } 423 if pdu.rlength() == 0 { 424 if pdu.last { 425 err = io.EOF 426 if obj.IsUnsized() { 427 obj.hdr.ObjAttrs.Size = obj.off 428 } else if obj.Size() != obj.off { 429 nlog.Errorf("sbr9 %s: off %d != %s", obj.loghdr, obj.off, obj) 430 } 431 } else { 432 pdu.reset() 433 } 434 } 435 return 436 } 437 438 // 439 // session ID <=> unique ID 440 // 441 442 func uniqueID(r *http.Request, sessID int64) uint64 { 443 hash := xxhash.Checksum64S(cos.UnsafeB(r.RemoteAddr), cos.MLCG32) 444 return (hash&math.MaxUint32)<<32 | uint64(sessID) 445 } 446 447 func UID2SessID(uid uint64) (xxh, sessID uint64) { 448 xxh, sessID = uid>>32, uid&math.MaxUint32 449 return 450 } 451 452 // DrainAndFreeReader: 453 // 1) reads and discards all the data from `r` - the `objReader`; 454 // 2) frees this objReader back to the `recvPool`. 455 // As such, this function is intended for usage only and exclusively by 456 // `transport.RecvObj` implementations. 457 func DrainAndFreeReader(r io.Reader) { 458 if r == nil { 459 return 460 } 461 obj, ok := r.(*objReader) 462 debug.Assert(ok) 463 if obj.body != nil && !obj.hdr.IsHeaderOnly() { 464 cos.DrainReader(obj) 465 } 466 FreeRecv(obj) 467 }