github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/remotestorage/internal/reliable/grpc.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package reliable
    16  
    17  import (
    18  	"context"
    19  	"io"
    20  	"time"
    21  
    22  	"github.com/cenkalti/backoff/v4"
    23  	"golang.org/x/sync/errgroup"
    24  	"google.golang.org/grpc"
    25  )
    26  
    27  type ClientStream[Req, Resp any] interface {
    28  	Send(Req) error
    29  	Recv() (Resp, error)
    30  	CloseSend() error
    31  }
    32  
    33  type ReqClientStream[Req, Resp any] interface {
    34  	ClientStream[Req, Resp]
    35  
    36  	// After a successful |Recv| call, calling |AssociatedReq| returns the
    37  	// request which was associated with the last returned response. This
    38  	// is only safe to do from the same goroutine which is calling |Recv|.
    39  	AssociatedReq() Req
    40  }
    41  
    42  type OpenStreamFunc[Req, Resp any] func(context.Context, ...grpc.CallOption) (ClientStream[Req, Resp], error)
    43  
    44  type CallOptions[Req, Resp any] struct {
    45  	Open     OpenStreamFunc[Req, Resp]
    46  	GrpcOpts []grpc.CallOption
    47  	ErrF     func(error) error
    48  	BackOffF func(context.Context) backoff.BackOff
    49  
    50  	ReadRequestTimeout time.Duration
    51  	DeliverRespTimeout time.Duration
    52  }
    53  
    54  type reqResp[Req, Resp any] struct {
    55  	Req  Req
    56  	Resp Resp
    57  }
    58  
    59  func MakeCall[Req, Resp any](ctx context.Context, opts CallOptions[Req, Resp]) (ReqClientStream[Req, Resp], error) {
    60  	eg, ctx := errgroup.WithContext(ctx)
    61  	ret := &reliableCall[Req, Resp]{
    62  		eg:     eg,
    63  		ctx:    ctx,
    64  		reqCh:  make(chan Req),
    65  		respCh: make(chan reqResp[Req, Resp]),
    66  		opts:   opts,
    67  	}
    68  	eg.Go(func() error {
    69  		return ret.thread()
    70  	})
    71  	return ret, nil
    72  }
    73  
    74  type reliableCall[Req, Resp any] struct {
    75  	eg  *errgroup.Group
    76  	ctx context.Context
    77  
    78  	reqCh  chan Req
    79  	respCh chan reqResp[Req, Resp]
    80  
    81  	opts CallOptions[Req, Resp]
    82  
    83  	associatedReq Req
    84  }
    85  
    86  func (c *reliableCall[Req, Resp]) thread() error {
    87  	bo := c.opts.BackOffF(c.ctx)
    88  
    89  	requests := NewChan(c.reqCh)
    90  	defer requests.Close()
    91  
    92  	sm := &reliableCallStateMachine[Req, Resp]{
    93  		call:     c,
    94  		requests: requests,
    95  		bo:       bo,
    96  	}
    97  
    98  	return sm.run(c.ctx)
    99  }
   100  
   101  func (c *reliableCall[Req, Resp]) Send(r Req) error {
   102  	select {
   103  	case c.reqCh <- r:
   104  		return nil
   105  	case <-c.ctx.Done():
   106  		return c.eg.Wait()
   107  	}
   108  }
   109  
   110  func (c *reliableCall[Req, Resp]) CloseSend() error {
   111  	close(c.reqCh)
   112  	return nil
   113  }
   114  
   115  func (c *reliableCall[Req, Resp]) Recv() (Resp, error) {
   116  	var r reqResp[Req, Resp]
   117  	var ok bool
   118  	select {
   119  	case r, ok = <-c.respCh:
   120  		if !ok {
   121  			return r.Resp, io.EOF
   122  		}
   123  		c.associatedReq = r.Req
   124  		return r.Resp, nil
   125  	case <-c.ctx.Done():
   126  		return r.Resp, c.eg.Wait()
   127  	}
   128  }
   129  
   130  func (c *reliableCall[Req, Resp]) AssociatedReq() Req {
   131  	return c.associatedReq
   132  }