github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/cluster/interceptors_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cluster
    16  
    17  import (
    18  	"context"
    19  	"crypto/ed25519"
    20  	"crypto/rand"
    21  	"net"
    22  	"strconv"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/sirupsen/logrus"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  	"google.golang.org/grpc"
    31  	"google.golang.org/grpc/codes"
    32  	"google.golang.org/grpc/health/grpc_health_v1"
    33  	"google.golang.org/grpc/metadata"
    34  	"google.golang.org/grpc/status"
    35  	"gopkg.in/square/go-jose.v2"
    36  	"gopkg.in/square/go-jose.v2/jwt"
    37  
    38  	"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
    39  )
    40  
    41  type server struct {
    42  	md metadata.MD
    43  }
    44  
    45  func (s *server) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
    46  	s.md, _ = metadata.FromIncomingContext(ctx)
    47  	return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
    48  }
    49  
    50  func (s *server) Watch(req *grpc_health_v1.HealthCheckRequest, ss grpc_health_v1.Health_WatchServer) error {
    51  	s.md, _ = metadata.FromIncomingContext(ss.Context())
    52  	return status.Errorf(codes.Unimplemented, "method Watch not implemented")
    53  }
    54  
    55  func noopSetRole(string, int) {
    56  }
    57  
    58  var lgr = logrus.StandardLogger().WithFields(logrus.Fields{})
    59  
    60  var kp jwtauth.KeyProvider
    61  var pub ed25519.PublicKey
    62  var priv ed25519.PrivateKey
    63  
    64  func init() {
    65  	var err error
    66  	pub, priv, err = ed25519.GenerateKey(rand.Reader)
    67  	if err != nil {
    68  		panic(err)
    69  	}
    70  	kp = keyProvider{pub}
    71  }
    72  
    73  type keyProvider struct {
    74  	ed25519.PublicKey
    75  }
    76  
    77  func (p keyProvider) GetKey(string) ([]jose.JSONWebKey, error) {
    78  	return []jose.JSONWebKey{{
    79  		Key:   p.PublicKey,
    80  		KeyID: "1",
    81  	}}, nil
    82  }
    83  
    84  func newJWT() string {
    85  	key := jose.SigningKey{Algorithm: jose.EdDSA, Key: priv}
    86  	opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
    87  		"kid": "1",
    88  	}}
    89  	signer, err := jose.NewSigner(key, opts)
    90  	if err != nil {
    91  		panic(err)
    92  	}
    93  	jwtBuilder := jwt.Signed(signer)
    94  	jwtBuilder = jwtBuilder.Claims(jwt.Claims{
    95  		Audience: []string{"some_audience"},
    96  		Issuer:   "some_issuer",
    97  		Subject:  "some_subject",
    98  		Expiry:   jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
    99  	})
   100  	res, err := jwtBuilder.CompactSerialize()
   101  	if err != nil {
   102  		panic(err)
   103  	}
   104  	return res
   105  }
   106  
   107  func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient), serveropts []grpc.ServerOption, dialopts []grpc.DialOption) *server {
   108  	addr, err := net.ResolveUnixAddr("unix", "test_grpc.socket")
   109  	require.NoError(t, err)
   110  	lis, err := net.ListenUnix("unix", addr)
   111  	require.NoError(t, err)
   112  
   113  	var wg sync.WaitGroup
   114  	var srvErr error
   115  	wg.Add(1)
   116  
   117  	srv := grpc.NewServer(serveropts...)
   118  	hs := new(server)
   119  	grpc_health_v1.RegisterHealthServer(srv, hs)
   120  	defer func() {
   121  		if srv != nil {
   122  			srv.GracefulStop()
   123  			wg.Wait()
   124  		}
   125  	}()
   126  
   127  	go func() {
   128  		defer wg.Done()
   129  		srvErr = srv.Serve(lis)
   130  	}()
   131  
   132  	cc, err := grpc.Dial("unix:test_grpc.socket", append([]grpc.DialOption{grpc.WithInsecure()}, dialopts...)...)
   133  	require.NoError(t, err)
   134  	client := grpc_health_v1.NewHealthClient(cc)
   135  
   136  	cb(t, client)
   137  
   138  	srv.GracefulStop()
   139  	wg.Wait()
   140  	srv = nil
   141  	require.NoError(t, srvErr)
   142  
   143  	return hs
   144  }
   145  
   146  func outboundCtx(vals ...interface{}) context.Context {
   147  	ctx := context.Background()
   148  	if len(vals) == 0 {
   149  		return metadata.AppendToOutgoingContext(ctx,
   150  			"authorization", "Bearer "+newJWT())
   151  	}
   152  	if len(vals) == 2 {
   153  		return metadata.AppendToOutgoingContext(ctx,
   154  			clusterRoleHeader, string(vals[0].(Role)),
   155  			clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)),
   156  			"authorization", "Bearer "+newJWT())
   157  	}
   158  	panic("bad test --- outboundCtx must take 0 or 2 values")
   159  }
   160  
   161  func TestServerInterceptorUnauthenticatedWithoutClientHeaders(t *testing.T) {
   162  	var si serverinterceptor
   163  	si.roleSetter = noopSetRole
   164  	si.lgr = lgr
   165  	si.setRole(RoleStandby, 10)
   166  	si.keyProvider = kp
   167  	t.Run("Standby", func(t *testing.T) {
   168  		withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   169  			_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   170  			assert.Equal(t, codes.Unauthenticated, status.Code(err))
   171  			srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   172  			assert.NoError(t, err)
   173  			_, err = srv.Recv()
   174  			assert.Equal(t, codes.Unauthenticated, status.Code(err))
   175  		}, si.Options(), nil)
   176  	})
   177  	si.setRole(RolePrimary, 10)
   178  	t.Run("Primary", func(t *testing.T) {
   179  		withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   180  			_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   181  			assert.Equal(t, codes.Unauthenticated, status.Code(err))
   182  			srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   183  			assert.NoError(t, err)
   184  			_, err = srv.Recv()
   185  			assert.Equal(t, codes.Unauthenticated, status.Code(err))
   186  		}, si.Options(), nil)
   187  	})
   188  }
   189  
   190  func TestServerInterceptorAddsUnaryResponseHeaders(t *testing.T) {
   191  	var si serverinterceptor
   192  	si.setRole(RoleStandby, 10)
   193  	si.roleSetter = noopSetRole
   194  	si.lgr = lgr
   195  	si.keyProvider = kp
   196  	withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   197  		var md metadata.MD
   198  		_, err := client.Check(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
   199  		assert.Equal(t, codes.Unimplemented, status.Code(err))
   200  		if assert.Len(t, md.Get(clusterRoleHeader), 1) {
   201  			assert.Equal(t, "standby", md.Get(clusterRoleHeader)[0])
   202  		}
   203  		if assert.Len(t, md.Get(clusterRoleEpochHeader), 1) {
   204  			assert.Equal(t, "10", md.Get(clusterRoleEpochHeader)[0])
   205  		}
   206  	}, si.Options(), nil)
   207  }
   208  
   209  func TestServerInterceptorAddsStreamResponseHeaders(t *testing.T) {
   210  	var si serverinterceptor
   211  	si.setRole(RoleStandby, 10)
   212  	si.roleSetter = noopSetRole
   213  	si.lgr = lgr
   214  	si.keyProvider = kp
   215  	withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   216  		var md metadata.MD
   217  		srv, err := client.Watch(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
   218  		require.NoError(t, err)
   219  		_, err = srv.Recv()
   220  		assert.Equal(t, codes.Unimplemented, status.Code(err))
   221  		if assert.Len(t, md.Get(clusterRoleHeader), 1) {
   222  			assert.Equal(t, "standby", md.Get(clusterRoleHeader)[0])
   223  		}
   224  		if assert.Len(t, md.Get(clusterRoleEpochHeader), 1) {
   225  			assert.Equal(t, "10", md.Get(clusterRoleEpochHeader)[0])
   226  		}
   227  	}, si.Options(), nil)
   228  }
   229  
   230  func TestServerInterceptorAsPrimaryDoesNotSendRequest(t *testing.T) {
   231  	var si serverinterceptor
   232  	si.setRole(RolePrimary, 10)
   233  	si.roleSetter = noopSetRole
   234  	si.lgr = lgr
   235  	si.keyProvider = kp
   236  	srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   237  		ctx := metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value")
   238  		_, err := client.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
   239  		assert.Equal(t, codes.FailedPrecondition, status.Code(err))
   240  		ctx = metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value")
   241  		ss, err := client.Watch(ctx, &grpc_health_v1.HealthCheckRequest{})
   242  		assert.NoError(t, err)
   243  		_, err = ss.Recv()
   244  		assert.Equal(t, codes.FailedPrecondition, status.Code(err))
   245  	}, si.Options(), nil)
   246  	assert.Nil(t, srv.md)
   247  }
   248  
   249  func TestClientInterceptorAddsUnaryRequestHeaders(t *testing.T) {
   250  	var ci clientinterceptor
   251  	ci.setRole(RolePrimary, 10)
   252  	ci.roleSetter = noopSetRole
   253  	ci.lgr = lgr
   254  	srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   255  		_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   256  		assert.Equal(t, codes.Unimplemented, status.Code(err))
   257  	}, nil, ci.Options())
   258  	if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) {
   259  		assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0])
   260  	}
   261  	if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) {
   262  		assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0])
   263  	}
   264  }
   265  
   266  func TestClientInterceptorAddsStreamRequestHeaders(t *testing.T) {
   267  	var ci clientinterceptor
   268  	ci.setRole(RolePrimary, 10)
   269  	ci.roleSetter = noopSetRole
   270  	ci.lgr = lgr
   271  	srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   272  		srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   273  		require.NoError(t, err)
   274  		_, err = srv.Recv()
   275  		assert.Equal(t, codes.Unimplemented, status.Code(err))
   276  	}, nil, ci.Options())
   277  	if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) {
   278  		assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0])
   279  	}
   280  	if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) {
   281  		assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0])
   282  	}
   283  }
   284  
   285  func TestClientInterceptorAsStandbyDoesNotSendRequest(t *testing.T) {
   286  	var ci clientinterceptor
   287  	ci.setRole(RolePrimary, 10)
   288  	ci.roleSetter = noopSetRole
   289  	ci.lgr = lgr
   290  	srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
   291  		_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   292  		assert.Equal(t, codes.Unimplemented, status.Code(err))
   293  		ci.setRole(RoleStandby, 11)
   294  		_, err = client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   295  		assert.Equal(t, codes.FailedPrecondition, status.Code(err))
   296  		_, err = client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
   297  		assert.Equal(t, codes.FailedPrecondition, status.Code(err))
   298  	}, nil, ci.Options())
   299  	if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) {
   300  		assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0])
   301  	}
   302  	if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) {
   303  		assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0])
   304  	}
   305  }