github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/grid/muxclient.go (about)

     1  // Copyright (c) 2015-2023 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package grid
    19  
    20  import (
    21  	"context"
    22  	"encoding/binary"
    23  	"errors"
    24  	"fmt"
    25  	"sync"
    26  	"sync/atomic"
    27  	"time"
    28  
    29  	xioutil "github.com/minio/minio/internal/ioutil"
    30  	"github.com/minio/minio/internal/logger"
    31  	"github.com/zeebo/xxh3"
    32  )
    33  
// muxClient is a stateful connection to a remote.
// It represents the client side of a single mux: outgoing messages are
// stamped with MuxID and a running sequence number, and incoming responses
// are delivered on respWait.
// A mux client is single-use; the request methods refuse to run once init is set.
type muxClient struct {
	MuxID            uint64                  // Mux identifier; sendLocked stamps it on every outgoing message.
	SendSeq, RecvSeq uint32                  // Next outgoing sequence number / next expected incoming sequence number.
	LastPong         int64                   // Unix seconds of the last response or pong (initially creation time); updated atomically.
	BaseFlags        Flags                   // Flags OR'ed into every outgoing message; copied from parent.baseFlags.
	ctx              context.Context         // Cancelable context derived from the one passed to newMuxClient.
	cancelFn         context.CancelCauseFunc // Cancels ctx with a cause.
	parent           *Connection             // Connection used for sending and for deleteMux.
	respWait         chan<- Response         // Destination for incoming responses; closed and cleared by closeLocked.
	respMu           sync.Mutex              // Held while sending and while touching respWait or closed.
	singleResp       bool                    // True for roundtrip, false for the streaming request methods.
	closed           bool                    // Set once the mux has been closed; checked under respMu.
	stateless        bool                    // If set, addResponse drops a response that would block instead of closing the mux.
	acked            bool                    // Set once ack has filled outBlock with the initial send tokens.
	init             bool                    // Set by the first request method called on this client.
	deadline         time.Duration           // Sent to the remote as DeadlineMS; zero makes roundtrip apply a default timeout.
	outBlock         chan struct{}           // Send tokens for client-to-server stream messages; only created when requests is buffered.
	subroute         *subHandlerID           // Optional subroute; when non-nil requests are flagged with FlagSubroute.
	respErr          atomic.Pointer[error]   // Error stored by addErrorNonBlockingClose, which then closes the response channel itself.
}
    55  
// Response is a response from the server.
// Within this file, error responses are created with only Err set, and
// delivering a Response with a non-nil Err closes the mux (see addResponse).
type Response struct {
	Msg []byte // Response payload.
	Err error  // Non-nil when the response reports a failure.
}
    61  
    62  func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
    63  	ctx, cancelFn := context.WithCancelCause(ctx)
    64  	return &muxClient{
    65  		MuxID:     muxID,
    66  		ctx:       ctx,
    67  		cancelFn:  cancelFn,
    68  		parent:    parent,
    69  		LastPong:  time.Now().Unix(),
    70  		BaseFlags: parent.baseFlags,
    71  	}
    72  }
    73  
    74  // roundtrip performs a roundtrip, returning the first response.
    75  // This cannot be used concurrently.
    76  func (m *muxClient) roundtrip(h HandlerID, req []byte) ([]byte, error) {
    77  	if m.init {
    78  		return nil, errors.New("mux client already used")
    79  	}
    80  	m.init = true
    81  	m.singleResp = true
    82  	msg := message{
    83  		Op:         OpRequest,
    84  		MuxID:      m.MuxID,
    85  		Handler:    h,
    86  		Flags:      m.BaseFlags | FlagEOF,
    87  		Payload:    req,
    88  		DeadlineMS: uint32(m.deadline.Milliseconds()),
    89  	}
    90  	if m.subroute != nil {
    91  		msg.Flags |= FlagSubroute
    92  	}
    93  	ch := make(chan Response, 1)
    94  	m.respMu.Lock()
    95  	if m.closed {
    96  		m.respMu.Unlock()
    97  		return nil, ErrDisconnected
    98  	}
    99  	m.respWait = ch
   100  	m.respMu.Unlock()
   101  	ctx := m.ctx
   102  
   103  	// Add deadline if none.
   104  	if msg.DeadlineMS == 0 {
   105  		msg.DeadlineMS = uint32(defaultSingleRequestTimeout / time.Millisecond)
   106  		var cancel context.CancelFunc
   107  		ctx, cancel = context.WithTimeout(ctx, defaultSingleRequestTimeout)
   108  		defer cancel()
   109  	}
   110  	// Send request
   111  	if err := m.send(msg); err != nil {
   112  		return nil, err
   113  	}
   114  	if debugReqs {
   115  		fmt.Println(m.MuxID, m.parent.String(), "SEND")
   116  	}
   117  	// Wait for response or context.
   118  	select {
   119  	case v, ok := <-ch:
   120  		if !ok {
   121  			return nil, ErrDisconnected
   122  		}
   123  		if debugReqs && v.Err != nil {
   124  			v.Err = fmt.Errorf("%d %s RESP ERR: %w", m.MuxID, m.parent.String(), v.Err)
   125  		}
   126  		return v.Msg, v.Err
   127  	case <-ctx.Done():
   128  		if debugReqs {
   129  			return nil, fmt.Errorf("%d %s ERR: %w", m.MuxID, m.parent.String(), context.Cause(ctx))
   130  		}
   131  		return nil, context.Cause(ctx)
   132  	}
   133  }
   134  
   135  // send the message. msg.Seq and msg.MuxID will be set
   136  func (m *muxClient) send(msg message) error {
   137  	m.respMu.Lock()
   138  	defer m.respMu.Unlock()
   139  	if m.closed {
   140  		return errors.New("mux client closed")
   141  	}
   142  	return m.sendLocked(msg)
   143  }
   144  
   145  // sendLocked the message. msg.Seq and msg.MuxID will be set.
   146  // m.respMu must be held.
   147  func (m *muxClient) sendLocked(msg message) error {
   148  	dst := GetByteBuffer()[:0]
   149  	msg.Seq = m.SendSeq
   150  	msg.MuxID = m.MuxID
   151  	msg.Flags |= m.BaseFlags
   152  	if debugPrint {
   153  		fmt.Println("Client sending", &msg, "to", m.parent.Remote)
   154  	}
   155  	m.SendSeq++
   156  
   157  	dst, err := msg.MarshalMsg(dst)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	if msg.Flags&FlagSubroute != 0 {
   162  		if m.subroute == nil {
   163  			return fmt.Errorf("internal error: subroute not defined on client")
   164  		}
   165  		hid := m.subroute.withHandler(msg.Handler)
   166  		before := len(dst)
   167  		dst = append(dst, hid[:]...)
   168  		if debugPrint {
   169  			fmt.Println("Added subroute", hid.String(), "to message", msg, "len", len(dst)-before)
   170  		}
   171  	}
   172  	if msg.Flags&FlagCRCxxh3 != 0 {
   173  		h := xxh3.Hash(dst)
   174  		dst = binary.LittleEndian.AppendUint32(dst, uint32(h))
   175  	}
   176  	return m.parent.send(m.ctx, dst)
   177  }
   178  
   179  // RequestStateless will send a single payload request and stream back results.
   180  // req may not be read/written to after calling.
   181  // TODO: Not implemented
   182  func (m *muxClient) RequestStateless(h HandlerID, req []byte, out chan<- Response) {
   183  	if m.init {
   184  		out <- Response{Err: errors.New("mux client already used")}
   185  	}
   186  	m.init = true
   187  
   188  	// Try to grab an initial block.
   189  	m.singleResp = false
   190  	msg := message{
   191  		Op:         OpConnectMux,
   192  		Handler:    h,
   193  		Flags:      FlagEOF,
   194  		Payload:    req,
   195  		DeadlineMS: uint32(m.deadline.Milliseconds()),
   196  	}
   197  	msg.setZeroPayloadFlag()
   198  	if m.subroute != nil {
   199  		msg.Flags |= FlagSubroute
   200  	}
   201  
   202  	// Send...
   203  	err := m.send(msg)
   204  	if err != nil {
   205  		out <- Response{Err: err}
   206  		return
   207  	}
   208  
   209  	// Route directly to output.
   210  	m.respWait = out
   211  }
   212  
   213  // RequestStream will send a single payload request and stream back results.
   214  // 'requests' can be nil, in which case only req is sent as input.
   215  // It will however take less resources.
   216  func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []byte, responses chan Response) (*Stream, error) {
   217  	if m.init {
   218  		return nil, errors.New("mux client already used")
   219  	}
   220  	if responses == nil {
   221  		return nil, errors.New("RequestStream: responses channel is nil")
   222  	}
   223  	m.init = true
   224  	m.respMu.Lock()
   225  	if m.closed {
   226  		m.respMu.Unlock()
   227  		return nil, ErrDisconnected
   228  	}
   229  	m.respWait = responses // Route directly to output.
   230  	m.respMu.Unlock()
   231  
   232  	// Try to grab an initial block.
   233  	m.singleResp = false
   234  	m.RecvSeq = m.SendSeq // Sync
   235  	if cap(requests) > 0 {
   236  		m.outBlock = make(chan struct{}, cap(requests))
   237  	}
   238  	msg := message{
   239  		Op:         OpConnectMux,
   240  		Handler:    h,
   241  		Payload:    payload,
   242  		DeadlineMS: uint32(m.deadline.Milliseconds()),
   243  	}
   244  	msg.setZeroPayloadFlag()
   245  	if requests == nil {
   246  		msg.Flags |= FlagEOF
   247  	}
   248  	if m.subroute != nil {
   249  		msg.Flags |= FlagSubroute
   250  	}
   251  
   252  	// Send...
   253  	err := m.send(msg)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	if debugPrint {
   258  		fmt.Println("Connecting Mux", m.MuxID, ",to", m.parent.Remote)
   259  	}
   260  
   261  	// Space for one message and an error.
   262  	responseCh := make(chan Response, 1)
   263  
   264  	// Spawn simple disconnect
   265  	if requests == nil {
   266  		go m.handleOneWayStream(responseCh, responses)
   267  		return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
   268  	}
   269  
   270  	// Deliver responses and send unblocks back to the server.
   271  	go m.handleTwowayResponses(responseCh, responses)
   272  	go m.handleTwowayRequests(responses, requests)
   273  
   274  	return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
   275  }
   276  
   277  func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) {
   278  	m.respMu.Lock()
   279  	defer m.respMu.Unlock()
   280  	if !m.closed {
   281  		m.respErr.Store(&err)
   282  		// Do not block.
   283  		select {
   284  		case respHandler <- Response{Err: err}:
   285  			xioutil.SafeClose(respHandler)
   286  		default:
   287  			go func() {
   288  				respHandler <- Response{Err: err}
   289  				xioutil.SafeClose(respHandler)
   290  			}()
   291  		}
   292  		logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
   293  		m.closed = true
   294  	}
   295  }
   296  
// handleOneWayStream forwards server responses for a stream that has no
// client-to-server requests. It runs until the mux context is done,
// respServer is closed, or the remote stops answering pings.
//
// respHandler is the channel that goes to the requester.
// respServer is the channel that incoming server responses arrive on.
//
// On exit the mux is removed from the parent connection and respHandler is
// closed, either here or by addErrorNonBlockingClose when an error was reported.
func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
	if debugPrint {
		start := time.Now()
		defer func() {
			fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
		}()
	}
	defer func() {
		// addErrorNonBlockingClose will close the response channel
		// - maybe async, so we shouldn't do it here.
		if m.respErr.Load() == nil {
			xioutil.SafeClose(respHandler)
		}
	}()
	// pingTimer stays nil (and its select case never fires) unless there is
	// no deadline or the deadline is longer than one ping interval.
	var pingTimer <-chan time.Time
	if m.deadline == 0 || m.deadline > clientPingInterval {
		ticker := time.NewTicker(clientPingInterval)
		defer ticker.Stop()
		pingTimer = ticker.C
		// Start the liveness clock now.
		atomic.StoreInt64(&m.LastPong, time.Now().Unix())
	}
	// Deferred last, so this runs first on exit.
	defer m.parent.deleteMux(false, m.MuxID)
	for {
		select {
		case <-m.ctx.Done():
			if debugPrint {
				fmt.Println("Client sending disconnect to mux", m.MuxID)
			}
			err := context.Cause(m.ctx)
			// A cancellation caused by errStreamEOF is not reported as an error.
			if !errors.Is(err, errStreamEOF) {
				m.addErrorNonBlockingClose(respHandler, err)
			}
			return
		case resp, ok := <-respServer:
			if !ok {
				// Response channel was closed; the deferred func closes respHandler.
				return
			}
			select {
			case respHandler <- resp:
				// Delivered: send OpUnblockSrvMux to the remote, unless the mux was closed meanwhile.
				m.respMu.Lock()
				if !m.closed {
					logger.LogIf(m.ctx, m.sendLocked(message{Op: OpUnblockSrvMux, MuxID: m.MuxID}))
				}
				m.respMu.Unlock()
			case <-m.ctx.Done():
				// Client canceled. Don't block.
				// Next loop will catch it.
			}
		case <-pingTimer:
			// No response or pong for more than two ping intervals: treat as disconnected.
			if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 {
				m.addErrorNonBlockingClose(respHandler, ErrDisconnected)
				return
			}
			// Send new ping.
			logger.LogIf(m.ctx, m.send(message{Op: OpPing, MuxID: m.MuxID}))
		}
	}
}
   356  
   357  // responseCh is the channel to that goes to the requester.
   358  // internalResp is the channel that comes from the server.
   359  func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
   360  	defer m.parent.deleteMux(false, m.MuxID)
   361  	defer xioutil.SafeClose(responseCh)
   362  	for resp := range internalResp {
   363  		responseCh <- resp
   364  		m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
   365  	}
   366  }
   367  
   368  func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
   369  	var errState bool
   370  	if debugPrint {
   371  		start := time.Now()
   372  		defer func() {
   373  			fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
   374  		}()
   375  	}
   376  
   377  	// Listen for client messages.
   378  	for {
   379  		if errState {
   380  			go func() {
   381  				// Drain requests.
   382  				for range requests {
   383  				}
   384  			}()
   385  			return
   386  		}
   387  		select {
   388  		case <-m.ctx.Done():
   389  			if debugPrint {
   390  				fmt.Println("Client sending disconnect to mux", m.MuxID)
   391  			}
   392  			m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
   393  			errState = true
   394  			continue
   395  		case req, ok := <-requests:
   396  			if !ok {
   397  				// Done send EOF
   398  				if debugPrint {
   399  					fmt.Println("Client done, sending EOF to mux", m.MuxID)
   400  				}
   401  				msg := message{
   402  					Op:    OpMuxClientMsg,
   403  					MuxID: m.MuxID,
   404  					Seq:   1,
   405  					Flags: FlagEOF,
   406  				}
   407  				msg.setZeroPayloadFlag()
   408  				err := m.send(msg)
   409  				if err != nil {
   410  					m.addErrorNonBlockingClose(internalResp, err)
   411  				}
   412  				return
   413  			}
   414  			// Grab a send token.
   415  			select {
   416  			case <-m.ctx.Done():
   417  				m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
   418  				errState = true
   419  				continue
   420  			case <-m.outBlock:
   421  			}
   422  			msg := message{
   423  				Op:      OpMuxClientMsg,
   424  				MuxID:   m.MuxID,
   425  				Seq:     1,
   426  				Payload: req,
   427  			}
   428  			msg.setZeroPayloadFlag()
   429  			err := m.send(msg)
   430  			PutByteBuffer(req)
   431  			if err != nil {
   432  				m.addErrorNonBlockingClose(internalResp, err)
   433  				errState = true
   434  				continue
   435  			}
   436  			msg.Seq++
   437  		}
   438  	}
   439  }
   440  
   441  // checkSeq will check if sequence number is correct and increment it by 1.
   442  func (m *muxClient) checkSeq(seq uint32) (ok bool) {
   443  	if seq != m.RecvSeq {
   444  		if debugPrint {
   445  			fmt.Printf("MuxID: %d client, expected sequence %d, got %d\n", m.MuxID, m.RecvSeq, seq)
   446  		}
   447  		m.addResponse(Response{Err: ErrIncorrectSequence})
   448  		return false
   449  	}
   450  	m.RecvSeq++
   451  	return true
   452  }
   453  
   454  // response will send handleIncoming response to client.
   455  // may never block.
   456  // Should return whether the next call would block.
   457  func (m *muxClient) response(seq uint32, r Response) {
   458  	if debugReqs {
   459  		fmt.Println(m.MuxID, m.parent.String(), "RESP")
   460  	}
   461  	if debugPrint {
   462  		fmt.Printf("mux %d: got msg seqid %d, payload length: %d, err:%v\n", m.MuxID, seq, len(r.Msg), r.Err)
   463  	}
   464  	if !m.checkSeq(seq) {
   465  		if debugReqs {
   466  			fmt.Println(m.MuxID, m.parent.String(), "CHECKSEQ FAIL", m.RecvSeq, seq)
   467  		}
   468  		PutByteBuffer(r.Msg)
   469  		r.Msg = nil
   470  		r.Err = ErrIncorrectSequence
   471  		m.addResponse(r)
   472  		return
   473  	}
   474  	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
   475  	ok := m.addResponse(r)
   476  	if !ok {
   477  		PutByteBuffer(r.Msg)
   478  	}
   479  }
   480  
// errStreamEOF is a sentinel cancellation cause: when the mux context is
// canceled with it, handleOneWayStream returns without reporting an error.
var errStreamEOF = errors.New("stream EOF")
   482  
   483  // error is a message from the server to disconnect.
   484  func (m *muxClient) error(err RemoteErr) {
   485  	if debugPrint {
   486  		fmt.Printf("mux %d: got remote err:%v\n", m.MuxID, string(err))
   487  	}
   488  	m.addResponse(Response{Err: &err})
   489  }
   490  
   491  func (m *muxClient) ack(seq uint32) {
   492  	if !m.checkSeq(seq) {
   493  		return
   494  	}
   495  	if m.acked || m.outBlock == nil {
   496  		return
   497  	}
   498  	available := cap(m.outBlock)
   499  	for i := 0; i < available; i++ {
   500  		m.outBlock <- struct{}{}
   501  	}
   502  	m.acked = true
   503  }
   504  
   505  func (m *muxClient) unblockSend(seq uint32) {
   506  	if !m.checkSeq(seq) {
   507  		return
   508  	}
   509  	select {
   510  	case m.outBlock <- struct{}{}:
   511  	default:
   512  		logger.LogIf(m.ctx, errors.New("output unblocked overflow"))
   513  	}
   514  }
   515  
   516  func (m *muxClient) pong(msg pongMsg) {
   517  	if msg.NotFound || msg.Err != nil {
   518  		err := errors.New("remote terminated call")
   519  		if msg.Err != nil {
   520  			err = fmt.Errorf("remove pong failed: %v", &msg.Err)
   521  		}
   522  		m.addResponse(Response{Err: err})
   523  		return
   524  	}
   525  	atomic.StoreInt64(&m.LastPong, time.Now().Unix())
   526  }
   527  
   528  // addResponse will add a response to the response channel.
   529  // This function will never block
   530  func (m *muxClient) addResponse(r Response) (ok bool) {
   531  	m.respMu.Lock()
   532  	defer m.respMu.Unlock()
   533  	if m.closed {
   534  		return false
   535  	}
   536  	select {
   537  	case m.respWait <- r:
   538  		if r.Err != nil {
   539  			if debugPrint {
   540  				fmt.Println("Closing mux", m.MuxID, "due to error:", r.Err)
   541  			}
   542  			m.closeLocked()
   543  		}
   544  		return true
   545  	default:
   546  		if m.stateless {
   547  			// Drop message if not stateful.
   548  			return
   549  		}
   550  		err := errors.New("INTERNAL ERROR: Response was blocked")
   551  		logger.LogIf(m.ctx, err)
   552  		m.closeLocked()
   553  		return false
   554  	}
   555  }
   556  
   557  func (m *muxClient) close() {
   558  	if debugPrint {
   559  		fmt.Println("closing outgoing mux", m.MuxID)
   560  	}
   561  	if !m.respMu.TryLock() {
   562  		// Cancel before locking - will unblock any pending sends.
   563  		if m.cancelFn != nil {
   564  			m.cancelFn(context.Canceled)
   565  		}
   566  		// Wait for senders to release.
   567  		m.respMu.Lock()
   568  	}
   569  
   570  	defer m.respMu.Unlock()
   571  	m.closeLocked()
   572  }
   573  
   574  func (m *muxClient) closeLocked() {
   575  	if m.closed {
   576  		return
   577  	}
   578  	// We hold the lock, so nobody can modify m.respWait while we're closing.
   579  	if m.respWait != nil {
   580  		xioutil.SafeClose(m.respWait)
   581  		m.respWait = nil
   582  	}
   583  	m.closed = true
   584  }