github.com/swiftstack/ProxyFS@v0.0.0-20210203235616-4017c267d62f/retryrpc/api.go (about)

     1  // Copyright (c) 2015-2021, NVIDIA CORPORATION.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  package retryrpc
     5  
// Package retryrpc provides a client and server RPC model which survives
// lost connections on either the client or the server.
//
// NOTE: This package handles cases where the server process dies.  However,
// there are still gaps where a server may complete an RPC and die before
// returning a response.
    12  
    13  import (
    14  	"container/list"
    15  	"context"
    16  	"crypto/tls"
    17  	"crypto/x509"
    18  	"fmt"
    19  	"net"
    20  	"reflect"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/google/btree"
    25  	"github.com/swiftstack/ProxyFS/bucketstats"
    26  	"github.com/swiftstack/ProxyFS/logger"
    27  )
    28  
// ServerCreds holds the server's TLS credentials: the root CA certificate
// (PEM) and the server's own TLS certificate.
type ServerCreds struct {
	RootCAx509CertificatePEM []byte          // PEM-encoded root CA certificate
	serverTLSCertificate     tls.Certificate // Placed in tls.Config.Certificates by Start, i.e. presented to clients
}
    35  
// Server tracks the state of the server.
//
// The embedded Mutex is held around updates to halting (Close) and lookups
// in perClientInfo (SendCallback).  connLock separately guards the
// connections list (CloseClientConn).
type Server struct {
	sync.Mutex
	completedLongTTL time.Duration          // How long a completed request stays on queue
	completedAckTrim time.Duration          // How frequently trim requests acked by client
	svrMap           map[string]*methodArgs // Key: Method name
	ipaddr           string                 // IP address server listens on
	port             int                    // Port of server
	netListener      net.Listener           // Raw TCP listener created by Start
	tlsListener      net.Listener           // TLS wrapper around netListener; closed by Close

	halting              bool           // Set by Close to indicate shutdown
	goroutineWG          sync.WaitGroup // Used to track outstanding goroutines
	connLock             sync.Mutex     // Guards connections
	connections          *list.List     // Client connections; elements hold net.Conn values
	connWG               sync.WaitGroup
	Creds                *ServerCreds           // TLS credentials built by NewServer
	listenersWG          sync.WaitGroup         // Add(1) in Start; Close waits on it
	receiver             reflect.Value          // Package receiver being served
	perClientInfo        map[string]*clientInfo // Key: "clientID".  Tracks clients
	completedTickerDone  chan bool              // Signals the trimmer goroutine started by Start to exit
	completedLongTicker  *time.Ticker           // Longer ~10 minute timer to trim (period: completedLongTTL)
	completedShortTicker *time.Ticker           // Shorter ~100ms timer to trim known completed (period: completedAckTrim)
	deadlineIO           time.Duration
	keepAlivePeriod      time.Duration
	completedDoneWG      sync.WaitGroup // Close waits on it until the trimmer goroutine exits
	dontStartTrimmers    bool           // Used for testing
}
    64  
// ServerConfig is used to configure a retryrpc Server.
//
// Unless dontStartTrimmers is set, LongTrim and ShortTrim are used by Start
// as time.NewTicker periods and therefore must be positive.
type ServerConfig struct {
	LongTrim          time.Duration // How long the results of an RPC are stored on a Server before removed
	ShortTrim         time.Duration // How frequently completed and ACKed RPCs results are removed from Server
	IPAddr            string        // IP Address that Server uses to listen
	Port              int           // Port that Server uses to listen
	DeadlineIO        time.Duration // How long I/Os on sockets wait even if idle
	KeepAlivePeriod   time.Duration // How frequently a KEEPALIVE is sent (net.ListenConfig.KeepAlive)
	dontStartTrimmers bool          // Used for testing; Start skips creating the trim tickers
}
    75  
    76  // NewServer creates the Server object
    77  func NewServer(config *ServerConfig) *Server {
    78  	var (
    79  		err error
    80  	)
    81  	server := &Server{ipaddr: config.IPAddr, port: config.Port, completedLongTTL: config.LongTrim,
    82  		completedAckTrim: config.ShortTrim, deadlineIO: config.DeadlineIO,
    83  		keepAlivePeriod: config.KeepAlivePeriod, dontStartTrimmers: config.dontStartTrimmers}
    84  	server.svrMap = make(map[string]*methodArgs)
    85  	server.perClientInfo = make(map[string]*clientInfo)
    86  	server.completedTickerDone = make(chan bool)
    87  	server.connections = list.New()
    88  
    89  	server.Creds, err = constructServerCreds(server.ipaddr)
    90  	if err != nil {
    91  		logger.Errorf("Construction of server credentials failed with err: %v", err)
    92  		panic(err)
    93  	}
    94  
    95  	return server
    96  }
    97  
    98  // Register creates the map of server methods
    99  func (server *Server) Register(retrySvr interface{}) (err error) {
   100  
   101  	// Find all the methods associated with retrySvr and put into serviceMap
   102  	server.receiver = reflect.ValueOf(retrySvr)
   103  	return server.register(retrySvr)
   104  }
   105  
   106  // Start listener
   107  func (server *Server) Start() (err error) {
   108  	portStr := fmt.Sprintf("%d", server.port)
   109  	hostPortStr := net.JoinHostPort(server.ipaddr, portStr)
   110  
   111  	tlsConfig := &tls.Config{
   112  		Certificates: []tls.Certificate{server.Creds.serverTLSCertificate},
   113  	}
   114  
   115  	listenConfig := &net.ListenConfig{KeepAlive: server.keepAlivePeriod}
   116  	server.netListener, err = listenConfig.Listen(context.Background(), "tcp", hostPortStr)
   117  	if nil != err {
   118  		err = fmt.Errorf("tls.Listen() failed: %v", err)
   119  		return
   120  	}
   121  
   122  	server.tlsListener = tls.NewListener(server.netListener, tlsConfig)
   123  
   124  	server.listenersWG.Add(1)
   125  
   126  	// Some of the unit tests disable starting trimmers
   127  	if !server.dontStartTrimmers {
   128  		// Start ticker which removes older completedRequests
   129  		server.completedLongTicker = time.NewTicker(server.completedLongTTL)
   130  		// Start ticker which removes requests already ACKed by client
   131  		server.completedShortTicker = time.NewTicker(server.completedAckTrim)
   132  	}
   133  	server.completedDoneWG.Add(1)
   134  	if !server.dontStartTrimmers {
   135  		go func() {
   136  			for {
   137  				select {
   138  				case <-server.completedTickerDone:
   139  					server.completedDoneWG.Done()
   140  					return
   141  				case tl := <-server.completedLongTicker.C:
   142  					server.trimCompleted(tl, true)
   143  				case ts := <-server.completedShortTicker.C:
   144  					server.trimCompleted(ts, false)
   145  				}
   146  			}
   147  		}()
   148  	} else {
   149  		go func() {
   150  			for {
   151  				select {
   152  				case <-server.completedTickerDone:
   153  					server.completedDoneWG.Done()
   154  					return
   155  				}
   156  			}
   157  		}()
   158  	}
   159  
   160  	return err
   161  }
   162  
   163  // Run server loop, accept connections, read request, run RPC method and
   164  // return the results.
   165  func (server *Server) Run() {
   166  	server.goroutineWG.Add(1)
   167  	go server.run()
   168  }
   169  
   170  // SendCallback sends a message to clientID so that clientID contacts
   171  // the RPC server.
   172  //
   173  // The assumption is that this callback only gets called when the server has
   174  // an async message for the client
   175  //
   176  // The message is "best effort" - if we fail to write on socket then the
   177  // message is silently dropped on floor.
   178  func (server *Server) SendCallback(clientID string, msg []byte) {
   179  
   180  	// TODO - what if client no longer in list of current clients?
   181  	var (
   182  		localIOR ioReply
   183  	)
   184  	server.Lock()
   185  	lci, ok := server.perClientInfo[clientID]
   186  	if !ok {
   187  		fmt.Printf("SERVER: SendCallback() - unable to find client UniqueID: %v\n", clientID)
   188  		server.Unlock()
   189  		return
   190  	}
   191  	server.Unlock()
   192  
   193  	lci.Lock()
   194  	currentCtx := lci.cCtx
   195  	lci.Unlock()
   196  
   197  	localIOR.JResult = msg
   198  	setupHdrReply(&localIOR, Upcall)
   199  
   200  	server.returnResults(&localIOR, currentCtx)
   201  }
   202  
// Close stops the server.
//
// Shutdown is strictly ordered:
//  1. Mark the server as halting.
//  2. Close the TLS listener.
//  3. Wait for the listener and server goroutines (listenersWG, goroutineWG).
//  4. Close the remaining client sockets.
//  5. Stop the trim tickers, if they were started.
//  6. Signal the trimmer goroutine via completedTickerDone and wait for it.
//  7. Unregister each client's bucketstats group.
//
// NOTE(review): Close calls tlsListener.Close, so it assumes Start
// succeeded.  It must also not be called twice.  completedTickerDone is
// unbuffered, and the only receiver visible here (the goroutine started by
// Start) returns after the first signal.  A second send would therefore
// block forever.
func (server *Server) Close() {
	server.Lock()
	server.halting = true
	server.Unlock()

	// A listener close failure is logged; shutdown continues regardless.
	err := server.tlsListener.Close()
	if err != nil {
		logger.Errorf("server.tlsListener.Close() returned err: %v", err)
	}

	server.listenersWG.Wait()

	server.goroutineWG.Wait()

	// Now close the client sockets to wake them up
	server.closeClientConn()

	// The tickers only exist when Start created them.
	if !server.dontStartTrimmers {
		server.completedLongTicker.Stop()
		server.completedShortTicker.Stop()
	}
	server.completedTickerDone <- true
	server.completedDoneWG.Wait()

	// Cleanup bucketstats so that unit tests can run
	for _, ci := range server.perClientInfo {
		ci.Lock()
		bucketstats.UnRegister("proxyfs.retryrpc", ci.myUniqueID)
		ci.Unlock()

	}
}
   236  
   237  // CloseClientConn - This is debug code to cause some connections to be closed
   238  // It is called from a stress test case to cause retransmits
   239  func (server *Server) CloseClientConn() {
   240  	if server == nil {
   241  		return
   242  	}
   243  	server.connLock.Lock()
   244  	for c := server.connections.Front(); c != nil; c = c.Next() {
   245  		conn := c.Value.(net.Conn)
   246  		/* DEBUG code
   247  		fmt.Printf("SERVER - closing localaddr conn: %v remoteaddr: %v\n", conn.LocalAddr().String(), conn.RemoteAddr().String())
   248  		*/
   249  		conn.Close()
   250  	}
   251  	server.connLock.Unlock()
   252  }
   253  
   254  // CompletedCnt returns count of pendingRequests
   255  //
   256  // This is only useful for testing.
   257  func (server *Server) CompletedCnt() (totalCnt int) {
   258  	for _, v := range server.perClientInfo {
   259  		totalCnt += v.completedCnt()
   260  	}
   261  	return
   262  }
   263  
   264  // Client methods
   265  type clientState int
   266  
   267  const (
   268  	// INITIAL means the Client struct has just been created
   269  	INITIAL clientState = iota + 1
   270  	// DISCONNECTED means the Client has lost the connection to the server
   271  	DISCONNECTED
   272  	// CONNECTED means the Client is connected to the server
   273  	CONNECTED
   274  	// RETRANSMITTING means a goroutine is in the middle of recovering
   275  	// from a loss of a connection with the server
   276  	RETRANSMITTING
   277  )
   278  
// connectionTracker holds a Client's connection to the server together with
// the TLS material and address used to reach it.
type connectionTracker struct {
	state                    clientState    // INITIAL, DISCONNECTED, CONNECTED or RETRANSMITTING
	genNum                   uint64         // Generation number of tlsConn - avoid racing recoveries
	tlsConfig                *tls.Config
	tlsConn                  *tls.Conn      // Our connection to the server
	x509CertPool             *x509.CertPool // NewClient adds the root CA from ClientConfig here
	rootCAx509CertificatePEM []byte
	hostPortStr              string         // Server "host:port", built by NewClient
}
   288  
// Client tracks the client side of retryrpc: the connection to the server,
// requests that are outstanding, and bookkeeping for requestIDs that have
// been ACKed.
type Client struct {
	sync.Mutex
	halting bool // Set by Close; presumably checked by background goroutines - TODO confirm
	// currentRequestID is the last request ID assigned.
	// TODO: start from a clock tick at mount and increment from there?
	// TODO: handle reset of time?
	currentRequestID   requestID
	connection         connectionTracker
	myUniqueID         string      // Unique ID across all clients
	cb                 interface{} // Callbacks to client; expected to implement ClientCallbacks
	deadlineIO         time.Duration
	keepAlivePeriod    time.Duration
	// outstandingRequest holds requests sent, or still to be sent, to the
	// server.  The key is assigned from currentRequestID.
	outstandingRequest map[requestID]*reqCtx
	// highestConsecutive is the highest requestID that can be trimmed.
	highestConsecutive requestID
	bt          *btree.BTree   // btree of requestID's acked
	goroutineWG sync.WaitGroup // Used to track outstanding goroutines; Close waits on it
	stats       clientSideStatsInfo // Registered with bucketstats by NewClient
}
   309  
// ClientCallbacks contains the methods required when supporting
// callbacks from the Server.
//
// The server delivers an async message to the client by invoking Interrupt
// (see NewClient).  The payload is presumably the msg passed to
// Server.SendCallback - TODO confirm against the client read path.
type ClientCallbacks interface {
	Interrupt(payload []byte)
}
   315  
// ClientConfig is used to configure a retryrpc Client.
type ClientConfig struct {
	MyUniqueID               string        // Must be unique across all clients (see NewClient)
	IPAddr                   string        // IP Address of Server
	Port                     int           // Port of Server
	RootCAx509CertificatePEM []byte        // Root certificate (PEM); NewClient adds it to the client's cert pool
	Callbacks                interface{}   // Structure implementing ClientCallbacks
	DeadlineIO               time.Duration // How long I/Os on sockets wait even if idle
	KeepAlivePeriod          time.Duration // How frequently a KEEPALIVE is sent
}
   326  
// TODO - pass loggers to Client and Server objects
   328  
   329  // NewClient returns a Client structure
   330  //
   331  // If the server wants to send an async message to the client
   332  // it uses the Interrupt method defined in cb
   333  //
   334  // NOTE: It is assumed that if a client calls NewClient(), it will
   335  // always use a unique myUniqueID.   Otherwise, the server may have
   336  // old entries.
   337  //
   338  // TODO - purge cache of old entries on server and/or use different
   339  // starting point for requestID.
   340  func NewClient(config *ClientConfig) (client *Client, err error) {
   341  
   342  	client = &Client{myUniqueID: config.MyUniqueID, cb: config.Callbacks,
   343  		keepAlivePeriod: config.KeepAlivePeriod, deadlineIO: config.DeadlineIO}
   344  	portStr := fmt.Sprintf("%d", config.Port)
   345  	client.connection.state = INITIAL
   346  	client.connection.hostPortStr = net.JoinHostPort(config.IPAddr, portStr)
   347  	client.outstandingRequest = make(map[requestID]*reqCtx)
   348  	client.connection.x509CertPool = x509.NewCertPool()
   349  	client.bt = btree.New(2)
   350  
   351  	// Add cert for root CA to our pool
   352  	ok := client.connection.x509CertPool.AppendCertsFromPEM(config.RootCAx509CertificatePEM)
   353  	if !ok {
   354  		err = fmt.Errorf("x509CertPool.AppendCertsFromPEM() returned !ok")
   355  		return nil, err
   356  	}
   357  
   358  	bucketstats.Register("proxyfs.retryrpc", client.GetStatsGroupName(), &client.stats)
   359  
   360  	return client, err
   361  }
   362  
   363  // Send the request and block until it has completed
   364  func (client *Client) Send(method string, request interface{}, reply interface{}) (err error) {
   365  
   366  	return client.send(method, request, reply)
   367  }
   368  
   369  // GetStatsGroupName returns the bucketstats GroupName for this client
   370  func (client *Client) GetStatsGroupName() (s string) {
   371  
   372  	return clientSideGroupPrefix + client.myUniqueID
   373  }
   374  
   375  // Close gracefully shuts down the client
   376  func (client *Client) Close() {
   377  	// Set halting flag and then close our socket to server.
   378  	// This will cause the blocked getIO() in readReplies() to return.
   379  	client.Lock()
   380  	client.halting = true
   381  	if client.connection.state == CONNECTED {
   382  		client.connection.state = INITIAL
   383  		client.connection.tlsConn.Close()
   384  	}
   385  	client.Unlock()
   386  
   387  	// Wait for the goroutines to return
   388  	client.goroutineWG.Wait()
   389  	bucketstats.UnRegister("proxyfs.retryrpc", client.GetStatsGroupName())
   390  
   391  }