github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/internal/pkg/comm/server_test.go (about) 1 /* 2 Copyright IBM Corp. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm_test 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "crypto/x509" 14 "fmt" 15 "io" 16 "io/ioutil" 17 "log" 18 "net" 19 "path/filepath" 20 "sync/atomic" 21 "testing" 22 "time" 23 24 "github.com/osdi23p228/fabric/common/crypto/tlsgen" 25 "github.com/osdi23p228/fabric/internal/pkg/comm" 26 "github.com/osdi23p228/fabric/internal/pkg/comm/testpb" 27 "github.com/pkg/errors" 28 "github.com/stretchr/testify/assert" 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/codes" 31 "google.golang.org/grpc/credentials" 32 "google.golang.org/grpc/status" 33 ) 34 35 // Embedded certificates for testing 36 // The self-signed cert expires in 2028 37 var selfSignedKeyPEM = `-----BEGIN EC PRIVATE KEY----- 38 MHcCAQEEIMLemLh3+uDzww1pvqP6Xj2Z0Kc6yqf3RxyfTBNwRuuyoAoGCCqGSM49 39 AwEHoUQDQgAEDB3l94vM7EqKr2L/vhqU5IsEub0rviqCAaWGiVAPp3orb/LJqFLS 40 yo/k60rhUiir6iD4S4pb5TEb2ouWylQI3A== 41 -----END EC PRIVATE KEY----- 42 ` 43 var selfSignedCertPEM = `-----BEGIN CERTIFICATE----- 44 MIIBdDCCARqgAwIBAgIRAKCiW5r6W32jGUn+l9BORMAwCgYIKoZIzj0EAwIwEjEQ 45 MA4GA1UEChMHQWNtZSBDbzAeFw0xODA4MjExMDI1MzJaFw0yODA4MTgxMDI1MzJa 46 MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQM 47 HeX3i8zsSoqvYv++GpTkiwS5vSu+KoIBpYaJUA+neitv8smoUtLKj+TrSuFSKKvq 48 IPhLilvlMRvai5bKVAjco1EwTzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYI 49 KwYBBQUHAwEwDAYDVR0TAQH/BAIwADAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A 50 AAEwCgYIKoZIzj0EAwIDSAAwRQIgOaYc3pdGf2j0uXRyvdBJq2PlK9FkgvsUjXOT 51 bQ9fWRkCIQCr1FiRRzapgtrnttDn3O2fhLlbrw67kClzY8pIIN42Qw== 52 -----END CERTIFICATE----- 53 ` 54 55 var badPEM = `-----BEGIN CERTIFICATE----- 56 MIICRDCCAemgAwIBAgIJALwW//dz2ZBvMAoGCCqGSM49BAMCMH4xCzAJBgNVBAYT 57 AlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2Nv 58 MRgwFgYDVQQKDA9MaW51eEZvdW5kYXRpb24xFDASBgNVBAsMC0h5cGVybGVkZ2Vy 59 MRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTYxMjA0MjIzMDE4WhcNMjYxMjAyMjIz 60 MDE4WjB+MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UE 61 BwwNU2FuIEZyYW5jaXNjbzEYMBYGA1UECgwPTGludXhGb3VuZGF0aW9uMRQwEgYD 62 VQQLDAtIeXBlcmxlZGdlcjESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0C 63 -----END CERTIFICATE----- 64 ` 65 66 var testOrgs = []testOrg{} 67 68 func init() { 69 //load up crypto material for test orgs 70 for i := 1; i <= numOrgs; i++ { 71 testOrg, err := loadOrg(i) 72 if err != nil { 73 log.Fatalf("Failed to load test organizations due to error: %s", err.Error()) 74 } 75 testOrgs = append(testOrgs, testOrg) 76 } 77 } 78 79 // test servers to be registered with the GRPCServer 80 type emptyServiceServer struct{} 81 82 func (ess *emptyServiceServer) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { 83 return new(testpb.Empty), nil 84 } 85 86 func (esss *emptyServiceServer) EmptyStream(stream testpb.EmptyService_EmptyStreamServer) error { 87 for { 88 _, err := stream.Recv() 89 if err == io.EOF { 90 return nil 91 } 92 if err != nil { 93 return err 94 } 95 if err := stream.Send(&testpb.Empty{}); err != nil { 96 return err 97 } 98 99 } 100 } 101 102 // invoke the EmptyCall RPC 103 func invokeEmptyCall(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) { 104 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 105 defer cancel() 106 //create GRPC client conn 107 clientConn, err := grpc.DialContext(ctx, address, dialOptions...) 108 if err != nil { 109 return nil, err 110 } 111 defer clientConn.Close() 112 113 //create GRPC client 114 client := testpb.NewEmptyServiceClient(clientConn) 115 116 //invoke service 117 empty, err := client.EmptyCall(context.Background(), new(testpb.Empty)) 118 if err != nil { 119 return nil, err 120 } 121 122 return empty, nil 123 } 124 125 // invoke the EmptyStream RPC 126 func invokeEmptyStream(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) { 127 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 128 defer cancel() 129 //create GRPC client conn 130 clientConn, err := grpc.DialContext(ctx, address, dialOptions...) 131 if err != nil { 132 return nil, err 133 } 134 defer clientConn.Close() 135 136 stream, err := testpb.NewEmptyServiceClient(clientConn).EmptyStream(ctx) 137 if err != nil { 138 return nil, err 139 } 140 141 var msg *testpb.Empty 142 var streamErr error 143 144 waitc := make(chan struct{}) 145 go func() { 146 for { 147 in, err := stream.Recv() 148 if err == io.EOF { 149 close(waitc) 150 return 151 } 152 if err != nil { 153 streamErr = err 154 close(waitc) 155 return 156 } 157 msg = in 158 } 159 }() 160 161 // TestServerInterceptors adds an interceptor that does not call the target 162 // StreamHandler and returns an error so Send can return with an io.EOF since 163 // the server side has already terminated. Whether or not we get an error 164 // depends on timing. 165 err = stream.Send(&testpb.Empty{}) 166 if err != nil && err != io.EOF { 167 return nil, fmt.Errorf("stream send failed: %s", err) 168 } 169 170 stream.CloseSend() 171 <-waitc 172 return msg, streamErr 173 } 174 175 const ( 176 numOrgs = 2 177 numChildOrgs = 2 178 numServerCerts = 2 179 ) 180 181 // string for cert filenames 182 var ( 183 orgCACert = filepath.Join("testdata", "certs", "Org%d-cert.pem") 184 orgServerKey = filepath.Join("testdata", "certs", "Org%d-server%d-key.pem") 185 orgServerCert = filepath.Join("testdata", "certs", "Org%d-server%d-cert.pem") 186 orgClientKey = filepath.Join("testdata", "certs", "Org%d-client%d-key.pem") 187 orgClientCert = filepath.Join("testdata", "certs", "Org%d-client%d-cert.pem") 188 childCACert = filepath.Join("testdata", "certs", "Org%d-child%d-cert.pem") 189 childServerKey = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-key.pem") 190 childServerCert = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-cert.pem") 191 childClientKey = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-key.pem") 192 childClientCert = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-cert.pem") 193 ) 194 195 type testServer struct { 196 config comm.ServerConfig 197 } 198 199 type serverCert struct { 200 keyPEM []byte 201 certPEM []byte 202 } 203 204 type testOrg struct { 205 rootCA []byte 206 serverCerts []serverCert 207 clientCerts []tls.Certificate 208 childOrgs []testOrg 209 } 210 211 // return *X509.CertPool for the rootCA of the org 212 func (org *testOrg) rootCertPool() *x509.CertPool { 213 certPool := x509.NewCertPool() 214 certPool.AppendCertsFromPEM(org.rootCA) 215 return certPool 216 } 217 218 // return testServers for the org 219 func (org *testOrg) testServers(clientRootCAs [][]byte) []testServer { 220 clientRootCAs = append(clientRootCAs, org.rootCA) 221 222 // loop through the serverCerts and create testServers 223 var testServers = []testServer{} 224 for _, serverCert := range org.serverCerts { 225 testServer := testServer{ 226 comm.ServerConfig{ 227 ConnectionTimeout: 250 * time.Millisecond, 228 SecOpts: comm.SecureOptions{ 229 UseTLS: true, 230 Certificate: serverCert.certPEM, 231 Key: serverCert.keyPEM, 232 RequireClientCert: true, 233 ClientRootCAs: clientRootCAs, 234 }, 235 }, 236 } 237 testServers = append(testServers, testServer) 238 } 239 return testServers 240 } 241 242 // return trusted clients for the org 243 func (org *testOrg) trustedClients(serverRootCAs [][]byte) []*tls.Config { 244 // if we have any additional server root CAs add them to the certPool 245 certPool := org.rootCertPool() 246 for _, serverRootCA := range serverRootCAs { 247 certPool.AppendCertsFromPEM(serverRootCA) 248 } 249 250 // loop through the clientCerts and create tls.Configs 251 var trustedClients = []*tls.Config{} 252 for _, clientCert := range org.clientCerts { 253 trustedClient := &tls.Config{ 254 Certificates: []tls.Certificate{clientCert}, 255 RootCAs: certPool, 256 } 257 trustedClients = append(trustedClients, trustedClient) 258 } 259 return trustedClients 260 } 261 262 // createCertPool creates an x509.CertPool from an array of PEM-encoded certificates 263 func createCertPool(rootCAs [][]byte) (*x509.CertPool, error) { 264 certPool := x509.NewCertPool() 265 for _, rootCA := range rootCAs { 266 if !certPool.AppendCertsFromPEM(rootCA) { 267 return nil, errors.New("Failed to load root certificates") 268 } 269 } 270 return certPool, nil 271 } 272 273 // utility function to load crypto material for organizations 274 func loadOrg(parent int) (testOrg, error) { 275 var org = testOrg{} 276 // load the CA 277 caPEM, err := ioutil.ReadFile(fmt.Sprintf(orgCACert, parent)) 278 if err != nil { 279 return org, err 280 } 281 282 // loop through and load servers 283 var serverCerts = []serverCert{} 284 for i := 1; i <= numServerCerts; i++ { 285 keyPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerKey, parent, i)) 286 if err != nil { 287 return org, err 288 } 289 certPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerCert, parent, i)) 290 if err != nil { 291 return org, err 292 } 293 serverCerts = append(serverCerts, serverCert{keyPEM, certPEM}) 294 } 295 296 // loop through and load clients 297 var clientCerts = []tls.Certificate{} 298 for j := 1; j <= numServerCerts; j++ { 299 clientCert, err := loadTLSKeyPairFromFile(fmt.Sprintf(orgClientKey, parent, j), 300 fmt.Sprintf(orgClientCert, parent, j)) 301 if err != nil { 302 return org, err 303 } 304 clientCerts = append(clientCerts, clientCert) 305 } 306 307 // loop through and load child orgs 308 var childOrgs = []testOrg{} 309 for k := 1; k <= numChildOrgs; k++ { 310 childOrg, err := loadChildOrg(parent, k) 311 if err != nil { 312 return org, err 313 } 314 childOrgs = append(childOrgs, childOrg) 315 } 316 317 return testOrg{caPEM, serverCerts, clientCerts, childOrgs}, nil 318 } 319 320 // utility function to load crypto material for child organizations 321 func loadChildOrg(parent, child int) (testOrg, error) { 322 // load the CA 323 caPEM, err := ioutil.ReadFile(fmt.Sprintf(childCACert, parent, child)) 324 if err != nil { 325 return testOrg{}, err 326 } 327 328 // loop through and load servers 329 var serverCerts = []serverCert{} 330 for i := 1; i <= numServerCerts; i++ { 331 keyPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerKey, parent, child, i)) 332 if err != nil { 333 return testOrg{}, err 334 } 335 certPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerCert, parent, child, i)) 336 if err != nil { 337 return testOrg{}, err 338 } 339 serverCerts = append(serverCerts, serverCert{keyPEM, certPEM}) 340 } 341 342 // loop through and load clients 343 var clientCerts = []tls.Certificate{} 344 for j := 1; j <= numServerCerts; j++ { 345 clientCert, err := loadTLSKeyPairFromFile( 346 fmt.Sprintf(childClientKey, parent, child, j), 347 fmt.Sprintf(childClientCert, parent, child, j), 348 ) 349 if err != nil { 350 return testOrg{}, err 351 } 352 clientCerts = append(clientCerts, clientCert) 353 } 354 355 return testOrg{caPEM, serverCerts, clientCerts, []testOrg{}}, nil 356 } 357 358 // loadTLSKeyPairFromFile creates a tls.Certificate from PEM-encoded key and cert files 359 func loadTLSKeyPairFromFile(keyFile, certFile string) (tls.Certificate, error) { 360 certPEMBlock, err := ioutil.ReadFile(certFile) 361 if err != nil { 362 return tls.Certificate{}, err 363 } 364 365 keyPEMBlock, err := ioutil.ReadFile(keyFile) 366 if err != nil { 367 return tls.Certificate{}, err 368 } 369 370 cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 371 if err != nil { 372 return tls.Certificate{}, err 373 } 374 375 return cert, nil 376 } 377 378 func TestNewGRPCServerInvalidParameters(t *testing.T) { 379 t.Parallel() 380 381 // missing address 382 _, err := comm.NewGRPCServer( 383 "", 384 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 385 ) 386 assert.EqualError(t, err, "missing address parameter") 387 388 // missing port 389 _, err = comm.NewGRPCServer( 390 "abcdef", 391 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 392 ) 393 assert.Error(t, err, "Expected error with missing port") 394 assert.Contains(t, err.Error(), "missing port in address") 395 396 // bad port 397 _, err = comm.NewGRPCServer( 398 "127.0.0.1:1BBB", 399 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 400 ) 401 //check for possible errors based on platform and Go release 402 msgs := []string{ 403 "listen tcp: lookup tcp/1BBB: nodename nor servname provided, or not known", 404 "listen tcp: unknown port tcp/1BBB", 405 "listen tcp: address tcp/1BBB: unknown port", 406 "listen tcp: lookup tcp/1BBB: Servname not supported for ai_socktype", 407 } 408 if assert.Error(t, err, fmt.Sprintf("[%s], [%s] [%s] or [%s] expected", msgs[0], msgs[1], msgs[2], msgs[3])) { 409 assert.Contains(t, msgs, err.Error()) 410 } 411 412 // bad hostname 413 _, err = comm.NewGRPCServer( 414 "hostdoesnotexist.localdomain:9050", 415 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 416 ) 417 // We cannot check for a specific error message due to the fact that some 418 // systems will automatically resolve unknown host names to a "search" 419 // address so we just check to make sure that an error was returned 420 assert.Error(t, err, "error expected") 421 422 // address in use 423 lis, err := net.Listen("tcp", "127.0.0.1:0") 424 assert.NoError(t, err, "failed to create listener") 425 defer lis.Close() 426 427 _, err = comm.NewGRPCServerFromListener( 428 lis, 429 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 430 ) 431 assert.NoError(t, err, "failed to create grpc server") 432 433 _, err = comm.NewGRPCServer( 434 lis.Addr().String(), 435 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 436 ) 437 assert.Error(t, err) 438 assert.Contains(t, err.Error(), "address already in use") 439 440 // missing server Certificate 441 _, err = comm.NewGRPCServerFromListener( 442 lis, 443 comm.ServerConfig{ 444 SecOpts: comm.SecureOptions{UseTLS: true, Key: []byte{}}, 445 }, 446 ) 447 assert.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true") 448 449 // missing server Key 450 _, err = comm.NewGRPCServerFromListener( 451 lis, 452 comm.ServerConfig{ 453 SecOpts: comm.SecureOptions{ 454 UseTLS: true, 455 Certificate: []byte{}}, 456 }, 457 ) 458 assert.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true") 459 460 // bad server Key 461 _, err = comm.NewGRPCServerFromListener( 462 lis, 463 comm.ServerConfig{ 464 SecOpts: comm.SecureOptions{ 465 UseTLS: true, 466 Certificate: []byte(selfSignedCertPEM), 467 Key: []byte{}, 468 }, 469 }, 470 ) 471 assert.EqualError(t, err, "tls: failed to find any PEM data in key input") 472 473 // bad server Certificate 474 _, err = comm.NewGRPCServerFromListener( 475 lis, 476 comm.ServerConfig{ 477 SecOpts: comm.SecureOptions{ 478 UseTLS: true, 479 Certificate: []byte{}, 480 Key: []byte(selfSignedKeyPEM)}, 481 }, 482 ) 483 assert.EqualError(t, err, "tls: failed to find any PEM data in certificate input") 484 485 srv, err := comm.NewGRPCServerFromListener( 486 lis, 487 comm.ServerConfig{ 488 SecOpts: comm.SecureOptions{ 489 UseTLS: true, 490 Certificate: []byte(selfSignedCertPEM), 491 Key: []byte(selfSignedKeyPEM), 492 RequireClientCert: true}, 493 }, 494 ) 495 assert.NoError(t, err) 496 497 badRootCAs := [][]byte{[]byte(badPEM)} 498 err = srv.SetClientRootCAs(badRootCAs) 499 assert.EqualError(t, err, "failed to set client root certificate(s): asn1: syntax error: data truncated") 500 } 501 502 func TestNewGRPCServer(t *testing.T) { 503 t.Parallel() 504 505 testAddress := "127.0.0.1:9053" 506 srv, err := comm.NewGRPCServer( 507 testAddress, 508 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 509 ) 510 assert.NoError(t, err, "failed to create new GRPC server") 511 512 // resolve the address 513 addr, err := net.ResolveTCPAddr("tcp", testAddress) 514 assert.NoError(t, err) 515 516 // make sure our properties are as expected 517 assert.Equal(t, srv.Address(), addr.String()) 518 assert.Equal(t, srv.Listener().Addr().String(), addr.String()) 519 assert.Equal(t, srv.TLSEnabled(), false) 520 assert.Equal(t, srv.MutualTLSRequired(), false) 521 522 // register the GRPC test server 523 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 524 525 // start the server 526 go srv.Start() 527 defer srv.Stop() 528 529 // should not be needed 530 time.Sleep(10 * time.Millisecond) 531 532 // invoke the EmptyCall service 533 _, err = invokeEmptyCall(testAddress, grpc.WithInsecure()) 534 assert.NoError(t, err, "failed to invoke the EmptyCall service") 535 } 536 537 func TestNewGRPCServerFromListener(t *testing.T) { 538 t.Parallel() 539 540 // create our listener 541 lis, err := net.Listen("tcp", "127.0.0.1:0") 542 assert.NoError(t, err, "failed to create listener") 543 testAddress := lis.Addr().String() 544 545 srv, err := comm.NewGRPCServerFromListener( 546 lis, 547 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 548 ) 549 assert.NoError(t, err, "failed to create new GRPC server") 550 551 assert.Equal(t, srv.Address(), testAddress) 552 assert.Equal(t, srv.Listener().Addr().String(), testAddress) 553 assert.Equal(t, srv.TLSEnabled(), false) 554 assert.Equal(t, srv.MutualTLSRequired(), false) 555 556 // register the GRPC test server 557 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 558 559 // start the server 560 go srv.Start() 561 defer srv.Stop() 562 563 // should not be needed 564 time.Sleep(10 * time.Millisecond) 565 566 // invoke the EmptyCall service 567 _, err = invokeEmptyCall(testAddress, grpc.WithInsecure()) 568 assert.NoError(t, err, "client failed to invoke the EmptyCall service") 569 } 570 571 func TestNewSecureGRPCServer(t *testing.T) { 572 t.Parallel() 573 574 // create our listener 575 lis, err := net.Listen("tcp", "127.0.0.1:0") 576 assert.NoError(t, err, "failed to create listener") 577 testAddress := lis.Addr().String() 578 579 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 580 ConnectionTimeout: 250 * time.Millisecond, 581 SecOpts: comm.SecureOptions{ 582 UseTLS: true, 583 Certificate: []byte(selfSignedCertPEM), 584 Key: []byte(selfSignedKeyPEM)}, 585 }, 586 ) 587 assert.NoError(t, err, "failed to create new grpc server") 588 589 // make sure our properties are as expected 590 assert.NoError(t, err) 591 assert.Equal(t, srv.Address(), testAddress) 592 assert.Equal(t, srv.Listener().Addr().String(), testAddress) 593 594 cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) 595 assert.Equal(t, srv.ServerCertificate(), cert) 596 597 assert.Equal(t, srv.TLSEnabled(), true) 598 assert.Equal(t, srv.MutualTLSRequired(), false) 599 600 // register the GRPC test server 601 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 602 603 //start the server 604 go srv.Start() 605 defer srv.Stop() 606 607 // should not be needed 608 time.Sleep(10 * time.Millisecond) 609 610 // create the client credentials 611 certPool := x509.NewCertPool() 612 if !certPool.AppendCertsFromPEM([]byte(selfSignedCertPEM)) { 613 t.Fatal("Failed to append certificate to client credentials") 614 } 615 creds := credentials.NewClientTLSFromCert(certPool, "") 616 617 // invoke the EmptyCall service 618 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 619 assert.NoError(t, err, "client failed to invoke the EmptyCall service") 620 621 // Test TLS versions which should be valid 622 tlsVersions := map[string]uint16{ 623 "TLS12": tls.VersionTLS12, 624 "TLS13": tls.VersionTLS13, 625 } 626 for name, tlsVersion := range tlsVersions { 627 tlsVersion := tlsVersion 628 629 t.Run(name, func(t *testing.T) { 630 creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion}) 631 _, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock()) 632 assert.NoError(t, err) 633 }) 634 } 635 636 // Test TLS versions which should be invalid 637 tlsVersions = map[string]uint16{ 638 "SSL30": tls.VersionSSL30, 639 "TLS10": tls.VersionTLS10, 640 "TLS11": tls.VersionTLS11, 641 } 642 for name, tlsVersion := range tlsVersions { 643 tlsVersion := tlsVersion 644 t.Run(name, func(t *testing.T) { 645 t.Parallel() 646 647 creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion}) 648 _, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock()) 649 assert.Error(t, err, "should not have been able to connect with TLS version < 1.2") 650 assert.Contains(t, err.Error(), "context deadline exceeded") 651 }) 652 } 653 } 654 655 func TestVerifyCertificateCallback(t *testing.T) { 656 t.Parallel() 657 658 ca, err := tlsgen.NewCA() 659 assert.NoError(t, err) 660 661 authorizedClientKeyPair, err := ca.NewClientCertKeyPair() 662 assert.NoError(t, err) 663 664 notAuthorizedClientKeyPair, err := ca.NewClientCertKeyPair() 665 assert.NoError(t, err) 666 667 serverKeyPair, err := ca.NewServerCertKeyPair("127.0.0.1") 668 assert.NoError(t, err) 669 670 verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 671 if bytes.Equal(rawCerts[0], authorizedClientKeyPair.TLSCert.Raw) { 672 return nil 673 } 674 return errors.New("certificate mismatch") 675 } 676 677 probeTLS := func(endpoint string, clientKeyPair *tlsgen.CertKeyPair) error { 678 cert, err := tls.X509KeyPair(clientKeyPair.Cert, clientKeyPair.Key) 679 if err != nil { 680 return err 681 } 682 tlsCfg := &tls.Config{ 683 Certificates: []tls.Certificate{cert}, 684 RootCAs: x509.NewCertPool(), 685 MinVersion: tls.VersionTLS12, 686 MaxVersion: tls.VersionTLS12, 687 } 688 tlsCfg.RootCAs.AppendCertsFromPEM(ca.CertBytes()) 689 690 conn, err := tls.Dial("tcp", endpoint, tlsCfg) 691 if err != nil { 692 return err 693 } 694 conn.Close() 695 return nil 696 } 697 698 gRPCServer, err := comm.NewGRPCServer("127.0.0.1:", comm.ServerConfig{ 699 SecOpts: comm.SecureOptions{ 700 ClientRootCAs: [][]byte{ca.CertBytes()}, 701 Key: serverKeyPair.Key, 702 Certificate: serverKeyPair.Cert, 703 UseTLS: true, 704 VerifyCertificate: verifyFunc, 705 }, 706 }) 707 go gRPCServer.Start() 708 defer gRPCServer.Stop() 709 710 t.Run("Success path", func(t *testing.T) { 711 err = probeTLS(gRPCServer.Address(), authorizedClientKeyPair) 712 assert.NoError(t, err) 713 }) 714 715 t.Run("Failure path", func(t *testing.T) { 716 err = probeTLS(gRPCServer.Address(), notAuthorizedClientKeyPair) 717 assert.EqualError(t, err, "remote error: tls: bad certificate") 718 }) 719 } 720 721 // prior tests used self-signed certficates loaded by the GRPCServer and the test client 722 // here we'll use certificates signed by certificate authorities 723 func TestWithSignedRootCertificates(t *testing.T) { 724 t.Parallel() 725 726 // use Org1 testdata 727 fileBase := "Org1" 728 certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-cert.pem")) 729 assert.NoError(t, err, "failed to load test certificates") 730 731 keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-key.pem")) 732 assert.NoError(t, err, "failed to load test certificates: %v") 733 734 caPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-cert.pem")) 735 assert.NoError(t, err, "failed to load test certificates") 736 737 // create our listener 738 lis, err := net.Listen("tcp", "127.0.0.1:0") 739 assert.NoError(t, err, "failed to create listener") 740 testAddress := lis.Addr().String() 741 742 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 743 SecOpts: comm.SecureOptions{ 744 UseTLS: true, 745 Certificate: certPEMBlock, 746 Key: keyPEMBlock, 747 }, 748 }) 749 assert.NoError(t, err, "failed to create new grpc server") 750 // register the GRPC test server 751 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 752 753 //start the server 754 go srv.Start() 755 defer srv.Stop() 756 757 // should not be needed 758 time.Sleep(10 * time.Millisecond) 759 760 // create a CertPool for use by the client with the server cert only 761 certPoolServer, err := createCertPool([][]byte{certPEMBlock}) 762 assert.NoError(t, err, "failed to load root certificates into pool") 763 creds := credentials.NewClientTLSFromCert(certPoolServer, "") 764 765 // invoke the EmptyCall service 766 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 767 assert.NoError(t, err, "Expected client to connect with server cert only") 768 769 // now use the CA certificate 770 certPoolCA := x509.NewCertPool() 771 if !certPoolCA.AppendCertsFromPEM(caPEMBlock) { 772 t.Fatal("Failed to append certificate to client credentials") 773 } 774 creds = credentials.NewClientTLSFromCert(certPoolCA, "") 775 776 // invoke the EmptyCall service 777 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 778 assert.NoError(t, err, "client failed to invoke the EmptyCall") 779 } 780 781 // here we'll use certificates signed by intermediate certificate authorities 782 func TestWithSignedIntermediateCertificates(t *testing.T) { 783 t.Parallel() 784 785 // use Org1 testdata 786 fileBase := "Org1" 787 certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-cert.pem")) 788 assert.NoError(t, err) 789 790 keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-key.pem")) 791 assert.NoError(t, err) 792 793 intermediatePEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-cert.pem")) 794 if err != nil { 795 t.Fatalf("Failed to load test certificates: %v", err) 796 } 797 798 // create our listener 799 lis, err := net.Listen("tcp", "127.0.0.1:0") 800 if err != nil { 801 t.Fatalf("Failed to create listener: %v", err) 802 } 803 testAddress := lis.Addr().String() 804 805 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 806 SecOpts: comm.SecureOptions{ 807 UseTLS: true, 808 Certificate: certPEMBlock, 809 Key: keyPEMBlock}}) 810 // check for error 811 if err != nil { 812 t.Fatalf("Failed to return new GRPC server: %v", err) 813 } 814 815 // register the GRPC test server 816 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 817 818 // start the server 819 go srv.Start() 820 defer srv.Stop() 821 822 // should not be needed 823 time.Sleep(10 * time.Millisecond) 824 825 // create a CertPool for use by the client with the server cert only 826 certPoolServer, err := createCertPool([][]byte{certPEMBlock}) 827 if err != nil { 828 t.Fatalf("Failed to load root certificates into pool: %v", err) 829 } 830 // create the client credentials 831 creds := credentials.NewClientTLSFromCert(certPoolServer, "") 832 833 // invoke the EmptyCall service 834 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 835 836 // client should be able to connect with Go 1.9 837 assert.NoError(t, err, "Expected client to connect with server cert only") 838 839 // now use the CA certificate 840 // create a CertPool for use by the client with the intermediate root CA 841 certPoolCA, err := createCertPool([][]byte{intermediatePEMBlock}) 842 assert.NoError(t, err, "failed to load root certificates into pool") 843 844 creds = credentials.NewClientTLSFromCert(certPoolCA, "") 845 846 // invoke the EmptyCall service 847 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 848 assert.NoError(t, err, "client failed to invoke the EmptyCall service") 849 } 850 851 // utility function for testing client / server communication using TLS 852 func runMutualAuth(t *testing.T, servers []testServer, trustedClients, unTrustedClients []*tls.Config) error { 853 // loop through all the test servers 854 for i := 0; i < len(servers); i++ { 855 //create listener 856 lis, err := net.Listen("tcp", "127.0.0.1:0") 857 if err != nil { 858 return err 859 } 860 srvAddr := lis.Addr().String() 861 862 // create GRPCServer 863 srv, err := comm.NewGRPCServerFromListener(lis, servers[i].config) 864 if err != nil { 865 return err 866 } 867 868 // MutualTLSRequired should be true 869 assert.Equal(t, srv.MutualTLSRequired(), true) 870 871 //register the GRPC test server and start the GRPCServer 872 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 873 go srv.Start() 874 defer srv.Stop() 875 876 // should not be needed but just in case 877 time.Sleep(10 * time.Millisecond) 878 879 // loop through all the trusted clients 880 for j := 0; j < len(trustedClients); j++ { 881 // invoke the EmptyCall service 882 _, err = invokeEmptyCall(srvAddr, grpc.WithTransportCredentials(credentials.NewTLS(trustedClients[j]))) 883 // we expect success from trusted clients 884 if err != nil { 885 return err 886 } else { 887 t.Logf("Trusted client%d successfully connected to %s", j, srvAddr) 888 } 889 } 890 891 // loop through all the untrusted clients 892 for k := 0; k < len(unTrustedClients); k++ { 893 // invoke the EmptyCall service 894 _, err = invokeEmptyCall( 895 srvAddr, 896 grpc.WithTransportCredentials(credentials.NewTLS(unTrustedClients[k])), 897 ) 898 // we expect failure from untrusted clients 899 if err != nil { 900 t.Logf("Untrusted client%d was correctly rejected by %s", k, srvAddr) 901 } else { 902 return fmt.Errorf("Untrusted client %d should not have been able to connect to %s", k, srvAddr) 903 } 904 } 905 } 906 907 return nil 908 } 909 910 func TestMutualAuth(t *testing.T) { 911 t.Parallel() 912 913 var tests = []struct { 914 name string 915 servers []testServer 916 trustedClients []*tls.Config 917 unTrustedClients []*tls.Config 918 }{ 919 { 920 name: "ClientAuthRequiredWithSingleOrg", 921 servers: testOrgs[0].testServers([][]byte{}), 922 trustedClients: testOrgs[0].trustedClients([][]byte{}), 923 unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 924 }, 925 { 926 name: "ClientAuthRequiredWithChildClientOrg", 927 servers: testOrgs[0].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}), 928 trustedClients: testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA}), 929 unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 930 }, 931 { 932 name: "ClientAuthRequiredWithMultipleChildClientOrgs", 933 servers: testOrgs[0].testServers(append([][]byte{}, 934 testOrgs[0].childOrgs[0].rootCA, 935 testOrgs[0].childOrgs[1].rootCA, 936 )), 937 trustedClients: append(append([]*tls.Config{}, 938 testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})...), 939 testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})...), 940 unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 941 }, 942 { 943 name: "ClientAuthRequiredWithDifferentServerAndClientOrgs", 944 servers: testOrgs[0].testServers([][]byte{testOrgs[1].rootCA}), 945 trustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 946 unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 947 }, 948 { 949 name: "ClientAuthRequiredWithDifferentServerAndChildClientOrgs", 950 servers: testOrgs[1].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}), 951 trustedClients: testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}), 952 unTrustedClients: testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}), 953 }, 954 } 955 956 for _, test := range tests { 957 test := test 958 t.Run(test.name, func(t *testing.T) { 959 t.Parallel() 960 t.Logf("Running test %s ...", test.name) 961 testErr := runMutualAuth(t, test.servers, test.trustedClients, test.unTrustedClients) 962 assert.NoError(t, testErr) 963 }) 964 } 965 } 966 967 func TestSetClientRootCAs(t *testing.T) { 968 t.Parallel() 969 970 // get the config for one of our Org1 test servers 971 serverConfig := testOrgs[0].testServers([][]byte{})[0].config 972 lis, err := net.Listen("tcp", "127.0.0.1:0") 973 assert.NoError(t, err, "listen failed") 974 defer lis.Close() 975 address := lis.Addr().String() 976 977 // create a GRPCServer 978 srv, err := comm.NewGRPCServerFromListener(lis, serverConfig) 979 assert.NoError(t, err, "failed to create GRPCServer") 980 981 // register the GRPC test server and start the GRPCServer 982 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 983 go srv.Start() 984 defer srv.Stop() 985 986 // should not be needed but just in case 987 time.Sleep(10 * time.Millisecond) 988 989 // set up our test clients 990 // Org1 991 clientConfigOrg1Child1 := testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0] 992 clientConfigOrg1Child2 := testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0] 993 clientConfigsOrg1Children := []*tls.Config{clientConfigOrg1Child1, clientConfigOrg1Child2} 994 org1ChildRootCAs := [][]byte{testOrgs[0].childOrgs[0].rootCA, testOrgs[0].childOrgs[1].rootCA} 995 // Org2 996 clientConfigOrg2Child1 := testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0] 997 clientConfigOrg2Child2 := testOrgs[1].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0] 998 clientConfigsOrg2Children := []*tls.Config{clientConfigOrg2Child1, clientConfigOrg2Child2} 999 org2ChildRootCAs := [][]byte{testOrgs[1].childOrgs[0].rootCA, testOrgs[1].childOrgs[1].rootCA} 1000 1001 // initially set client CAs to Org1 children 1002 err = srv.SetClientRootCAs(org1ChildRootCAs) 1003 assert.NoError(t, err, "SetClientRootCAs failed") 1004 1005 // clientConfigsOrg1Children are currently trusted 1006 for _, clientConfig := range clientConfigsOrg1Children { 1007 // we expect success as these are trusted clients 1008 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1009 assert.NoError(t, err, "trusted client should have connected") 1010 } 1011 1012 // clientConfigsOrg2Children are currently not trusted 1013 for _, clientConfig := range clientConfigsOrg2Children { 1014 // we expect failure as these are now untrusted clients 1015 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1016 assert.Error(t, err, "untrusted client should not have been able to connect") 1017 } 1018 1019 // now set client CAs to Org2 children 1020 err = srv.SetClientRootCAs(org2ChildRootCAs) 1021 assert.NoError(t, err, "SetClientRootCAs failed") 1022 1023 // now reverse trusted and not trusted 1024 // clientConfigsOrg1Children are currently trusted 1025 for _, clientConfig := range clientConfigsOrg2Children { 1026 // we expect success as these are trusted clients 1027 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1028 assert.NoError(t, err, "trusted client should have connected") 1029 } 1030 1031 // clientConfigsOrg2Children are currently not trusted 1032 for _, clientConfig := range clientConfigsOrg1Children { 1033 // we expect failure as these are now untrusted clients 1034 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1035 assert.Error(t, err, "untrusted client should not have connected") 1036 } 1037 } 1038 1039 func TestUpdateTLSCert(t *testing.T) { 1040 t.Parallel() 1041 1042 readFile := func(path string) []byte { 1043 fName := filepath.Join("testdata", "dynamic_cert_update", path) 1044 data, err := ioutil.ReadFile(fName) 1045 if err != nil { 1046 panic(fmt.Errorf("Failed reading %s: %v", fName, err)) 1047 } 1048 return data 1049 } 1050 loadBytes := func(prefix string) (key, cert, caCert []byte) { 1051 cert = readFile(filepath.Join(prefix, "server.crt")) 1052 key = readFile(filepath.Join(prefix, "server.key")) 1053 caCert = readFile(filepath.Join("ca.crt")) 1054 return 1055 } 1056 1057 key, cert, caCert := loadBytes("notlocalhost") 1058 1059 cfg := comm.ServerConfig{ 1060 SecOpts: comm.SecureOptions{ 1061 UseTLS: true, 1062 Key: key, 1063 Certificate: cert, 1064 }, 1065 } 1066 1067 // create our listener 1068 lis, err := net.Listen("tcp", "127.0.0.1:0") 1069 assert.NoError(t, err, "listen failed") 1070 testAddress := lis.Addr().String() 1071 1072 srv, err := comm.NewGRPCServerFromListener(lis, cfg) 1073 assert.NoError(t, err) 1074 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 1075 1076 go srv.Start() 1077 defer srv.Stop() 1078 1079 certPool := x509.NewCertPool() 1080 certPool.AppendCertsFromPEM(caCert) 1081 1082 probeServer := func() error { 1083 _, err = invokeEmptyCall( 1084 testAddress, 1085 grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: certPool})), 1086 grpc.WithBlock(), 1087 ) 1088 return err 1089 } 1090 1091 // bootstrap TLS certificate has a SAN of "notlocalhost" so it should fail 1092 err = probeServer() 1093 assert.Error(t, err) 1094 assert.Contains(t, err.Error(), "context deadline exceeded") 1095 1096 // new TLS certificate has a SAN of "127.0.0.1" so it should succeed 1097 certPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.crt") 1098 keyPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.key") 1099 tlsCert, err := tls.LoadX509KeyPair(certPath, keyPath) 1100 assert.NoError(t, err) 1101 srv.SetServerCertificate(tlsCert) 1102 err = probeServer() 1103 assert.NoError(t, err) 1104 1105 // revert back to the old certificate, should fail. 1106 certPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.crt") 1107 keyPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.key") 1108 tlsCert, err = tls.LoadX509KeyPair(certPath, keyPath) 1109 assert.NoError(t, err) 1110 srv.SetServerCertificate(tlsCert) 1111 1112 err = probeServer() 1113 assert.Error(t, err) 1114 assert.Contains(t, err.Error(), "context deadline exceeded") 1115 } 1116 1117 func TestCipherSuites(t *testing.T) { 1118 t.Parallel() 1119 1120 certPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-cert.pem")) 1121 assert.NoError(t, err) 1122 keyPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-key.pem")) 1123 assert.NoError(t, err) 1124 caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem")) 1125 assert.NoError(t, err) 1126 certPool, err := createCertPool([][]byte{caPEM}) 1127 assert.NoError(t, err) 1128 1129 serverConfig := comm.ServerConfig{ 1130 SecOpts: comm.SecureOptions{ 1131 Certificate: certPEM, 1132 Key: keyPEM, 1133 UseTLS: true, 1134 }} 1135 1136 fabricDefaultCipherSuite := func(cipher uint16) bool { 1137 for _, defaultCipher := range comm.DefaultTLSCipherSuites { 1138 if cipher == defaultCipher { 1139 return true 1140 } 1141 } 1142 return false 1143 } 1144 1145 var otherCipherSuites []uint16 1146 for _, cipher := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) { 1147 if !fabricDefaultCipherSuite(cipher.ID) { 1148 otherCipherSuites = append(otherCipherSuites, cipher.ID) 1149 } 1150 } 1151 1152 var tests = []struct { 1153 name string 1154 clientCiphers []uint16 1155 success bool 1156 versions []uint16 1157 }{ 1158 { 1159 name: "server default / client all", 1160 success: true, 1161 versions: []uint16{tls.VersionTLS12, tls.VersionTLS13}, 1162 }, 1163 { 1164 name: "server default / client match", 1165 clientCiphers: comm.DefaultTLSCipherSuites, 1166 success: true, 1167 // Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites 1168 // https://github.com/golang/go/issues/29349 1169 versions: []uint16{tls.VersionTLS12}, 1170 }, 1171 { 1172 name: "server default / client no match", 1173 clientCiphers: otherCipherSuites, 1174 success: false, 1175 // Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites 1176 // https://github.com/golang/go/issues/29349 1177 versions: []uint16{tls.VersionTLS12}, 1178 }, 1179 } 1180 1181 // create our listener 1182 lis, err := net.Listen("tcp", "127.0.0.1:0") 1183 assert.NoError(t, err, "listen failed") 1184 testAddress := lis.Addr().String() 1185 srv, err := comm.NewGRPCServerFromListener(lis, serverConfig) 1186 assert.NoError(t, err) 1187 go srv.Start() 1188 1189 for _, test := range tests { 1190 test := test 1191 t.Run(test.name, func(t *testing.T) { 1192 t.Parallel() 1193 1194 for _, tlsVersion := range test.versions { 1195 tlsConfig := &tls.Config{ 1196 RootCAs: certPool, 1197 CipherSuites: test.clientCiphers, 1198 MinVersion: tlsVersion, 1199 MaxVersion: tlsVersion, 1200 } 1201 _, err := tls.Dial("tcp", testAddress, tlsConfig) 1202 if test.success { 1203 assert.NoError(t, err) 1204 } else { 1205 assert.Error(t, err, "expected handshake failure") 1206 assert.Contains(t, err.Error(), "handshake failure") 1207 } 1208 } 1209 }) 1210 } 1211 } 1212 1213 func TestServerInterceptors(t *testing.T) { 1214 lis, err := net.Listen("tcp", "127.0.0.1:0") 1215 assert.NoError(t, err, "listen failed") 1216 msg := "error from interceptor" 1217 1218 // set up interceptors 1219 usiCount := uint32(0) 1220 ssiCount := uint32(0) 1221 usi1 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 1222 atomic.AddUint32(&usiCount, 1) 1223 return handler(ctx, req) 1224 } 1225 usi2 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 1226 atomic.AddUint32(&usiCount, 1) 1227 return nil, status.Error(codes.Aborted, msg) 1228 } 1229 ssi1 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 1230 atomic.AddUint32(&ssiCount, 1) 1231 return handler(srv, ss) 1232 } 1233 ssi2 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 1234 atomic.AddUint32(&ssiCount, 1) 1235 return status.Error(codes.Aborted, msg) 1236 } 1237 1238 srvConfig := comm.ServerConfig{} 1239 srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi1) 1240 srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi2) 1241 srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi1) 1242 srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi2) 1243 1244 srv, err := comm.NewGRPCServerFromListener(lis, srvConfig) 1245 assert.NoError(t, err, "failed to create gRPC server") 1246 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 1247 defer srv.Stop() 1248 go srv.Start() 1249 1250 _, err = invokeEmptyCall( 1251 lis.Addr().String(), 1252 grpc.WithBlock(), 1253 grpc.WithInsecure(), 1254 ) 1255 assert.Error(t, err) 1256 assert.Equal(t, status.Convert(err).Message(), msg, "Expected error from second usi") 1257 assert.Equal(t, uint32(2), atomic.LoadUint32(&usiCount), "Expected both usi handlers to be invoked") 1258 1259 _, err = invokeEmptyStream( 1260 lis.Addr().String(), 1261 grpc.WithBlock(), 1262 grpc.WithInsecure(), 1263 ) 1264 assert.Error(t, err) 1265 assert.Equal(t, status.Convert(err).Message(), msg, "Expected error from second ssi") 1266 assert.Equal(t, uint32(2), atomic.LoadUint32(&ssiCount), "Expected both ssi handlers to be invoked") 1267 }