github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/server_test.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm_test 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "crypto/x509" 14 "fmt" 15 "io" 16 "io/ioutil" 17 "log" 18 "net" 19 "path/filepath" 20 "sync/atomic" 21 "testing" 22 "time" 23 24 "github.com/hechain20/hechain/common/crypto/tlsgen" 25 "github.com/hechain20/hechain/internal/pkg/comm" 26 "github.com/hechain20/hechain/internal/pkg/comm/testpb" 27 "github.com/pkg/errors" 28 "github.com/stretchr/testify/require" 29 "google.golang.org/grpc" 30 "google.golang.org/grpc/codes" 31 "google.golang.org/grpc/credentials" 32 "google.golang.org/grpc/status" 33 ) 34 35 // Embedded certificates for testing 36 // The self-signed cert expires in 2028 37 var selfSignedKeyPEM = `-----BEGIN EC PRIVATE KEY----- 38 MHcCAQEEIMLemLh3+uDzww1pvqP6Xj2Z0Kc6yqf3RxyfTBNwRuuyoAoGCCqGSM49 39 AwEHoUQDQgAEDB3l94vM7EqKr2L/vhqU5IsEub0rviqCAaWGiVAPp3orb/LJqFLS 40 yo/k60rhUiir6iD4S4pb5TEb2ouWylQI3A== 41 -----END EC PRIVATE KEY----- 42 ` 43 44 var selfSignedCertPEM = `-----BEGIN CERTIFICATE----- 45 MIIBdDCCARqgAwIBAgIRAKCiW5r6W32jGUn+l9BORMAwCgYIKoZIzj0EAwIwEjEQ 46 MA4GA1UEChMHQWNtZSBDbzAeFw0xODA4MjExMDI1MzJaFw0yODA4MTgxMDI1MzJa 47 MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQM 48 HeX3i8zsSoqvYv++GpTkiwS5vSu+KoIBpYaJUA+neitv8smoUtLKj+TrSuFSKKvq 49 IPhLilvlMRvai5bKVAjco1EwTzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYI 50 KwYBBQUHAwEwDAYDVR0TAQH/BAIwADAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A 51 AAEwCgYIKoZIzj0EAwIDSAAwRQIgOaYc3pdGf2j0uXRyvdBJq2PlK9FkgvsUjXOT 52 bQ9fWRkCIQCr1FiRRzapgtrnttDn3O2fhLlbrw67kClzY8pIIN42Qw== 53 -----END CERTIFICATE----- 54 ` 55 56 var badPEM = `-----BEGIN CERTIFICATE----- 57 MIICRDCCAemgAwIBAgIJALwW//dz2ZBvMAoGCCqGSM49BAMCMH4xCzAJBgNVBAYT 58 AlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2Nv 59 MRgwFgYDVQQKDA9MaW51eEZvdW5kYXRpb24xFDASBgNVBAsMC0h5cGVybGVkZ2Vy 60 MRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTYxMjA0MjIzMDE4WhcNMjYxMjAyMjIz 61 MDE4WjB+MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UE 62 BwwNU2FuIEZyYW5jaXNjbzEYMBYGA1UECgwPTGludXhGb3VuZGF0aW9uMRQwEgYD 63 VQQLDAtIeXBlcmxlZGdlcjESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0C 64 -----END CERTIFICATE----- 65 ` 66 67 var testOrgs = []testOrg{} 68 69 func init() { 70 // load up crypto material for test orgs 71 for i := 1; i <= numOrgs; i++ { 72 testOrg, err := loadOrg(i) 73 if err != nil { 74 log.Fatalf("Failed to load test organizations due to error: %s", err.Error()) 75 } 76 testOrgs = append(testOrgs, testOrg) 77 } 78 } 79 80 // test servers to be registered with the GRPCServer 81 type emptyServiceServer struct{} 82 83 func (ess *emptyServiceServer) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) { 84 return new(testpb.Empty), nil 85 } 86 87 func (esss *emptyServiceServer) EmptyStream(stream testpb.EmptyService_EmptyStreamServer) error { 88 for { 89 _, err := stream.Recv() 90 if err == io.EOF { 91 return nil 92 } 93 if err != nil { 94 return err 95 } 96 if err := stream.Send(&testpb.Empty{}); err != nil { 97 return err 98 } 99 100 } 101 } 102 103 // invoke the EmptyCall RPC 104 func invokeEmptyCall(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) { 105 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 106 defer cancel() 107 // create GRPC client conn 108 clientConn, err := grpc.DialContext(ctx, address, dialOptions...) 109 if err != nil { 110 return nil, err 111 } 112 defer clientConn.Close() 113 114 // create GRPC client 115 client := testpb.NewEmptyServiceClient(clientConn) 116 117 // invoke service 118 empty, err := client.EmptyCall(context.Background(), new(testpb.Empty)) 119 if err != nil { 120 return nil, err 121 } 122 123 return empty, nil 124 } 125 126 // invoke the EmptyStream RPC 127 func invokeEmptyStream(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) { 128 ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 129 defer cancel() 130 // create GRPC client conn 131 clientConn, err := grpc.DialContext(ctx, address, dialOptions...) 132 if err != nil { 133 return nil, err 134 } 135 defer clientConn.Close() 136 137 stream, err := testpb.NewEmptyServiceClient(clientConn).EmptyStream(ctx) 138 if err != nil { 139 return nil, err 140 } 141 142 var msg *testpb.Empty 143 var streamErr error 144 145 waitc := make(chan struct{}) 146 go func() { 147 for { 148 in, err := stream.Recv() 149 if err == io.EOF { 150 close(waitc) 151 return 152 } 153 if err != nil { 154 streamErr = err 155 close(waitc) 156 return 157 } 158 msg = in 159 } 160 }() 161 162 // TestServerInterceptors adds an interceptor that does not call the target 163 // StreamHandler and returns an error so Send can return with an io.EOF since 164 // the server side has already terminated. Whether or not we get an error 165 // depends on timing. 166 err = stream.Send(&testpb.Empty{}) 167 if err != nil && err != io.EOF { 168 return nil, fmt.Errorf("stream send failed: %s", err) 169 } 170 171 stream.CloseSend() 172 <-waitc 173 return msg, streamErr 174 } 175 176 const ( 177 numOrgs = 2 178 numChildOrgs = 2 179 numServerCerts = 2 180 ) 181 182 // string for cert filenames 183 var ( 184 orgCACert = filepath.Join("testdata", "certs", "Org%d-cert.pem") 185 orgServerKey = filepath.Join("testdata", "certs", "Org%d-server%d-key.pem") 186 orgServerCert = filepath.Join("testdata", "certs", "Org%d-server%d-cert.pem") 187 orgClientKey = filepath.Join("testdata", "certs", "Org%d-client%d-key.pem") 188 orgClientCert = filepath.Join("testdata", "certs", "Org%d-client%d-cert.pem") 189 childCACert = filepath.Join("testdata", "certs", "Org%d-child%d-cert.pem") 190 childServerKey = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-key.pem") 191 childServerCert = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-cert.pem") 192 childClientKey = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-key.pem") 193 childClientCert = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-cert.pem") 194 ) 195 196 type testServer struct { 197 config comm.ServerConfig 198 } 199 200 type serverCert struct { 201 keyPEM []byte 202 certPEM []byte 203 } 204 205 type testOrg struct { 206 rootCA []byte 207 serverCerts []serverCert 208 clientCerts []tls.Certificate 209 childOrgs []testOrg 210 } 211 212 // return *X509.CertPool for the rootCA of the org 213 func (org *testOrg) rootCertPool() *x509.CertPool { 214 certPool := x509.NewCertPool() 215 certPool.AppendCertsFromPEM(org.rootCA) 216 return certPool 217 } 218 219 // return testServers for the org 220 func (org *testOrg) testServers(clientRootCAs [][]byte) []testServer { 221 clientRootCAs = append(clientRootCAs, org.rootCA) 222 223 // loop through the serverCerts and create testServers 224 testServers := []testServer{} 225 for _, serverCert := range org.serverCerts { 226 testServer := testServer{ 227 comm.ServerConfig{ 228 ConnectionTimeout: 250 * time.Millisecond, 229 SecOpts: comm.SecureOptions{ 230 UseTLS: true, 231 Certificate: serverCert.certPEM, 232 Key: serverCert.keyPEM, 233 RequireClientCert: true, 234 ClientRootCAs: clientRootCAs, 235 }, 236 }, 237 } 238 testServers = append(testServers, testServer) 239 } 240 return testServers 241 } 242 243 // return trusted clients for the org 244 func (org *testOrg) trustedClients(serverRootCAs [][]byte) []*tls.Config { 245 // if we have any additional server root CAs add them to the certPool 246 certPool := org.rootCertPool() 247 for _, serverRootCA := range serverRootCAs { 248 certPool.AppendCertsFromPEM(serverRootCA) 249 } 250 251 // loop through the clientCerts and create tls.Configs 252 trustedClients := []*tls.Config{} 253 for _, clientCert := range org.clientCerts { 254 trustedClient := &tls.Config{ 255 Certificates: []tls.Certificate{clientCert}, 256 RootCAs: certPool, 257 } 258 trustedClients = append(trustedClients, trustedClient) 259 } 260 return trustedClients 261 } 262 263 // createCertPool creates an x509.CertPool from an array of PEM-encoded certificates 264 func createCertPool(rootCAs [][]byte) (*x509.CertPool, error) { 265 certPool := x509.NewCertPool() 266 for _, rootCA := range rootCAs { 267 if !certPool.AppendCertsFromPEM(rootCA) { 268 return nil, errors.New("Failed to load root certificates") 269 } 270 } 271 return certPool, nil 272 } 273 274 // utility function to load crypto material for organizations 275 func loadOrg(parent int) (testOrg, error) { 276 org := testOrg{} 277 // load the CA 278 caPEM, err := ioutil.ReadFile(fmt.Sprintf(orgCACert, parent)) 279 if err != nil { 280 return org, err 281 } 282 283 // loop through and load servers 284 serverCerts := []serverCert{} 285 for i := 1; i <= numServerCerts; i++ { 286 keyPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerKey, parent, i)) 287 if err != nil { 288 return org, err 289 } 290 certPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerCert, parent, i)) 291 if err != nil { 292 return org, err 293 } 294 serverCerts = append(serverCerts, serverCert{keyPEM, certPEM}) 295 } 296 297 // loop through and load clients 298 clientCerts := []tls.Certificate{} 299 for j := 1; j <= numServerCerts; j++ { 300 clientCert, err := loadTLSKeyPairFromFile(fmt.Sprintf(orgClientKey, parent, j), 301 fmt.Sprintf(orgClientCert, parent, j)) 302 if err != nil { 303 return org, err 304 } 305 clientCerts = append(clientCerts, clientCert) 306 } 307 308 // loop through and load child orgs 309 childOrgs := []testOrg{} 310 for k := 1; k <= numChildOrgs; k++ { 311 childOrg, err := loadChildOrg(parent, k) 312 if err != nil { 313 return org, err 314 } 315 childOrgs = append(childOrgs, childOrg) 316 } 317 318 return testOrg{caPEM, serverCerts, clientCerts, childOrgs}, nil 319 } 320 321 // utility function to load crypto material for child organizations 322 func loadChildOrg(parent, child int) (testOrg, error) { 323 // load the CA 324 caPEM, err := ioutil.ReadFile(fmt.Sprintf(childCACert, parent, child)) 325 if err != nil { 326 return testOrg{}, err 327 } 328 329 // loop through and load servers 330 serverCerts := []serverCert{} 331 for i := 1; i <= numServerCerts; i++ { 332 keyPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerKey, parent, child, i)) 333 if err != nil { 334 return testOrg{}, err 335 } 336 certPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerCert, parent, child, i)) 337 if err != nil { 338 return testOrg{}, err 339 } 340 serverCerts = append(serverCerts, serverCert{keyPEM, certPEM}) 341 } 342 343 // loop through and load clients 344 clientCerts := []tls.Certificate{} 345 for j := 1; j <= numServerCerts; j++ { 346 clientCert, err := loadTLSKeyPairFromFile( 347 fmt.Sprintf(childClientKey, parent, child, j), 348 fmt.Sprintf(childClientCert, parent, child, j), 349 ) 350 if err != nil { 351 return testOrg{}, err 352 } 353 clientCerts = append(clientCerts, clientCert) 354 } 355 356 return testOrg{caPEM, serverCerts, clientCerts, []testOrg{}}, nil 357 } 358 359 // loadTLSKeyPairFromFile creates a tls.Certificate from PEM-encoded key and cert files 360 func loadTLSKeyPairFromFile(keyFile, certFile string) (tls.Certificate, error) { 361 certPEMBlock, err := ioutil.ReadFile(certFile) 362 if err != nil { 363 return tls.Certificate{}, err 364 } 365 366 keyPEMBlock, err := ioutil.ReadFile(keyFile) 367 if err != nil { 368 return tls.Certificate{}, err 369 } 370 371 cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 372 if err != nil { 373 return tls.Certificate{}, err 374 } 375 376 return cert, nil 377 } 378 379 func TestNewGRPCServerInvalidParameters(t *testing.T) { 380 t.Parallel() 381 382 // missing address 383 _, err := comm.NewGRPCServer( 384 "", 385 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 386 ) 387 require.EqualError(t, err, "missing address parameter") 388 389 // missing port 390 _, err = comm.NewGRPCServer( 391 "abcdef", 392 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 393 ) 394 require.Error(t, err, "Expected error with missing port") 395 require.Contains(t, err.Error(), "missing port in address") 396 397 // bad port 398 _, err = comm.NewGRPCServer( 399 "127.0.0.1:1BBB", 400 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 401 ) 402 // check for possible errors based on platform and Go release 403 msgs := []string{ 404 "listen tcp: lookup tcp/1BBB: nodename nor servname provided, or not known", 405 "listen tcp: unknown port tcp/1BBB", 406 "listen tcp: address tcp/1BBB: unknown port", 407 "listen tcp: lookup tcp/1BBB: Servname not supported for ai_socktype", 408 } 409 require.Error(t, err, fmt.Sprintf("[%s], [%s] [%s] or [%s] expected", msgs[0], msgs[1], msgs[2], msgs[3])) 410 require.Contains(t, msgs, err.Error()) 411 412 // bad hostname 413 _, err = comm.NewGRPCServer( 414 "hostdoesnotexist.localdomain:9050", 415 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 416 ) 417 // We cannot check for a specific error message due to the fact that some 418 // systems will automatically resolve unknown host names to a "search" 419 // address so we just check to make sure that an error was returned 420 require.Error(t, err, "error expected") 421 422 // address in use 423 lis, err := net.Listen("tcp", "127.0.0.1:0") 424 require.NoError(t, err, "failed to create listener") 425 defer lis.Close() 426 427 _, err = comm.NewGRPCServerFromListener( 428 lis, 429 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 430 ) 431 require.NoError(t, err, "failed to create grpc server") 432 433 _, err = comm.NewGRPCServer( 434 lis.Addr().String(), 435 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 436 ) 437 require.Error(t, err) 438 require.Contains(t, err.Error(), "address already in use") 439 440 // missing server Certificate 441 _, err = comm.NewGRPCServerFromListener( 442 lis, 443 comm.ServerConfig{ 444 SecOpts: comm.SecureOptions{UseTLS: true, Key: []byte{}}, 445 }, 446 ) 447 require.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true") 448 449 // missing server Key 450 _, err = comm.NewGRPCServerFromListener( 451 lis, 452 comm.ServerConfig{ 453 SecOpts: comm.SecureOptions{ 454 UseTLS: true, 455 Certificate: []byte{}, 456 }, 457 }, 458 ) 459 require.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true") 460 461 // bad server Key 462 _, err = comm.NewGRPCServerFromListener( 463 lis, 464 comm.ServerConfig{ 465 SecOpts: comm.SecureOptions{ 466 UseTLS: true, 467 Certificate: []byte(selfSignedCertPEM), 468 Key: []byte{}, 469 }, 470 }, 471 ) 472 require.EqualError(t, err, "tls: failed to find any PEM data in key input") 473 474 // bad server Certificate 475 _, err = comm.NewGRPCServerFromListener( 476 lis, 477 comm.ServerConfig{ 478 SecOpts: comm.SecureOptions{ 479 UseTLS: true, 480 Certificate: []byte{}, 481 Key: []byte(selfSignedKeyPEM), 482 }, 483 }, 484 ) 485 require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") 486 487 srv, err := comm.NewGRPCServerFromListener( 488 lis, 489 comm.ServerConfig{ 490 SecOpts: comm.SecureOptions{ 491 UseTLS: true, 492 Certificate: []byte(selfSignedCertPEM), 493 Key: []byte(selfSignedKeyPEM), 494 RequireClientCert: true, 495 }, 496 }, 497 ) 498 require.NoError(t, err) 499 500 badRootCAs := [][]byte{[]byte(badPEM)} 501 err = srv.SetClientRootCAs(badRootCAs) 502 require.EqualError(t, err, "failed to set client root certificate(s)") 503 } 504 505 func TestNewGRPCServer(t *testing.T) { 506 t.Parallel() 507 508 testAddress := "127.0.0.1:9053" 509 srv, err := comm.NewGRPCServer( 510 testAddress, 511 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 512 ) 513 require.NoError(t, err, "failed to create new GRPC server") 514 515 // resolve the address 516 addr, err := net.ResolveTCPAddr("tcp", testAddress) 517 require.NoError(t, err) 518 519 // make sure our properties are as expected 520 require.Equal(t, srv.Address(), addr.String()) 521 require.Equal(t, srv.Listener().Addr().String(), addr.String()) 522 require.Equal(t, srv.TLSEnabled(), false) 523 require.Equal(t, srv.MutualTLSRequired(), false) 524 525 // register the GRPC test server 526 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 527 528 // start the server 529 go srv.Start() 530 defer srv.Stop() 531 532 // should not be needed 533 time.Sleep(10 * time.Millisecond) 534 535 // invoke the EmptyCall service 536 _, err = invokeEmptyCall(testAddress, grpc.WithInsecure()) 537 require.NoError(t, err, "failed to invoke the EmptyCall service") 538 } 539 540 func TestNewGRPCServerFromListener(t *testing.T) { 541 t.Parallel() 542 543 // create our listener 544 lis, err := net.Listen("tcp", "127.0.0.1:0") 545 require.NoError(t, err, "failed to create listener") 546 testAddress := lis.Addr().String() 547 548 srv, err := comm.NewGRPCServerFromListener( 549 lis, 550 comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}}, 551 ) 552 require.NoError(t, err, "failed to create new GRPC server") 553 554 require.Equal(t, srv.Address(), testAddress) 555 require.Equal(t, srv.Listener().Addr().String(), testAddress) 556 require.Equal(t, srv.TLSEnabled(), false) 557 require.Equal(t, srv.MutualTLSRequired(), false) 558 559 // register the GRPC test server 560 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 561 562 // start the server 563 go srv.Start() 564 defer srv.Stop() 565 566 // should not be needed 567 time.Sleep(10 * time.Millisecond) 568 569 // invoke the EmptyCall service 570 _, err = invokeEmptyCall(testAddress, grpc.WithInsecure()) 571 require.NoError(t, err, "client failed to invoke the EmptyCall service") 572 } 573 574 func TestNewSecureGRPCServer(t *testing.T) { 575 t.Parallel() 576 577 // create our listener 578 lis, err := net.Listen("tcp", "127.0.0.1:0") 579 require.NoError(t, err, "failed to create listener") 580 testAddress := lis.Addr().String() 581 582 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 583 ConnectionTimeout: 250 * time.Millisecond, 584 SecOpts: comm.SecureOptions{ 585 UseTLS: true, 586 Certificate: []byte(selfSignedCertPEM), 587 Key: []byte(selfSignedKeyPEM), 588 }, 589 }, 590 ) 591 require.NoError(t, err, "failed to create new grpc server") 592 593 // make sure our properties are as expected 594 require.NoError(t, err) 595 require.Equal(t, srv.Address(), testAddress) 596 require.Equal(t, srv.Listener().Addr().String(), testAddress) 597 598 cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) 599 require.Equal(t, srv.ServerCertificate(), cert) 600 601 require.Equal(t, srv.TLSEnabled(), true) 602 require.Equal(t, srv.MutualTLSRequired(), false) 603 604 // register the GRPC test server 605 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 606 607 // start the server 608 go srv.Start() 609 defer srv.Stop() 610 611 // should not be needed 612 time.Sleep(10 * time.Millisecond) 613 614 // create the client credentials 615 certPool := x509.NewCertPool() 616 if !certPool.AppendCertsFromPEM([]byte(selfSignedCertPEM)) { 617 t.Fatal("Failed to append certificate to client credentials") 618 } 619 creds := credentials.NewClientTLSFromCert(certPool, "") 620 621 // invoke the EmptyCall service 622 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 623 require.NoError(t, err, "client failed to invoke the EmptyCall service") 624 625 // Test TLS versions which should be valid 626 tlsVersions := map[string]uint16{ 627 "TLS12": tls.VersionTLS12, 628 "TLS13": tls.VersionTLS13, 629 } 630 for name, tlsVersion := range tlsVersions { 631 tlsVersion := tlsVersion 632 633 t.Run(name, func(t *testing.T) { 634 creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion}) 635 _, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock()) 636 require.NoError(t, err) 637 }) 638 } 639 640 // Test TLS versions which should be invalid 641 tlsVersions = map[string]uint16{ 642 "SSL30": tls.VersionSSL30, 643 "TLS10": tls.VersionTLS10, 644 "TLS11": tls.VersionTLS11, 645 } 646 for name, tlsVersion := range tlsVersions { 647 tlsVersion := tlsVersion 648 t.Run(name, func(t *testing.T) { 649 t.Parallel() 650 651 creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion}) 652 _, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock()) 653 require.Error(t, err, "should not have been able to connect with TLS version < 1.2") 654 require.Contains(t, err.Error(), "context deadline exceeded") 655 }) 656 } 657 } 658 659 func TestVerifyCertificateCallback(t *testing.T) { 660 t.Parallel() 661 662 ca, err := tlsgen.NewCA() 663 require.NoError(t, err) 664 665 authorizedClientKeyPair, err := ca.NewClientCertKeyPair() 666 require.NoError(t, err) 667 668 notAuthorizedClientKeyPair, err := ca.NewClientCertKeyPair() 669 require.NoError(t, err) 670 671 serverKeyPair, err := ca.NewServerCertKeyPair("127.0.0.1") 672 require.NoError(t, err) 673 674 verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 675 if bytes.Equal(rawCerts[0], authorizedClientKeyPair.TLSCert.Raw) { 676 return nil 677 } 678 return errors.New("certificate mismatch") 679 } 680 681 probeTLS := func(endpoint string, clientKeyPair *tlsgen.CertKeyPair) error { 682 cert, err := tls.X509KeyPair(clientKeyPair.Cert, clientKeyPair.Key) 683 if err != nil { 684 return err 685 } 686 tlsCfg := &tls.Config{ 687 Certificates: []tls.Certificate{cert}, 688 RootCAs: x509.NewCertPool(), 689 MinVersion: tls.VersionTLS12, 690 MaxVersion: tls.VersionTLS12, 691 } 692 tlsCfg.RootCAs.AppendCertsFromPEM(ca.CertBytes()) 693 694 conn, err := tls.Dial("tcp", endpoint, tlsCfg) 695 if err != nil { 696 return err 697 } 698 conn.Close() 699 return nil 700 } 701 702 gRPCServer, err := comm.NewGRPCServer("127.0.0.1:", comm.ServerConfig{ 703 SecOpts: comm.SecureOptions{ 704 ClientRootCAs: [][]byte{ca.CertBytes()}, 705 Key: serverKeyPair.Key, 706 Certificate: serverKeyPair.Cert, 707 UseTLS: true, 708 VerifyCertificate: verifyFunc, 709 }, 710 }) 711 go gRPCServer.Start() 712 defer gRPCServer.Stop() 713 714 t.Run("Success path", func(t *testing.T) { 715 err = probeTLS(gRPCServer.Address(), authorizedClientKeyPair) 716 require.NoError(t, err) 717 }) 718 719 t.Run("Failure path", func(t *testing.T) { 720 err = probeTLS(gRPCServer.Address(), notAuthorizedClientKeyPair) 721 require.EqualError(t, err, "remote error: tls: bad certificate") 722 }) 723 } 724 725 // prior tests used self-signed certficates loaded by the GRPCServer and the test client 726 // here we'll use certificates signed by certificate authorities 727 func TestWithSignedRootCertificates(t *testing.T) { 728 t.Parallel() 729 730 // use Org1 testdata 731 fileBase := "Org1" 732 certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-cert.pem")) 733 require.NoError(t, err, "failed to load test certificates") 734 735 keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-key.pem")) 736 require.NoError(t, err, "failed to load test certificates: %v") 737 738 caPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-cert.pem")) 739 require.NoError(t, err, "failed to load test certificates") 740 741 // create our listener 742 lis, err := net.Listen("tcp", "127.0.0.1:0") 743 require.NoError(t, err, "failed to create listener") 744 testAddress := lis.Addr().String() 745 746 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 747 SecOpts: comm.SecureOptions{ 748 UseTLS: true, 749 Certificate: certPEMBlock, 750 Key: keyPEMBlock, 751 }, 752 }) 753 require.NoError(t, err, "failed to create new grpc server") 754 // register the GRPC test server 755 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 756 757 // start the server 758 go srv.Start() 759 defer srv.Stop() 760 761 // should not be needed 762 time.Sleep(10 * time.Millisecond) 763 764 // create a CertPool for use by the client with the server cert only 765 certPoolServer, err := createCertPool([][]byte{certPEMBlock}) 766 require.NoError(t, err, "failed to load root certificates into pool") 767 creds := credentials.NewClientTLSFromCert(certPoolServer, "") 768 769 // invoke the EmptyCall service 770 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 771 require.NoError(t, err, "Expected client to connect with server cert only") 772 773 // now use the CA certificate 774 certPoolCA := x509.NewCertPool() 775 if !certPoolCA.AppendCertsFromPEM(caPEMBlock) { 776 t.Fatal("Failed to append certificate to client credentials") 777 } 778 creds = credentials.NewClientTLSFromCert(certPoolCA, "") 779 780 // invoke the EmptyCall service 781 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 782 require.NoError(t, err, "client failed to invoke the EmptyCall") 783 } 784 785 // here we'll use certificates signed by intermediate certificate authorities 786 func TestWithSignedIntermediateCertificates(t *testing.T) { 787 t.Parallel() 788 789 // use Org1 testdata 790 fileBase := "Org1" 791 certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-cert.pem")) 792 require.NoError(t, err) 793 794 keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-key.pem")) 795 require.NoError(t, err) 796 797 intermediatePEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-cert.pem")) 798 if err != nil { 799 t.Fatalf("Failed to load test certificates: %v", err) 800 } 801 802 // create our listener 803 lis, err := net.Listen("tcp", "127.0.0.1:0") 804 if err != nil { 805 t.Fatalf("Failed to create listener: %v", err) 806 } 807 testAddress := lis.Addr().String() 808 809 srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{ 810 SecOpts: comm.SecureOptions{ 811 UseTLS: true, 812 Certificate: certPEMBlock, 813 Key: keyPEMBlock, 814 }, 815 }) 816 // check for error 817 if err != nil { 818 t.Fatalf("Failed to return new GRPC server: %v", err) 819 } 820 821 // register the GRPC test server 822 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 823 824 // start the server 825 go srv.Start() 826 defer srv.Stop() 827 828 // should not be needed 829 time.Sleep(10 * time.Millisecond) 830 831 // create a CertPool for use by the client with the server cert only 832 certPoolServer, err := createCertPool([][]byte{certPEMBlock}) 833 if err != nil { 834 t.Fatalf("Failed to load root certificates into pool: %v", err) 835 } 836 // create the client credentials 837 creds := credentials.NewClientTLSFromCert(certPoolServer, "") 838 839 // invoke the EmptyCall service 840 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 841 842 // client should be able to connect with Go 1.9 843 require.NoError(t, err, "Expected client to connect with server cert only") 844 845 // now use the CA certificate 846 // create a CertPool for use by the client with the intermediate root CA 847 certPoolCA, err := createCertPool([][]byte{intermediatePEMBlock}) 848 require.NoError(t, err, "failed to load root certificates into pool") 849 850 creds = credentials.NewClientTLSFromCert(certPoolCA, "") 851 852 // invoke the EmptyCall service 853 _, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds)) 854 require.NoError(t, err, "client failed to invoke the EmptyCall service") 855 } 856 857 // utility function for testing client / server communication using TLS 858 func runMutualAuth(t *testing.T, servers []testServer, trustedClients, unTrustedClients []*tls.Config) error { 859 // loop through all the test servers 860 for i := 0; i < len(servers); i++ { 861 // create listener 862 lis, err := net.Listen("tcp", "127.0.0.1:0") 863 if err != nil { 864 return err 865 } 866 srvAddr := lis.Addr().String() 867 868 // create GRPCServer 869 srv, err := comm.NewGRPCServerFromListener(lis, servers[i].config) 870 if err != nil { 871 return err 872 } 873 874 // MutualTLSRequired should be true 875 require.Equal(t, srv.MutualTLSRequired(), true) 876 877 // register the GRPC test server and start the GRPCServer 878 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 879 go srv.Start() 880 defer srv.Stop() 881 882 // should not be needed but just in case 883 time.Sleep(10 * time.Millisecond) 884 885 // loop through all the trusted clients 886 for j := 0; j < len(trustedClients); j++ { 887 // invoke the EmptyCall service 888 _, err = invokeEmptyCall(srvAddr, grpc.WithTransportCredentials(credentials.NewTLS(trustedClients[j]))) 889 // we expect success from trusted clients 890 if err != nil { 891 return err 892 } else { 893 t.Logf("Trusted client%d successfully connected to %s", j, srvAddr) 894 } 895 } 896 897 // loop through all the untrusted clients 898 for k := 0; k < len(unTrustedClients); k++ { 899 // invoke the EmptyCall service 900 _, err = invokeEmptyCall( 901 srvAddr, 902 grpc.WithTransportCredentials(credentials.NewTLS(unTrustedClients[k])), 903 ) 904 // we expect failure from untrusted clients 905 if err != nil { 906 t.Logf("Untrusted client%d was correctly rejected by %s", k, srvAddr) 907 } else { 908 return fmt.Errorf("Untrusted client %d should not have been able to connect to %s", k, srvAddr) 909 } 910 } 911 } 912 913 return nil 914 } 915 916 func TestMutualAuth(t *testing.T) { 917 t.Parallel() 918 919 tests := []struct { 920 name string 921 servers []testServer 922 trustedClients []*tls.Config 923 unTrustedClients []*tls.Config 924 }{ 925 { 926 name: "ClientAuthRequiredWithSingleOrg", 927 servers: testOrgs[0].testServers([][]byte{}), 928 trustedClients: testOrgs[0].trustedClients([][]byte{}), 929 unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 930 }, 931 { 932 name: "ClientAuthRequiredWithChildClientOrg", 933 servers: testOrgs[0].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}), 934 trustedClients: testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA}), 935 unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 936 }, 937 { 938 name: "ClientAuthRequiredWithMultipleChildClientOrgs", 939 servers: testOrgs[0].testServers(append([][]byte{}, 940 testOrgs[0].childOrgs[0].rootCA, 941 testOrgs[0].childOrgs[1].rootCA, 942 )), 943 trustedClients: append(append([]*tls.Config{}, 944 testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})...), 945 testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})...), 946 unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 947 }, 948 { 949 name: "ClientAuthRequiredWithDifferentServerAndClientOrgs", 950 servers: testOrgs[0].testServers([][]byte{testOrgs[1].rootCA}), 951 trustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 952 unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}), 953 }, 954 { 955 name: "ClientAuthRequiredWithDifferentServerAndChildClientOrgs", 956 servers: testOrgs[1].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}), 957 trustedClients: testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}), 958 unTrustedClients: testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}), 959 }, 960 } 961 962 for _, test := range tests { 963 test := test 964 t.Run(test.name, func(t *testing.T) { 965 t.Parallel() 966 t.Logf("Running test %s ...", test.name) 967 testErr := runMutualAuth(t, test.servers, test.trustedClients, test.unTrustedClients) 968 require.NoError(t, testErr) 969 }) 970 } 971 } 972 973 func TestSetClientRootCAs(t *testing.T) { 974 t.Parallel() 975 976 // get the config for one of our Org1 test servers 977 serverConfig := testOrgs[0].testServers([][]byte{})[0].config 978 lis, err := net.Listen("tcp", "127.0.0.1:0") 979 require.NoError(t, err, "listen failed") 980 defer lis.Close() 981 address := lis.Addr().String() 982 983 // create a GRPCServer 984 srv, err := comm.NewGRPCServerFromListener(lis, serverConfig) 985 require.NoError(t, err, "failed to create GRPCServer") 986 987 // register the GRPC test server and start the GRPCServer 988 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 989 go srv.Start() 990 defer srv.Stop() 991 992 // should not be needed but just in case 993 time.Sleep(10 * time.Millisecond) 994 995 // set up our test clients 996 // Org1 997 clientConfigOrg1Child1 := testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0] 998 clientConfigOrg1Child2 := testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0] 999 clientConfigsOrg1Children := []*tls.Config{clientConfigOrg1Child1, clientConfigOrg1Child2} 1000 org1ChildRootCAs := [][]byte{testOrgs[0].childOrgs[0].rootCA, testOrgs[0].childOrgs[1].rootCA} 1001 // Org2 1002 clientConfigOrg2Child1 := testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0] 1003 clientConfigOrg2Child2 := testOrgs[1].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0] 1004 clientConfigsOrg2Children := []*tls.Config{clientConfigOrg2Child1, clientConfigOrg2Child2} 1005 org2ChildRootCAs := [][]byte{testOrgs[1].childOrgs[0].rootCA, testOrgs[1].childOrgs[1].rootCA} 1006 1007 // initially set client CAs to Org1 children 1008 err = srv.SetClientRootCAs(org1ChildRootCAs) 1009 require.NoError(t, err, "SetClientRootCAs failed") 1010 1011 // clientConfigsOrg1Children are currently trusted 1012 for _, clientConfig := range clientConfigsOrg1Children { 1013 // we expect success as these are trusted clients 1014 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1015 require.NoError(t, err, "trusted client should have connected") 1016 } 1017 1018 // clientConfigsOrg2Children are currently not trusted 1019 for _, clientConfig := range clientConfigsOrg2Children { 1020 // we expect failure as these are now untrusted clients 1021 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1022 require.Error(t, err, "untrusted client should not have been able to connect") 1023 } 1024 1025 // now set client CAs to Org2 children 1026 err = srv.SetClientRootCAs(org2ChildRootCAs) 1027 require.NoError(t, err, "SetClientRootCAs failed") 1028 1029 // now reverse trusted and not trusted 1030 // clientConfigsOrg1Children are currently trusted 1031 for _, clientConfig := range clientConfigsOrg2Children { 1032 // we expect success as these are trusted clients 1033 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1034 require.NoError(t, err, "trusted client should have connected") 1035 } 1036 1037 // clientConfigsOrg2Children are currently not trusted 1038 for _, clientConfig := range clientConfigsOrg1Children { 1039 // we expect failure as these are now untrusted clients 1040 _, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig))) 1041 require.Error(t, err, "untrusted client should not have connected") 1042 } 1043 } 1044 1045 func TestUpdateTLSCert(t *testing.T) { 1046 t.Parallel() 1047 1048 readFile := func(path string) []byte { 1049 fName := filepath.Join("testdata", "dynamic_cert_update", path) 1050 data, err := ioutil.ReadFile(fName) 1051 if err != nil { 1052 panic(fmt.Errorf("Failed reading %s: %v", fName, err)) 1053 } 1054 return data 1055 } 1056 loadBytes := func(prefix string) (key, cert, caCert []byte) { 1057 cert = readFile(filepath.Join(prefix, "server.crt")) 1058 key = readFile(filepath.Join(prefix, "server.key")) 1059 caCert = readFile(filepath.Join("ca.crt")) 1060 return 1061 } 1062 1063 key, cert, caCert := loadBytes("notlocalhost") 1064 1065 cfg := comm.ServerConfig{ 1066 SecOpts: comm.SecureOptions{ 1067 UseTLS: true, 1068 Key: key, 1069 Certificate: cert, 1070 }, 1071 } 1072 1073 // create our listener 1074 lis, err := net.Listen("tcp", "127.0.0.1:0") 1075 require.NoError(t, err, "listen failed") 1076 testAddress := lis.Addr().String() 1077 1078 srv, err := comm.NewGRPCServerFromListener(lis, cfg) 1079 require.NoError(t, err) 1080 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 1081 1082 go srv.Start() 1083 defer srv.Stop() 1084 1085 certPool := x509.NewCertPool() 1086 certPool.AppendCertsFromPEM(caCert) 1087 1088 probeServer := func() error { 1089 _, err = invokeEmptyCall( 1090 testAddress, 1091 grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: certPool})), 1092 grpc.WithBlock(), 1093 ) 1094 return err 1095 } 1096 1097 // bootstrap TLS certificate has a SAN of "notlocalhost" so it should fail 1098 err = probeServer() 1099 require.Error(t, err) 1100 require.Contains(t, err.Error(), "context deadline exceeded") 1101 1102 // new TLS certificate has a SAN of "127.0.0.1" so it should succeed 1103 certPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.crt") 1104 keyPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.key") 1105 tlsCert, err := tls.LoadX509KeyPair(certPath, keyPath) 1106 require.NoError(t, err) 1107 srv.SetServerCertificate(tlsCert) 1108 err = probeServer() 1109 require.NoError(t, err) 1110 1111 // revert back to the old certificate, should fail. 1112 certPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.crt") 1113 keyPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.key") 1114 tlsCert, err = tls.LoadX509KeyPair(certPath, keyPath) 1115 require.NoError(t, err) 1116 srv.SetServerCertificate(tlsCert) 1117 1118 err = probeServer() 1119 require.Error(t, err) 1120 require.Contains(t, err.Error(), "context deadline exceeded") 1121 } 1122 1123 func TestCipherSuites(t *testing.T) { 1124 t.Parallel() 1125 1126 certPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-cert.pem")) 1127 require.NoError(t, err) 1128 keyPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-key.pem")) 1129 require.NoError(t, err) 1130 caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem")) 1131 require.NoError(t, err) 1132 certPool, err := createCertPool([][]byte{caPEM}) 1133 require.NoError(t, err) 1134 1135 serverConfig := comm.ServerConfig{ 1136 SecOpts: comm.SecureOptions{ 1137 Certificate: certPEM, 1138 Key: keyPEM, 1139 UseTLS: true, 1140 }, 1141 } 1142 1143 fabricDefaultCipherSuite := func(cipher uint16) bool { 1144 for _, defaultCipher := range comm.DefaultTLSCipherSuites { 1145 if cipher == defaultCipher { 1146 return true 1147 } 1148 } 1149 return false 1150 } 1151 1152 var otherCipherSuites []uint16 1153 for _, cipher := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) { 1154 if !fabricDefaultCipherSuite(cipher.ID) { 1155 otherCipherSuites = append(otherCipherSuites, cipher.ID) 1156 } 1157 } 1158 1159 tests := []struct { 1160 name string 1161 clientCiphers []uint16 1162 success bool 1163 versions []uint16 1164 }{ 1165 { 1166 name: "server default / client all", 1167 success: true, 1168 versions: []uint16{tls.VersionTLS12, tls.VersionTLS13}, 1169 }, 1170 { 1171 name: "server default / client match", 1172 clientCiphers: comm.DefaultTLSCipherSuites, 1173 success: true, 1174 // Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites 1175 // https://github.com/golang/go/issues/29349 1176 versions: []uint16{tls.VersionTLS12}, 1177 }, 1178 { 1179 name: "server default / client no match", 1180 clientCiphers: otherCipherSuites, 1181 success: false, 1182 // Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites 1183 // https://github.com/golang/go/issues/29349 1184 versions: []uint16{tls.VersionTLS12}, 1185 }, 1186 } 1187 1188 // create our listener 1189 lis, err := net.Listen("tcp", "127.0.0.1:0") 1190 require.NoError(t, err, "listen failed") 1191 testAddress := lis.Addr().String() 1192 srv, err := comm.NewGRPCServerFromListener(lis, serverConfig) 1193 require.NoError(t, err) 1194 go srv.Start() 1195 1196 for _, test := range tests { 1197 test := test 1198 t.Run(test.name, func(t *testing.T) { 1199 t.Parallel() 1200 1201 for _, tlsVersion := range test.versions { 1202 tlsConfig := &tls.Config{ 1203 RootCAs: certPool, 1204 CipherSuites: test.clientCiphers, 1205 MinVersion: tlsVersion, 1206 MaxVersion: tlsVersion, 1207 } 1208 _, err := tls.Dial("tcp", testAddress, tlsConfig) 1209 if test.success { 1210 require.NoError(t, err) 1211 } else { 1212 require.Error(t, err, "expected handshake failure") 1213 require.Contains(t, err.Error(), "handshake failure") 1214 } 1215 } 1216 }) 1217 } 1218 } 1219 1220 func TestServerInterceptors(t *testing.T) { 1221 lis, err := net.Listen("tcp", "127.0.0.1:0") 1222 require.NoError(t, err, "listen failed") 1223 msg := "error from interceptor" 1224 1225 // set up interceptors 1226 usiCount := uint32(0) 1227 ssiCount := uint32(0) 1228 usi1 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 1229 atomic.AddUint32(&usiCount, 1) 1230 return handler(ctx, req) 1231 } 1232 usi2 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 1233 atomic.AddUint32(&usiCount, 1) 1234 return nil, status.Error(codes.Aborted, msg) 1235 } 1236 ssi1 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 1237 atomic.AddUint32(&ssiCount, 1) 1238 return handler(srv, ss) 1239 } 1240 ssi2 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 1241 atomic.AddUint32(&ssiCount, 1) 1242 return status.Error(codes.Aborted, msg) 1243 } 1244 1245 srvConfig := comm.ServerConfig{} 1246 srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi1) 1247 srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi2) 1248 srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi1) 1249 srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi2) 1250 1251 srv, err := comm.NewGRPCServerFromListener(lis, srvConfig) 1252 require.NoError(t, err, "failed to create gRPC server") 1253 testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{}) 1254 defer srv.Stop() 1255 go srv.Start() 1256 1257 _, err = invokeEmptyCall( 1258 lis.Addr().String(), 1259 grpc.WithBlock(), 1260 grpc.WithInsecure(), 1261 ) 1262 require.Error(t, err) 1263 require.Equal(t, status.Convert(err).Message(), msg, "Expected error from second usi") 1264 require.Equal(t, uint32(2), atomic.LoadUint32(&usiCount), "Expected both usi handlers to be invoked") 1265 1266 _, err = invokeEmptyStream( 1267 lis.Addr().String(), 1268 grpc.WithBlock(), 1269 grpc.WithInsecure(), 1270 ) 1271 require.Error(t, err) 1272 require.Equal(t, status.Convert(err).Message(), msg, "Expected error from second ssi") 1273 require.Equal(t, uint32(2), atomic.LoadUint32(&ssiCount), "Expected both ssi handlers to be invoked") 1274 }