github.com/letsencrypt/boulder@v0.20251208.0/grpc/interceptors_test.go (about) 1 package grpc 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "errors" 8 "fmt" 9 "log" 10 "net" 11 "strconv" 12 "strings" 13 "sync" 14 "testing" 15 "time" 16 17 "github.com/jmhodges/clock" 18 "github.com/prometheus/client_golang/prometheus" 19 "google.golang.org/grpc" 20 "google.golang.org/grpc/balancer/roundrobin" 21 "google.golang.org/grpc/credentials" 22 "google.golang.org/grpc/credentials/insecure" 23 "google.golang.org/grpc/metadata" 24 "google.golang.org/grpc/peer" 25 "google.golang.org/grpc/status" 26 "google.golang.org/protobuf/types/known/durationpb" 27 28 "github.com/letsencrypt/boulder/grpc/test_proto" 29 "github.com/letsencrypt/boulder/metrics" 30 "github.com/letsencrypt/boulder/test" 31 "github.com/letsencrypt/boulder/web" 32 ) 33 34 var fc = clock.NewFake() 35 36 func testHandler(_ context.Context, i any) (any, error) { 37 if i != nil { 38 return nil, errors.New("") 39 } 40 fc.Sleep(time.Second) 41 return nil, nil 42 } 43 44 func testInvoker(_ context.Context, method string, _, _ any, _ *grpc.ClientConn, opts ...grpc.CallOption) error { 45 switch method { 46 case "-service-brokeTest": 47 return errors.New("") 48 case "-service-requesterCanceledTest": 49 return status.Error(1, context.Canceled.Error()) 50 } 51 fc.Sleep(time.Second) 52 return nil 53 } 54 55 func TestServerInterceptor(t *testing.T) { 56 serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) 57 test.AssertNotError(t, err, "creating server metrics") 58 si := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) 59 60 md := metadata.New(map[string]string{clientRequestTimeKey: "0"}) 61 ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md) 62 63 _, err = si.Unary(context.Background(), nil, nil, testHandler) 64 test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata") 65 66 _, err = si.Unary(ctxWithMetadata, nil, nil, testHandler) 67 test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo") 68 69 _, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler) 70 test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo") 71 72 _, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler) 73 test.AssertError(t, err, "si.intercept didn't fail when handler returned a error") 74 } 75 76 func TestClientInterceptor(t *testing.T) { 77 clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) 78 test.AssertNotError(t, err, "creating client metrics") 79 ci := clientMetadataInterceptor{ 80 timeout: time.Second, 81 metrics: clientMetrics, 82 clk: clock.NewFake(), 83 } 84 85 err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker) 86 test.AssertNotError(t, err, "ci.intercept failed with a non-nil grpc.UnaryServerInfo") 87 88 err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker) 89 test.AssertError(t, err, "ci.intercept didn't fail when handler returned a error") 90 } 91 92 // TestWaitForReadyTrue configures a gRPC client with waitForReady: true and 93 // sends a request to a backend that is unavailable. It ensures that the 94 // request doesn't error out until the timeout is reached, i.e. that 95 // FailFast is set to false. 96 // https://github.com/grpc/grpc/blob/main/doc/wait-for-ready.md 97 func TestWaitForReadyTrue(t *testing.T) { 98 clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) 99 test.AssertNotError(t, err, "creating client metrics") 100 ci := &clientMetadataInterceptor{ 101 timeout: 100 * time.Millisecond, 102 metrics: clientMetrics, 103 clk: clock.NewFake(), 104 waitForReady: true, 105 } 106 conn, err := grpc.NewClient("localhost:19876", // random, probably unused port 107 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)), 108 grpc.WithTransportCredentials(insecure.NewCredentials()), 109 grpc.WithUnaryInterceptor(ci.Unary)) 110 if err != nil { 111 t.Fatalf("did not connect: %v", err) 112 } 113 defer conn.Close() 114 c := test_proto.NewChillerClient(conn) 115 116 start := time.Now() 117 _, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)}) 118 if err == nil { 119 t.Errorf("Successful Chill when we expected failure.") 120 } 121 if time.Since(start) < 90*time.Millisecond { 122 t.Errorf("Chill failed fast, when WaitForReady should be enabled.") 123 } 124 } 125 126 // TestWaitForReadyFalse configures a gRPC client with waitForReady: false and 127 // sends a request to a backend that is unavailable, and ensures that the request 128 // errors out promptly. 129 func TestWaitForReadyFalse(t *testing.T) { 130 clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) 131 test.AssertNotError(t, err, "creating client metrics") 132 ci := &clientMetadataInterceptor{ 133 timeout: time.Second, 134 metrics: clientMetrics, 135 clk: clock.NewFake(), 136 waitForReady: false, 137 } 138 conn, err := grpc.NewClient("localhost:19876", // random, probably unused port 139 grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)), 140 grpc.WithTransportCredentials(insecure.NewCredentials()), 141 grpc.WithUnaryInterceptor(ci.Unary)) 142 if err != nil { 143 t.Fatalf("did not connect: %v", err) 144 } 145 defer conn.Close() 146 c := test_proto.NewChillerClient(conn) 147 148 start := time.Now() 149 _, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)}) 150 if err == nil { 151 t.Errorf("Successful Chill when we expected failure.") 152 } 153 if time.Since(start) > 200*time.Millisecond { 154 t.Errorf("Chill failed slow, when WaitForReady should be disabled.") 155 } 156 } 157 158 // testTimeoutServer is used to implement TestTimeouts, and will attempt to sleep for 159 // the given amount of time (unless it hits a timeout or cancel). 160 type testTimeoutServer struct { 161 test_proto.UnimplementedChillerServer 162 } 163 164 // Chill implements ChillerServer.Chill 165 func (s *testTimeoutServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) { 166 start := time.Now() 167 // Sleep for either the requested amount of time, or the context times out or 168 // is canceled. 169 select { 170 case <-time.After(in.Duration.AsDuration() * time.Nanosecond): 171 spent := time.Since(start) / time.Nanosecond 172 return &test_proto.Time{Duration: durationpb.New(spent)}, nil 173 case <-ctx.Done(): 174 return nil, errors.New("unique error indicating that the server's shortened context timed itself out") 175 } 176 } 177 178 func TestTimeouts(t *testing.T) { 179 server := new(testTimeoutServer) 180 client, _, stop := setup(t, server, clock.NewFake()) 181 defer stop() 182 183 testCases := []struct { 184 timeout time.Duration 185 expectedErrorPrefix string 186 }{ 187 {250 * time.Millisecond, "rpc error: code = Unknown desc = unique error indicating that the server's shortened context timed itself out"}, 188 {100 * time.Millisecond, "Chiller.Chill timed out after 0 ms"}, 189 {10 * time.Millisecond, "Chiller.Chill timed out after 0 ms"}, 190 } 191 for _, tc := range testCases { 192 t.Run(tc.timeout.String(), func(t *testing.T) { 193 ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) 194 defer cancel() 195 _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)}) 196 if err == nil { 197 t.Fatal("Got no error, expected a timeout") 198 } 199 if !strings.HasPrefix(err.Error(), tc.expectedErrorPrefix) { 200 t.Errorf("Wrong error. Got %s, expected %s", err.Error(), tc.expectedErrorPrefix) 201 } 202 }) 203 } 204 } 205 206 func TestRequestTimeTagging(t *testing.T) { 207 server := new(testTimeoutServer) 208 serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) 209 test.AssertNotError(t, err, "creating server metrics") 210 client, _, stop := setup(t, server, serverMetrics) 211 defer stop() 212 213 // Make an RPC request with the ChillerClient with a timeout higher than the 214 // requested ChillerServer delay so that the RPC completes normally 215 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 216 defer cancel() 217 if _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil { 218 t.Fatalf("Unexpected error calling Chill RPC: %s", err) 219 } 220 221 // There should be one histogram sample in the serverInterceptor rpcLag stat 222 test.AssertMetricWithLabelsEquals(t, serverMetrics.rpcLag, prometheus.Labels{}, 1) 223 } 224 225 func TestClockSkew(t *testing.T) { 226 // Create two separate clocks for the client and server 227 serverClk := clock.NewFake() 228 serverClk.Set(time.Now()) 229 clientClk := clock.NewFake() 230 clientClk.Set(time.Now()) 231 232 _, serverPort, stop := setup(t, &testTimeoutServer{}, serverClk) 233 defer stop() 234 235 clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) 236 test.AssertNotError(t, err, "creating client metrics") 237 ci := &clientMetadataInterceptor{ 238 timeout: 30 * time.Second, 239 metrics: clientMetrics, 240 clk: clientClk, 241 } 242 conn, err := grpc.NewClient(net.JoinHostPort("localhost", strconv.Itoa(serverPort)), 243 grpc.WithTransportCredentials(insecure.NewCredentials()), 244 grpc.WithUnaryInterceptor(ci.Unary)) 245 if err != nil { 246 t.Fatalf("did not connect: %v", err) 247 } 248 249 client := test_proto.NewChillerClient(conn) 250 251 // Create a context with plenty of timeout 252 ctx, cancel := context.WithDeadline(context.Background(), clientClk.Now().Add(10*time.Second)) 253 defer cancel() 254 255 // Attempt a gRPC request which should succeed 256 _, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)}) 257 test.AssertNotError(t, err, "should succeed with no skew") 258 259 // Skew the client clock forward and the request should fail due to skew 260 clientClk.Add(time.Hour) 261 _, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)}) 262 test.AssertError(t, err, "should fail with positive client skew") 263 test.AssertContains(t, err.Error(), "very different time") 264 265 // Skew the server clock forward and the request should fail due to skew 266 serverClk.Add(2 * time.Hour) 267 _, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)}) 268 test.AssertError(t, err, "should fail with negative client skew") 269 test.AssertContains(t, err.Error(), "very different time") 270 } 271 272 // blockedServer implements a ChillerServer with a Chill method that: 273 // 1. Calls Done() on the received waitgroup when receiving an RPC 274 // 2. Blocks the RPC on the roadblock waitgroup 275 // 276 // This is used by TestInFlightRPCStat to test that the gauge for in-flight RPCs 277 // is incremented and decremented as expected. 278 type blockedServer struct { 279 test_proto.UnimplementedChillerServer 280 roadblock, received sync.WaitGroup 281 } 282 283 // Chill implements ChillerServer.Chill 284 func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto.Time, error) { 285 // Note that a client RPC arrived 286 s.received.Done() 287 // Wait for the roadblock to be cleared 288 s.roadblock.Wait() 289 // Return a dummy spent value to adhere to the chiller protocol 290 return &test_proto.Time{Duration: durationpb.New(time.Millisecond)}, nil 291 } 292 293 func TestInFlightRPCStat(t *testing.T) { 294 // Create a new blockedServer to act as a ChillerServer 295 server := &blockedServer{} 296 297 metrics, err := newClientMetrics(metrics.NoopRegisterer) 298 test.AssertNotError(t, err, "creating client metrics") 299 300 client, _, stop := setup(t, server, metrics) 301 defer stop() 302 303 // Increment the roadblock waitgroup - this will cause all chill RPCs to 304 // the server to block until we call Done()! 305 server.roadblock.Add(1) 306 307 // Increment the sentRPCs waitgroup - we use this to find out when all the 308 // RPCs we want to send have been received and we can count the in-flight 309 // gauge 310 numRPCs := 5 311 server.received.Add(numRPCs) 312 313 // Fire off a few RPCs. They will block on the blockedServer's roadblock wg 314 for range numRPCs { 315 go func() { 316 // Ignore errors, just chilllll. 317 _, _ = client.Chill(context.Background(), &test_proto.Time{}) 318 }() 319 } 320 321 // wait until all of the client RPCs have been sent and are blocking. We can 322 // now check the gauge. 323 server.received.Wait() 324 325 // Specify the labels for the RPCs we're interested in 326 labels := prometheus.Labels{ 327 "service": "Chiller", 328 "method": "Chill", 329 } 330 331 // We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal to numRPCs. 332 test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, float64(numRPCs)) 333 334 // Unblock the blockedServer to let all of the Chiller.Chill RPCs complete 335 server.roadblock.Done() 336 // Sleep for a little bit to let all the RPCs complete 337 time.Sleep(1 * time.Second) 338 339 // Check the gauge value again 340 test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, 0) 341 } 342 343 func TestServiceAuthChecker(t *testing.T) { 344 ac := authInterceptor{ 345 map[string]map[string]struct{}{ 346 "package.ServiceName": { 347 "allowed.client": {}, 348 "also.allowed": {}, 349 }, 350 }, 351 } 352 353 // No allowlist is a bad configuration. 354 ctx := context.Background() 355 err := ac.checkContextAuth(ctx, "/package.OtherService/Method/") 356 test.AssertError(t, err, "checking empty allowlist") 357 358 // Context with no peering information is disallowed. 359 err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") 360 test.AssertError(t, err, "checking un-peered context") 361 362 // Context with no auth info is disallowed. 363 ctx = peer.NewContext(ctx, &peer.Peer{}) 364 err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") 365 test.AssertError(t, err, "checking peer with no auth") 366 367 // Context with no verified chains is disallowed. 368 ctx = peer.NewContext(ctx, &peer.Peer{ 369 AuthInfo: credentials.TLSInfo{ 370 State: tls.ConnectionState{}, 371 }, 372 }) 373 err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") 374 test.AssertError(t, err, "checking TLS with no valid chains") 375 376 // Context with cert with wrong name is disallowed. 377 ctx = peer.NewContext(ctx, &peer.Peer{ 378 AuthInfo: credentials.TLSInfo{ 379 State: tls.ConnectionState{ 380 VerifiedChains: [][]*x509.Certificate{ 381 { 382 &x509.Certificate{ 383 DNSNames: []string{ 384 "disallowed.client", 385 }, 386 }, 387 }, 388 }, 389 }, 390 }, 391 }) 392 err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") 393 test.AssertError(t, err, "checking disallowed cert") 394 395 // Context with cert with good name is allowed. 396 ctx = peer.NewContext(ctx, &peer.Peer{ 397 AuthInfo: credentials.TLSInfo{ 398 State: tls.ConnectionState{ 399 VerifiedChains: [][]*x509.Certificate{ 400 { 401 &x509.Certificate{ 402 DNSNames: []string{ 403 "disallowed.client", 404 "also.allowed", 405 }, 406 }, 407 }, 408 }, 409 }, 410 }, 411 }) 412 err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") 413 test.AssertNotError(t, err, "checking allowed cert") 414 } 415 416 // testUserAgentServer stores the last value it saw in the user agent field of its context. 417 type testUserAgentServer struct { 418 test_proto.UnimplementedChillerServer 419 420 lastSeenUA string 421 } 422 423 // Chill implements ChillerServer.Chill 424 func (s *testUserAgentServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) { 425 s.lastSeenUA = web.UserAgent(ctx) 426 return nil, nil 427 } 428 429 func TestUserAgentMetadata(t *testing.T) { 430 server := new(testUserAgentServer) 431 client, _, stop := setup(t, server) 432 defer stop() 433 434 testUA := "test UA" 435 ctx := web.WithUserAgent(context.Background(), testUA) 436 437 _, err := client.Chill(ctx, &test_proto.Time{}) 438 if err != nil { 439 t.Fatalf("calling c.Chill: %s", err) 440 } 441 442 if server.lastSeenUA != testUA { 443 t.Errorf("last seen User-Agent on server side was %q, want %q", server.lastSeenUA, testUA) 444 } 445 } 446 447 // setup creates a server and client, returning the created client, the running server's port, and a stop function. 448 func setup(t *testing.T, server test_proto.ChillerServer, opts ...any) (test_proto.ChillerClient, int, func()) { 449 clk := clock.NewFake() 450 serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer) 451 test.AssertNotError(t, err, "creating server metrics") 452 clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer) 453 test.AssertNotError(t, err, "creating client metrics") 454 455 for _, opt := range opts { 456 switch optTyped := opt.(type) { 457 case clock.FakeClock: 458 clk = optTyped 459 case clientMetrics: 460 clientMetricsVal = optTyped 461 case serverMetrics: 462 serverMetricsVal = optTyped 463 default: 464 t.Fatalf("setup called with unrecognize option %#v", t) 465 } 466 } 467 lis, err := net.Listen("tcp", ":0") 468 if err != nil { 469 log.Fatalf("failed to listen: %v", err) 470 } 471 port := lis.Addr().(*net.TCPAddr).Port 472 473 si := newServerMetadataInterceptor(serverMetricsVal, clk) 474 s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) 475 test_proto.RegisterChillerServer(s, server) 476 477 go func() { 478 start := time.Now() 479 err := s.Serve(lis) 480 if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { 481 t.Logf("s.Serve: %v after %s", err, time.Since(start)) 482 } 483 }() 484 485 ci := &clientMetadataInterceptor{ 486 timeout: 30 * time.Second, 487 metrics: clientMetricsVal, 488 clk: clock.NewFake(), 489 } 490 conn, err := grpc.NewClient(net.JoinHostPort("localhost", strconv.Itoa(port)), 491 grpc.WithTransportCredentials(insecure.NewCredentials()), 492 grpc.WithUnaryInterceptor(ci.Unary)) 493 if err != nil { 494 t.Fatalf("did not connect: %v", err) 495 } 496 return test_proto.NewChillerClient(conn), port, s.Stop 497 }