github.com/letsencrypt/boulder@v0.20251208.0/grpc/interceptors_test.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/jmhodges/clock"
    18  	"github.com/prometheus/client_golang/prometheus"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/balancer/roundrobin"
    21  	"google.golang.org/grpc/credentials"
    22  	"google.golang.org/grpc/credentials/insecure"
    23  	"google.golang.org/grpc/metadata"
    24  	"google.golang.org/grpc/peer"
    25  	"google.golang.org/grpc/status"
    26  	"google.golang.org/protobuf/types/known/durationpb"
    27  
    28  	"github.com/letsencrypt/boulder/grpc/test_proto"
    29  	"github.com/letsencrypt/boulder/metrics"
    30  	"github.com/letsencrypt/boulder/test"
    31  	"github.com/letsencrypt/boulder/web"
    32  )
    33  
    34  var fc = clock.NewFake()
    35  
    36  func testHandler(_ context.Context, i any) (any, error) {
    37  	if i != nil {
    38  		return nil, errors.New("")
    39  	}
    40  	fc.Sleep(time.Second)
    41  	return nil, nil
    42  }
    43  
    44  func testInvoker(_ context.Context, method string, _, _ any, _ *grpc.ClientConn, opts ...grpc.CallOption) error {
    45  	switch method {
    46  	case "-service-brokeTest":
    47  		return errors.New("")
    48  	case "-service-requesterCanceledTest":
    49  		return status.Error(1, context.Canceled.Error())
    50  	}
    51  	fc.Sleep(time.Second)
    52  	return nil
    53  }
    54  
    55  func TestServerInterceptor(t *testing.T) {
    56  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
    57  	test.AssertNotError(t, err, "creating server metrics")
    58  	si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
    59  
    60  	md := metadata.New(map[string]string{clientRequestTimeKey: "0"})
    61  	ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md)
    62  
    63  	_, err = si.Unary(context.Background(), nil, nil, testHandler)
    64  	test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata")
    65  
    66  	_, err = si.Unary(ctxWithMetadata, nil, nil, testHandler)
    67  	test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo")
    68  
    69  	_, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler)
    70  	test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo")
    71  
    72  	_, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler)
    73  	test.AssertError(t, err, "si.intercept didn't fail when handler returned a error")
    74  }
    75  
    76  func TestClientInterceptor(t *testing.T) {
    77  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
    78  	test.AssertNotError(t, err, "creating client metrics")
    79  	ci := clientMetadataInterceptor{
    80  		timeout: time.Second,
    81  		metrics: clientMetrics,
    82  		clk:     clock.NewFake(),
    83  	}
    84  
    85  	err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker)
    86  	test.AssertNotError(t, err, "ci.intercept failed with a non-nil grpc.UnaryServerInfo")
    87  
    88  	err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker)
    89  	test.AssertError(t, err, "ci.intercept didn't fail when handler returned a error")
    90  }
    91  
    92  // TestWaitForReadyTrue configures a gRPC client with waitForReady: true and
    93  // sends a request to a backend that is unavailable. It ensures that the
    94  // request doesn't error out until the timeout is reached, i.e. that
    95  // FailFast is set to false.
    96  // https://github.com/grpc/grpc/blob/main/doc/wait-for-ready.md
    97  func TestWaitForReadyTrue(t *testing.T) {
    98  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
    99  	test.AssertNotError(t, err, "creating client metrics")
   100  	ci := &clientMetadataInterceptor{
   101  		timeout:      100 * time.Millisecond,
   102  		metrics:      clientMetrics,
   103  		clk:          clock.NewFake(),
   104  		waitForReady: true,
   105  	}
   106  	conn, err := grpc.NewClient("localhost:19876", // random, probably unused port
   107  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
   108  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   109  		grpc.WithUnaryInterceptor(ci.Unary))
   110  	if err != nil {
   111  		t.Fatalf("did not connect: %v", err)
   112  	}
   113  	defer conn.Close()
   114  	c := test_proto.NewChillerClient(conn)
   115  
   116  	start := time.Now()
   117  	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
   118  	if err == nil {
   119  		t.Errorf("Successful Chill when we expected failure.")
   120  	}
   121  	if time.Since(start) < 90*time.Millisecond {
   122  		t.Errorf("Chill failed fast, when WaitForReady should be enabled.")
   123  	}
   124  }
   125  
   126  // TestWaitForReadyFalse configures a gRPC client with waitForReady: false and
   127  // sends a request to a backend that is unavailable, and ensures that the request
   128  // errors out promptly.
   129  func TestWaitForReadyFalse(t *testing.T) {
   130  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   131  	test.AssertNotError(t, err, "creating client metrics")
   132  	ci := &clientMetadataInterceptor{
   133  		timeout:      time.Second,
   134  		metrics:      clientMetrics,
   135  		clk:          clock.NewFake(),
   136  		waitForReady: false,
   137  	}
   138  	conn, err := grpc.NewClient("localhost:19876", // random, probably unused port
   139  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
   140  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   141  		grpc.WithUnaryInterceptor(ci.Unary))
   142  	if err != nil {
   143  		t.Fatalf("did not connect: %v", err)
   144  	}
   145  	defer conn.Close()
   146  	c := test_proto.NewChillerClient(conn)
   147  
   148  	start := time.Now()
   149  	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
   150  	if err == nil {
   151  		t.Errorf("Successful Chill when we expected failure.")
   152  	}
   153  	if time.Since(start) > 200*time.Millisecond {
   154  		t.Errorf("Chill failed slow, when WaitForReady should be disabled.")
   155  	}
   156  }
   157  
   158  // testTimeoutServer is used to implement TestTimeouts, and will attempt to sleep for
   159  // the given amount of time (unless it hits a timeout or cancel).
   160  type testTimeoutServer struct {
   161  	test_proto.UnimplementedChillerServer
   162  }
   163  
   164  // Chill implements ChillerServer.Chill
   165  func (s *testTimeoutServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
   166  	start := time.Now()
   167  	// Sleep for either the requested amount of time, or the context times out or
   168  	// is canceled.
   169  	select {
   170  	case <-time.After(in.Duration.AsDuration() * time.Nanosecond):
   171  		spent := time.Since(start) / time.Nanosecond
   172  		return &test_proto.Time{Duration: durationpb.New(spent)}, nil
   173  	case <-ctx.Done():
   174  		return nil, errors.New("unique error indicating that the server's shortened context timed itself out")
   175  	}
   176  }
   177  
   178  func TestTimeouts(t *testing.T) {
   179  	server := new(testTimeoutServer)
   180  	client, _, stop := setup(t, server, clock.NewFake())
   181  	defer stop()
   182  
   183  	testCases := []struct {
   184  		timeout             time.Duration
   185  		expectedErrorPrefix string
   186  	}{
   187  		{250 * time.Millisecond, "rpc error: code = Unknown desc = unique error indicating that the server's shortened context timed itself out"},
   188  		{100 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
   189  		{10 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
   190  	}
   191  	for _, tc := range testCases {
   192  		t.Run(tc.timeout.String(), func(t *testing.T) {
   193  			ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
   194  			defer cancel()
   195  			_, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)})
   196  			if err == nil {
   197  				t.Fatal("Got no error, expected a timeout")
   198  			}
   199  			if !strings.HasPrefix(err.Error(), tc.expectedErrorPrefix) {
   200  				t.Errorf("Wrong error. Got %s, expected %s", err.Error(), tc.expectedErrorPrefix)
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  func TestRequestTimeTagging(t *testing.T) {
   207  	server := new(testTimeoutServer)
   208  	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
   209  	test.AssertNotError(t, err, "creating server metrics")
   210  	client, _, stop := setup(t, server, serverMetrics)
   211  	defer stop()
   212  
   213  	// Make an RPC request with the ChillerClient with a timeout higher than the
   214  	// requested ChillerServer delay so that the RPC completes normally
   215  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   216  	defer cancel()
   217  	if _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil {
   218  		t.Fatalf("Unexpected error calling Chill RPC: %s", err)
   219  	}
   220  
   221  	// There should be one histogram sample in the serverInterceptor rpcLag stat
   222  	test.AssertMetricWithLabelsEquals(t, serverMetrics.rpcLag, prometheus.Labels{}, 1)
   223  }
   224  
   225  func TestClockSkew(t *testing.T) {
   226  	// Create two separate clocks for the client and server
   227  	serverClk := clock.NewFake()
   228  	serverClk.Set(time.Now())
   229  	clientClk := clock.NewFake()
   230  	clientClk.Set(time.Now())
   231  
   232  	_, serverPort, stop := setup(t, &testTimeoutServer{}, serverClk)
   233  	defer stop()
   234  
   235  	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
   236  	test.AssertNotError(t, err, "creating client metrics")
   237  	ci := &clientMetadataInterceptor{
   238  		timeout: 30 * time.Second,
   239  		metrics: clientMetrics,
   240  		clk:     clientClk,
   241  	}
   242  	conn, err := grpc.NewClient(net.JoinHostPort("localhost", strconv.Itoa(serverPort)),
   243  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   244  		grpc.WithUnaryInterceptor(ci.Unary))
   245  	if err != nil {
   246  		t.Fatalf("did not connect: %v", err)
   247  	}
   248  
   249  	client := test_proto.NewChillerClient(conn)
   250  
   251  	// Create a context with plenty of timeout
   252  	ctx, cancel := context.WithDeadline(context.Background(), clientClk.Now().Add(10*time.Second))
   253  	defer cancel()
   254  
   255  	// Attempt a gRPC request which should succeed
   256  	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
   257  	test.AssertNotError(t, err, "should succeed with no skew")
   258  
   259  	// Skew the client clock forward and the request should fail due to skew
   260  	clientClk.Add(time.Hour)
   261  	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
   262  	test.AssertError(t, err, "should fail with positive client skew")
   263  	test.AssertContains(t, err.Error(), "very different time")
   264  
   265  	// Skew the server clock forward and the request should fail due to skew
   266  	serverClk.Add(2 * time.Hour)
   267  	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
   268  	test.AssertError(t, err, "should fail with negative client skew")
   269  	test.AssertContains(t, err.Error(), "very different time")
   270  }
   271  
   272  // blockedServer implements a ChillerServer with a Chill method that:
   273  //  1. Calls Done() on the received waitgroup when receiving an RPC
   274  //  2. Blocks the RPC on the roadblock waitgroup
   275  //
   276  // This is used by TestInFlightRPCStat to test that the gauge for in-flight RPCs
   277  // is incremented and decremented as expected.
   278  type blockedServer struct {
   279  	test_proto.UnimplementedChillerServer
   280  	roadblock, received sync.WaitGroup
   281  }
   282  
   283  // Chill implements ChillerServer.Chill
   284  func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto.Time, error) {
   285  	// Note that a client RPC arrived
   286  	s.received.Done()
   287  	// Wait for the roadblock to be cleared
   288  	s.roadblock.Wait()
   289  	// Return a dummy spent value to adhere to the chiller protocol
   290  	return &test_proto.Time{Duration: durationpb.New(time.Millisecond)}, nil
   291  }
   292  
   293  func TestInFlightRPCStat(t *testing.T) {
   294  	// Create a new blockedServer to act as a ChillerServer
   295  	server := &blockedServer{}
   296  
   297  	metrics, err := newClientMetrics(metrics.NoopRegisterer)
   298  	test.AssertNotError(t, err, "creating client metrics")
   299  
   300  	client, _, stop := setup(t, server, metrics)
   301  	defer stop()
   302  
   303  	// Increment the roadblock waitgroup - this will cause all chill RPCs to
   304  	// the server to block until we call Done()!
   305  	server.roadblock.Add(1)
   306  
   307  	// Increment the sentRPCs waitgroup - we use this to find out when all the
   308  	// RPCs we want to send have been received and we can count the in-flight
   309  	// gauge
   310  	numRPCs := 5
   311  	server.received.Add(numRPCs)
   312  
   313  	// Fire off a few RPCs. They will block on the blockedServer's roadblock wg
   314  	for range numRPCs {
   315  		go func() {
   316  			// Ignore errors, just chilllll.
   317  			_, _ = client.Chill(context.Background(), &test_proto.Time{})
   318  		}()
   319  	}
   320  
   321  	// wait until all of the client RPCs have been sent and are blocking. We can
   322  	// now check the gauge.
   323  	server.received.Wait()
   324  
   325  	// Specify the labels for the RPCs we're interested in
   326  	labels := prometheus.Labels{
   327  		"service": "Chiller",
   328  		"method":  "Chill",
   329  	}
   330  
   331  	// We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal to numRPCs.
   332  	test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, float64(numRPCs))
   333  
   334  	// Unblock the blockedServer to let all of the Chiller.Chill RPCs complete
   335  	server.roadblock.Done()
   336  	// Sleep for a little bit to let all the RPCs complete
   337  	time.Sleep(1 * time.Second)
   338  
   339  	// Check the gauge value again
   340  	test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, 0)
   341  }
   342  
   343  func TestServiceAuthChecker(t *testing.T) {
   344  	ac := authInterceptor{
   345  		map[string]map[string]struct{}{
   346  			"package.ServiceName": {
   347  				"allowed.client": {},
   348  				"also.allowed":   {},
   349  			},
   350  		},
   351  	}
   352  
   353  	// No allowlist is a bad configuration.
   354  	ctx := context.Background()
   355  	err := ac.checkContextAuth(ctx, "/package.OtherService/Method/")
   356  	test.AssertError(t, err, "checking empty allowlist")
   357  
   358  	// Context with no peering information is disallowed.
   359  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   360  	test.AssertError(t, err, "checking un-peered context")
   361  
   362  	// Context with no auth info is disallowed.
   363  	ctx = peer.NewContext(ctx, &peer.Peer{})
   364  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   365  	test.AssertError(t, err, "checking peer with no auth")
   366  
   367  	// Context with no verified chains is disallowed.
   368  	ctx = peer.NewContext(ctx, &peer.Peer{
   369  		AuthInfo: credentials.TLSInfo{
   370  			State: tls.ConnectionState{},
   371  		},
   372  	})
   373  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   374  	test.AssertError(t, err, "checking TLS with no valid chains")
   375  
   376  	// Context with cert with wrong name is disallowed.
   377  	ctx = peer.NewContext(ctx, &peer.Peer{
   378  		AuthInfo: credentials.TLSInfo{
   379  			State: tls.ConnectionState{
   380  				VerifiedChains: [][]*x509.Certificate{
   381  					{
   382  						&x509.Certificate{
   383  							DNSNames: []string{
   384  								"disallowed.client",
   385  							},
   386  						},
   387  					},
   388  				},
   389  			},
   390  		},
   391  	})
   392  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   393  	test.AssertError(t, err, "checking disallowed cert")
   394  
   395  	// Context with cert with good name is allowed.
   396  	ctx = peer.NewContext(ctx, &peer.Peer{
   397  		AuthInfo: credentials.TLSInfo{
   398  			State: tls.ConnectionState{
   399  				VerifiedChains: [][]*x509.Certificate{
   400  					{
   401  						&x509.Certificate{
   402  							DNSNames: []string{
   403  								"disallowed.client",
   404  								"also.allowed",
   405  							},
   406  						},
   407  					},
   408  				},
   409  			},
   410  		},
   411  	})
   412  	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
   413  	test.AssertNotError(t, err, "checking allowed cert")
   414  }
   415  
   416  // testUserAgentServer stores the last value it saw in the user agent field of its context.
   417  type testUserAgentServer struct {
   418  	test_proto.UnimplementedChillerServer
   419  
   420  	lastSeenUA string
   421  }
   422  
   423  // Chill implements ChillerServer.Chill
   424  func (s *testUserAgentServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
   425  	s.lastSeenUA = web.UserAgent(ctx)
   426  	return nil, nil
   427  }
   428  
   429  func TestUserAgentMetadata(t *testing.T) {
   430  	server := new(testUserAgentServer)
   431  	client, _, stop := setup(t, server)
   432  	defer stop()
   433  
   434  	testUA := "test UA"
   435  	ctx := web.WithUserAgent(context.Background(), testUA)
   436  
   437  	_, err := client.Chill(ctx, &test_proto.Time{})
   438  	if err != nil {
   439  		t.Fatalf("calling c.Chill: %s", err)
   440  	}
   441  
   442  	if server.lastSeenUA != testUA {
   443  		t.Errorf("last seen User-Agent on server side was %q, want %q", server.lastSeenUA, testUA)
   444  	}
   445  }
   446  
   447  // setup creates a server and client, returning the created client, the running server's port, and a stop function.
   448  func setup(t *testing.T, server test_proto.ChillerServer, opts ...any) (test_proto.ChillerClient, int, func()) {
   449  	clk := clock.NewFake()
   450  	serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer)
   451  	test.AssertNotError(t, err, "creating server metrics")
   452  	clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer)
   453  	test.AssertNotError(t, err, "creating client metrics")
   454  
   455  	for _, opt := range opts {
   456  		switch optTyped := opt.(type) {
   457  		case clock.FakeClock:
   458  			clk = optTyped
   459  		case clientMetrics:
   460  			clientMetricsVal = optTyped
   461  		case serverMetrics:
   462  			serverMetricsVal = optTyped
   463  		default:
   464  			t.Fatalf("setup called with unrecognize option %#v", t)
   465  		}
   466  	}
   467  	lis, err := net.Listen("tcp", ":0")
   468  	if err != nil {
   469  		log.Fatalf("failed to listen: %v", err)
   470  	}
   471  	port := lis.Addr().(*net.TCPAddr).Port
   472  
   473  	si := newServerMetadataInterceptor(serverMetricsVal, clk)
   474  	s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
   475  	test_proto.RegisterChillerServer(s, server)
   476  
   477  	go func() {
   478  		start := time.Now()
   479  		err := s.Serve(lis)
   480  		if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
   481  			t.Logf("s.Serve: %v after %s", err, time.Since(start))
   482  		}
   483  	}()
   484  
   485  	ci := &clientMetadataInterceptor{
   486  		timeout: 30 * time.Second,
   487  		metrics: clientMetricsVal,
   488  		clk:     clock.NewFake(),
   489  	}
   490  	conn, err := grpc.NewClient(net.JoinHostPort("localhost", strconv.Itoa(port)),
   491  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   492  		grpc.WithUnaryInterceptor(ci.Unary))
   493  	if err != nil {
   494  		t.Fatalf("did not connect: %v", err)
   495  	}
   496  	return test_proto.NewChillerClient(conn), port, s.Stop
   497  }