github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/transport/recv.go (about)

     1  // Package transport provides long-lived http/tcp connections for
     2  // intra-cluster communications (see README for details and usage example).
     3  /*
     4   * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
     5   */
     6  package transport
     7  
     8  import (
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  	"net/http"
    13  	"path"
    14  	"runtime"
    15  	"strconv"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/NVIDIA/aistore/api/apc"
    20  	"github.com/NVIDIA/aistore/cmn"
    21  	"github.com/NVIDIA/aistore/cmn/atomic"
    22  	"github.com/NVIDIA/aistore/cmn/cos"
    23  	"github.com/NVIDIA/aistore/cmn/debug"
    24  	"github.com/NVIDIA/aistore/cmn/mono"
    25  	"github.com/NVIDIA/aistore/cmn/nlog"
    26  	"github.com/NVIDIA/aistore/hk"
    27  	"github.com/NVIDIA/aistore/memsys"
    28  	"github.com/OneOfOne/xxhash"
    29  	"github.com/pierrec/lz4/v3"
    30  )
    31  
    32  const sessionIsOld = time.Hour
    33  
    34  // private types
    35  type (
    36  	rxStats interface {
    37  		addOff(int64)
    38  		incNum()
    39  	}
    40  	iterator struct {
    41  		body    io.Reader
    42  		handler handler
    43  		pdu     *rpdu
    44  		stats   rxStats
    45  		hbuf    []byte
    46  	}
    47  	objReader struct {
    48  		body   io.Reader
    49  		pdu    *rpdu
    50  		loghdr string
    51  		hdr    ObjHdr
    52  		off    int64
    53  	}
    54  
    55  	handler interface {
    56  		recv(hdr *ObjHdr, objReader io.Reader, err error) error // RecvObj
    57  		stats(*http.Request, string) (rxStats, uint64, string)
    58  		unreg()
    59  		addOld(uint64)
    60  		getStats() RxStats
    61  	}
    62  	hdl struct {
    63  		rxObj  RecvObj
    64  		trname string
    65  		now    int64
    66  	}
    67  	hdlExtra struct {
    68  		hdl
    69  		hkName      string
    70  		sessions    sync.Map
    71  		oldSessions sync.Map
    72  	}
    73  )
    74  
    75  // interface guard
    76  var (
    77  	_ handler = (*hdl)(nil)
    78  	_ handler = (*hdlExtra)(nil)
    79  )
    80  
    81  // global
    82  var (
    83  	nextSessionID atomic.Int64 // next unique session ID
    84  )
    85  
    86  // main Rx objects
    87  func RxAnyStream(w http.ResponseWriter, r *http.Request) {
    88  	var (
    89  		reader    io.Reader = r.Body
    90  		lz4Reader *lz4.Reader
    91  		trname    = path.Base(r.URL.Path)
    92  		mm        = memsys.PageMM()
    93  	)
    94  	// Rx handler
    95  	h, err := oget(trname)
    96  	if err != nil {
    97  		//
    98  		//  Try reading `nextProtoHdr` containing transport.ObjHdr -
    99  		//  that's because low-level `Stream.Fin` (sending graceful `opcFin`)
   100  		//  could be the cause for `errUnknownTrname` and `errAlreadyClosedTrname`.
   101  		//  Secondly, `errAlreadyClosedTrname` is considered benign, attributed
   102  		//  to xaction abort and such - the fact that'd be difficult to confirm
   103  		//  at the lowest level (and with no handler and its rxObj cb).
   104  		//
   105  		if _, ok := err.(*errAlreadyClosedTrname); ok {
   106  			if verbose {
   107  				nlog.Errorln(err)
   108  			}
   109  		} else {
   110  			cmn.WriteErr(w, r, err, 0)
   111  		}
   112  		return
   113  	}
   114  	// compression
   115  	if compressionType := r.Header.Get(apc.HdrCompress); compressionType != "" {
   116  		debug.Assert(compressionType == apc.LZ4Compression)
   117  		lz4Reader = lz4.NewReader(r.Body)
   118  		reader = lz4Reader
   119  	}
   120  
   121  	stats, uid, loghdr := h.stats(r, trname)
   122  	it := &iterator{handler: h, body: reader, stats: stats}
   123  	it.hbuf, _ = mm.AllocSize(dfltMaxHdr)
   124  
   125  	// receive loop
   126  	err = it.rxloop(uid, loghdr, mm)
   127  
   128  	// cleanup
   129  	if lz4Reader != nil {
   130  		lz4Reader.Reset(nil)
   131  	}
   132  	if it.pdu != nil {
   133  		it.pdu.free(mm)
   134  	}
   135  	mm.Free(it.hbuf)
   136  
   137  	// if err != io.EOF {
   138  	if !cos.IsEOF(err) {
   139  		cmn.WriteErr(w, r, err)
   140  	}
   141  }
   142  
   143  ////////////////
   144  // Rx handler //
   145  ////////////////
   146  
   147  // begin t2t session
   148  func (h *hdl) stats(r *http.Request, trname string) (rxStats, uint64, string) {
   149  	debug.Assertf(h.trname == trname, "%q vs %q", h.trname, trname)
   150  	sid := r.Header.Get(apc.HdrSessID)
   151  	loghdr := h.trname + "[" + r.RemoteAddr + ":" + sid + "]"
   152  	return nopRxStats{}, 0, loghdr
   153  }
   154  
   155  // ditto, with Rx stats
   156  func (h *hdlExtra) stats(r *http.Request, trname string) (rxStats, uint64, string) {
   157  	debug.Assertf(h.trname == trname, "%q vs %q", h.trname, trname)
   158  	sid := r.Header.Get(apc.HdrSessID)
   159  
   160  	sessID, err := strconv.ParseInt(sid, 10, 64)
   161  	if err != nil || sessID == 0 {
   162  		err = fmt.Errorf("%s[:%q]: invalid session ID, err %v", h.trname, sid, err)
   163  		cos.AssertNoErr(err)
   164  	}
   165  
   166  	// yet another id to index optional h.sessions & h.oldSessions sync.Maps
   167  	uid := uniqueID(r, sessID)
   168  	statsif, _ := h.sessions.LoadOrStore(uid, &Stats{})
   169  
   170  	xxh, _ := UID2SessID(uid)
   171  	loghdr := fmt.Sprintf("%s[%d:%d]", h.trname, xxh, sessID)
   172  	if verbose {
   173  		nlog.Infof("%s: start-of-stream from %s", loghdr, r.RemoteAddr)
   174  	}
   175  	return statsif.(rxStats), uid, loghdr
   176  }
   177  
   178  func (*hdl) unreg()        {}
   179  func (h *hdlExtra) unreg() { hk.Unreg(h.hkName + hk.NameSuffix) }
   180  
   181  func (*hdl) addOld(uint64)            {}
   182  func (h *hdlExtra) addOld(uid uint64) { h.oldSessions.Store(uid, mono.NanoTime()) }
   183  
   184  func (h *hdlExtra) cleanup() time.Duration {
   185  	h.now = mono.NanoTime()
   186  	h.oldSessions.Range(h.cl)
   187  	return sessionIsOld
   188  }
   189  
   190  func (h *hdlExtra) cl(key, value any) bool {
   191  	timeClosed := value.(int64)
   192  	if time.Duration(h.now-timeClosed) > sessionIsOld {
   193  		uid := key.(uint64)
   194  		h.oldSessions.Delete(uid)
   195  		h.sessions.Delete(uid)
   196  	}
   197  	return true
   198  }
   199  
   200  func (h *hdl) recv(hdr *ObjHdr, objReader io.Reader, err error) error {
   201  	return h.rxObj(hdr, objReader, err)
   202  }
   203  
   204  func (*hdl) getStats() RxStats { return nil }
   205  
   206  func (h *hdlExtra) getStats() (s RxStats) {
   207  	s = make(RxStats, 4)
   208  	h.sessions.Range(s.f)
   209  	return
   210  }
   211  
   212  func (s RxStats) f(key, value any) bool {
   213  	out := &Stats{}
   214  	uid := key.(uint64)
   215  	in := value.(*Stats)
   216  	out.Num.Store(in.Num.Load())       // via rxStats.incNum
   217  	out.Offset.Store(in.Offset.Load()) // via rxStats.addOff
   218  	s[uid] = out
   219  	return true
   220  }
   221  
   222  //////////////////////////////////
   223  // next(obj, msg, pdu) iterator //
   224  //////////////////////////////////
   225  
   226  func (it *iterator) Read(p []byte) (n int, err error) { return it.body.Read(p) }
   227  
   228  func (it *iterator) rxloop(uid uint64, loghdr string, mm *memsys.MMSA) (err error) {
   229  	for err == nil {
   230  		var (
   231  			flags uint64
   232  			hlen  int
   233  		)
   234  		hlen, flags, err = it.nextProtoHdr(loghdr)
   235  		if err != nil {
   236  			break
   237  		}
   238  		if hlen > cap(it.hbuf) {
   239  			if hlen > maxSizeHeader {
   240  				err = fmt.Errorf("sbr1 %s: hlen %d exceeds maximum %d", loghdr, hlen, maxSizeHeader)
   241  				break
   242  			}
   243  			// grow
   244  			nlog.Warningf("%s: header length %d exceeds the current buffer %d", loghdr, hlen, cap(it.hbuf))
   245  			mm.Free(it.hbuf)
   246  			it.hbuf, _ = mm.AllocSize(min(int64(hlen)<<1, maxSizeHeader))
   247  		}
   248  
   249  		it.stats.addOff(int64(hlen + sizeProtoHdr))
   250  		debug.Assert(flags&msgFl == 0) //  messaging: not used, removed
   251  		if flags&pduStreamFl != 0 {
   252  			if it.pdu == nil {
   253  				pbuf, _ := mm.AllocSize(maxSizePDU)
   254  				it.pdu = newRecvPDU(it.body, pbuf)
   255  			} else {
   256  				it.pdu.reset()
   257  			}
   258  		}
   259  		err = it.rxObj(loghdr, hlen)
   260  	}
   261  
   262  	it.handler.addOld(uid)
   263  	return
   264  }
   265  
   266  func (it *iterator) rxObj(loghdr string, hlen int) (err error) {
   267  	var (
   268  		obj *objReader
   269  		h   = it.handler
   270  	)
   271  	obj, err = it.nextObj(loghdr, hlen)
   272  	if obj != nil {
   273  		if !obj.hdr.IsHeaderOnly() {
   274  			obj.pdu = it.pdu
   275  		}
   276  		err = eofOK(err)
   277  		size, off := obj.hdr.ObjAttrs.Size, obj.off
   278  		if errCb := h.recv(&obj.hdr, obj, err); errCb != nil {
   279  			err = errCb
   280  		}
   281  		// stats
   282  		if err == nil {
   283  			it.stats.incNum()        // 1. this stream stats
   284  			g.tstats.Inc(InObjCount) // 2. stats/target_stats.go
   285  
   286  			if size >= 0 {
   287  				g.tstats.Add(InObjSize, size)
   288  			} else {
   289  				debug.Assert(size == SizeUnknown)
   290  				g.tstats.Add(InObjSize, obj.off-off)
   291  			}
   292  		}
   293  	} else if err != nil && err != io.EOF {
   294  		if errCb := h.recv(&ObjHdr{}, nil, err); errCb != nil {
   295  			err = errCb
   296  		}
   297  	}
   298  	return
   299  }
   300  
   301  func eofOK(err error) error {
   302  	if err == io.EOF {
   303  		err = nil
   304  	}
   305  	return err
   306  }
   307  
   308  // nextProtoHdr receives and handles 16 bytes of the protocol header (not to confuse with transport.Obj.Hdr)
   309  // returns hlen, which is header length - for transport.Obj (and formerly, message length for transport.Msg)
   310  func (it *iterator) nextProtoHdr(loghdr string) (hlen int, flags uint64, err error) {
   311  	var n int
   312  	n, err = it.Read(it.hbuf[:sizeProtoHdr])
   313  	if n < sizeProtoHdr {
   314  		if err == nil {
   315  			err = fmt.Errorf("sbr3 %s: failed to receive proto hdr (n=%d)", loghdr, n)
   316  		}
   317  		return
   318  	}
   319  	// extract and validate hlen
   320  	hlen, flags, err = extProtoHdr(it.hbuf, loghdr)
   321  	return
   322  }
   323  
   324  func (it *iterator) nextObj(loghdr string, hlen int) (obj *objReader, err error) {
   325  	var n int
   326  	n, err = it.Read(it.hbuf[:hlen])
   327  	if n < hlen {
   328  		if err == nil {
   329  			// [retry] insist on receiving the full length
   330  			var m int
   331  			for range maxInReadRetries {
   332  				runtime.Gosched()
   333  				m, err = it.Read(it.hbuf[n:hlen])
   334  				if err != nil {
   335  					break
   336  				}
   337  				n += m
   338  				if n == hlen {
   339  					break
   340  				}
   341  			}
   342  		}
   343  		if n < hlen {
   344  			err = fmt.Errorf("sbr4 %s: failed to receive obj hdr (%d < %d)", loghdr, n, hlen)
   345  			return
   346  		}
   347  	}
   348  	hdr := ExtObjHeader(it.hbuf, hlen)
   349  	if hdr.isFin() {
   350  		err = io.EOF
   351  		return
   352  	}
   353  	obj = allocRecv()
   354  	obj.body, obj.hdr, obj.loghdr = it.body, hdr, loghdr
   355  	return
   356  }
   357  
   358  ///////////////
   359  // objReader //
   360  ///////////////
   361  
   362  func (obj *objReader) Read(b []byte) (n int, err error) {
   363  	if obj.pdu != nil {
   364  		return obj.readPDU(b)
   365  	}
   366  	debug.Assert(obj.Size() >= 0)
   367  	rem := obj.Size() - obj.off
   368  	if rem < int64(len(b)) {
   369  		b = b[:int(rem)]
   370  	}
   371  	n, err = obj.body.Read(b)
   372  	obj.off += int64(n) // NOTE: `GORACE` complaining here can be safely ignored
   373  	switch err {
   374  	case nil:
   375  		if obj.off >= obj.Size() {
   376  			err = io.EOF
   377  		}
   378  	case io.EOF:
   379  		if obj.off != obj.Size() {
   380  			err = fmt.Errorf("sbr6 %s: premature eof %d != %s, err %w", obj.loghdr, obj.off, obj, err)
   381  		}
   382  	default:
   383  		err = fmt.Errorf("sbr7 %s: off %d, obj %s, err %w", obj.loghdr, obj.off, obj, err)
   384  	}
   385  	return
   386  }
   387  
   388  func (obj *objReader) String() string {
   389  	return fmt.Sprintf("%s(size=%d)", obj.hdr.Cname(), obj.Size())
   390  }
   391  
   392  func (obj *objReader) Size() int64     { return obj.hdr.ObjSize() }
   393  func (obj *objReader) IsUnsized() bool { return obj.hdr.IsUnsized() }
   394  
   395  //
   396  // pduReader
   397  //
   398  
   399  func (obj *objReader) readPDU(b []byte) (n int, err error) {
   400  	pdu := obj.pdu
   401  	if pdu.woff == 0 {
   402  		err = pdu.readHdr(obj.loghdr)
   403  		if err != nil {
   404  			return
   405  		}
   406  	}
   407  	for !pdu.done {
   408  		if _, err = pdu.readFrom(); err != nil && err != io.EOF {
   409  			err = fmt.Errorf("sbr8 %s: failed to receive PDU, err %w, obj %s", obj.loghdr, err, obj)
   410  			break
   411  		}
   412  		debug.Assert(err == nil || (err == io.EOF && pdu.done))
   413  		if !pdu.done {
   414  			runtime.Gosched()
   415  		}
   416  	}
   417  	n = pdu.read(b)
   418  	obj.off += int64(n)
   419  
   420  	if err != nil {
   421  		return
   422  	}
   423  	if pdu.rlength() == 0 {
   424  		if pdu.last {
   425  			err = io.EOF
   426  			if obj.IsUnsized() {
   427  				obj.hdr.ObjAttrs.Size = obj.off
   428  			} else if obj.Size() != obj.off {
   429  				nlog.Errorf("sbr9 %s: off %d != %s", obj.loghdr, obj.off, obj)
   430  			}
   431  		} else {
   432  			pdu.reset()
   433  		}
   434  	}
   435  	return
   436  }
   437  
   438  //
   439  // session ID <=> unique ID
   440  //
   441  
   442  func uniqueID(r *http.Request, sessID int64) uint64 {
   443  	hash := xxhash.Checksum64S(cos.UnsafeB(r.RemoteAddr), cos.MLCG32)
   444  	return (hash&math.MaxUint32)<<32 | uint64(sessID)
   445  }
   446  
   447  func UID2SessID(uid uint64) (xxh, sessID uint64) {
   448  	xxh, sessID = uid>>32, uid&math.MaxUint32
   449  	return
   450  }
   451  
   452  // DrainAndFreeReader:
   453  // 1) reads and discards all the data from `r` - the `objReader`;
   454  // 2) frees this objReader back to the `recvPool`.
   455  // As such, this function is intended for usage only and exclusively by
   456  // `transport.RecvObj` implementations.
   457  func DrainAndFreeReader(r io.Reader) {
   458  	if r == nil {
   459  		return
   460  	}
   461  	obj, ok := r.(*objReader)
   462  	debug.Assert(ok)
   463  	if obj.body != nil && !obj.hdr.IsHeaderOnly() {
   464  		cos.DrainReader(obj)
   465  	}
   466  	FreeRecv(obj)
   467  }