google.golang.org/grpc@v1.62.1/credentials/xds/xds_server_test.go (about) 1 /* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package xds 20 21 import ( 22 "context" 23 "crypto/tls" 24 "crypto/x509" 25 "errors" 26 "fmt" 27 "net" 28 "os" 29 "strings" 30 "testing" 31 "time" 32 33 "google.golang.org/grpc/credentials" 34 "google.golang.org/grpc/credentials/tls/certprovider" 35 xdsinternal "google.golang.org/grpc/internal/credentials/xds" 36 "google.golang.org/grpc/testdata" 37 ) 38 39 func makeClientTLSConfig(t *testing.T, mTLS bool) *tls.Config { 40 t.Helper() 41 42 pemData, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) 43 if err != nil { 44 t.Fatal(err) 45 } 46 roots := x509.NewCertPool() 47 roots.AppendCertsFromPEM(pemData) 48 49 var certs []tls.Certificate 50 if mTLS { 51 cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem")) 52 if err != nil { 53 t.Fatal(err) 54 } 55 certs = append(certs, cert) 56 } 57 58 return &tls.Config{ 59 Certificates: certs, 60 RootCAs: roots, 61 ServerName: "*.test.example.com", 62 // Setting this to true completely turns off the certificate validation 63 // on the client side. So, the client side handshake always seems to 64 // succeed. But if we want to turn this ON, we will need to generate 65 // certificates which work with localhost, or supply a custom 66 // verification function. So, the server credentials tests will rely 67 // solely on the success/failure of the server-side handshake. 68 InsecureSkipVerify: true, 69 } 70 } 71 72 // Helper function to create a real TLS server credentials which is used as 73 // fallback credentials from multiple tests. 74 func makeFallbackServerCreds(t *testing.T) credentials.TransportCredentials { 75 t.Helper() 76 77 creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) 78 if err != nil { 79 t.Fatal(err) 80 } 81 return creds 82 } 83 84 type errorCreds struct { 85 credentials.TransportCredentials 86 } 87 88 // TestServerCredsWithoutFallback verifies that the call to 89 // NewServerCredentials() fails when no fallback is specified. 90 func (s) TestServerCredsWithoutFallback(t *testing.T) { 91 if _, err := NewServerCredentials(ServerOptions{}); err == nil { 92 t.Fatal("NewServerCredentials() succeeded without specifying fallback") 93 } 94 } 95 96 type wrapperConn struct { 97 net.Conn 98 xdsHI *xdsinternal.HandshakeInfo 99 deadline time.Time 100 handshakeInfoErr error 101 } 102 103 func (wc *wrapperConn) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { 104 return wc.xdsHI, wc.handshakeInfoErr 105 } 106 107 func (wc *wrapperConn) GetDeadline() time.Time { 108 return wc.deadline 109 } 110 111 func newWrappedConn(conn net.Conn, xdsHI *xdsinternal.HandshakeInfo, deadline time.Time) *wrapperConn { 112 return &wrapperConn{Conn: conn, xdsHI: xdsHI, deadline: deadline} 113 } 114 115 // TestServerCredsInvalidHandshakeInfo verifies scenarios where the passed in 116 // HandshakeInfo is invalid because it does not contain the expected certificate 117 // providers. 118 func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) { 119 opts := ServerOptions{FallbackCreds: &errorCreds{}} 120 creds, err := NewServerCredentials(opts) 121 if err != nil { 122 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 123 } 124 125 info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false) 126 conn := newWrappedConn(nil, info, time.Time{}) 127 if _, _, err := creds.ServerHandshake(conn); err == nil { 128 t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo") 129 } 130 } 131 132 // TestServerCredsProviderFailure verifies the cases where an expected 133 // certificate provider is missing in the HandshakeInfo value in the context. 134 func (s) TestServerCredsProviderFailure(t *testing.T) { 135 opts := ServerOptions{FallbackCreds: &errorCreds{}} 136 creds, err := NewServerCredentials(opts) 137 if err != nil { 138 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 139 } 140 141 tests := []struct { 142 desc string 143 rootProvider certprovider.Provider 144 identityProvider certprovider.Provider 145 wantErr string 146 }{ 147 { 148 desc: "erroring identity provider", 149 identityProvider: &fakeProvider{err: errors.New("identity provider error")}, 150 wantErr: "identity provider error", 151 }, 152 { 153 desc: "erroring root provider", 154 identityProvider: &fakeProvider{km: &certprovider.KeyMaterial{}}, 155 rootProvider: &fakeProvider{err: errors.New("root provider error")}, 156 wantErr: "root provider error", 157 }, 158 } 159 for _, test := range tests { 160 t.Run(test.desc, func(t *testing.T) { 161 info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false) 162 conn := newWrappedConn(nil, info, time.Time{}) 163 if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) { 164 t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr) 165 } 166 }) 167 } 168 } 169 170 // TestServerCredsHandshake_XDSHandshakeInfoError verifies the case where the 171 // call to XDSHandshakeInfo() from the ServerHandshake() method returns an 172 // error, and the test verifies that the ServerHandshake() fails with the 173 // expected error. 174 func (s) TestServerCredsHandshake_XDSHandshakeInfoError(t *testing.T) { 175 opts := ServerOptions{FallbackCreds: &errorCreds{}} 176 creds, err := NewServerCredentials(opts) 177 if err != nil { 178 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 179 } 180 181 // Create a test server which uses the xDS server credentials created above 182 // to perform TLS handshake on incoming connections. 183 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 184 // Create a wrapped conn which returns a nil HandshakeInfo and a non-nil error. 185 conn := newWrappedConn(rawConn, nil, time.Now().Add(defaultTestTimeout)) 186 hiErr := errors.New("xdsHandshakeInfo error") 187 conn.handshakeInfoErr = hiErr 188 189 // Invoke the ServerHandshake() method on the xDS credentials and verify 190 // that the error returned by the XDSHandshakeInfo() method on the 191 // wrapped conn is returned here. 192 _, _, err := creds.ServerHandshake(conn) 193 if !errors.Is(err, hiErr) { 194 return handshakeResult{err: fmt.Errorf("ServerHandshake() returned err: %v, wantErr: %v", err, hiErr)} 195 } 196 return handshakeResult{} 197 }) 198 defer ts.stop() 199 200 // Dial the test server, but don't trigger the TLS handshake. This will 201 // cause ServerHandshake() to fail. 202 rawConn, err := net.Dial("tcp", ts.address) 203 if err != nil { 204 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 205 } 206 defer rawConn.Close() 207 208 // Read handshake result from the testServer which will return an error if 209 // the handshake succeeded. 210 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 211 defer cancel() 212 val, err := ts.hsResult.Receive(ctx) 213 if err != nil { 214 t.Fatalf("testServer failed to return handshake result: %v", err) 215 } 216 hsr := val.(handshakeResult) 217 if hsr.err != nil { 218 t.Fatalf("testServer handshake failure: %v", hsr.err) 219 } 220 } 221 222 // TestServerCredsHandshakeTimeout verifies the case where the client does not 223 // send required handshake data before the deadline set on the net.Conn passed 224 // to ServerHandshake(). 225 func (s) TestServerCredsHandshakeTimeout(t *testing.T) { 226 opts := ServerOptions{FallbackCreds: &errorCreds{}} 227 creds, err := NewServerCredentials(opts) 228 if err != nil { 229 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 230 } 231 232 // Create a test server which uses the xDS server credentials created above 233 // to perform TLS handshake on incoming connections. 234 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 235 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true) 236 237 // Create a wrapped conn which can return the HandshakeInfo created 238 // above with a very small deadline. 239 d := time.Now().Add(defaultTestShortTimeout) 240 rawConn.SetDeadline(d) 241 conn := newWrappedConn(rawConn, hi, d) 242 243 // ServerHandshake() on the xDS credentials is expected to fail. 244 if _, _, err := creds.ServerHandshake(conn); err == nil { 245 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to timeout")} 246 } 247 return handshakeResult{} 248 }) 249 defer ts.stop() 250 251 // Dial the test server, but don't trigger the TLS handshake. This will 252 // cause ServerHandshake() to fail. 253 rawConn, err := net.Dial("tcp", ts.address) 254 if err != nil { 255 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 256 } 257 defer rawConn.Close() 258 259 // Read handshake result from the testServer and expect a failure result. 260 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 261 defer cancel() 262 val, err := ts.hsResult.Receive(ctx) 263 if err != nil { 264 t.Fatalf("testServer failed to return handshake result: %v", err) 265 } 266 hsr := val.(handshakeResult) 267 if hsr.err != nil { 268 t.Fatalf("testServer handshake failure: %v", hsr.err) 269 } 270 } 271 272 // TestServerCredsHandshakeFailure verifies the case where the server-side 273 // credentials uses a root certificate which does not match the certificate 274 // presented by the client, and hence the handshake must fail. 275 func (s) TestServerCredsHandshakeFailure(t *testing.T) { 276 opts := ServerOptions{FallbackCreds: &errorCreds{}} 277 creds, err := NewServerCredentials(opts) 278 if err != nil { 279 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 280 } 281 282 // Create a test server which uses the xDS server credentials created above 283 // to perform TLS handshake on incoming connections. 284 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 285 // Create a HandshakeInfo which has a root provider which does not match 286 // the certificate sent by the client. 287 hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) 288 289 // Create a wrapped conn which can return the HandshakeInfo and 290 // configured deadline to the xDS credentials' ServerHandshake() 291 // method. 292 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 293 294 // ServerHandshake() on the xDS credentials is expected to fail. 295 if _, _, err := creds.ServerHandshake(conn); err == nil { 296 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} 297 } 298 return handshakeResult{} 299 }) 300 defer ts.stop() 301 302 // Dial the test server, and trigger the TLS handshake. 303 rawConn, err := net.Dial("tcp", ts.address) 304 if err != nil { 305 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 306 } 307 defer rawConn.Close() 308 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) 309 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 310 if err := tlsConn.Handshake(); err != nil { 311 t.Fatal(err) 312 } 313 314 // Read handshake result from the testServer which will return an error if 315 // the handshake succeeded. 316 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 317 defer cancel() 318 val, err := ts.hsResult.Receive(ctx) 319 if err != nil { 320 t.Fatalf("testServer failed to return handshake result: %v", err) 321 } 322 hsr := val.(handshakeResult) 323 if hsr.err != nil { 324 t.Fatalf("testServer handshake failure: %v", hsr.err) 325 } 326 } 327 328 // TestServerCredsHandshakeSuccess verifies success handshake cases. 329 func (s) TestServerCredsHandshakeSuccess(t *testing.T) { 330 tests := []struct { 331 desc string 332 fallbackCreds credentials.TransportCredentials 333 rootProvider certprovider.Provider 334 identityProvider certprovider.Provider 335 requireClientCert bool 336 }{ 337 { 338 desc: "fallback", 339 fallbackCreds: makeFallbackServerCreds(t), 340 }, 341 { 342 desc: "TLS", 343 fallbackCreds: &errorCreds{}, 344 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), 345 }, 346 { 347 desc: "mTLS", 348 fallbackCreds: &errorCreds{}, 349 identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), 350 rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"), 351 requireClientCert: true, 352 }, 353 } 354 355 for _, test := range tests { 356 t.Run(test.desc, func(t *testing.T) { 357 // Create an xDS server credentials. 358 opts := ServerOptions{FallbackCreds: test.fallbackCreds} 359 creds, err := NewServerCredentials(opts) 360 if err != nil { 361 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 362 } 363 364 // Create a test server which uses the xDS server credentials 365 // created above to perform TLS handshake on incoming connections. 366 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 367 // Create a HandshakeInfo with information from the test table. 368 hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert) 369 370 // Create a wrapped conn which can return the HandshakeInfo and 371 // configured deadline to the xDS credentials' ServerHandshake() 372 // method. 373 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 374 375 // Invoke the ServerHandshake() method on the xDS credentials 376 // and make some sanity checks before pushing the result for 377 // inspection by the main test body. 378 _, ai, err := creds.ServerHandshake(conn) 379 if err != nil { 380 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} 381 } 382 if ai.AuthType() != "tls" { 383 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} 384 } 385 info, ok := ai.(credentials.TLSInfo) 386 if !ok { 387 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} 388 } 389 return handshakeResult{connState: info.State} 390 }) 391 defer ts.stop() 392 393 // Dial the test server, and trigger the TLS handshake. 394 rawConn, err := net.Dial("tcp", ts.address) 395 if err != nil { 396 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 397 } 398 defer rawConn.Close() 399 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, test.requireClientCert)) 400 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 401 if err := tlsConn.Handshake(); err != nil { 402 t.Fatal(err) 403 } 404 405 // Read the handshake result from the testServer which contains the 406 // TLS connection state on the server-side and compare it with the 407 // one received on the client-side. 408 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 409 defer cancel() 410 val, err := ts.hsResult.Receive(ctx) 411 if err != nil { 412 t.Fatalf("testServer failed to return handshake result: %v", err) 413 } 414 hsr := val.(handshakeResult) 415 if hsr.err != nil { 416 t.Fatalf("testServer handshake failure: %v", hsr.err) 417 } 418 419 // AuthInfo contains a variety of information. We only verify a 420 // subset here. This is the same subset which is verified in TLS 421 // credentials tests. 422 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { 423 t.Fatal(err) 424 } 425 }) 426 } 427 } 428 429 func (s) TestServerCredsProviderSwitch(t *testing.T) { 430 opts := ServerOptions{FallbackCreds: &errorCreds{}} 431 creds, err := NewServerCredentials(opts) 432 if err != nil { 433 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 434 } 435 436 // The first time the handshake function is invoked, it returns a 437 // HandshakeInfo which is expected to fail. Further invocations return a 438 // HandshakeInfo which is expected to succeed. 439 cnt := 0 440 // Create a test server which uses the xDS server credentials created above 441 // to perform TLS handshake on incoming connections. 442 ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult { 443 cnt++ 444 var hi *xdsinternal.HandshakeInfo 445 if cnt == 1 { 446 // Create a HandshakeInfo which has a root provider which does not match 447 // the certificate sent by the client. 448 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true) 449 450 // Create a wrapped conn which can return the HandshakeInfo and 451 // configured deadline to the xDS credentials' ServerHandshake() 452 // method. 453 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 454 455 // ServerHandshake() on the xDS credentials is expected to fail. 456 if _, _, err := creds.ServerHandshake(conn); err == nil { 457 return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")} 458 } 459 return handshakeResult{} 460 } 461 462 hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true) 463 464 // Create a wrapped conn which can return the HandshakeInfo and 465 // configured deadline to the xDS credentials' ServerHandshake() 466 // method. 467 conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout)) 468 469 // Invoke the ServerHandshake() method on the xDS credentials 470 // and make some sanity checks before pushing the result for 471 // inspection by the main test body. 472 _, ai, err := creds.ServerHandshake(conn) 473 if err != nil { 474 return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)} 475 } 476 if ai.AuthType() != "tls" { 477 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")} 478 } 479 info, ok := ai.(credentials.TLSInfo) 480 if !ok { 481 return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})} 482 } 483 return handshakeResult{connState: info.State} 484 }) 485 defer ts.stop() 486 487 for i := 0; i < 5; i++ { 488 // Dial the test server, and trigger the TLS handshake. 489 rawConn, err := net.Dial("tcp", ts.address) 490 if err != nil { 491 t.Fatalf("net.Dial(%s) failed: %v", ts.address, err) 492 } 493 defer rawConn.Close() 494 tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true)) 495 tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout)) 496 if err := tlsConn.Handshake(); err != nil { 497 t.Fatal(err) 498 } 499 500 // Read the handshake result from the testServer which contains the 501 // TLS connection state on the server-side and compare it with the 502 // one received on the client-side. 503 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 504 defer cancel() 505 val, err := ts.hsResult.Receive(ctx) 506 if err != nil { 507 t.Fatalf("testServer failed to return handshake result: %v", err) 508 } 509 hsr := val.(handshakeResult) 510 if hsr.err != nil { 511 t.Fatalf("testServer handshake failure: %v", hsr.err) 512 } 513 if i == 0 { 514 // We expect the first handshake to fail. So, we skip checks which 515 // compare connection state. 516 continue 517 } 518 // AuthInfo contains a variety of information. We only verify a 519 // subset here. This is the same subset which is verified in TLS 520 // credentials tests. 521 if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil { 522 t.Fatal(err) 523 } 524 } 525 } 526 527 // TestServerClone verifies the Clone() method on client credentials. 528 func (s) TestServerClone(t *testing.T) { 529 opts := ServerOptions{FallbackCreds: makeFallbackServerCreds(t)} 530 orig, err := NewServerCredentials(opts) 531 if err != nil { 532 t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err) 533 } 534 535 // The credsImpl does not have any exported fields, and it does not make 536 // sense to use any cmp options to look deep into. So, all we make sure here 537 // is that the cloned object points to a different location in memory. 538 if clone := orig.Clone(); clone == orig { 539 t.Fatal("return value from Clone() doesn't point to new credentials instance") 540 } 541 }