github.com/cilium/cilium@v1.16.2/pkg/clustermesh/common/interceptor_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package common
     5  
     6  import (
     7  	"context"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"go.etcd.io/etcd/api/v3/etcdserverpb"
    13  	"google.golang.org/grpc"
    14  )
    15  
    16  type responseType int
    17  
    18  const (
    19  	status responseType = iota
    20  	watch
    21  	leaseKeepAlive
    22  	leaseGrant
    23  	invalid
    24  )
    25  
    26  type mockClientStream struct {
    27  	grpc.ClientStream
    28  	toClient chan *etcdResponse
    29  }
    30  
    31  func newMockClientStream() mockClientStream {
    32  	return mockClientStream{
    33  		toClient: make(chan *etcdResponse),
    34  	}
    35  }
    36  
    37  func (c mockClientStream) RecvMsg(msg interface{}) error {
    38  	return nil
    39  }
    40  
    41  func (c mockClientStream) Send(resp *etcdResponse) error {
    42  	return nil
    43  }
    44  
    45  func newStreamerMock(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    46  	return newMockClientStream(), nil
    47  }
    48  
    49  func (u unaryResponder) recv() etcdResponse {
    50  	var resp unaryResponse
    51  	switch u.rt {
    52  	case status:
    53  		resp = unaryResponse{&etcdserverpb.StatusResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: u.cid}}}
    54  	case leaseGrant:
    55  		resp = unaryResponse{&etcdserverpb.LeaseGrantResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: u.cid}}}
    56  	case invalid:
    57  		resp = unaryResponse{&etcdserverpb.StatusResponse{}}
    58  	}
    59  
    60  	return resp
    61  }
    62  
    63  func (s streamResponder) recv() etcdResponse {
    64  	var resp streamResponse
    65  	switch s.rt {
    66  	case watch:
    67  		resp = streamResponse{&etcdserverpb.WatchResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: s.cid}}}
    68  	case leaseKeepAlive:
    69  		resp = streamResponse{&etcdserverpb.LeaseKeepAliveResponse{Header: &etcdserverpb.ResponseHeader{ClusterId: s.cid}}}
    70  	case invalid:
    71  		resp = streamResponse{&etcdserverpb.WatchResponse{}}
    72  	}
    73  	return resp
    74  
    75  }
    76  
    77  func noopInvoker(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
    78  	return nil
    79  }
    80  
    81  type unaryResponder struct {
    82  	rt       responseType
    83  	cid      uint64
    84  	expError error
    85  }
    86  
    87  func (u unaryResponder) expectedErr() error {
    88  	return u.expError
    89  }
    90  
    91  type unaryResponse struct {
    92  	etcdResponse
    93  }
    94  
    95  type streamResponder struct {
    96  	rt       responseType
    97  	cid      uint64
    98  	expError error
    99  }
   100  
   101  func (s streamResponder) expectedErr() error {
   102  	return s.expError
   103  }
   104  
   105  type streamResponse struct {
   106  	etcdResponse
   107  }
   108  
   109  type mockResponder interface {
   110  	recv() etcdResponse
   111  	expectedErr() error
   112  }
   113  
   114  var maxId uint64 = 0xFFFFFFFFFFFFFFFF
   115  
   116  func TestInterceptors(t *testing.T) {
   117  	tests := []struct {
   118  		name             string
   119  		initialClusterId uint64
   120  		r                []mockResponder
   121  	}{
   122  		{
   123  			name:             "healthy stream responses",
   124  			initialClusterId: 1,
   125  			r: []mockResponder{
   126  				streamResponder{rt: watch, cid: 1, expError: nil},
   127  				streamResponder{rt: watch, cid: 1, expError: nil},
   128  				streamResponder{rt: watch, cid: 1, expError: nil},
   129  			},
   130  		},
   131  		{
   132  			name:             "healthy unary responses",
   133  			initialClusterId: 1,
   134  			r: []mockResponder{
   135  				unaryResponder{rt: leaseGrant, cid: 1, expError: nil},
   136  				unaryResponder{rt: status, cid: 1, expError: nil},
   137  			},
   138  		},
   139  		{
   140  			name:             "healthy stream and unary responses",
   141  			initialClusterId: maxId,
   142  			r: []mockResponder{
   143  				unaryResponder{rt: leaseGrant, cid: maxId, expError: nil},
   144  				unaryResponder{rt: status, cid: maxId, expError: nil},
   145  				streamResponder{rt: watch, cid: maxId, expError: nil},
   146  				unaryResponder{rt: status, cid: maxId, expError: nil},
   147  				streamResponder{rt: watch, cid: maxId, expError: nil},
   148  			},
   149  		},
   150  		{
   151  			name:             "watch response from another cluster",
   152  			initialClusterId: 1,
   153  			r: []mockResponder{
   154  				streamResponder{rt: watch, cid: 1, expError: nil},
   155  				streamResponder{rt: watch, cid: 2, expError: ErrClusterIDChanged},
   156  				streamResponder{rt: watch, cid: 1, expError: nil},
   157  			},
   158  		},
   159  		{
   160  			name:             "status response from another cluster",
   161  			initialClusterId: 1,
   162  			r: []mockResponder{
   163  				streamResponder{rt: watch, cid: 1, expError: nil},
   164  				unaryResponder{rt: status, cid: maxId, expError: ErrClusterIDChanged},
   165  				streamResponder{rt: watch, cid: 1, expError: nil},
   166  			},
   167  		},
   168  		{
   169  			name:             "receive an invalid response with no header",
   170  			initialClusterId: 1,
   171  			r: []mockResponder{
   172  				streamResponder{rt: leaseKeepAlive, cid: 1, expError: nil},
   173  				streamResponder{rt: invalid, cid: 0, expError: ErrEtcdInvalidResponse},
   174  			},
   175  		},
   176  	}
   177  
   178  	for _, tt := range tests {
   179  		t.Run(tt.name, func(t *testing.T) {
   180  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   181  			defer cancel()
   182  
   183  			cl := newClusterLock()
   184  			checkForError := func() error {
   185  				select {
   186  				case err := <-cl.errors:
   187  					return err
   188  				default:
   189  					return nil
   190  				}
   191  			}
   192  
   193  			si := newStreamInterceptor(cl)
   194  			desc := &grpc.StreamDesc{
   195  				StreamName:    "test",
   196  				Handler:       nil,
   197  				ServerStreams: true,
   198  				ClientStreams: true,
   199  			}
   200  
   201  			cc := &grpc.ClientConn{}
   202  
   203  			stream, err := si(ctx, desc, cc, "test", newStreamerMock)
   204  			require.NoError(t, err)
   205  
   206  			unaryRecvMsg := newUnaryInterceptor(cl)
   207  			for _, responder := range tt.r {
   208  
   209  				switch response := responder.recv().(type) {
   210  				case unaryResponse:
   211  					unaryRecvMsg(ctx, "test", nil, response, cc, noopInvoker)
   212  				case streamResponse:
   213  					stream.RecvMsg(responder.recv())
   214  				}
   215  				require.ErrorIs(t, checkForError(), responder.expectedErr())
   216  				require.Equal(t, tt.initialClusterId, cl.etcdClusterID.Load())
   217  			}
   218  		})
   219  	}
   220  
   221  }