github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/grid/muxserver.go (about)

     1  // Copyright (c) 2015-2023 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package grid
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"sync"
    25  	"sync/atomic"
    26  	"time"
    27  
    28  	xioutil "github.com/minio/minio/internal/ioutil"
    29  	"github.com/minio/minio/internal/logger"
    30  )
    31  
    32  const lastPingThreshold = 4 * clientPingInterval
    33  
// muxServer holds the server-side state of one multiplexed stream on a
// Connection.
type muxServer struct {
	// ID identifies the mux; send stamps it on every outgoing message.
	ID               uint64
	// LastPing is the Unix time (seconds) of the most recent ping, initialized
	// when the mux is created. Always accessed through sync/atomic.
	// NOTE(review): keep this directly after ID so it stays 64-bit aligned for
	// atomic access on 32-bit platforms.
	LastPing         int64
	// SendSeq is the sequence number for the next outgoing message (guarded by
	// sendMu). RecvSeq is the sequence number expected on the next incoming
	// message (checked and advanced by checkSeq).
	SendSeq, RecvSeq uint32
	// NOTE(review): Resp is not referenced in this file; confirm whether it is
	// used elsewhere or is dead.
	Resp             chan []byte
	// BaseFlags are copied from the parent connection's baseFlags.
	BaseFlags        Flags
	// ctx bounds the lifetime of the mux; cancel is invoked by close.
	ctx              context.Context
	cancel           context.CancelFunc
	// inbound carries client payloads towards the handler. It is nil when the
	// handler accepts no input and is reset to nil once closed. Guarded by recvMu.
	inbound          chan []byte
	// parent is the connection that owns this mux.
	parent           *Connection
	// sendMu serializes send and protects SendSeq.
	sendMu           sync.Mutex
	// recvMu guards inbound and outBlock (and RecvSeq while in message).
	recvMu           sync.Mutex
	// outBlock holds outbound flow-control tokens: each token allows one
	// message to be sent. unblockSend returns tokens; close closes the channel
	// and sets it to nil.
	outBlock         chan struct{}
}
    48  
    49  func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
    50  	var cancel context.CancelFunc
    51  	ctx = setCaller(ctx, c.remote)
    52  	if msg.DeadlineMS > 0 {
    53  		ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond)
    54  	} else {
    55  		ctx, cancel = context.WithCancel(ctx)
    56  	}
    57  	m := muxServer{
    58  		ID:        msg.MuxID,
    59  		RecvSeq:   msg.Seq + 1,
    60  		SendSeq:   msg.Seq,
    61  		ctx:       ctx,
    62  		cancel:    cancel,
    63  		parent:    c,
    64  		LastPing:  time.Now().Unix(),
    65  		BaseFlags: c.baseFlags,
    66  	}
    67  	go func() {
    68  		// TODO: Handle
    69  	}()
    70  
    71  	return &m
    72  }
    73  
// newMuxStream creates the server side of a streaming mux for the request in
// msg on connection c and starts the goroutines that serve it:
//   - an acknowledgement sender (OpAckMux); all other goroutines below wait
//     for it, so the ack is queued before anything they send,
//   - an inbound forwarder, only when handler.InCapacity > 0,
//   - the handler itself,
//   - a response sender, which calls parent.deleteMux(true, m.ID) when it
//     returns,
//   - a remote liveness checker, unless the request carries a deadline of at
//     most lastPingThreshold.
//
// The returned muxServer is already running when this function returns.
func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
	var cancel context.CancelFunc
	ctx = setCaller(ctx, c.remote)
	if len(handler.Subroute) > 0 {
		ctx = setSubroute(ctx, handler.Subroute)
	}
	if msg.DeadlineMS > 0 {
		// The requested deadline is extended by c.addDeadline on the server side.
		ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond+c.addDeadline)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}

	// Unbuffered channel the handler writes responses to; it is closed when
	// the handler goroutine returns, which sendResponses treats as EOF.
	send := make(chan []byte)
	inboundCap, outboundCap := handler.InCapacity, handler.OutCapacity
	if outboundCap <= 0 {
		outboundCap = 1
	}

	m := muxServer{
		ID: msg.MuxID,
		// The opening request consumed msg.Seq; the next incoming message
		// must carry msg.Seq+1.
		RecvSeq:   msg.Seq + 1,
		SendSeq:   msg.Seq,
		ctx:       ctx,
		cancel:    cancel,
		parent:    c,
		inbound:   nil,
		outBlock:  make(chan struct{}, outboundCap),
		LastPing:  time.Now().Unix(),
		BaseFlags: c.baseFlags,
	}
	// Acknowledge Mux created.
	// Send async.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		var ack message
		ack.Op = OpAckMux
		ack.Flags = m.BaseFlags
		ack.MuxID = m.ID
		m.send(ack)
		if debugPrint {
			fmt.Println("connected stream mux:", ack.MuxID)
		}
	}()

	// Data inbound to the handler
	var handlerIn chan []byte
	if inboundCap > 0 {
		m.inbound = make(chan []byte, inboundCap)
		// handleInbound moves payloads from m.inbound into handlerIn (buffer of
		// one) and sends an unblock to the client for each delivered payload.
		handlerIn = make(chan []byte, 1)
		go func(inbound chan []byte) {
			wg.Wait()
			defer xioutil.SafeClose(handlerIn)
			m.handleInbound(c, inbound, handlerIn)
		}(m.inbound)
	}
	// Fill outbound block.
	// Each token represents a message that can be sent to the client without blocking.
	// The client will refill the tokens as they confirm delivery of the messages.
	for i := 0; i < outboundCap; i++ {
		m.outBlock <- struct{}{}
	}

	// Handler goroutine.
	// The error is stored before the deferred close of send runs, so
	// sendResponses can load it once it observes send closed.
	var handlerErr atomic.Pointer[RemoteErr]
	go func() {
		wg.Wait()
		defer xioutil.SafeClose(send)
		err := m.handleRequests(ctx, msg, send, handler, handlerIn)
		if err != nil {
			handlerErr.Store(err)
		}
	}()

	// Response sender goroutine...
	go func(outBlock <-chan struct{}) {
		wg.Wait()
		defer m.parent.deleteMux(true, m.ID)
		m.sendResponses(ctx, send, c, &handlerErr, outBlock)
	}(m.outBlock)

	// Remote aliveness check if needed.
	// A request whose deadline is no longer than lastPingThreshold is already
	// bounded by that deadline, so no liveness checker is started for it.
	if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
		go func() {
			wg.Wait()
			m.checkRemoteAlive()
		}()
	}
	return &m
}
   165  
   166  // handleInbound sends unblocks when we have delivered the message to the handler.
   167  func (m *muxServer) handleInbound(c *Connection, inbound <-chan []byte, handlerIn chan<- []byte) {
   168  	for in := range inbound {
   169  		handlerIn <- in
   170  		m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
   171  	}
   172  }
   173  
   174  // sendResponses will send responses to the client.
   175  func (m *muxServer) sendResponses(ctx context.Context, toSend <-chan []byte, c *Connection, handlerErr *atomic.Pointer[RemoteErr], outBlock <-chan struct{}) {
   176  	for {
   177  		// Process outgoing message.
   178  		var payload []byte
   179  		var ok bool
   180  		select {
   181  		case payload, ok = <-toSend:
   182  		case <-ctx.Done():
   183  			return
   184  		}
   185  		select {
   186  		case <-ctx.Done():
   187  			return
   188  		case <-outBlock:
   189  		}
   190  		msg := message{
   191  			MuxID: m.ID,
   192  			Op:    OpMuxServerMsg,
   193  			Flags: c.baseFlags,
   194  		}
   195  		if !ok {
   196  			hErr := handlerErr.Load()
   197  			if debugPrint {
   198  				fmt.Println("muxServer: Mux", m.ID, "send EOF", hErr)
   199  			}
   200  			msg.Flags |= FlagEOF
   201  			if hErr != nil {
   202  				msg.Flags |= FlagPayloadIsErr
   203  				msg.Payload = []byte(*hErr)
   204  			}
   205  			msg.setZeroPayloadFlag()
   206  			m.send(msg)
   207  			return
   208  		}
   209  		msg.Payload = payload
   210  		msg.setZeroPayloadFlag()
   211  		m.send(msg)
   212  	}
   213  }
   214  
   215  // handleRequests will handle the requests from the client and call the handler function.
   216  func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- []byte, handler StreamHandler, handlerIn <-chan []byte) (handlerErr *RemoteErr) {
   217  	start := time.Now()
   218  	defer func() {
   219  		if debugPrint {
   220  			fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond))
   221  		}
   222  		if r := recover(); r != nil {
   223  			logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r))
   224  			err := RemoteErr(fmt.Sprintf("handler panic: %v", r))
   225  			handlerErr = &err
   226  		}
   227  		if debugPrint {
   228  			fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr)
   229  		}
   230  	}()
   231  	// handlerErr is guarded by 'send' channel.
   232  	handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send)
   233  	return handlerErr
   234  }
   235  
   236  // checkRemoteAlive will check if the remote is alive.
   237  func (m *muxServer) checkRemoteAlive() {
   238  	t := time.NewTicker(lastPingThreshold / 4)
   239  	defer t.Stop()
   240  	for {
   241  		select {
   242  		case <-m.ctx.Done():
   243  			return
   244  		case <-t.C:
   245  			last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
   246  			if last > lastPingThreshold {
   247  				logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
   248  				m.close()
   249  				return
   250  			}
   251  		}
   252  	}
   253  }
   254  
   255  // checkSeq will check if sequence number is correct and increment it by 1.
   256  func (m *muxServer) checkSeq(seq uint32) (ok bool) {
   257  	if seq != m.RecvSeq {
   258  		if debugPrint {
   259  			fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq)
   260  		}
   261  		m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq))
   262  		return false
   263  	}
   264  	m.RecvSeq++
   265  	return true
   266  }
   267  
   268  func (m *muxServer) message(msg message) {
   269  	if debugPrint {
   270  		fmt.Printf("muxServer: received message %d, length %d\n", msg.Seq, len(msg.Payload))
   271  	}
   272  	m.recvMu.Lock()
   273  	defer m.recvMu.Unlock()
   274  	if cap(m.inbound) == 0 {
   275  		m.disconnect("did not expect inbound message")
   276  		return
   277  	}
   278  	if !m.checkSeq(msg.Seq) {
   279  		return
   280  	}
   281  	// Note, on EOF no value can be sent.
   282  	if msg.Flags&FlagEOF != 0 {
   283  		if len(msg.Payload) > 0 {
   284  			logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload"))
   285  		}
   286  		if m.inbound != nil {
   287  			xioutil.SafeClose(m.inbound)
   288  			m.inbound = nil
   289  		}
   290  		return
   291  	}
   292  
   293  	select {
   294  	case <-m.ctx.Done():
   295  	case m.inbound <- msg.Payload:
   296  		if debugPrint {
   297  			fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq)
   298  		}
   299  	default:
   300  		m.disconnect("handler blocked")
   301  	}
   302  }
   303  
   304  func (m *muxServer) unblockSend(seq uint32) {
   305  	if !m.checkSeq(seq) {
   306  		return
   307  	}
   308  	m.recvMu.Lock()
   309  	defer m.recvMu.Unlock()
   310  	if m.outBlock == nil {
   311  		// Closed
   312  		return
   313  	}
   314  	select {
   315  	case m.outBlock <- struct{}{}:
   316  	default:
   317  		logger.LogIf(m.ctx, errors.New("output unblocked overflow"))
   318  	}
   319  }
   320  
   321  func (m *muxServer) ping(seq uint32) pongMsg {
   322  	if !m.checkSeq(seq) {
   323  		msg := fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq)
   324  		return pongMsg{Err: &msg}
   325  	}
   326  	select {
   327  	case <-m.ctx.Done():
   328  		err := context.Cause(m.ctx).Error()
   329  		return pongMsg{Err: &err}
   330  	default:
   331  		atomic.StoreInt64(&m.LastPing, time.Now().Unix())
   332  		return pongMsg{}
   333  	}
   334  }
   335  
   336  func (m *muxServer) disconnect(msg string) {
   337  	if debugPrint {
   338  		fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg)
   339  	}
   340  	if msg != "" {
   341  		m.send(message{Op: OpMuxServerMsg, MuxID: m.ID, Flags: FlagPayloadIsErr | FlagEOF, Payload: []byte(msg)})
   342  	} else {
   343  		m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID})
   344  	}
   345  	m.parent.deleteMux(true, m.ID)
   346  }
   347  
   348  func (m *muxServer) send(msg message) {
   349  	m.sendMu.Lock()
   350  	defer m.sendMu.Unlock()
   351  	msg.MuxID = m.ID
   352  	msg.Seq = m.SendSeq
   353  	m.SendSeq++
   354  	if debugPrint {
   355  		fmt.Printf("Mux %d, Sending %+v\n", m.ID, msg)
   356  	}
   357  	logger.LogIf(m.ctx, m.parent.queueMsg(msg, nil))
   358  }
   359  
   360  func (m *muxServer) close() {
   361  	m.cancel()
   362  	m.recvMu.Lock()
   363  	defer m.recvMu.Unlock()
   364  
   365  	if m.inbound != nil {
   366  		xioutil.SafeClose(m.inbound)
   367  		m.inbound = nil
   368  	}
   369  
   370  	if m.outBlock != nil {
   371  		xioutil.SafeClose(m.outBlock)
   372  		m.outBlock = nil
   373  
   374  	}
   375  }