google.golang.org/grpc@v1.72.2/credentials/xds/xds_server_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package xds 20 21 import ( 22 "context" 23 "crypto/tls" 24 "crypto/x509" 25 "errors" 26 "fmt" 27 "net" 28 "os" 29 "strings" 30 "testing" 31 "time" 32 33 "google.golang.org/grpc/credentials" 34 "google.golang.org/grpc/credentials/tls/certprovider" 35 xdsinternal "google.golang.org/grpc/internal/credentials/xds" 36 "google.golang.org/grpc/testdata" 37 ) 38 39 func makeClientTLSConfig(t *testing.T, mTLS bool) *tls.Config { 40 t.Helper() 41 42 pemData, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) 43 if err != nil { 44 t.Fatal(err) 45 } 46 roots := x509.NewCertPool() 47 roots.AppendCertsFromPEM(pemData) 48 49 var certs []tls.Certificate 50 if mTLS { 51 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) 52 if err != nil { 53 t.Fatal(err) 54 } 55 certs = append(certs, cert) 56 } 57 58 return &tls.Config{ 59 Certificates: certs, 60 RootCAs: roots, 61 ServerName: "*.test.example.com", 62 // Setting this to true completely turns off the certificate validation 63 // on the client side. So, the client side handshake always seems to 64 // succeed. But if we want to turn this ON, we will need to generate 65 // certificates which work with localhost, or supply a custom 66 // verification function. So, the server credentials tests will rely 67 // solely on the success/failure of the server-side handshake. 68 InsecureSkipVerify: true, 69 NextProtos: []string{"h2"}, 70 } 71 } 72 73 // Helper function to create a real TLS server credentials which is used as 74 // fallback credentials from multiple tests. 75 func makeFallbackServerCreds(t *testing.T) credentials.TransportCredentials { 76 t.Helper() 77 78 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 79 if err != nil { 80 t.Fatal(err) 81 } 82 return creds 83 } 84 85 type errorCreds struct { 86 credentials.TransportCredentials 87 } 88 89 // TestServerCredsWithoutFallback verifies that the call to 90 // NewServerCredentials() fails when no fallback is specified. 91 func (s) TestServerCredsWithoutFallback(t *testing.T) { 92 if _, err := NewServerCredentials(ServerOptions{}); err == nil { 93 t.Fatal("NewServerCredentials() succeeded without specifying fallback") 94 } 95 } 96 97 type wrapperConn struct { 98 net.Conn 99 xdsHI *xdsinternal.HandshakeInfo 100 deadline time.Time 101 handshakeInfoErr error 102 } 103 104 func (wc *wrapperConn) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { 105 return wc.xdsHI, wc.handshakeInfoErr 106 } 107 108 func (wc *wrapperConn) GetDeadline() time.Time { 109 return wc.deadline 110 } 111 112 func newWrappedConn(conn net.Conn, xdsHI *xdsinternal.HandshakeInfo, deadline time.Time) *wrapperConn { 113 return &wrapperConn{Conn: conn, xdsHI: xdsHI, deadline: deadline} 114 } 115 116 // TestServerCredsInvalidHandshakeInfo verifies scenarios where the passed in 117 // HandshakeInfo is invalid because it does not contain the expected certificate 118 // providers. 119 func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) { 120 opts := ServerOptions{FallbackCreds: &errorCreds{}} 121 creds, err := NewServerCredentials(opts) 122 if err != nil { 123 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 124 } 125 126 info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false) 127 conn := newWrappedConn(nil, info, time.Time{}) 128 if _, _, err := creds.ServerHandshake(conn); err == nil { 129 t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo") 130 } 131 } 132 133 // TestServerCredsProviderFailure verifies the cases where an expected 134 // certificate provider is missing in the HandshakeInfo value in the context. 135 func (s) TestServerCredsProviderFailure(t *testing.T) { 136 opts := ServerOptions{FallbackCreds: &errorCreds{}} 137 creds, err := NewServerCredentials(opts) 138 if err != nil { 139 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 140 } 141 142 tests := []struct { 143 desc string 144 rootProvider certprovider.Provider 145 identityProvider certprovider.Provider 146 wantErr string 147 }{ 148 { 149 desc: "erroring identity provider", 150 identityProvider: &fakeProvider{err: errors.New("identity provider error")}, 151 wantErr: "identity provider error", 152 }, 153 { 154 desc: "erroring root provider", 155 identityProvider: &fakeProvider{km: &certprovider.KeyMaterial{}}, 156 rootProvider: &fakeProvider{err: errors.New("root provider error")}, 157 wantErr: "root provider error", 158 }, 159 } 160 for _, test := range tests { 161 t.Run(test.desc, func(t *testing.T) { 162 info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false) 163 conn := newWrappedConn(nil, info, time.Time{}) 164 if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { 165 t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr) 166 } 167 }) 168 } 169 } 170 171 // TestServerCredsHandshake_XDSHandshakeInfoError verifies the case where the 172 // call to XDSHandshakeInfo() from the ServerHandshake() method returns an 173 // error, and the test verifies that the ServerHandshake() fails with the 174 // expected error. 175 func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { 176 opts := ServerOptions{FallbackCreds: &errorCreds{}} 177 creds, err := NewServerCredentials(opts) 178 if err != nil { 179 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 180 } 181 182 // Create a test server which uses the xDS server credentials created above 183 // to perform TLS handshake on incoming connections. 184 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 185 // Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error. 186 conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout)) 187 hiErr := errors.New("xdsHandshakeInfo error") 188 conn.handshakeInfoErr = hiErr 189 190 // Invoke the ServerHandshake() method on the xDS credentials and verify 191 // that the error returned by the XDSHandshakeInfo() method on the 192 // wrapped conn is returned here. 193 _, _, err := creds.ServerHandshake(conn) 194 if !errors.Is(err, hiErr) { 195 return handshakeResult{err: fmt.Errorf("ServerHandshake() returned err: %v, wantErr: %v", err, hiErr)} 196 } 197 return handshakeResult{} 198 }) 199 defer ts.stop() 200 201 // Dial the test server, but don't trigger the TLS handshake. This will 202 // cause ServerHandshake() to fail. 203 rawConn, err := net.Dial("tcp", ts.address) 204 if err != nil { 205 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 206 } 207 defer rawConn.Close() 208 209 // Read handshake result from the testServer which will return an error if 210 // the handshake succeeded. 211 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 212 defer cancel() 213 val, err := ts.hsResult.Receive(ctx) 214 if err != nil { 215 t.Fatalf("testServer failed to return handshake result: %v", err) 216 } 217 hsr := val.(handshakeResult) 218 if hsr.err != nil { 219 t.Fatalf("testServer handshake failure: %v", hsr.err) 220 } 221 } 222 223 // TestServerCredsHandshakeTimeout verifies the case where the client does not 224 // send required handshake data before the deadline set on the net.Conn passed 225 // to ServerHandshake(). 226 func (s) TestServerCredsHandshakeTimeout(t *testing.T) { 227 opts := ServerOptions{FallbackCreds: &errorCreds{}} 228 creds, err := NewServerCredentials(opts) 229 if err != nil { 230 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 231 } 232 233 // Create a test server which uses the xDS server credentials created above 234 // to perform TLS handshake on incoming connections. 235 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 236 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true) 237 238 // Create a wrapped conn which can return the HandshakeInfo created 239 // above with a very small deadline. 240 d := time.Now().Add(defaultTestShortTimeout) 241 rawConn.SetDeadline(d) 242 conn := newWrappedConn(rawConn, hi, d) 243 244 // ServerHandshake() on the xDS credentials is expected to fail. 245 if _, _, err := creds.ServerHandshake(conn); err == nil { 246 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to timeout")} 247 } 248 return handshakeResult{} 249 }) 250 defer ts.stop() 251 252 // Dial the test server, but don't trigger the TLS handshake. This will 253 // cause ServerHandshake() to fail. 254 rawConn, err := net.Dial("tcp", ts.address) 255 if err != nil { 256 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 257 } 258 defer rawConn.Close() 259 260 // Read handshake result from the testServer and expect a failure result. 261 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 262 defer cancel() 263 val, err := ts.hsResult.Receive(ctx) 264 if err != nil { 265 t.Fatalf("testServer failed to return handshake result: %v", err) 266 } 267 hsr := val.(handshakeResult) 268 if hsr.err != nil { 269 t.Fatalf("testServer handshake failure: %v", hsr.err) 270 } 271 } 272 273 // TestServerCredsHandshakeFailure verifies the case where the server-side 274 // credentials uses a root certificate which does not match the certificate 275 // presented by the client, and hence the handshake must fail. 276 func (s) TestServerCredsHandshakeFailure(t *testing.T) { 277 opts := ServerOptions{FallbackCreds: &errorCreds{}} 278 creds, err := NewServerCredentials(opts) 279 if err != nil { 280 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 281 } 282 283 // Create a test server which uses the xDS server credentials created above 284 // to perform TLS handshake on incoming connections. 285 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 286 // Create a HandshakeInfo which has a root provider which does not match 287 // the certificate sent by the client. 288 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) 289 290 // Create a wrapped conn which can return the HandshakeInfo and 291 // configured deadline to the xDS credentials' ServerHandshake() 292 // method. 293 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 294 295 // ServerHandshake() on the xDS credentials is expected to fail. 296 if _, _, err := creds.ServerHandshake(conn); err == nil { 297 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} 298 } 299 return handshakeResult{} 300 }) 301 defer ts.stop() 302 303 // Dial the test server, and trigger the TLS handshake. 304 rawConn, err := net.Dial("tcp", ts.address) 305 if err != nil { 306 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 307 } 308 defer rawConn.Close() 309 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) 310 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 311 if err := tlsConn.Handshake(); err != nil { 312 t.Fatal(err) 313 } 314 315 // Read handshake result from the testServer which will return an error if 316 // the handshake succeeded. 317 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 318 defer cancel() 319 val, err := ts.hsResult.Receive(ctx) 320 if err != nil { 321 t.Fatalf("testServer failed to return handshake result: %v", err) 322 } 323 hsr := val.(handshakeResult) 324 if hsr.err != nil { 325 t.Fatalf("testServer handshake failure: %v", hsr.err) 326 } 327 } 328 329 // TestServerCredsHandshakeSuccess verifies success handshake cases. 330 func (s) TestServerCredsHandshakeSuccess(t *testing.T) { 331 tests := []struct { 332 desc string 333 fallbackCreds credentials.TransportCredentials 334 rootProvider certprovider.Provider 335 identityProvider certprovider.Provider 336 requireClientCert bool 337 }{ 338 { 339 desc: "fallback", 340 fallbackCreds: makeFallbackServerCreds(t), 341 }, 342 { 343 desc: "TLS", 344 fallbackCreds: &errorCreds{}, 345 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), 346 }, 347 { 348 desc: "mTLS", 349 fallbackCreds: &errorCreds{}, 350 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), 351 rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"), 352 requireClientCert: true, 353 }, 354 } 355 356 for _, test := range tests { 357 t.Run(test.desc, func(t *testing.T) { 358 // Create an xDS server credentials. 359 opts := ServerOptions{FallbackCreds: test.fallbackCreds} 360 creds, err := NewServerCredentials(opts) 361 if err != nil { 362 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 363 } 364 365 // Create a test server which uses the xDS server credentials 366 // created above to perform TLS handshake on incoming connections. 367 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 368 // Create a HandshakeInfo with information from the test table. 369 hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert) 370 371 // Create a wrapped conn which can return the HandshakeInfo and 372 // configured deadline to the xDS credentials' ServerHandshake() 373 // method. 374 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 375 376 // Invoke the ServerHandshake() method on the xDS credentials 377 // and make some sanity checks before pushing the result for 378 // inspection by the main test body. 379 _, ai, err := creds.ServerHandshake(conn) 380 if err != nil { 381 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} 382 } 383 if ai.AuthType() != "tls" { 384 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} 385 } 386 info, ok := ai.(credentials.TLSInfo) 387 if !ok { 388 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} 389 } 390 return handshakeResult{connState: info.State} 391 }) 392 defer ts.stop() 393 394 // Dial the test server, and trigger the TLS handshake. 395 rawConn, err := net.Dial("tcp", ts.address) 396 if err != nil { 397 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 398 } 399 defer rawConn.Close() 400 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, test.requireClientCert)) 401 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 402 if err := tlsConn.Handshake(); err != nil { 403 t.Fatal(err) 404 } 405 406 // Read the handshake result from the testServer which contains the 407 // TLS connection state on the server-side and compare it with the 408 // one received on the client-side. 409 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 410 defer cancel() 411 val, err := ts.hsResult.Receive(ctx) 412 if err != nil { 413 t.Fatalf("testServer failed to return handshake result: %v", err) 414 } 415 hsr := val.(handshakeResult) 416 if hsr.err != nil { 417 t.Fatalf("testServer handshake failure: %v", hsr.err) 418 } 419 420 // AuthInfo contains a variety of information. We only verify a 421 // subset here. This is the same subset which is verified in TLS 422 // credentials tests. 423 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { 424 t.Fatal(err) 425 } 426 }) 427 } 428 } 429 430 func (s) TestServerCredsProviderSwitch(t *testing.T) { 431 opts := ServerOptions{FallbackCreds: &errorCreds{}} 432 creds, err := NewServerCredentials(opts) 433 if err != nil { 434 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 435 } 436 437 // The first time the handshake function is invoked, it returns a 438 // HandshakeInfo which is expected to fail. Further invocations return a 439 // HandshakeInfo which is expected to succeed. 440 cnt := 0 441 // Create a test server which uses the xDS server credentials created above 442 // to perform TLS handshake on incoming connections. 443 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 444 cnt++ 445 var hi *xdsinternal.HandshakeInfo 446 if cnt == 1 { 447 // Create a HandshakeInfo which has a root provider which does not match 448 // the certificate sent by the client. 449 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) 450 451 // Create a wrapped conn which can return the HandshakeInfo and 452 // configured deadline to the xDS credentials' ServerHandshake() 453 // method. 454 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 455 456 // ServerHandshake() on the xDS credentials is expected to fail. 457 if _, _, err := creds.ServerHandshake(conn); err == nil { 458 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} 459 } 460 return handshakeResult{} 461 } 462 463 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true) 464 465 // Create a wrapped conn which can return the HandshakeInfo and 466 // configured deadline to the xDS credentials' ServerHandshake() 467 // method. 468 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 469 470 // Invoke the ServerHandshake() method on the xDS credentials 471 // and make some sanity checks before pushing the result for 472 // inspection by the main test body. 473 _, ai, err := creds.ServerHandshake(conn) 474 if err != nil { 475 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} 476 } 477 if ai.AuthType() != "tls" { 478 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} 479 } 480 info, ok := ai.(credentials.TLSInfo) 481 if !ok { 482 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} 483 } 484 return handshakeResult{connState: info.State} 485 }) 486 defer ts.stop() 487 488 for i := 0; i < 5; i++ { 489 // Dial the test server, and trigger the TLS handshake. 490 rawConn, err := net.Dial("tcp", ts.address) 491 if err != nil { 492 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 493 } 494 defer rawConn.Close() 495 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) 496 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 497 if err := tlsConn.Handshake(); err != nil { 498 t.Fatal(err) 499 } 500 501 // Read the handshake result from the testServer which contains the 502 // TLS connection state on the server-side and compare it with the 503 // one received on the client-side. 504 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 505 defer cancel() 506 val, err := ts.hsResult.Receive(ctx) 507 if err != nil { 508 t.Fatalf("testServer failed to return handshake result: %v", err) 509 } 510 hsr := val.(handshakeResult) 511 if hsr.err != nil { 512 t.Fatalf("testServer handshake failure: %v", hsr.err) 513 } 514 if i == 0 { 515 // We expect the first handshake to fail. So, we skip checks which 516 // compare connection state. 517 continue 518 } 519 // AuthInfo contains a variety of information. We only verify a 520 // subset here. This is the same subset which is verified in TLS 521 // credentials tests. 522 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { 523 t.Fatal(err) 524 } 525 } 526 } 527 528 // TestServerClone verifies the Clone() method on client credentials. 529 func (s) TestServerClone(t *testing.T) { 530 opts := ServerOptions{FallbackCreds: makeFallbackServerCreds(t)} 531 orig, err := NewServerCredentials(opts) 532 if err != nil { 533 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 534 } 535 536 // The credsImpl does not have any exported fields, and it does not make 537 // sense to use any cmp options to look deep into. So, all we make sure here 538 // is that the cloned object points to a different location in memory. 539 if clone := orig.Clone(); clone == orig { 540 t.Fatal("return value from Clone() doesn't point to new credentials instance") 541 } 542 }