github.com/swiftstack/ProxyFS@v0.0.0-20210203235616-4017c267d62f/retryrpc/client.go (about)

     1  // Copyright (c) 2015-2021, NVIDIA CORPORATION.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  package retryrpc
     5  
     6  import (
     7  	"crypto/tls"
     8  	"encoding/binary"
     9  	"encoding/json"
    10  	"fmt"
    11  	"net"
    12  	"os"
    13  	"time"
    14  
    15  	"github.com/google/btree"
    16  	"github.com/swiftstack/ProxyFS/bucketstats"
    17  	"github.com/swiftstack/ProxyFS/logger"
    18  )
    19  
    20  const (
    21  	ConnectionRetryDelayMultiplier = 2
    22  	ConnectionRetryInitialDelay    = 100 * time.Millisecond
    23  	ConnectionRetryLimit           = 8
    24  )
    25  
    26  const (
    27  	// Prefix used for bucketstats of client
    28  	clientSideGroupPrefix = "ClientSide-"
    29  )
    30  
    31  // Useful stats for the client side
    32  type clientSideStatsInfo struct {
    33  	RetransmitsStarted bucketstats.Total // Number of retransmits attempted
    34  	SendCalled         bucketstats.Total // Number of times Send called
    35  	ReplyCalled        bucketstats.Total // Number of times receive Reply to RPC
    36  	UpcallCalled       bucketstats.Total // Number of times received an Upcall
    37  }
    38  
// TODO - what happens if an RPC completes on Server1 but, before the
// response is delivered, proxyfsd fails over to Server2?  The client will
// resend the request, which is a problem because the resend is not
// idempotent.  This is outside of our initial requirements but is
// something we should review.
    43  
    44  //
    45  // Send algorithm is:
    46  // 1. Build ctx including channel for reply struct
    47  // 2. Call goroutine to do marshalling and sending of
    48  //    request to server
    49  // 3. Wait on channel in reply struct for result
    50  // 4. readResponses goroutine will read response on socket
    51  //    and call a goroutine to do unmarshalling and notification
    52  func (client *Client) send(method string, rpcRequest interface{}, rpcReply interface{}) (err error) {
    53  	var (
    54  		connectionRetryCount int
    55  		connectionRetryDelay time.Duration
    56  		crID                 requestID
    57  	)
    58  	client.stats.SendCalled.Add(1)
    59  
    60  	client.Lock()
    61  	if client.connection.state == INITIAL {
    62  
    63  		connectionRetryCount = 0
    64  		connectionRetryDelay = ConnectionRetryInitialDelay
    65  
    66  		for {
    67  			err = client.dial()
    68  			if err == nil {
    69  				break
    70  			}
    71  			client.Unlock()
    72  			connectionRetryCount++
    73  			if connectionRetryCount > ConnectionRetryLimit {
    74  				err = fmt.Errorf("In send(), ConnectionRetryLimit (%v) on calling dial() exceeded", ConnectionRetryLimit)
    75  				logger.PanicfWithError(err, "")
    76  			}
    77  			time.Sleep(connectionRetryDelay)
    78  			connectionRetryDelay *= ConnectionRetryDelayMultiplier
    79  			client.Lock()
    80  			if client.connection.state != INITIAL {
    81  				break
    82  			}
    83  		}
    84  	}
    85  
    86  	// Put request data into structure to be be marshaled into JSON
    87  	jreq := jsonRequest{Method: method, HighestReplySeen: client.highestConsecutive}
    88  	jreq.Params[0] = rpcRequest
    89  	jreq.MyUniqueID = client.myUniqueID
    90  
    91  	if client.halting == true {
    92  		client.Unlock()
    93  		e := fmt.Errorf("Calling retryrpc.Send() without dialing")
    94  		logger.PanicfWithError(e, "")
    95  		return
    96  	}
    97  	client.currentRequestID++
    98  	crID = client.currentRequestID
    99  	jreq.RequestID = crID
   100  	client.Unlock()
   101  
   102  	// Setup ioreq to write structure on socket to server
   103  	ioreq, err := buildIoRequest(jreq)
   104  	if err != nil {
   105  		e := fmt.Errorf("Client buildIoRequest returned err: %v", err)
   106  		logger.PanicfWithError(e, "")
   107  		return err
   108  	}
   109  
   110  	// Create context to wait result and to handle retransmits
   111  	ctx := &reqCtx{ioreq: *ioreq, rpcReply: rpcReply}
   112  	ctx.answer = make(chan replyCtx)
   113  
   114  	client.goroutineWG.Add(1)
   115  	go client.sendToServer(crID, ctx, true)
   116  
   117  	// Now wait for response
   118  	answer := <-ctx.answer
   119  
   120  	return answer.err
   121  }
   122  
   123  // sendToServer packages the request and marshals it before
   124  // sending to server.
   125  //
   126  // At this point, the client will retry the request until either it
   127  // completes OR the client is shutdown.
   128  func (client *Client) sendToServer(crID requestID, ctx *reqCtx, queue bool) {
   129  
   130  	defer client.goroutineWG.Done()
   131  
   132  	// Now send the request to the server.
   133  	// We need to grab the mutex here to serialize writes on socket
   134  	client.Lock()
   135  
   136  	// Keep track of requests we are sending so we can resend them later
   137  	// as needed.   We queue the request first since we may get an error
   138  	// we can just return.
   139  	//
   140  	// That should be okay since the restransmit goroutine will walk the
   141  	// outstandingRequests queue and resend the request.
   142  	//
   143  	// Don't queue the request if we are retransmitting....
   144  	if queue == true {
   145  		client.outstandingRequest[crID] = ctx
   146  	}
   147  
   148  	// Record generation number of connection.  It is used during
   149  	// retransmit to prevent multiple goroutines from closing the
   150  	// connection and opening a new socket when only one is needed.
   151  	ctx.genNum = client.connection.genNum
   152  
   153  	// The connection state may have changed between when this goroutine
   154  	// was scheduled and when it grabbed the client lock.
   155  	//
   156  	// After we have queued the request, verify the state again before
   157  	// attempting to use the connection.  If we are not CONNECTED, return
   158  	// since we must already be in RETRANSMITTING. Since the request is
   159  	// on the queue, it will be retried automatically.
   160  	if client.connection.state != CONNECTED {
   161  		client.Unlock()
   162  		return
   163  	}
   164  
   165  	// Send header
   166  	client.connection.tlsConn.SetDeadline(time.Now().Add(client.deadlineIO))
   167  	err := binary.Write(client.connection.tlsConn, binary.BigEndian, ctx.ioreq.Hdr)
   168  	if err != nil {
   169  		genNum := ctx.genNum
   170  		client.Unlock()
   171  
   172  		// Just return - the retransmit code will start another
   173  		// sendToServer() goroutine
   174  		client.retransmit(genNum)
   175  		return
   176  	}
   177  
   178  	// Send JSON request
   179  	client.connection.tlsConn.SetDeadline(time.Now().Add(client.deadlineIO))
   180  	bytesWritten, writeErr := client.connection.tlsConn.Write(ctx.ioreq.JReq)
   181  
   182  	if (bytesWritten != len(ctx.ioreq.JReq)) || (writeErr != nil) {
   183  		/* TODO - log message?
   184  		fmt.Printf("CLIENT: PARTIAL Write! bytesWritten is: %v len(ctx.ioreq.JReq): %v writeErr: %v\n",
   185  			bytesWritten, len(ctx.ioreq.JReq), writeErr)
   186  		*/
   187  		client.Unlock()
   188  
   189  		// Just return - the retransmit code will start another
   190  		// sendToServer() goroutine
   191  		client.retransmit(ctx.genNum)
   192  		return
   193  	}
   194  
   195  	client.Unlock()
   196  	return
   197  }
   198  
   199  func (client *Client) notifyReply(buf []byte, genNum uint64) {
   200  	defer client.goroutineWG.Done()
   201  
   202  	// Unmarshal once to get the header fields
   203  	jReply := jsonReply{}
   204  	err := json.Unmarshal(buf, &jReply)
   205  	if err != nil {
   206  		// Don't have ctx to reply.  Assume read garbage on socket and
   207  		// reconnect.
   208  
   209  		// TODO - make log message
   210  		e := fmt.Errorf("notifyReply failed to unmarshal buf: %+v err: %v", string(buf), err)
   211  		fmt.Printf("%v\n", e)
   212  
   213  		client.retransmit(genNum)
   214  		return
   215  	}
   216  
   217  	// Remove request from client.outstandingRequest
   218  	//
   219  	// We do it here since we need to retrieve the RequestID from the
   220  	// original request anyway.
   221  	crID := jReply.RequestID
   222  	client.Lock()
   223  
   224  	// If this message is from an old socket - throw it away
   225  	// since the request was resent.
   226  	if client.connection.genNum != genNum {
   227  		client.Unlock()
   228  		return
   229  	}
   230  	ctx, ok := client.outstandingRequest[crID]
   231  
   232  	if !ok {
   233  		// Saw reply for request which is no longer on outstandingRequest list
   234  		// Can happen if handling retransmit
   235  		client.Unlock()
   236  		return
   237  	}
   238  
   239  	// Unmarshal the buf into the original reply structure
   240  	m := svrResponse{Result: ctx.rpcReply}
   241  	unmarshalErr := json.Unmarshal(buf, &m)
   242  	if unmarshalErr != nil {
   243  		e := fmt.Errorf("notifyReply failed to unmarshal buf: %v err: %v ctx: %v", string(buf), unmarshalErr, ctx)
   244  		fmt.Printf("%v\n", e)
   245  
   246  		// Assume read garbage on socket - close the socket and reconnect
   247  		client.retransmit(genNum)
   248  		client.Unlock()
   249  		return
   250  	}
   251  
   252  	delete(client.outstandingRequest, crID)
   253  
   254  	// Give reply to blocked send() - most developers test for nil err so
   255  	// only set it if there is an error
   256  	r := replyCtx{}
   257  	if jReply.ErrStr != "" {
   258  		r.err = fmt.Errorf("%v", jReply.ErrStr)
   259  	}
   260  	client.Unlock()
   261  	ctx.answer <- r
   262  
   263  	// Fork off a goroutine to update highestConsecutiveNum
   264  	go client.updateHighestConsecutiveNum(crID)
   265  }
   266  
   267  // readReplies is a goroutine dedicated to reading responses from the server.
   268  //
   269  // As soon as it reads a complete response, it launches a goroutine to process
   270  // the response and notify the blocked Send().
   271  func (client *Client) readReplies(callingGenNum uint64, tlsConn *tls.Conn) {
   272  	defer client.goroutineWG.Done()
   273  	var localCnt int
   274  
   275  	for {
   276  
   277  		// Wait reply from server
   278  		buf, msgType, getErr := getIO(callingGenNum, client.deadlineIO, tlsConn)
   279  
   280  		// This must happen before checking error
   281  		client.Lock()
   282  		if client.halting {
   283  			client.Unlock()
   284  			return
   285  		}
   286  		localCnt = len(client.outstandingRequest)
   287  		client.Unlock()
   288  
   289  		// Ignore timeouts on idle connections while reading header
   290  		//
   291  		// We consider a connection to be idle if we have no outstanding requests when
   292  		// we get the timeout.   Otherwise, we call retransmit.
   293  		if os.IsTimeout(getErr) == true && localCnt == 0 {
   294  			continue
   295  		}
   296  
   297  		if getErr != nil {
   298  
   299  			// If we had an error reading socket - call retransmit() and exit
   300  			// the goroutine.  retransmit()/dial() will start another
   301  			// readReplies() goroutine.
   302  			client.retransmit(callingGenNum)
   303  			return
   304  		}
   305  
   306  		// Figure out what type of message it is
   307  		switch msgType {
   308  		case RPC:
   309  			// We have a reply to an RPC - let a goroutine do the unmarshalling
   310  			// and sending the reply to blocked Send() so that this routine
   311  			// can read the next response.
   312  			client.goroutineWG.Add(1)
   313  			go client.notifyReply(buf, callingGenNum)
   314  			client.stats.ReplyCalled.Add(1)
   315  
   316  		case Upcall:
   317  
   318  			// Spawn off goroutine to call callback
   319  			client.goroutineWG.Add(1)
   320  			go func(buf []byte) {
   321  				client.cb.(ClientCallbacks).Interrupt(buf)
   322  				client.goroutineWG.Done()
   323  			}(buf)
   324  			client.stats.UpcallCalled.Add(1)
   325  
   326  		default:
   327  			fmt.Printf("CLIENT - invalid msgType: %v\n", msgType)
   328  		}
   329  	}
   330  }
   331  
   332  // retransmit is called when a socket related error occurs on the
   333  // connection to the server.
   334  func (client *Client) retransmit(genNum uint64) {
   335  	var (
   336  		connectionRetryCount int
   337  		connectionRetryDelay time.Duration
   338  	)
   339  
   340  	client.Lock()
   341  
   342  	// Check if we are already processing the socket error via
   343  	// another goroutine.  If it is - return now.
   344  	//
   345  	// Since the original request is on client.outstandingRequest it will
   346  	// have been resent by the first goroutine to encounter the error.
   347  	if (genNum != client.connection.genNum) || (client.connection.state == RETRANSMITTING) {
   348  		client.Unlock()
   349  		return
   350  	}
   351  
   352  	if client.halting == true {
   353  		client.Unlock()
   354  		return
   355  	}
   356  
   357  	// We are the first goroutine to notice the error on the
   358  	// socket - close the connection and start trying to reconnect.
   359  	_ = client.connection.tlsConn.Close()
   360  	client.connection.state = RETRANSMITTING
   361  	client.stats.RetransmitsStarted.Add(1)
   362  
   363  	connectionRetryCount = 0
   364  	connectionRetryDelay = ConnectionRetryInitialDelay
   365  
   366  	for {
   367  		err := client.dial()
   368  		// If we were able to connect then break - otherwise retry
   369  		// after a delay
   370  		if err == nil {
   371  			break
   372  		}
   373  		client.Unlock()
   374  		connectionRetryCount++
   375  		if connectionRetryCount > ConnectionRetryLimit {
   376  			err = fmt.Errorf("In retransmit(), ConnectionRetryLimit (%v) on calling dial() exceeded", ConnectionRetryLimit)
   377  			logger.PanicfWithError(err, "")
   378  		}
   379  		time.Sleep(connectionRetryDelay)
   380  		connectionRetryDelay *= ConnectionRetryDelayMultiplier
   381  		client.Lock()
   382  		// While the lock was dropped we may be halting....
   383  		if client.halting == true {
   384  			client.Unlock()
   385  			return
   386  		}
   387  	}
   388  
   389  	for crID, ctx := range client.outstandingRequest {
   390  		// Note that we are holding the lock so these
   391  		// goroutines will block until we release it.
   392  		client.goroutineWG.Add(1)
   393  		go client.sendToServer(crID, ctx, false)
   394  	}
   395  	client.Unlock()
   396  }
   397  
   398  // Send myUniqueID to server
   399  //
   400  // NOTE: Client lock is already held during this call.
   401  func (client *Client) sendMyInfo(tlsConn *tls.Conn) (err error) {
   402  
   403  	// Setup ioreq to write structure on socket to server
   404  	isreq, err := buildSetIDRequest(client.myUniqueID)
   405  	if err != nil {
   406  		e := fmt.Errorf("Client buildSetIDRequest returned err: %v", err)
   407  		logger.PanicfWithError(e, "")
   408  		return err
   409  	}
   410  
   411  	// Send header
   412  	client.connection.tlsConn.SetDeadline(time.Now().Add(client.deadlineIO))
   413  	err = binary.Write(tlsConn, binary.BigEndian, isreq.Hdr)
   414  	if err != nil {
   415  		return
   416  	}
   417  
   418  	// Send MyUniqueID
   419  	client.connection.tlsConn.SetDeadline(time.Now().Add(client.deadlineIO))
   420  	bytesWritten, writeErr := tlsConn.Write(isreq.MyUniqueID)
   421  
   422  	if uint32(bytesWritten) != isreq.Hdr.Len {
   423  		e := fmt.Errorf("sendMyInfo length incorrect")
   424  		err = e
   425  		return
   426  	}
   427  
   428  	if writeErr != nil {
   429  		err = writeErr
   430  		return
   431  	}
   432  
   433  	// Nothing is sent back from server
   434  
   435  	return
   436  }
   437  
   438  // dial sets up connection to server
   439  // It is assumed that the client lock is held.
   440  //
   441  // NOTE: Client lock is held
   442  func (client *Client) dial() (err error) {
   443  	var entryState = client.connection.state
   444  
   445  	client.connection.tlsConfig = &tls.Config{
   446  		RootCAs: client.connection.x509CertPool,
   447  	}
   448  
   449  	// Now dial the server
   450  	d := &net.Dialer{KeepAlive: client.keepAlivePeriod}
   451  	tlsConn, dialErr := tls.DialWithDialer(d, "tcp", client.connection.hostPortStr, client.connection.tlsConfig)
   452  	if dialErr != nil {
   453  		err = fmt.Errorf("tls.Dial() failed: %v", dialErr)
   454  		return
   455  	}
   456  
   457  	if client.connection.tlsConn != nil {
   458  		client.connection.tlsConn.Close()
   459  		client.connection.tlsConn = nil
   460  	}
   461  
   462  	client.connection.tlsConn = tlsConn
   463  	client.connection.state = CONNECTED
   464  	client.connection.genNum++
   465  
   466  	// Send myUniqueID to server.   If this fails the dial will
   467  	// be retried.
   468  	err = client.sendMyInfo(tlsConn)
   469  	if err != nil {
   470  		_ = client.connection.tlsConn.Close()
   471  		client.connection.tlsConn = nil
   472  		client.connection.state = entryState
   473  		return
   474  	}
   475  
   476  	// Start readResponse goroutine to read responses from server
   477  	client.goroutineWG.Add(1)
   478  	go client.readReplies(client.connection.genNum, tlsConn)
   479  
   480  	return
   481  }
   482  
   483  // Less tests whether the current item is less than the given argument.
   484  //
   485  // This must provide a strict weak ordering.
   486  // If !a.Less(b) && !b.Less(a), we treat this to mean a == b (i.e. we can only
   487  // hold one of either a or b in the tree).
   488  //
   489  // NOTE: It is assumed client lock is held when this is called.
   490  func (a requestID) Less(b btree.Item) bool {
   491  	return a < b.(requestID)
   492  }
   493  
   494  // printBTree prints the btree contents and is only for debugging
   495  //
   496  // NOTE: It is assumed client lock is held when this is called.
   497  func printBTree(tr *btree.BTree, msg string) {
   498  	tr.Ascend(func(a btree.Item) bool {
   499  		r := a.(requestID)
   500  		fmt.Printf("%v =========== - r is: %v\n", msg, r)
   501  		return true
   502  	})
   503  
   504  }
   505  
   506  // It is assumed the client lock is already held
   507  func (client *Client) setHighestConsecutive() {
   508  	client.bt.AscendGreaterOrEqual(client.highestConsecutive, func(a btree.Item) bool {
   509  		r := a.(requestID)
   510  		c := client.highestConsecutive
   511  
   512  		// If this item is a consecutive number then keep going.
   513  		// Otherwise stop the Ascend now
   514  		c++
   515  		if r == c {
   516  			client.highestConsecutive = r
   517  		} else {
   518  			// If we are past the first leaf and we do not have
   519  			// consecutive numbers than break now instead of going
   520  			// through rest of tree
   521  			if r != client.bt.Min() {
   522  				return false
   523  			}
   524  		}
   525  		return true
   526  	})
   527  
   528  	// Now trim the btree up to highestConsecutiveNum
   529  	m := client.bt.Min()
   530  	if m != nil {
   531  		i := m.(requestID)
   532  		for ; i < client.highestConsecutive; i++ {
   533  			client.bt.Delete(i)
   534  		}
   535  	}
   536  }
   537  
   538  // updateHighestConsecutiveNum takes the requestID and calculates the
   539  // highestConsective request ID we have seen.  This is done by putting
   540  // the requestID into a btree of completed requestIDs.  Then calculating
   541  // the highest consective number seen and updating Client.
   542  func (client *Client) updateHighestConsecutiveNum(crID requestID) {
   543  	client.Lock()
   544  	client.bt.ReplaceOrInsert(crID)
   545  	client.setHighestConsecutive()
   546  	client.Unlock()
   547  }