github.com/cilium/cilium@v1.16.2/pkg/clustermesh/common/interceptor.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package common 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "sync/atomic" 11 12 "go.etcd.io/etcd/api/v3/etcdserverpb" 13 "google.golang.org/grpc" 14 ) 15 16 var ( 17 ErrClusterIDChanged = errors.New("etcd cluster ID has changed") 18 ErrEtcdInvalidResponse = errors.New("received an invalid etcd response") 19 ) 20 21 // newUnaryInterceptor returns a new unary client interceptor that validates the 22 // cluster ID of any received etcd responses. 23 func newUnaryInterceptor(cl *clusterLock) grpc.UnaryClientInterceptor { 24 return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 25 if err := invoker(ctx, method, req, reply, cc, opts...); err != nil { 26 return err 27 } 28 return validateReply(cl, reply) 29 } 30 } 31 32 // newStreamInterceptor returns a new stream client interceptor that validates 33 // the cluster ID of any received etcd responses. 34 func newStreamInterceptor(cl *clusterLock) grpc.StreamClientInterceptor { 35 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 36 s, err := streamer(ctx, desc, cc, method, opts...) 37 if err != nil { 38 return nil, err 39 } 40 return &wrappedClientStream{ 41 ClientStream: s, 42 clusterLock: cl, 43 }, nil 44 } 45 } 46 47 // wrappedClientStream is a wrapper around a grpc.ClientStream that adds 48 // validation for the etcd cluster ID 49 type wrappedClientStream struct { 50 grpc.ClientStream 51 clusterLock *clusterLock 52 } 53 54 // RecvMsg implements the grpc.ClientStream interface, adding validation for the etcd cluster ID 55 func (w *wrappedClientStream) RecvMsg(m interface{}) error { 56 if err := w.ClientStream.RecvMsg(m); err != nil { 57 return err 58 } 59 60 return validateReply(w.clusterLock, m) 61 } 62 63 type etcdResponse interface { 64 GetHeader() *etcdserverpb.ResponseHeader 65 } 66 67 func validateReply(cl *clusterLock, reply any) error { 68 resp, ok := reply.(etcdResponse) 69 if !ok || resp.GetHeader() == nil { 70 select { 71 case cl.errors <- ErrEtcdInvalidResponse: 72 default: 73 } 74 return ErrEtcdInvalidResponse 75 } 76 77 if err := cl.validateClusterId(resp.GetHeader().ClusterId); err != nil { 78 select { 79 case cl.errors <- err: 80 default: 81 } 82 return err 83 } 84 return nil 85 } 86 87 // clusterLock is a wrapper around an atomic uint64 that can only be set once. It 88 // provides validation for an etcd connection to ensure that it is only used 89 // for the same etcd cluster it was initially connected to. This is to prevent 90 // accidentally connecting to the wrong cluster in a high availability 91 // configuration utilizing mutiple active clusters. 92 type clusterLock struct { 93 etcdClusterID atomic.Uint64 94 errors chan error 95 } 96 97 func newClusterLock() *clusterLock { 98 return &clusterLock{ 99 etcdClusterID: atomic.Uint64{}, 100 errors: make(chan error, 1), 101 } 102 } 103 104 func (c *clusterLock) validateClusterId(clusterId uint64) error { 105 // If the cluster ID has not been set, set it to the received cluster ID 106 c.etcdClusterID.CompareAndSwap(0, clusterId) 107 108 if clusterId != c.etcdClusterID.Load() { 109 return fmt.Errorf("%w: expected %d, got %d", ErrClusterIDChanged, c.etcdClusterID.Load(), clusterId) 110 } 111 return nil 112 }