google.golang.org/grpc@v1.74.2/credentials/credentials_ext_test.go (about)

     1  /*
     2   *
     3   * Copyright 2025 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package credentials_test
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"fmt"
    25  	"net"
    26  	"testing"
    27  	"time"
    28  
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/credentials"
    32  	"google.golang.org/grpc/credentials/insecure"
    33  	"google.golang.org/grpc/credentials/local"
    34  	"google.golang.org/grpc/internal/stubserver"
    35  	"google.golang.org/grpc/metadata"
    36  	"google.golang.org/grpc/status"
    37  	"google.golang.org/grpc/testdata"
    38  
    39  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    40  	testpb "google.golang.org/grpc/interop/grpc_testing"
    41  )
    42  
    43  func authorityChecker(ctx context.Context, wantAuthority string) error {
    44  	md, ok := metadata.FromIncomingContext(ctx)
    45  	if !ok {
    46  		return status.Error(codes.InvalidArgument, "failed to parse metadata")
    47  	}
    48  	auths, ok := md[":authority"]
    49  	if !ok {
    50  		return status.Error(codes.InvalidArgument, "no authority header")
    51  	}
    52  	if len(auths) != 1 {
    53  		return status.Errorf(codes.InvalidArgument, "expected exactly one authority header, got %v", auths)
    54  	}
    55  	if auths[0] != wantAuthority {
    56  		return status.Errorf(codes.InvalidArgument, "invalid authority header %q, want %q", auths[0], wantAuthority)
    57  	}
    58  	return nil
    59  }
    60  
    61  func loadTLSCreds(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
    62  	t.Helper()
    63  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
    64  	if err != nil {
    65  		t.Fatalf("Failed to load key pair: %v", err)
    66  		return nil, nil
    67  	}
    68  	serverCreds := grpc.Creds(credentials.NewServerTLSFromCert(&cert))
    69  
    70  	clientCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
    71  	if err != nil {
    72  		t.Fatalf("Failed to create client credentials: %v", err)
    73  	}
    74  	return serverCreds, grpc.WithTransportCredentials(clientCreds)
    75  }
    76  
    77  // Tests the scenario where the `grpc.CallAuthority` call option is used with
    78  // different transport credentials. The test verifies that the specified
    79  // authority is correctly propagated to the serve when a correct authority is
    80  // used.
    81  func (s) TestCorrectAuthorityWithCreds(t *testing.T) {
    82  	const authority = "auth.test.example.com"
    83  
    84  	tests := []struct {
    85  		name         string
    86  		creds        func(t *testing.T) (grpc.ServerOption, grpc.DialOption)
    87  		expectedAuth string
    88  	}{
    89  		{
    90  			name: "Insecure",
    91  			creds: func(*testing.T) (grpc.ServerOption, grpc.DialOption) {
    92  				c := insecure.NewCredentials()
    93  				return grpc.Creds(c), grpc.WithTransportCredentials(c)
    94  			},
    95  			expectedAuth: authority,
    96  		},
    97  		{
    98  			name: "Local",
    99  			creds: func(*testing.T) (grpc.ServerOption, grpc.DialOption) {
   100  				c := local.NewCredentials()
   101  				return grpc.Creds(c), grpc.WithTransportCredentials(c)
   102  			},
   103  			expectedAuth: authority,
   104  		},
   105  		{
   106  			name: "TLS",
   107  			creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
   108  				return loadTLSCreds(t)
   109  			},
   110  			expectedAuth: authority,
   111  		},
   112  	}
   113  
   114  	for _, tt := range tests {
   115  		t.Run(tt.name, func(t *testing.T) {
   116  			ss := &stubserver.StubServer{
   117  				EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   118  					if err := authorityChecker(ctx, tt.expectedAuth); err != nil {
   119  						return nil, err
   120  					}
   121  					return &testpb.Empty{}, nil
   122  				},
   123  			}
   124  			serverOpt, dialOpt := tt.creds(t)
   125  			if err := ss.StartServer(serverOpt); err != nil {
   126  				t.Fatalf("Error starting endpoint server: %v", err)
   127  			}
   128  			defer ss.Stop()
   129  
   130  			cc, err := grpc.NewClient(ss.Address, dialOpt)
   131  			if err != nil {
   132  				t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
   133  			}
   134  			defer cc.Close()
   135  
   136  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   137  			defer cancel()
   138  			if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth)); err != nil {
   139  				t.Fatalf("EmptyCall() rpc failed: %v", err)
   140  			}
   141  		})
   142  	}
   143  }
   144  
   145  // Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies
   146  // that the RPC fails with `UNAVAILABLE` status code and doesn't reach the server
   147  // when an incorrect authority is used.
   148  func (s) TestIncorrectAuthorityWithTLS(t *testing.T) {
   149  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   150  	if err != nil {
   151  		t.Fatalf("Failed to load key pair: %s", err)
   152  	}
   153  	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
   154  	if err != nil {
   155  		t.Fatalf("Failed to create credentials %v", err)
   156  	}
   157  
   158  	serverCalled := make(chan struct{})
   159  	ss := &stubserver.StubServer{
   160  		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   161  			close(serverCalled)
   162  			return nil, nil
   163  		},
   164  	}
   165  	if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
   166  		t.Fatalf("Error starting endpoint server: %v", err)
   167  	}
   168  	defer ss.Stop()
   169  
   170  	cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
   171  	if err != nil {
   172  		t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
   173  	}
   174  	defer cc.Close()
   175  
   176  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   177  	defer cancel()
   178  
   179  	const authority = "auth.example.com"
   180  	if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.Unavailable {
   181  		t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable)
   182  	}
   183  	select {
   184  	case <-serverCalled:
   185  		t.Fatalf("Server handler should not have been called")
   186  	case <-time.After(defaultTestShortTimeout):
   187  	}
   188  }
   189  
   190  // testAuthInfoNoValidator implements only credentials.AuthInfo and not
   191  // credentials.AuthorityValidator.
   192  type testAuthInfoNoValidator struct{}
   193  
   194  // AuthType returns the authentication type.
   195  func (testAuthInfoNoValidator) AuthType() string {
   196  	return "test"
   197  }
   198  
   199  // testAuthInfoWithValidator implements both credentials.AuthInfo and
   200  // credentials.AuthorityValidator.
   201  type testAuthInfoWithValidator struct {
   202  	validAuthority string
   203  }
   204  
   205  // AuthType returns the authentication type.
   206  func (testAuthInfoWithValidator) AuthType() string {
   207  	return "test"
   208  }
   209  
   210  // ValidateAuthority implements credentials.AuthorityValidator.
   211  func (v testAuthInfoWithValidator) ValidateAuthority(authority string) error {
   212  	if authority == v.validAuthority {
   213  		return nil
   214  	}
   215  	return fmt.Errorf("invalid authority %q, want %q", authority, v.validAuthority)
   216  }
   217  
   218  // testCreds is a test TransportCredentials that can optionally support
   219  // authority validation.
   220  type testCreds struct {
   221  	authority string
   222  }
   223  
   224  // ClientHandshake performs the client-side handshake.
   225  func (c *testCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   226  	if c.authority != "" {
   227  		return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil
   228  	}
   229  	return rawConn, testAuthInfoNoValidator{}, nil
   230  }
   231  
   232  // ServerHandshake performs the server-side handshake.
   233  func (c *testCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   234  	if c.authority != "" {
   235  		return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil
   236  	}
   237  	return rawConn, testAuthInfoNoValidator{}, nil
   238  }
   239  
   240  // Clone creates a copy of testCreds.
   241  func (c *testCreds) Clone() credentials.TransportCredentials {
   242  	return &testCreds{authority: c.authority}
   243  }
   244  
   245  // Info provides protocol information.
   246  func (c *testCreds) Info() credentials.ProtocolInfo {
   247  	return credentials.ProtocolInfo{}
   248  }
   249  
   250  // OverrideServerName overrides the server name used for verification.
   251  func (c *testCreds) OverrideServerName(string) error {
   252  	return nil
   253  }
   254  
   255  // TestAuthorityValidationFailureWithCustomCreds tests the `grpc.CallAuthority`
   256  // call option using custom credentials. It covers two failure scenarios:
   257  // - The credentials implement AuthorityValidator but authority used to override
   258  // is not valid.
   259  // - The credentials do not implement AuthorityValidator, but an authority
   260  // override is specified.
   261  // In both cases, the RPC is expected to fail with an `UNAVAILABLE` status code.
   262  func (s) TestAuthorityValidationFailureWithCustomCreds(t *testing.T) {
   263  	tests := []struct {
   264  		name      string
   265  		creds     credentials.TransportCredentials
   266  		authority string
   267  	}{
   268  		{
   269  			name:      "IncorrectAuthorityWithFakeCreds",
   270  			authority: "auth.example.com",
   271  			creds:     &testCreds{authority: "auth.test.example.com"},
   272  		},
   273  		{
   274  			name:      "FakeCredsWithNoAuthValidator",
   275  			creds:     &testCreds{},
   276  			authority: "auth.test.example.com",
   277  		},
   278  	}
   279  	for _, tt := range tests {
   280  		t.Run(tt.name, func(t *testing.T) {
   281  			serverCalled := make(chan struct{})
   282  			ss := stubserver.StubServer{
   283  				EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
   284  					close(serverCalled)
   285  					return nil, nil
   286  				},
   287  			}
   288  			if err := ss.StartServer(); err != nil {
   289  				t.Fatalf("Failed to start stub server: %v", err)
   290  			}
   291  			defer ss.Stop()
   292  
   293  			cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(tt.creds))
   294  			if err != nil {
   295  				t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
   296  			}
   297  			defer cc.Close()
   298  
   299  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   300  			defer cancel()
   301  			if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority)); status.Code(err) != codes.Unavailable {
   302  				t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable)
   303  			}
   304  			select {
   305  			case <-serverCalled:
   306  				t.Fatalf("Server should not have been called")
   307  			case <-time.After(defaultTestShortTimeout):
   308  			}
   309  		})
   310  	}
   311  
   312  }
   313  
   314  // TestCorrectAuthorityWithCustomCreds tests the `grpc.CallAuthority` call
   315  // option using custom credentials. It verifies that the provided authority is
   316  // correctly propagated to the server when a correct authority is used.
   317  func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) {
   318  	const authority = "auth.test.example.com"
   319  	creds := &testCreds{authority: "auth.test.example.com"}
   320  	ss := stubserver.StubServer{
   321  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   322  			if err := authorityChecker(ctx, authority); err != nil {
   323  				return nil, err
   324  			}
   325  			return &testpb.Empty{}, nil
   326  		},
   327  	}
   328  	if err := ss.StartServer(); err != nil {
   329  		t.Fatalf("Failed to start stub server: %v", err)
   330  	}
   331  	defer ss.Stop()
   332  
   333  	cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
   334  	if err != nil {
   335  		t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
   336  	}
   337  	defer cc.Close()
   338  
   339  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   340  	defer cancel()
   341  	if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK {
   342  		t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK)
   343  	}
   344  }