github.com/cilium/cilium@v1.16.2/pkg/clustermesh/common/interceptor.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package common
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"sync/atomic"
    11  
    12  	"go.etcd.io/etcd/api/v3/etcdserverpb"
    13  	"google.golang.org/grpc"
    14  )
    15  
    16  var (
    17  	ErrClusterIDChanged    = errors.New("etcd cluster ID has changed")
    18  	ErrEtcdInvalidResponse = errors.New("received an invalid etcd response")
    19  )
    20  
    21  // newUnaryInterceptor returns a new unary client interceptor that validates the
    22  // cluster ID of any received etcd responses.
    23  func newUnaryInterceptor(cl *clusterLock) grpc.UnaryClientInterceptor {
    24  	return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    25  		if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
    26  			return err
    27  		}
    28  		return validateReply(cl, reply)
    29  	}
    30  }
    31  
    32  // newStreamInterceptor returns a new stream client interceptor that validates
    33  // the cluster ID of any received etcd responses.
    34  func newStreamInterceptor(cl *clusterLock) grpc.StreamClientInterceptor {
    35  	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    36  		s, err := streamer(ctx, desc, cc, method, opts...)
    37  		if err != nil {
    38  			return nil, err
    39  		}
    40  		return &wrappedClientStream{
    41  			ClientStream: s,
    42  			clusterLock:  cl,
    43  		}, nil
    44  	}
    45  }
    46  
    47  // wrappedClientStream is a wrapper around a grpc.ClientStream that adds
    48  // validation for the etcd cluster ID
    49  type wrappedClientStream struct {
    50  	grpc.ClientStream
    51  	clusterLock *clusterLock
    52  }
    53  
    54  // RecvMsg implements the grpc.ClientStream interface, adding validation for the etcd cluster ID
    55  func (w *wrappedClientStream) RecvMsg(m interface{}) error {
    56  	if err := w.ClientStream.RecvMsg(m); err != nil {
    57  		return err
    58  	}
    59  
    60  	return validateReply(w.clusterLock, m)
    61  }
    62  
    63  type etcdResponse interface {
    64  	GetHeader() *etcdserverpb.ResponseHeader
    65  }
    66  
    67  func validateReply(cl *clusterLock, reply any) error {
    68  	resp, ok := reply.(etcdResponse)
    69  	if !ok || resp.GetHeader() == nil {
    70  		select {
    71  		case cl.errors <- ErrEtcdInvalidResponse:
    72  		default:
    73  		}
    74  		return ErrEtcdInvalidResponse
    75  	}
    76  
    77  	if err := cl.validateClusterId(resp.GetHeader().ClusterId); err != nil {
    78  		select {
    79  		case cl.errors <- err:
    80  		default:
    81  		}
    82  		return err
    83  	}
    84  	return nil
    85  }
    86  
    87  // clusterLock is a wrapper around an atomic uint64 that can only be set once. It
    88  // provides validation for an etcd connection to ensure that it is only used
    89  // for the same etcd cluster it was initially connected to. This is to prevent
    90  // accidentally connecting to the wrong cluster in a high availability
    91  // configuration utilizing mutiple active clusters.
    92  type clusterLock struct {
    93  	etcdClusterID atomic.Uint64
    94  	errors        chan error
    95  }
    96  
    97  func newClusterLock() *clusterLock {
    98  	return &clusterLock{
    99  		etcdClusterID: atomic.Uint64{},
   100  		errors:        make(chan error, 1),
   101  	}
   102  }
   103  
   104  func (c *clusterLock) validateClusterId(clusterId uint64) error {
   105  	// If the cluster ID has not been set, set it to the received cluster ID
   106  	c.etcdClusterID.CompareAndSwap(0, clusterId)
   107  
   108  	if clusterId != c.etcdClusterID.Load() {
   109  		return fmt.Errorf("%w: expected %d, got %d", ErrClusterIDChanged, c.etcdClusterID.Load(), clusterId)
   110  	}
   111  	return nil
   112  }