github.com/swiftstack/ProxyFS@v0.0.0-20210203235616-4017c267d62f/retryrpc/api_internal.go (about)

     1  // Copyright (c) 2015-2021, NVIDIA CORPORATION.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  // Package retryrpc provides a client and server RPC model which survives
     5  // lost connections on either the client or the server.
     6  package retryrpc
     7  
     8  import (
     9  	"container/list"
    10  	"crypto/ed25519"
    11  	"crypto/rand"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"crypto/x509/pkix"
    15  	"encoding/binary"
    16  	"encoding/json"
    17  	"encoding/pem"
    18  	"fmt"
    19  	"io"
    20  	"math/big"
    21  	"net"
    22  	"reflect"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/swiftstack/ProxyFS/bucketstats"
    27  	"github.com/swiftstack/ProxyFS/logger"
    28  )
    29  
// PayloadProtocols identifies how a message payload is encoded.  The value
// travels on the wire in ioHeader.Protocol (as a uint16).
type PayloadProtocols int

// Supported payload protocols.
const (
	// JSON means the payload is JSON-encoded.  It is the only protocol the
	// header builders in this file emit.
	JSON PayloadProtocols = 1
)
    37  
const (
	// currentRetryVersion is stamped into ioHeader.Version by every header
	// builder in this file.
	// NOTE(review): getIO does not currently check the Version it receives.
	currentRetryVersion = 1
)
    41  
// requestID identifies an RPC request.  It is the type of the RequestID and
// HighestReplySeen fields exchanged in jsonRequest and jsonReply.
type requestID uint64
    43  
// statsInfo holds per-client RPC statistics; clientInfo keeps one in its
// stats field.  The unexported longestRPC/largestReplySize fields and their
// *Method companions remember the single largest observation seen and the
// method that produced it.
type statsInfo struct {
	AddCompleted           bucketstats.Total           // Number added to completed list
	RmCompleted            bucketstats.Total           // Number removed from completed list
	RPCLenUsec             bucketstats.BucketLog2Round // Tracks length of RPCs
	ReplySize              bucketstats.BucketLog2Round // Tracks completed RPC reply size
	longestRPC             time.Duration               // Time of longest RPC
	longestRPCMethod       string                      // Method of longest RPC
	largestReplySize       uint64                      // Tracks largest RPC reply size
	largestReplySizeMethod string                      // Method of largest RPC reply size completed
	RPCattempted           bucketstats.Total           // Number of RPCs attempted - may be completed or in process
	RPCcompleted           bucketstats.Total           // Number of RPCs which completed - incremented after call returns
	RPCretried             bucketstats.Total           // Number of RPCs whose reply was pulled from the completed list
}
    58  
// clientInfo is the server-side record kept for one client, such as the
// replies to requests it has already completed.
//
// NOTE(review): the embedded Mutex presumably guards the fields below —
// confirm against the code that locks it.
type clientInfo struct {
	sync.Mutex
	cCtx                     *connCtx                      // Current connCtx for client
	myUniqueID               string                        // Unique ID of this client
	completedRequest         map[requestID]*completedEntry // Key: "RequestID"
	completedRequestLRU      *list.List                    // LRU used to remove completed request in ticker
	highestReplySeen         requestID                     // Highest consecutive requestID client has seen
	previousHighestReplySeen requestID                     // Previous highest consecutive requestID client has seen
	stats                    statsInfo                     // Per-client statistics
}
    71  
// completedEntry is a value in clientInfo.completedRequest: the saved reply
// for a completed request.
type completedEntry struct {
	reply   *ioReply      // Reply to resend if the request is seen again
	lruElem *list.Element // Presumably this entry's element in clientInfo.completedRequestLRU — confirm
}
    76  
// connCtx tracks a conn which has been accepted.
//
// Its embedded Mutex is the lock used for serialization when reading or
// writing on the socket.
type connCtx struct {
	sync.Mutex
	conn                net.Conn
	activeRPCsWG        sync.WaitGroup // WaitGroup tracking active RPCs from this client on this connection
	cond                *sync.Cond     // Signal waiting goroutines that serviceClient() has exited
	serviceClientExited bool           // Presumably set when serviceClient() exits; waiters on cond check it — confirm
	ci                  *clientInfo    // Back pointer to the CI
}
    89  
// pendingCtx tracks an individual request from a client.
type pendingCtx struct {
	lock sync.Mutex // NOTE(review): presumably guards buf and cCtx — confirm
	buf  []byte     // Request
	cCtx *connCtx   // Most recent connection to return results
}
    96  
// methodArgs describes one method provided by the RPC server, together with
// the reflected types of its request and reply arguments.
type methodArgs struct {
	methodPtr *reflect.Method // The server method to invoke
	request   reflect.Type    // Type of the request argument
	reply     reflect.Type    // Type of the reply argument
}
   104  
// completedLRUEntry records which request completed and when.  It allows the
// completed-request cache to expire entries by age.
type completedLRUEntry struct {
	requestID     requestID // Request whose reply is cached
	timeCompleted time.Time // When the request completed
}
   111  
// headerMagic is the value every ioHeader carries in its last field, Magic.
// getIO rejects any header whose Magic differs.  This catches a corrupt
// header, or a stream that is no longer positioned on a message boundary.
const headerMagic uint32 = 0xCAFEFEED
   115  
// MsgType is the type of message being sent; it travels in ioHeader.Type.
type MsgType uint16

// Values start at 1 (iota + 1), presumably so that the zero MsgType is never
// a valid message type.
const (
	// RPC represents an RPC from client to server
	RPC MsgType = iota + 1
	// Upcall represents an upcall from server to client
	Upcall
	// PassID is the message sent by the client to identify itself to server
	PassID
)
   127  
// ioHeader is the fixed-size header sent on the socket ahead of every
// message payload.
//
// getIO decodes it with binary.Read in big-endian byte order.  The field
// order and fixed-width types therefore define the wire layout (14 bytes).
// Reordering or resizing fields changes the wire format.
type ioHeader struct {
	Len      uint32  // Number of bytes following header
	Protocol uint16  // Payload encoding, a PayloadProtocols value (JSON)
	Version  uint16  // currentRetryVersion of the sender
	Type     MsgType // RPC, Upcall or PassID
	Magic    uint32  // Magic number - must equal headerMagic or getIO rejects the header
}
   136  
// ioRequest tracks the fields written on the wire for an RPC request.
// buildIoRequest sets Hdr.Len to len(JReq).
type ioRequest struct {
	Hdr  ioHeader
	JReq []byte // JSON containing request
}
   142  
// ioReply is the structure returned over the wire.  setupHdrReply fills in
// Hdr from the current JResult.
type ioReply struct {
	Hdr     ioHeader
	JResult []byte // JSON containing response
}
   148  
// internalSetIDRequest is the structure sent over the wire when the
// connection is first made.  This is how the server learns the client ID.
type internalSetIDRequest struct {
	Hdr        ioHeader // Type is PassID
	MyUniqueID []byte   // JSON encoding of the client's unique ID string (see buildSetIDRequest)
}
   156  
// replyCtx is the value delivered on reqCtx.answer to report how a request
// finished.
type replyCtx struct {
	err error // Outcome of the request; presumably nil on success — confirm
}
   160  
// reqCtx exists on the client and tracks a request passed to Send().
type reqCtx struct {
	ioreq    ioRequest     // Wrapped request passed to Send()
	rpcReply interface{}   // Caller's reply value; presumably the result is decoded into it — confirm
	answer   chan replyCtx // Receives the request's outcome
	genNum   uint64        // Generation number of socket when request sent
}
   168  
// jsonRequest is used to marshal an RPC request in/out of JSON.
// Params always holds exactly one element: the request argument of Method.
type jsonRequest struct {
	MyUniqueID       string         `json:"myuniqueid"`       // ID of client
	RequestID        requestID      `json:"requestid"`        // ID of this request
	HighestReplySeen requestID      `json:"highestReplySeen"` // Used to trim completedRequests on server
	Method           string         `json:"method"`
	Params           [1]interface{} `json:"params"`
}
   177  
// jsonReply is used to marshal an RPC response in/out of JSON.
type jsonReply struct {
	MyUniqueID string      `json:"myuniqueid"` // ID of client
	RequestID  requestID   `json:"requestid"`  // ID of this request
	ErrStr     string      `json:"errstr"`     // Error text from the call; presumably empty on success — confirm
	Result     interface{} `json:"result"`     // The RPC's reply value
}
   185  
// svrRequest is used together with jsonRequest when the server unmarshals
// the parameters passed in an RPC.  Only the "params" key is decoded here.
// NOTE(review): presumably Params[0] is pre-set to a value of the method's
// request type so that json.Unmarshal decodes into the concrete type —
// confirm at the call site.
type svrRequest struct {
	Params [1]interface{} `json:"params"`
}
   192  
// svrResponse is used with jsonReply when the server marshals the reply;
// only the "result" key is produced from it.
type svrResponse struct {
	Result interface{} `json:"result"`
}
   197  
   198  func buildIoRequest(jReq jsonRequest) (ioreq *ioRequest, err error) {
   199  	ioreq = &ioRequest{}
   200  	ioreq.JReq, err = json.Marshal(jReq)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	ioreq.Hdr.Len = uint32(len(ioreq.JReq))
   205  	ioreq.Hdr.Protocol = uint16(JSON)
   206  	ioreq.Hdr.Version = currentRetryVersion
   207  	ioreq.Hdr.Type = RPC
   208  	ioreq.Hdr.Magic = headerMagic
   209  	return
   210  }
   211  
   212  func setupHdrReply(ioreply *ioReply, t MsgType) {
   213  	ioreply.Hdr.Len = uint32(len(ioreply.JResult))
   214  	ioreply.Hdr.Protocol = uint16(JSON)
   215  	ioreply.Hdr.Version = currentRetryVersion
   216  	ioreply.Hdr.Type = t
   217  	ioreply.Hdr.Magic = headerMagic
   218  	return
   219  }
   220  
   221  func buildSetIDRequest(myUniqueID string) (isreq *internalSetIDRequest, err error) {
   222  	isreq = &internalSetIDRequest{}
   223  	isreq.MyUniqueID, err = json.Marshal(myUniqueID)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	isreq.Hdr.Len = uint32(len(isreq.MyUniqueID))
   228  	isreq.Hdr.Protocol = uint16(JSON)
   229  	isreq.Hdr.Version = currentRetryVersion
   230  	isreq.Hdr.Type = PassID
   231  	isreq.Hdr.Magic = headerMagic
   232  	return
   233  }
   234  
   235  func getIO(genNum uint64, deadlineIO time.Duration, conn net.Conn) (buf []byte, msgType MsgType, err error) {
   236  	if printDebugLogs {
   237  		logger.Infof("conn: %v", conn)
   238  	}
   239  
   240  	// Read in the header of the request first
   241  	var hdr ioHeader
   242  
   243  	conn.SetDeadline(time.Now().Add(deadlineIO))
   244  	err = binary.Read(conn, binary.BigEndian, &hdr)
   245  	if err != nil {
   246  		return
   247  	}
   248  
   249  	if hdr.Magic != headerMagic {
   250  		err = fmt.Errorf("Incomplete read of header")
   251  		return
   252  	}
   253  
   254  	if hdr.Len == 0 {
   255  		err = fmt.Errorf("hdr.Len == 0")
   256  		return
   257  	}
   258  	msgType = hdr.Type
   259  
   260  	// Now read the rest of the structure off the wire.
   261  	var numBytes int
   262  	buf = make([]byte, hdr.Len)
   263  	conn.SetDeadline(time.Now().Add(deadlineIO))
   264  	numBytes, err = io.ReadFull(conn, buf)
   265  	if err != nil {
   266  		err = fmt.Errorf("Incomplete read of body")
   267  		return
   268  	}
   269  
   270  	if hdr.Len != uint32(numBytes) {
   271  		err = fmt.Errorf("Incomplete read of body")
   272  		return
   273  	}
   274  
   275  	return
   276  }
   277  
   278  // constructServerCreds will generate root CA cert and server cert
   279  //
   280  // It is assumed that this is called on the "server" process and
   281  // the caller will provide a mechanism to pass
   282  // serverCreds.rootCAx509CertificatePEMkeys to the "clients".
   283  func constructServerCreds(serverIPAddrAsString string) (serverCreds *ServerCreds, err error) {
   284  	var (
   285  		commonX509NotAfter            time.Time
   286  		commonX509NotBefore           time.Time
   287  		rootCAEd25519PrivateKey       ed25519.PrivateKey
   288  		rootCAEd25519PublicKey        ed25519.PublicKey
   289  		rootCAx509CertificateDER      []byte
   290  		rootCAx509CertificateTemplate *x509.Certificate
   291  		rootCAx509SerialNumber        *big.Int
   292  		serverEd25519PrivateKey       ed25519.PrivateKey
   293  		serverEd25519PrivateKeyDER    []byte
   294  		serverEd25519PrivateKeyPEM    []byte
   295  		serverEd25519PublicKey        ed25519.PublicKey
   296  		serverX509CertificateDER      []byte
   297  		serverX509CertificatePEM      []byte
   298  		serverX509CertificateTemplate *x509.Certificate
   299  		serverX509SerialNumber        *big.Int
   300  		timeNow                       time.Time
   301  	)
   302  
   303  	serverCreds = &ServerCreds{}
   304  
   305  	timeNow = time.Now()
   306  
   307  	// TODO - what should the length of this be?  What if we want to eject a client
   308  	// from the server?  How would that work?
   309  	//
   310  	// Do we even want the root CA at all?
   311  	commonX509NotBefore = time.Date(timeNow.Year()-1, time.January, 1, 0, 0, 0, 0, timeNow.Location())
   312  	commonX509NotAfter = time.Date(timeNow.Year()+99, time.January, 1, 0, 0, 0, 0, timeNow.Location())
   313  
   314  	rootCAx509SerialNumber, err = rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
   315  	if err != nil {
   316  		err = fmt.Errorf("rand.Int() [1] failed: %v", err)
   317  		return
   318  	}
   319  
   320  	rootCAx509CertificateTemplate = &x509.Certificate{
   321  		SerialNumber: rootCAx509SerialNumber,
   322  		Subject: pkix.Name{
   323  			Organization: []string{"CA Organization"},
   324  			CommonName:   "Root CA",
   325  		},
   326  		NotBefore:             commonX509NotBefore,
   327  		NotAfter:              commonX509NotAfter,
   328  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   329  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   330  		BasicConstraintsValid: true,
   331  		IsCA:                  true,
   332  	}
   333  
   334  	// Generate public and private key
   335  	rootCAEd25519PublicKey, rootCAEd25519PrivateKey, err = ed25519.GenerateKey(nil)
   336  	if err != nil {
   337  		err = fmt.Errorf("ed25519.GenerateKey() [1] failed: %v", err)
   338  		return
   339  	}
   340  
   341  	// Create the certificate with the keys
   342  	rootCAx509CertificateDER, err = x509.CreateCertificate(rand.Reader,
   343  		rootCAx509CertificateTemplate, rootCAx509CertificateTemplate, rootCAEd25519PublicKey, rootCAEd25519PrivateKey)
   344  	if err != nil {
   345  		err = fmt.Errorf("x509.CreateCertificate() [1] failed: %v", err)
   346  		return
   347  	}
   348  
   349  	serverCreds.RootCAx509CertificatePEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootCAx509CertificateDER})
   350  
   351  	serverX509SerialNumber, err = rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
   352  	if err != nil {
   353  		err = fmt.Errorf("rand.Int() [2] failed: %v", err)
   354  		return
   355  	}
   356  
   357  	serverX509CertificateTemplate = &x509.Certificate{
   358  		SerialNumber: serverX509SerialNumber,
   359  		Subject: pkix.Name{
   360  			Organization: []string{"Server Organization"},
   361  			CommonName:   "Server",
   362  		},
   363  		NotBefore:   commonX509NotBefore,
   364  		NotAfter:    commonX509NotAfter,
   365  		KeyUsage:    x509.KeyUsageDigitalSignature,
   366  		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   367  		IPAddresses: []net.IP{net.ParseIP(serverIPAddrAsString)},
   368  	}
   369  
   370  	// Generate the server public/private keys
   371  	serverEd25519PublicKey, serverEd25519PrivateKey, err = ed25519.GenerateKey(nil)
   372  	if err != nil {
   373  		err = fmt.Errorf("ed25519.GenerateKey() [2] failed: %v", err)
   374  		return
   375  	}
   376  
   377  	// Create the server certificate with the server public/private keys
   378  	serverX509CertificateDER, err = x509.CreateCertificate(rand.Reader, serverX509CertificateTemplate, rootCAx509CertificateTemplate, serverEd25519PublicKey, rootCAEd25519PrivateKey)
   379  	if err != nil {
   380  		err = fmt.Errorf("x509.CreateCertificate() [2] failed: %v", err)
   381  		return
   382  	}
   383  
   384  	serverX509CertificatePEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverX509CertificateDER})
   385  
   386  	serverEd25519PrivateKeyDER, err = x509.MarshalPKCS8PrivateKey(serverEd25519PrivateKey)
   387  	if err != nil {
   388  		err = fmt.Errorf("x509.MarshalPKCS8PrivateKey() failed: %v", err)
   389  		return
   390  	}
   391  
   392  	serverEd25519PrivateKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: serverEd25519PrivateKeyDER})
   393  
   394  	serverCreds.serverTLSCertificate, err = tls.X509KeyPair(serverX509CertificatePEM, serverEd25519PrivateKeyPEM)
   395  
   396  	return
   397  }