github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/balancer/gcp_interceptor.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  	"os"
     6  	"sync"
     7  
     8  	pb "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer/proto"
     9  	"google.golang.org/grpc"
    10  	"google.golang.org/protobuf/encoding/protojson"
    11  )
    12  
    13  const (
    14  	// Default max number of connections is 0, meaning "no limit"
    15  	defaultMaxConn = 0
    16  
    17  	// Default max stream watermark is 100, which is the current stream limit for GFE.
    18  	// Any value >100 will be rounded down to 100.
    19  	defaultMaxStream = 100
    20  )
    21  
    22  type key int
    23  
    24  var gcpKey key
    25  
    26  type poolConfig struct {
    27  	maxConn   uint32
    28  	maxStream uint32
    29  }
    30  
    31  type gcpContext struct {
    32  	affinityCfg *pb.AffinityConfig
    33  	poolCfg     *poolConfig
    34  	// request message used for pre-process of an affinity call
    35  	reqMsg interface{}
    36  	// response message used for post-process of an affinity call
    37  	replyMsg interface{}
    38  }
    39  
    40  // GCPInterceptor provides functions for intercepting client requests
    41  // in order to support GCP specific features
    42  type GCPInterceptor struct {
    43  	poolCfg *poolConfig
    44  
    45  	// Maps method path to AffinityConfig
    46  	methodToAffinity map[string]*pb.AffinityConfig
    47  }
    48  
    49  // NewGCPInterceptor creates a new GCPInterceptor with a given ApiConfig
    50  func NewGCPInterceptor(config *pb.ApiConfig) *GCPInterceptor {
    51  	mp := make(map[string]*pb.AffinityConfig)
    52  	methodCfgs := config.GetMethod()
    53  	for _, methodCfg := range methodCfgs {
    54  		methodNames := methodCfg.GetName()
    55  		affinityCfg := methodCfg.GetAffinity()
    56  		if methodNames != nil && affinityCfg != nil {
    57  			for _, method := range methodNames {
    58  				mp[method] = affinityCfg
    59  			}
    60  		}
    61  	}
    62  
    63  	poolCfg := &poolConfig{
    64  		maxConn:   defaultMaxConn,
    65  		maxStream: defaultMaxStream,
    66  	}
    67  
    68  	userPoolCfg := config.GetChannelPool()
    69  
    70  	// Set user defined MaxSize.
    71  	poolCfg.maxConn = userPoolCfg.GetMaxSize()
    72  
    73  	// Set user defined MaxConcurrentStreamsLowWatermark if ranged in [1, defaultMaxStream],
    74  	// otherwise use the defaultMaxStream.
    75  	watermarkValue := userPoolCfg.GetMaxConcurrentStreamsLowWatermark()
    76  	if watermarkValue >= 1 && watermarkValue <= defaultMaxStream {
    77  		poolCfg.maxStream = watermarkValue
    78  	}
    79  	return &GCPInterceptor{
    80  		poolCfg:          poolCfg,
    81  		methodToAffinity: mp,
    82  	}
    83  }
    84  
    85  // GCPUnaryClientInterceptor intercepts the execution of a unary RPC
    86  // and injects necessary information to be used by the picker.
    87  func (gcpInt *GCPInterceptor) GCPUnaryClientInterceptor(
    88  	ctx context.Context,
    89  	method string,
    90  	req interface{},
    91  	reply interface{},
    92  	cc *grpc.ClientConn,
    93  	invoker grpc.UnaryInvoker,
    94  	opts ...grpc.CallOption,
    95  ) error {
    96  	affinityCfg, _ := gcpInt.methodToAffinity[method]
    97  	gcpCtx := &gcpContext{
    98  		affinityCfg: affinityCfg,
    99  		reqMsg:      req,
   100  		replyMsg:    reply,
   101  		poolCfg:     gcpInt.poolCfg,
   102  	}
   103  	ctx = context.WithValue(ctx, gcpKey, gcpCtx)
   104  
   105  	return invoker(ctx, method, req, reply, cc, opts...)
   106  }
   107  
   108  // GCPStreamClientInterceptor intercepts the execution of a client streaming RPC
   109  // and injects necessary information to be used by the picker.
   110  func (gcpInt *GCPInterceptor) GCPStreamClientInterceptor(
   111  	ctx context.Context,
   112  	desc *grpc.StreamDesc,
   113  	cc *grpc.ClientConn,
   114  	method string,
   115  	streamer grpc.Streamer,
   116  	opts ...grpc.CallOption,
   117  ) (grpc.ClientStream, error) {
   118  	// This constructor does not create a real ClientStream,
   119  	// it only stores all parameters and let SendMsg() to create ClientStream.
   120  	affinityCfg, _ := gcpInt.methodToAffinity[method]
   121  	gcpCtx := &gcpContext{
   122  		affinityCfg: affinityCfg,
   123  		poolCfg:     gcpInt.poolCfg,
   124  	}
   125  	ctx = context.WithValue(ctx, gcpKey, gcpCtx)
   126  	cs := &gcpClientStream{
   127  		gcpInt:   gcpInt,
   128  		ctx:      ctx,
   129  		desc:     desc,
   130  		cc:       cc,
   131  		method:   method,
   132  		streamer: streamer,
   133  		opts:     opts,
   134  	}
   135  	cs.cond = sync.NewCond(cs)
   136  	return cs, nil
   137  }
   138  
   139  type gcpClientStream struct {
   140  	sync.Mutex
   141  	grpc.ClientStream
   142  
   143  	cond          *sync.Cond
   144  	initStreamErr error
   145  	gcpInt        *GCPInterceptor
   146  	ctx           context.Context
   147  	desc          *grpc.StreamDesc
   148  	cc            *grpc.ClientConn
   149  	method        string
   150  	streamer      grpc.Streamer
   151  	opts          []grpc.CallOption
   152  }
   153  
   154  func (cs *gcpClientStream) SendMsg(m interface{}) error {
   155  	cs.Lock()
   156  	// Initialize underlying ClientStream when getting the first request.
   157  	if cs.ClientStream == nil {
   158  		affinityCfg, ok := cs.gcpInt.methodToAffinity[cs.method]
   159  		ctx := cs.ctx
   160  		if ok {
   161  			gcpCtx := &gcpContext{
   162  				affinityCfg: affinityCfg,
   163  				reqMsg:      m,
   164  				poolCfg:     cs.gcpInt.poolCfg,
   165  			}
   166  			ctx = context.WithValue(cs.ctx, gcpKey, gcpCtx)
   167  		}
   168  		realCS, err := cs.streamer(ctx, cs.desc, cs.cc, cs.method, cs.opts...)
   169  		if err != nil {
   170  			cs.initStreamErr = err
   171  			cs.Unlock()
   172  			cs.cond.Broadcast()
   173  			return err
   174  		}
   175  		cs.ClientStream = realCS
   176  	}
   177  	cs.Unlock()
   178  	cs.cond.Broadcast()
   179  	return cs.ClientStream.SendMsg(m)
   180  }
   181  
   182  func (cs *gcpClientStream) RecvMsg(m interface{}) error {
   183  	// If RecvMsg is called before SendMsg, it should wait until cs.ClientStream
   184  	// is initialized or the initialization failed.
   185  	cs.Lock()
   186  	for cs.initStreamErr == nil && cs.ClientStream == nil {
   187  		cs.cond.Wait()
   188  	}
   189  	if cs.initStreamErr != nil {
   190  		cs.Unlock()
   191  		return cs.initStreamErr
   192  	}
   193  	cs.Unlock()
   194  	return cs.ClientStream.RecvMsg(m)
   195  }
   196  
   197  func (cs *gcpClientStream) CloseSend() error {
   198  	cs.Lock()
   199  	defer cs.Unlock()
   200  	if cs.ClientStream != nil {
   201  		return cs.ClientStream.CloseSend()
   202  	}
   203  	return nil
   204  }
   205  
   206  // ParseAPIConfig parses a json config file into ApiConfig proto message.
   207  func ParseAPIConfig(path string) (*pb.ApiConfig, error) {
   208  	jsonFile, err := os.ReadFile(path)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  	result := &pb.ApiConfig{}
   213  	protojson.Unmarshal(jsonFile, result)
   214  	return result, nil
   215  }