github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/cluster/interceptors_test.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package cluster 16 17 import ( 18 "context" 19 "crypto/ed25519" 20 "crypto/rand" 21 "net" 22 "strconv" 23 "sync" 24 "testing" 25 "time" 26 27 "github.com/sirupsen/logrus" 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 "google.golang.org/grpc" 31 "google.golang.org/grpc/codes" 32 "google.golang.org/grpc/health/grpc_health_v1" 33 "google.golang.org/grpc/metadata" 34 "google.golang.org/grpc/status" 35 "gopkg.in/square/go-jose.v2" 36 "gopkg.in/square/go-jose.v2/jwt" 37 38 "github.com/dolthub/dolt/go/libraries/utils/jwtauth" 39 ) 40 41 type server struct { 42 md metadata.MD 43 } 44 45 func (s *server) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) { 46 s.md, _ = metadata.FromIncomingContext(ctx) 47 return nil, status.Errorf(codes.Unimplemented, "method Check not implemented") 48 } 49 50 func (s *server) Watch(req *grpc_health_v1.HealthCheckRequest, ss grpc_health_v1.Health_WatchServer) error { 51 s.md, _ = metadata.FromIncomingContext(ss.Context()) 52 return status.Errorf(codes.Unimplemented, "method Watch not implemented") 53 } 54 55 func noopSetRole(string, int) { 56 } 57 58 var lgr = logrus.StandardLogger().WithFields(logrus.Fields{}) 59 60 var kp jwtauth.KeyProvider 61 var pub ed25519.PublicKey 62 var priv ed25519.PrivateKey 63 64 func init() { 65 var err error 66 pub, priv, err = ed25519.GenerateKey(rand.Reader) 67 if err != nil { 68 panic(err) 69 } 70 kp = keyProvider{pub} 71 } 72 73 type keyProvider struct { 74 ed25519.PublicKey 75 } 76 77 func (p keyProvider) GetKey(string) ([]jose.JSONWebKey, error) { 78 return []jose.JSONWebKey{{ 79 Key: p.PublicKey, 80 KeyID: "1", 81 }}, nil 82 } 83 84 func newJWT() string { 85 key := jose.SigningKey{Algorithm: jose.EdDSA, Key: priv} 86 opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{ 87 "kid": "1", 88 }} 89 signer, err := jose.NewSigner(key, opts) 90 if err != nil { 91 panic(err) 92 } 93 jwtBuilder := jwt.Signed(signer) 94 jwtBuilder = jwtBuilder.Claims(jwt.Claims{ 95 Audience: []string{"some_audience"}, 96 Issuer: "some_issuer", 97 Subject: "some_subject", 98 Expiry: jwt.NewNumericDate(time.Now().Add(30 * time.Second)), 99 }) 100 res, err := jwtBuilder.CompactSerialize() 101 if err != nil { 102 panic(err) 103 } 104 return res 105 } 106 107 func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient), serveropts []grpc.ServerOption, dialopts []grpc.DialOption) *server { 108 addr, err := net.ResolveUnixAddr("unix", "test_grpc.socket") 109 require.NoError(t, err) 110 lis, err := net.ListenUnix("unix", addr) 111 require.NoError(t, err) 112 113 var wg sync.WaitGroup 114 var srvErr error 115 wg.Add(1) 116 117 srv := grpc.NewServer(serveropts...) 118 hs := new(server) 119 grpc_health_v1.RegisterHealthServer(srv, hs) 120 defer func() { 121 if srv != nil { 122 srv.GracefulStop() 123 wg.Wait() 124 } 125 }() 126 127 go func() { 128 defer wg.Done() 129 srvErr = srv.Serve(lis) 130 }() 131 132 cc, err := grpc.Dial("unix:test_grpc.socket", append([]grpc.DialOption{grpc.WithInsecure()}, dialopts...)...) 133 require.NoError(t, err) 134 client := grpc_health_v1.NewHealthClient(cc) 135 136 cb(t, client) 137 138 srv.GracefulStop() 139 wg.Wait() 140 srv = nil 141 require.NoError(t, srvErr) 142 143 return hs 144 } 145 146 func outboundCtx(vals ...interface{}) context.Context { 147 ctx := context.Background() 148 if len(vals) == 0 { 149 return metadata.AppendToOutgoingContext(ctx, 150 "authorization", "Bearer "+newJWT()) 151 } 152 if len(vals) == 2 { 153 return metadata.AppendToOutgoingContext(ctx, 154 clusterRoleHeader, string(vals[0].(Role)), 155 clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)), 156 "authorization", "Bearer "+newJWT()) 157 } 158 panic("bad test --- outboundCtx must take 0 or 2 values") 159 } 160 161 func TestServerInterceptorUnauthenticatedWithoutClientHeaders(t *testing.T) { 162 var si serverinterceptor 163 si.roleSetter = noopSetRole 164 si.lgr = lgr 165 si.setRole(RoleStandby, 10) 166 si.keyProvider = kp 167 t.Run("Standby", func(t *testing.T) { 168 withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 169 _, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 170 assert.Equal(t, codes.Unauthenticated, status.Code(err)) 171 srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 172 assert.NoError(t, err) 173 _, err = srv.Recv() 174 assert.Equal(t, codes.Unauthenticated, status.Code(err)) 175 }, si.Options(), nil) 176 }) 177 si.setRole(RolePrimary, 10) 178 t.Run("Primary", func(t *testing.T) { 179 withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 180 _, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 181 assert.Equal(t, codes.Unauthenticated, status.Code(err)) 182 srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 183 assert.NoError(t, err) 184 _, err = srv.Recv() 185 assert.Equal(t, codes.Unauthenticated, status.Code(err)) 186 }, si.Options(), nil) 187 }) 188 } 189 190 func TestServerInterceptorAddsUnaryResponseHeaders(t *testing.T) { 191 var si serverinterceptor 192 si.setRole(RoleStandby, 10) 193 si.roleSetter = noopSetRole 194 si.lgr = lgr 195 si.keyProvider = kp 196 withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 197 var md metadata.MD 198 _, err := client.Check(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md)) 199 assert.Equal(t, codes.Unimplemented, status.Code(err)) 200 if assert.Len(t, md.Get(clusterRoleHeader), 1) { 201 assert.Equal(t, "standby", md.Get(clusterRoleHeader)[0]) 202 } 203 if assert.Len(t, md.Get(clusterRoleEpochHeader), 1) { 204 assert.Equal(t, "10", md.Get(clusterRoleEpochHeader)[0]) 205 } 206 }, si.Options(), nil) 207 } 208 209 func TestServerInterceptorAddsStreamResponseHeaders(t *testing.T) { 210 var si serverinterceptor 211 si.setRole(RoleStandby, 10) 212 si.roleSetter = noopSetRole 213 si.lgr = lgr 214 si.keyProvider = kp 215 withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 216 var md metadata.MD 217 srv, err := client.Watch(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md)) 218 require.NoError(t, err) 219 _, err = srv.Recv() 220 assert.Equal(t, codes.Unimplemented, status.Code(err)) 221 if assert.Len(t, md.Get(clusterRoleHeader), 1) { 222 assert.Equal(t, "standby", md.Get(clusterRoleHeader)[0]) 223 } 224 if assert.Len(t, md.Get(clusterRoleEpochHeader), 1) { 225 assert.Equal(t, "10", md.Get(clusterRoleEpochHeader)[0]) 226 } 227 }, si.Options(), nil) 228 } 229 230 func TestServerInterceptorAsPrimaryDoesNotSendRequest(t *testing.T) { 231 var si serverinterceptor 232 si.setRole(RolePrimary, 10) 233 si.roleSetter = noopSetRole 234 si.lgr = lgr 235 si.keyProvider = kp 236 srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 237 ctx := metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value") 238 _, err := client.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) 239 assert.Equal(t, codes.FailedPrecondition, status.Code(err)) 240 ctx = metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value") 241 ss, err := client.Watch(ctx, &grpc_health_v1.HealthCheckRequest{}) 242 assert.NoError(t, err) 243 _, err = ss.Recv() 244 assert.Equal(t, codes.FailedPrecondition, status.Code(err)) 245 }, si.Options(), nil) 246 assert.Nil(t, srv.md) 247 } 248 249 func TestClientInterceptorAddsUnaryRequestHeaders(t *testing.T) { 250 var ci clientinterceptor 251 ci.setRole(RolePrimary, 10) 252 ci.roleSetter = noopSetRole 253 ci.lgr = lgr 254 srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 255 _, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 256 assert.Equal(t, codes.Unimplemented, status.Code(err)) 257 }, nil, ci.Options()) 258 if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) { 259 assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0]) 260 } 261 if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) { 262 assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0]) 263 } 264 } 265 266 func TestClientInterceptorAddsStreamRequestHeaders(t *testing.T) { 267 var ci clientinterceptor 268 ci.setRole(RolePrimary, 10) 269 ci.roleSetter = noopSetRole 270 ci.lgr = lgr 271 srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 272 srv, err := client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 273 require.NoError(t, err) 274 _, err = srv.Recv() 275 assert.Equal(t, codes.Unimplemented, status.Code(err)) 276 }, nil, ci.Options()) 277 if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) { 278 assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0]) 279 } 280 if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) { 281 assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0]) 282 } 283 } 284 285 func TestClientInterceptorAsStandbyDoesNotSendRequest(t *testing.T) { 286 var ci clientinterceptor 287 ci.setRole(RolePrimary, 10) 288 ci.roleSetter = noopSetRole 289 ci.lgr = lgr 290 srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) { 291 _, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 292 assert.Equal(t, codes.Unimplemented, status.Code(err)) 293 ci.setRole(RoleStandby, 11) 294 _, err = client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 295 assert.Equal(t, codes.FailedPrecondition, status.Code(err)) 296 _, err = client.Watch(outboundCtx(), &grpc_health_v1.HealthCheckRequest{}) 297 assert.Equal(t, codes.FailedPrecondition, status.Code(err)) 298 }, nil, ci.Options()) 299 if assert.Len(t, srv.md.Get(clusterRoleHeader), 1) { 300 assert.Equal(t, "primary", srv.md.Get(clusterRoleHeader)[0]) 301 } 302 if assert.Len(t, srv.md.Get(clusterRoleEpochHeader), 1) { 303 assert.Equal(t, "10", srv.md.Get(clusterRoleEpochHeader)[0]) 304 } 305 }