google.golang.org/grpc@v1.74.2/credentials/credentials_ext_test.go (about) 1 /* 2 * 3 * Copyright 2025 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package credentials_test 20 21 import ( 22 "context" 23 "crypto/tls" 24 "fmt" 25 "net" 26 "testing" 27 "time" 28 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/codes" 31 "google.golang.org/grpc/credentials" 32 "google.golang.org/grpc/credentials/insecure" 33 "google.golang.org/grpc/credentials/local" 34 "google.golang.org/grpc/internal/stubserver" 35 "google.golang.org/grpc/metadata" 36 "google.golang.org/grpc/status" 37 "google.golang.org/grpc/testdata" 38 39 testgrpc "google.golang.org/grpc/interop/grpc_testing" 40 testpb "google.golang.org/grpc/interop/grpc_testing" 41 ) 42 43 func authorityChecker(ctx context.Context, wantAuthority string) error { 44 md, ok := metadata.FromIncomingContext(ctx) 45 if !ok { 46 return status.Error(codes.InvalidArgument, "failed to parse metadata") 47 } 48 auths, ok := md[":authority"] 49 if !ok { 50 return status.Error(codes.InvalidArgument, "no authority header") 51 } 52 if len(auths) != 1 { 53 return status.Errorf(codes.InvalidArgument, "expected exactly one authority header, got %v", auths) 54 } 55 if auths[0] != wantAuthority { 56 return status.Errorf(codes.InvalidArgument, "invalid authority header %q, want %q", auths[0], wantAuthority) 57 } 58 return nil 59 } 60 61 func loadTLSCreds(t *testing.T) (grpc.ServerOption, grpc.DialOption) { 62 t.Helper() 63 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 64 if err != nil { 65 t.Fatalf("Failed to load key pair: %v", err) 66 return nil, nil 67 } 68 serverCreds := grpc.Creds(credentials.NewServerTLSFromCert(&cert)) 69 70 clientCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") 71 if err != nil { 72 t.Fatalf("Failed to create client credentials: %v", err) 73 } 74 return serverCreds, grpc.WithTransportCredentials(clientCreds) 75 } 76 77 // Tests the scenario where the `grpc.CallAuthority` call option is used with 78 // different transport credentials. The test verifies that the specified 79 // authority is correctly propagated to the serve when a correct authority is 80 // used. 81 func (s) TestCorrectAuthorityWithCreds(t *testing.T) { 82 const authority = "auth.test.example.com" 83 84 tests := []struct { 85 name string 86 creds func(t *testing.T) (grpc.ServerOption, grpc.DialOption) 87 expectedAuth string 88 }{ 89 { 90 name: "Insecure", 91 creds: func(*testing.T) (grpc.ServerOption, grpc.DialOption) { 92 c := insecure.NewCredentials() 93 return grpc.Creds(c), grpc.WithTransportCredentials(c) 94 }, 95 expectedAuth: authority, 96 }, 97 { 98 name: "Local", 99 creds: func(*testing.T) (grpc.ServerOption, grpc.DialOption) { 100 c := local.NewCredentials() 101 return grpc.Creds(c), grpc.WithTransportCredentials(c) 102 }, 103 expectedAuth: authority, 104 }, 105 { 106 name: "TLS", 107 creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) { 108 return loadTLSCreds(t) 109 }, 110 expectedAuth: authority, 111 }, 112 } 113 114 for _, tt := range tests { 115 t.Run(tt.name, func(t *testing.T) { 116 ss := &stubserver.StubServer{ 117 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 118 if err := authorityChecker(ctx, tt.expectedAuth); err != nil { 119 return nil, err 120 } 121 return &testpb.Empty{}, nil 122 }, 123 } 124 serverOpt, dialOpt := tt.creds(t) 125 if err := ss.StartServer(serverOpt); err != nil { 126 t.Fatalf("Error starting endpoint server: %v", err) 127 } 128 defer ss.Stop() 129 130 cc, err := grpc.NewClient(ss.Address, dialOpt) 131 if err != nil { 132 t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) 133 } 134 defer cc.Close() 135 136 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 137 defer cancel() 138 if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth)); err != nil { 139 t.Fatalf("EmptyCall() rpc failed: %v", err) 140 } 141 }) 142 } 143 } 144 145 // Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies 146 // that the RPC fails with `UNAVAILABLE` status code and doesn't reach the server 147 // when an incorrect authority is used. 148 func (s) TestIncorrectAuthorityWithTLS(t *testing.T) { 149 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 150 if err != nil { 151 t.Fatalf("Failed to load key pair: %s", err) 152 } 153 creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") 154 if err != nil { 155 t.Fatalf("Failed to create credentials %v", err) 156 } 157 158 serverCalled := make(chan struct{}) 159 ss := &stubserver.StubServer{ 160 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 161 close(serverCalled) 162 return nil, nil 163 }, 164 } 165 if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil { 166 t.Fatalf("Error starting endpoint server: %v", err) 167 } 168 defer ss.Stop() 169 170 cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) 171 if err != nil { 172 t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) 173 } 174 defer cc.Close() 175 176 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 177 defer cancel() 178 179 const authority = "auth.example.com" 180 if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.Unavailable { 181 t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) 182 } 183 select { 184 case <-serverCalled: 185 t.Fatalf("Server handler should not have been called") 186 case <-time.After(defaultTestShortTimeout): 187 } 188 } 189 190 // testAuthInfoNoValidator implements only credentials.AuthInfo and not 191 // credentials.AuthorityValidator. 192 type testAuthInfoNoValidator struct{} 193 194 // AuthType returns the authentication type. 195 func (testAuthInfoNoValidator) AuthType() string { 196 return "test" 197 } 198 199 // testAuthInfoWithValidator implements both credentials.AuthInfo and 200 // credentials.AuthorityValidator. 201 type testAuthInfoWithValidator struct { 202 validAuthority string 203 } 204 205 // AuthType returns the authentication type. 206 func (testAuthInfoWithValidator) AuthType() string { 207 return "test" 208 } 209 210 // ValidateAuthority implements credentials.AuthorityValidator. 211 func (v testAuthInfoWithValidator) ValidateAuthority(authority string) error { 212 if authority == v.validAuthority { 213 return nil 214 } 215 return fmt.Errorf("invalid authority %q, want %q", authority, v.validAuthority) 216 } 217 218 // testCreds is a test TransportCredentials that can optionally support 219 // authority validation. 220 type testCreds struct { 221 authority string 222 } 223 224 // ClientHandshake performs the client-side handshake. 225 func (c *testCreds) ClientHandshake(_ context.Context, _ string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 226 if c.authority != "" { 227 return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil 228 } 229 return rawConn, testAuthInfoNoValidator{}, nil 230 } 231 232 // ServerHandshake performs the server-side handshake. 233 func (c *testCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 234 if c.authority != "" { 235 return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil 236 } 237 return rawConn, testAuthInfoNoValidator{}, nil 238 } 239 240 // Clone creates a copy of testCreds. 241 func (c *testCreds) Clone() credentials.TransportCredentials { 242 return &testCreds{authority: c.authority} 243 } 244 245 // Info provides protocol information. 246 func (c *testCreds) Info() credentials.ProtocolInfo { 247 return credentials.ProtocolInfo{} 248 } 249 250 // OverrideServerName overrides the server name used for verification. 251 func (c *testCreds) OverrideServerName(string) error { 252 return nil 253 } 254 255 // TestAuthorityValidationFailureWithCustomCreds tests the `grpc.CallAuthority` 256 // call option using custom credentials. It covers two failure scenarios: 257 // - The credentials implement AuthorityValidator but authority used to override 258 // is not valid. 259 // - The credentials do not implement AuthorityValidator, but an authority 260 // override is specified. 261 // In both cases, the RPC is expected to fail with an `UNAVAILABLE` status code. 262 func (s) TestAuthorityValidationFailureWithCustomCreds(t *testing.T) { 263 tests := []struct { 264 name string 265 creds credentials.TransportCredentials 266 authority string 267 }{ 268 { 269 name: "IncorrectAuthorityWithFakeCreds", 270 authority: "auth.example.com", 271 creds: &testCreds{authority: "auth.test.example.com"}, 272 }, 273 { 274 name: "FakeCredsWithNoAuthValidator", 275 creds: &testCreds{}, 276 authority: "auth.test.example.com", 277 }, 278 } 279 for _, tt := range tests { 280 t.Run(tt.name, func(t *testing.T) { 281 serverCalled := make(chan struct{}) 282 ss := stubserver.StubServer{ 283 EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { 284 close(serverCalled) 285 return nil, nil 286 }, 287 } 288 if err := ss.StartServer(); err != nil { 289 t.Fatalf("Failed to start stub server: %v", err) 290 } 291 defer ss.Stop() 292 293 cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(tt.creds)) 294 if err != nil { 295 t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) 296 } 297 defer cc.Close() 298 299 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 300 defer cancel() 301 if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority)); status.Code(err) != codes.Unavailable { 302 t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) 303 } 304 select { 305 case <-serverCalled: 306 t.Fatalf("Server should not have been called") 307 case <-time.After(defaultTestShortTimeout): 308 } 309 }) 310 } 311 312 } 313 314 // TestCorrectAuthorityWithCustomCreds tests the `grpc.CallAuthority` call 315 // option using custom credentials. It verifies that the provided authority is 316 // correctly propagated to the server when a correct authority is used. 317 func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) { 318 const authority = "auth.test.example.com" 319 creds := &testCreds{authority: "auth.test.example.com"} 320 ss := stubserver.StubServer{ 321 EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 322 if err := authorityChecker(ctx, authority); err != nil { 323 return nil, err 324 } 325 return &testpb.Empty{}, nil 326 }, 327 } 328 if err := ss.StartServer(); err != nil { 329 t.Fatalf("Failed to start stub server: %v", err) 330 } 331 defer ss.Stop() 332 333 cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) 334 if err != nil { 335 t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) 336 } 337 defer cc.Close() 338 339 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 340 defer cancel() 341 if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK { 342 t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK) 343 } 344 }