github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/server_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"fmt"
    15  	"io"
    16  	"io/ioutil"
    17  	"log"
    18  	"net"
    19  	"path/filepath"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/hechain20/hechain/common/crypto/tlsgen"
    25  	"github.com/hechain20/hechain/internal/pkg/comm"
    26  	"github.com/hechain20/hechain/internal/pkg/comm/testpb"
    27  	"github.com/pkg/errors"
    28  	"github.com/stretchr/testify/require"
    29  	"google.golang.org/grpc"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/credentials"
    32  	"google.golang.org/grpc/status"
    33  )
    34  
// Embedded certificates for testing.
// The self-signed cert expires in 2028.
//
// selfSignedKeyPEM is the PEM-encoded EC private key that the tests in this
// file load together with selfSignedCertPEM as a server key pair.
var selfSignedKeyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMLemLh3+uDzww1pvqP6Xj2Z0Kc6yqf3RxyfTBNwRuuyoAoGCCqGSM49
AwEHoUQDQgAEDB3l94vM7EqKr2L/vhqU5IsEub0rviqCAaWGiVAPp3orb/LJqFLS
yo/k60rhUiir6iD4S4pb5TEb2ouWylQI3A==
-----END EC PRIVATE KEY-----
`
    43  
// selfSignedCertPEM is the PEM-encoded self-signed certificate paired with
// selfSignedKeyPEM. The tests use it both as the server certificate and as the
// only entry in the client's root pool.
var selfSignedCertPEM = `-----BEGIN CERTIFICATE-----
MIIBdDCCARqgAwIBAgIRAKCiW5r6W32jGUn+l9BORMAwCgYIKoZIzj0EAwIwEjEQ
MA4GA1UEChMHQWNtZSBDbzAeFw0xODA4MjExMDI1MzJaFw0yODA4MTgxMDI1MzJa
MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQM
HeX3i8zsSoqvYv++GpTkiwS5vSu+KoIBpYaJUA+neitv8smoUtLKj+TrSuFSKKvq
IPhLilvlMRvai5bKVAjco1EwTzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYI
KwYBBQUHAwEwDAYDVR0TAQH/BAIwADAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A
AAEwCgYIKoZIzj0EAwIDSAAwRQIgOaYc3pdGf2j0uXRyvdBJq2PlK9FkgvsUjXOT
bQ9fWRkCIQCr1FiRRzapgtrnttDn3O2fhLlbrw67kClzY8pIIN42Qw==
-----END CERTIFICATE-----
`
    55  
// badPEM is a CERTIFICATE PEM block that the server is expected to reject when
// it is supplied as a client root CA (see TestNewGRPCServerInvalidParameters).
// NOTE(review): the base64 body looks truncated on purpose - confirm.
var badPEM = `-----BEGIN CERTIFICATE-----
MIICRDCCAemgAwIBAgIJALwW//dz2ZBvMAoGCCqGSM49BAMCMH4xCzAJBgNVBAYT
AlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2Nv
MRgwFgYDVQQKDA9MaW51eEZvdW5kYXRpb24xFDASBgNVBAsMC0h5cGVybGVkZ2Vy
MRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTYxMjA0MjIzMDE4WhcNMjYxMjAyMjIz
MDE4WjB+MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UE
BwwNU2FuIEZyYW5jaXNjbzEYMBYGA1UECgwPTGludXhGb3VuZGF0aW9uMRQwEgYD
VQQLDAtIeXBlcmxlZGdlcjESMBAGA1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0C
-----END CERTIFICATE-----
`
    66  
// testOrgs holds the crypto material of every test organization. It starts
// empty and is filled by init, one entry per organization number.
var testOrgs = []testOrg{}
    68  
    69  func init() {
    70  	// load up crypto material for test orgs
    71  	for i := 1; i <= numOrgs; i++ {
    72  		testOrg, err := loadOrg(i)
    73  		if err != nil {
    74  			log.Fatalf("Failed to load test organizations due to error: %s", err.Error())
    75  		}
    76  		testOrgs = append(testOrgs, testOrg)
    77  	}
    78  }
    79  
// emptyServiceServer is the stateless test service that the tests register
// with the GRPCServer via testpb.RegisterEmptyServiceServer.
type emptyServiceServer struct{}
    82  
    83  func (ess *emptyServiceServer) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, error) {
    84  	return new(testpb.Empty), nil
    85  }
    86  
    87  func (esss *emptyServiceServer) EmptyStream(stream testpb.EmptyService_EmptyStreamServer) error {
    88  	for {
    89  		_, err := stream.Recv()
    90  		if err == io.EOF {
    91  			return nil
    92  		}
    93  		if err != nil {
    94  			return err
    95  		}
    96  		if err := stream.Send(&testpb.Empty{}); err != nil {
    97  			return err
    98  		}
    99  
   100  	}
   101  }
   102  
   103  // invoke the EmptyCall RPC
   104  func invokeEmptyCall(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) {
   105  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   106  	defer cancel()
   107  	// create GRPC client conn
   108  	clientConn, err := grpc.DialContext(ctx, address, dialOptions...)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	defer clientConn.Close()
   113  
   114  	// create GRPC client
   115  	client := testpb.NewEmptyServiceClient(clientConn)
   116  
   117  	// invoke service
   118  	empty, err := client.EmptyCall(context.Background(), new(testpb.Empty))
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	return empty, nil
   124  }
   125  
   126  // invoke the EmptyStream RPC
   127  func invokeEmptyStream(address string, dialOptions ...grpc.DialOption) (*testpb.Empty, error) {
   128  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   129  	defer cancel()
   130  	// create GRPC client conn
   131  	clientConn, err := grpc.DialContext(ctx, address, dialOptions...)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	defer clientConn.Close()
   136  
   137  	stream, err := testpb.NewEmptyServiceClient(clientConn).EmptyStream(ctx)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	var msg *testpb.Empty
   143  	var streamErr error
   144  
   145  	waitc := make(chan struct{})
   146  	go func() {
   147  		for {
   148  			in, err := stream.Recv()
   149  			if err == io.EOF {
   150  				close(waitc)
   151  				return
   152  			}
   153  			if err != nil {
   154  				streamErr = err
   155  				close(waitc)
   156  				return
   157  			}
   158  			msg = in
   159  		}
   160  	}()
   161  
   162  	// TestServerInterceptors adds an interceptor that does not call the target
   163  	// StreamHandler and returns an error so Send can return with an io.EOF since
   164  	// the server side has already terminated. Whether or not we get an error
   165  	// depends on timing.
   166  	err = stream.Send(&testpb.Empty{})
   167  	if err != nil && err != io.EOF {
   168  		return nil, fmt.Errorf("stream send failed: %s", err)
   169  	}
   170  
   171  	stream.CloseSend()
   172  	<-waitc
   173  	return msg, streamErr
   174  }
   175  
// Sizes of the fixture set under testdata/certs. The loaders use these as
// inclusive upper bounds of 1-based loops when building file names.
const (
	numOrgs        = 2 // organizations loaded by init
	numChildOrgs   = 2 // child organizations loaded for each organization
	numServerCerts = 2 // server key/cert pairs per org; the client-cert loops use the same bound
)
   181  
// fmt.Sprintf templates for the crypto material under testdata/certs.
// The %d verbs are filled in order with the organization number, then (child*
// templates only) the child organization number, then the server or client
// index. The CA templates take no index.
var (
	orgCACert       = filepath.Join("testdata", "certs", "Org%d-cert.pem")
	orgServerKey    = filepath.Join("testdata", "certs", "Org%d-server%d-key.pem")
	orgServerCert   = filepath.Join("testdata", "certs", "Org%d-server%d-cert.pem")
	orgClientKey    = filepath.Join("testdata", "certs", "Org%d-client%d-key.pem")
	orgClientCert   = filepath.Join("testdata", "certs", "Org%d-client%d-cert.pem")
	childCACert     = filepath.Join("testdata", "certs", "Org%d-child%d-cert.pem")
	childServerKey  = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-key.pem")
	childServerCert = filepath.Join("testdata", "certs", "Org%d-child%d-server%d-cert.pem")
	childClientKey  = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-key.pem")
	childClientCert = filepath.Join("testdata", "certs", "Org%d-child%d-client%d-cert.pem")
)
   195  
// testServer carries the comm.ServerConfig from which a GRPCServer is created
// in the mutual TLS tests.
type testServer struct {
	config comm.ServerConfig
}
   199  
// serverCert holds one server key pair exactly as read from disk.
type serverCert struct {
	keyPEM  []byte // contents of the server key file
	certPEM []byte // contents of the server certificate file
}
   204  
// testOrg groups the crypto material loaded for one test organization.
type testOrg struct {
	rootCA      []byte            // contents of the organization's CA certificate file
	serverCerts []serverCert      // raw server key/cert pairs
	clientCerts []tls.Certificate // parsed client key pairs
	childOrgs   []testOrg         // child organizations; empty for a child org itself
}
   211  
   212  // return *X509.CertPool for the rootCA of the org
   213  func (org *testOrg) rootCertPool() *x509.CertPool {
   214  	certPool := x509.NewCertPool()
   215  	certPool.AppendCertsFromPEM(org.rootCA)
   216  	return certPool
   217  }
   218  
   219  // return testServers for the org
   220  func (org *testOrg) testServers(clientRootCAs [][]byte) []testServer {
   221  	clientRootCAs = append(clientRootCAs, org.rootCA)
   222  
   223  	// loop through the serverCerts and create testServers
   224  	testServers := []testServer{}
   225  	for _, serverCert := range org.serverCerts {
   226  		testServer := testServer{
   227  			comm.ServerConfig{
   228  				ConnectionTimeout: 250 * time.Millisecond,
   229  				SecOpts: comm.SecureOptions{
   230  					UseTLS:            true,
   231  					Certificate:       serverCert.certPEM,
   232  					Key:               serverCert.keyPEM,
   233  					RequireClientCert: true,
   234  					ClientRootCAs:     clientRootCAs,
   235  				},
   236  			},
   237  		}
   238  		testServers = append(testServers, testServer)
   239  	}
   240  	return testServers
   241  }
   242  
   243  // return trusted clients for the org
   244  func (org *testOrg) trustedClients(serverRootCAs [][]byte) []*tls.Config {
   245  	// if we have any additional server root CAs add them to the certPool
   246  	certPool := org.rootCertPool()
   247  	for _, serverRootCA := range serverRootCAs {
   248  		certPool.AppendCertsFromPEM(serverRootCA)
   249  	}
   250  
   251  	// loop through the clientCerts and create tls.Configs
   252  	trustedClients := []*tls.Config{}
   253  	for _, clientCert := range org.clientCerts {
   254  		trustedClient := &tls.Config{
   255  			Certificates: []tls.Certificate{clientCert},
   256  			RootCAs:      certPool,
   257  		}
   258  		trustedClients = append(trustedClients, trustedClient)
   259  	}
   260  	return trustedClients
   261  }
   262  
   263  // createCertPool creates an x509.CertPool from an array of PEM-encoded certificates
   264  func createCertPool(rootCAs [][]byte) (*x509.CertPool, error) {
   265  	certPool := x509.NewCertPool()
   266  	for _, rootCA := range rootCAs {
   267  		if !certPool.AppendCertsFromPEM(rootCA) {
   268  			return nil, errors.New("Failed to load root certificates")
   269  		}
   270  	}
   271  	return certPool, nil
   272  }
   273  
   274  // utility function to load crypto material for organizations
   275  func loadOrg(parent int) (testOrg, error) {
   276  	org := testOrg{}
   277  	// load the CA
   278  	caPEM, err := ioutil.ReadFile(fmt.Sprintf(orgCACert, parent))
   279  	if err != nil {
   280  		return org, err
   281  	}
   282  
   283  	// loop through and load servers
   284  	serverCerts := []serverCert{}
   285  	for i := 1; i <= numServerCerts; i++ {
   286  		keyPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerKey, parent, i))
   287  		if err != nil {
   288  			return org, err
   289  		}
   290  		certPEM, err := ioutil.ReadFile(fmt.Sprintf(orgServerCert, parent, i))
   291  		if err != nil {
   292  			return org, err
   293  		}
   294  		serverCerts = append(serverCerts, serverCert{keyPEM, certPEM})
   295  	}
   296  
   297  	// loop through and load clients
   298  	clientCerts := []tls.Certificate{}
   299  	for j := 1; j <= numServerCerts; j++ {
   300  		clientCert, err := loadTLSKeyPairFromFile(fmt.Sprintf(orgClientKey, parent, j),
   301  			fmt.Sprintf(orgClientCert, parent, j))
   302  		if err != nil {
   303  			return org, err
   304  		}
   305  		clientCerts = append(clientCerts, clientCert)
   306  	}
   307  
   308  	// loop through and load child orgs
   309  	childOrgs := []testOrg{}
   310  	for k := 1; k <= numChildOrgs; k++ {
   311  		childOrg, err := loadChildOrg(parent, k)
   312  		if err != nil {
   313  			return org, err
   314  		}
   315  		childOrgs = append(childOrgs, childOrg)
   316  	}
   317  
   318  	return testOrg{caPEM, serverCerts, clientCerts, childOrgs}, nil
   319  }
   320  
   321  // utility function to load crypto material for child organizations
   322  func loadChildOrg(parent, child int) (testOrg, error) {
   323  	// load the CA
   324  	caPEM, err := ioutil.ReadFile(fmt.Sprintf(childCACert, parent, child))
   325  	if err != nil {
   326  		return testOrg{}, err
   327  	}
   328  
   329  	// loop through and load servers
   330  	serverCerts := []serverCert{}
   331  	for i := 1; i <= numServerCerts; i++ {
   332  		keyPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerKey, parent, child, i))
   333  		if err != nil {
   334  			return testOrg{}, err
   335  		}
   336  		certPEM, err := ioutil.ReadFile(fmt.Sprintf(childServerCert, parent, child, i))
   337  		if err != nil {
   338  			return testOrg{}, err
   339  		}
   340  		serverCerts = append(serverCerts, serverCert{keyPEM, certPEM})
   341  	}
   342  
   343  	// loop through and load clients
   344  	clientCerts := []tls.Certificate{}
   345  	for j := 1; j <= numServerCerts; j++ {
   346  		clientCert, err := loadTLSKeyPairFromFile(
   347  			fmt.Sprintf(childClientKey, parent, child, j),
   348  			fmt.Sprintf(childClientCert, parent, child, j),
   349  		)
   350  		if err != nil {
   351  			return testOrg{}, err
   352  		}
   353  		clientCerts = append(clientCerts, clientCert)
   354  	}
   355  
   356  	return testOrg{caPEM, serverCerts, clientCerts, []testOrg{}}, nil
   357  }
   358  
   359  // loadTLSKeyPairFromFile creates a tls.Certificate from PEM-encoded key and cert files
   360  func loadTLSKeyPairFromFile(keyFile, certFile string) (tls.Certificate, error) {
   361  	certPEMBlock, err := ioutil.ReadFile(certFile)
   362  	if err != nil {
   363  		return tls.Certificate{}, err
   364  	}
   365  
   366  	keyPEMBlock, err := ioutil.ReadFile(keyFile)
   367  	if err != nil {
   368  		return tls.Certificate{}, err
   369  	}
   370  
   371  	cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
   372  	if err != nil {
   373  		return tls.Certificate{}, err
   374  	}
   375  
   376  	return cert, nil
   377  }
   378  
   379  func TestNewGRPCServerInvalidParameters(t *testing.T) {
   380  	t.Parallel()
   381  
   382  	// missing address
   383  	_, err := comm.NewGRPCServer(
   384  		"",
   385  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   386  	)
   387  	require.EqualError(t, err, "missing address parameter")
   388  
   389  	// missing port
   390  	_, err = comm.NewGRPCServer(
   391  		"abcdef",
   392  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   393  	)
   394  	require.Error(t, err, "Expected error with missing port")
   395  	require.Contains(t, err.Error(), "missing port in address")
   396  
   397  	// bad port
   398  	_, err = comm.NewGRPCServer(
   399  		"127.0.0.1:1BBB",
   400  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   401  	)
   402  	// check for possible errors based on platform and Go release
   403  	msgs := []string{
   404  		"listen tcp: lookup tcp/1BBB: nodename nor servname provided, or not known",
   405  		"listen tcp: unknown port tcp/1BBB",
   406  		"listen tcp: address tcp/1BBB: unknown port",
   407  		"listen tcp: lookup tcp/1BBB: Servname not supported for ai_socktype",
   408  	}
   409  	require.Error(t, err, fmt.Sprintf("[%s], [%s] [%s] or [%s] expected", msgs[0], msgs[1], msgs[2], msgs[3]))
   410  	require.Contains(t, msgs, err.Error())
   411  
   412  	// bad hostname
   413  	_, err = comm.NewGRPCServer(
   414  		"hostdoesnotexist.localdomain:9050",
   415  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   416  	)
   417  	// We cannot check for a specific error message due to the fact that some
   418  	// systems will automatically resolve unknown host names to a "search"
   419  	// address so we just check to make sure that an error was returned
   420  	require.Error(t, err, "error expected")
   421  
   422  	// address in use
   423  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   424  	require.NoError(t, err, "failed to create listener")
   425  	defer lis.Close()
   426  
   427  	_, err = comm.NewGRPCServerFromListener(
   428  		lis,
   429  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   430  	)
   431  	require.NoError(t, err, "failed to create grpc server")
   432  
   433  	_, err = comm.NewGRPCServer(
   434  		lis.Addr().String(),
   435  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   436  	)
   437  	require.Error(t, err)
   438  	require.Contains(t, err.Error(), "address already in use")
   439  
   440  	// missing server Certificate
   441  	_, err = comm.NewGRPCServerFromListener(
   442  		lis,
   443  		comm.ServerConfig{
   444  			SecOpts: comm.SecureOptions{UseTLS: true, Key: []byte{}},
   445  		},
   446  	)
   447  	require.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
   448  
   449  	// missing server Key
   450  	_, err = comm.NewGRPCServerFromListener(
   451  		lis,
   452  		comm.ServerConfig{
   453  			SecOpts: comm.SecureOptions{
   454  				UseTLS:      true,
   455  				Certificate: []byte{},
   456  			},
   457  		},
   458  	)
   459  	require.EqualError(t, err, "serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
   460  
   461  	// bad server Key
   462  	_, err = comm.NewGRPCServerFromListener(
   463  		lis,
   464  		comm.ServerConfig{
   465  			SecOpts: comm.SecureOptions{
   466  				UseTLS:      true,
   467  				Certificate: []byte(selfSignedCertPEM),
   468  				Key:         []byte{},
   469  			},
   470  		},
   471  	)
   472  	require.EqualError(t, err, "tls: failed to find any PEM data in key input")
   473  
   474  	// bad server Certificate
   475  	_, err = comm.NewGRPCServerFromListener(
   476  		lis,
   477  		comm.ServerConfig{
   478  			SecOpts: comm.SecureOptions{
   479  				UseTLS:      true,
   480  				Certificate: []byte{},
   481  				Key:         []byte(selfSignedKeyPEM),
   482  			},
   483  		},
   484  	)
   485  	require.EqualError(t, err, "tls: failed to find any PEM data in certificate input")
   486  
   487  	srv, err := comm.NewGRPCServerFromListener(
   488  		lis,
   489  		comm.ServerConfig{
   490  			SecOpts: comm.SecureOptions{
   491  				UseTLS:            true,
   492  				Certificate:       []byte(selfSignedCertPEM),
   493  				Key:               []byte(selfSignedKeyPEM),
   494  				RequireClientCert: true,
   495  			},
   496  		},
   497  	)
   498  	require.NoError(t, err)
   499  
   500  	badRootCAs := [][]byte{[]byte(badPEM)}
   501  	err = srv.SetClientRootCAs(badRootCAs)
   502  	require.EqualError(t, err, "failed to set client root certificate(s)")
   503  }
   504  
   505  func TestNewGRPCServer(t *testing.T) {
   506  	t.Parallel()
   507  
   508  	testAddress := "127.0.0.1:9053"
   509  	srv, err := comm.NewGRPCServer(
   510  		testAddress,
   511  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   512  	)
   513  	require.NoError(t, err, "failed to create new GRPC server")
   514  
   515  	// resolve the address
   516  	addr, err := net.ResolveTCPAddr("tcp", testAddress)
   517  	require.NoError(t, err)
   518  
   519  	// make sure our properties are as expected
   520  	require.Equal(t, srv.Address(), addr.String())
   521  	require.Equal(t, srv.Listener().Addr().String(), addr.String())
   522  	require.Equal(t, srv.TLSEnabled(), false)
   523  	require.Equal(t, srv.MutualTLSRequired(), false)
   524  
   525  	// register the GRPC test server
   526  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   527  
   528  	// start the server
   529  	go srv.Start()
   530  	defer srv.Stop()
   531  
   532  	// should not be needed
   533  	time.Sleep(10 * time.Millisecond)
   534  
   535  	// invoke the EmptyCall service
   536  	_, err = invokeEmptyCall(testAddress, grpc.WithInsecure())
   537  	require.NoError(t, err, "failed to invoke the EmptyCall service")
   538  }
   539  
   540  func TestNewGRPCServerFromListener(t *testing.T) {
   541  	t.Parallel()
   542  
   543  	// create our listener
   544  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   545  	require.NoError(t, err, "failed to create listener")
   546  	testAddress := lis.Addr().String()
   547  
   548  	srv, err := comm.NewGRPCServerFromListener(
   549  		lis,
   550  		comm.ServerConfig{SecOpts: comm.SecureOptions{UseTLS: false}},
   551  	)
   552  	require.NoError(t, err, "failed to create new GRPC server")
   553  
   554  	require.Equal(t, srv.Address(), testAddress)
   555  	require.Equal(t, srv.Listener().Addr().String(), testAddress)
   556  	require.Equal(t, srv.TLSEnabled(), false)
   557  	require.Equal(t, srv.MutualTLSRequired(), false)
   558  
   559  	// register the GRPC test server
   560  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   561  
   562  	// start the server
   563  	go srv.Start()
   564  	defer srv.Stop()
   565  
   566  	// should not be needed
   567  	time.Sleep(10 * time.Millisecond)
   568  
   569  	// invoke the EmptyCall service
   570  	_, err = invokeEmptyCall(testAddress, grpc.WithInsecure())
   571  	require.NoError(t, err, "client failed to invoke the EmptyCall service")
   572  }
   573  
   574  func TestNewSecureGRPCServer(t *testing.T) {
   575  	t.Parallel()
   576  
   577  	// create our listener
   578  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   579  	require.NoError(t, err, "failed to create listener")
   580  	testAddress := lis.Addr().String()
   581  
   582  	srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{
   583  		ConnectionTimeout: 250 * time.Millisecond,
   584  		SecOpts: comm.SecureOptions{
   585  			UseTLS:      true,
   586  			Certificate: []byte(selfSignedCertPEM),
   587  			Key:         []byte(selfSignedKeyPEM),
   588  		},
   589  	},
   590  	)
   591  	require.NoError(t, err, "failed to create new grpc server")
   592  
   593  	// make sure our properties are as expected
   594  	require.NoError(t, err)
   595  	require.Equal(t, srv.Address(), testAddress)
   596  	require.Equal(t, srv.Listener().Addr().String(), testAddress)
   597  
   598  	cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM))
   599  	require.Equal(t, srv.ServerCertificate(), cert)
   600  
   601  	require.Equal(t, srv.TLSEnabled(), true)
   602  	require.Equal(t, srv.MutualTLSRequired(), false)
   603  
   604  	// register the GRPC test server
   605  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   606  
   607  	// start the server
   608  	go srv.Start()
   609  	defer srv.Stop()
   610  
   611  	// should not be needed
   612  	time.Sleep(10 * time.Millisecond)
   613  
   614  	// create the client credentials
   615  	certPool := x509.NewCertPool()
   616  	if !certPool.AppendCertsFromPEM([]byte(selfSignedCertPEM)) {
   617  		t.Fatal("Failed to append certificate to client credentials")
   618  	}
   619  	creds := credentials.NewClientTLSFromCert(certPool, "")
   620  
   621  	// invoke the EmptyCall service
   622  	_, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds))
   623  	require.NoError(t, err, "client failed to invoke the EmptyCall service")
   624  
   625  	// Test TLS versions which should be valid
   626  	tlsVersions := map[string]uint16{
   627  		"TLS12": tls.VersionTLS12,
   628  		"TLS13": tls.VersionTLS13,
   629  	}
   630  	for name, tlsVersion := range tlsVersions {
   631  		tlsVersion := tlsVersion
   632  
   633  		t.Run(name, func(t *testing.T) {
   634  			creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion})
   635  			_, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock())
   636  			require.NoError(t, err)
   637  		})
   638  	}
   639  
   640  	// Test TLS versions which should be invalid
   641  	tlsVersions = map[string]uint16{
   642  		"SSL30": tls.VersionSSL30,
   643  		"TLS10": tls.VersionTLS10,
   644  		"TLS11": tls.VersionTLS11,
   645  	}
   646  	for name, tlsVersion := range tlsVersions {
   647  		tlsVersion := tlsVersion
   648  		t.Run(name, func(t *testing.T) {
   649  			t.Parallel()
   650  
   651  			creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, MinVersion: tlsVersion, MaxVersion: tlsVersion})
   652  			_, err := invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds), grpc.WithBlock())
   653  			require.Error(t, err, "should not have been able to connect with TLS version < 1.2")
   654  			require.Contains(t, err.Error(), "context deadline exceeded")
   655  		})
   656  	}
   657  }
   658  
   659  func TestVerifyCertificateCallback(t *testing.T) {
   660  	t.Parallel()
   661  
   662  	ca, err := tlsgen.NewCA()
   663  	require.NoError(t, err)
   664  
   665  	authorizedClientKeyPair, err := ca.NewClientCertKeyPair()
   666  	require.NoError(t, err)
   667  
   668  	notAuthorizedClientKeyPair, err := ca.NewClientCertKeyPair()
   669  	require.NoError(t, err)
   670  
   671  	serverKeyPair, err := ca.NewServerCertKeyPair("127.0.0.1")
   672  	require.NoError(t, err)
   673  
   674  	verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   675  		if bytes.Equal(rawCerts[0], authorizedClientKeyPair.TLSCert.Raw) {
   676  			return nil
   677  		}
   678  		return errors.New("certificate mismatch")
   679  	}
   680  
   681  	probeTLS := func(endpoint string, clientKeyPair *tlsgen.CertKeyPair) error {
   682  		cert, err := tls.X509KeyPair(clientKeyPair.Cert, clientKeyPair.Key)
   683  		if err != nil {
   684  			return err
   685  		}
   686  		tlsCfg := &tls.Config{
   687  			Certificates: []tls.Certificate{cert},
   688  			RootCAs:      x509.NewCertPool(),
   689  			MinVersion:   tls.VersionTLS12,
   690  			MaxVersion:   tls.VersionTLS12,
   691  		}
   692  		tlsCfg.RootCAs.AppendCertsFromPEM(ca.CertBytes())
   693  
   694  		conn, err := tls.Dial("tcp", endpoint, tlsCfg)
   695  		if err != nil {
   696  			return err
   697  		}
   698  		conn.Close()
   699  		return nil
   700  	}
   701  
   702  	gRPCServer, err := comm.NewGRPCServer("127.0.0.1:", comm.ServerConfig{
   703  		SecOpts: comm.SecureOptions{
   704  			ClientRootCAs:     [][]byte{ca.CertBytes()},
   705  			Key:               serverKeyPair.Key,
   706  			Certificate:       serverKeyPair.Cert,
   707  			UseTLS:            true,
   708  			VerifyCertificate: verifyFunc,
   709  		},
   710  	})
   711  	go gRPCServer.Start()
   712  	defer gRPCServer.Stop()
   713  
   714  	t.Run("Success path", func(t *testing.T) {
   715  		err = probeTLS(gRPCServer.Address(), authorizedClientKeyPair)
   716  		require.NoError(t, err)
   717  	})
   718  
   719  	t.Run("Failure path", func(t *testing.T) {
   720  		err = probeTLS(gRPCServer.Address(), notAuthorizedClientKeyPair)
   721  		require.EqualError(t, err, "remote error: tls: bad certificate")
   722  	})
   723  }
   724  
   725  // prior tests used self-signed certficates loaded by the GRPCServer and the test client
   726  // here we'll use certificates signed by certificate authorities
   727  func TestWithSignedRootCertificates(t *testing.T) {
   728  	t.Parallel()
   729  
   730  	// use Org1 testdata
   731  	fileBase := "Org1"
   732  	certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-cert.pem"))
   733  	require.NoError(t, err, "failed to load test certificates")
   734  
   735  	keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-server1-key.pem"))
   736  	require.NoError(t, err, "failed to load test certificates: %v")
   737  
   738  	caPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-cert.pem"))
   739  	require.NoError(t, err, "failed to load test certificates")
   740  
   741  	// create our listener
   742  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   743  	require.NoError(t, err, "failed to create listener")
   744  	testAddress := lis.Addr().String()
   745  
   746  	srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{
   747  		SecOpts: comm.SecureOptions{
   748  			UseTLS:      true,
   749  			Certificate: certPEMBlock,
   750  			Key:         keyPEMBlock,
   751  		},
   752  	})
   753  	require.NoError(t, err, "failed to create new grpc server")
   754  	// register the GRPC test server
   755  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   756  
   757  	// start the server
   758  	go srv.Start()
   759  	defer srv.Stop()
   760  
   761  	// should not be needed
   762  	time.Sleep(10 * time.Millisecond)
   763  
   764  	// create a CertPool for use by the client with the server cert only
   765  	certPoolServer, err := createCertPool([][]byte{certPEMBlock})
   766  	require.NoError(t, err, "failed to load root certificates into pool")
   767  	creds := credentials.NewClientTLSFromCert(certPoolServer, "")
   768  
   769  	// invoke the EmptyCall service
   770  	_, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds))
   771  	require.NoError(t, err, "Expected client to connect with server cert only")
   772  
   773  	// now use the CA certificate
   774  	certPoolCA := x509.NewCertPool()
   775  	if !certPoolCA.AppendCertsFromPEM(caPEMBlock) {
   776  		t.Fatal("Failed to append certificate to client credentials")
   777  	}
   778  	creds = credentials.NewClientTLSFromCert(certPoolCA, "")
   779  
   780  	// invoke the EmptyCall service
   781  	_, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds))
   782  	require.NoError(t, err, "client failed to invoke the EmptyCall")
   783  }
   784  
   785  // here we'll use certificates signed by intermediate certificate authorities
   786  func TestWithSignedIntermediateCertificates(t *testing.T) {
   787  	t.Parallel()
   788  
   789  	// use Org1 testdata
   790  	fileBase := "Org1"
   791  	certPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-cert.pem"))
   792  	require.NoError(t, err)
   793  
   794  	keyPEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-server1-key.pem"))
   795  	require.NoError(t, err)
   796  
   797  	intermediatePEMBlock, err := ioutil.ReadFile(filepath.Join("testdata", "certs", fileBase+"-child1-cert.pem"))
   798  	if err != nil {
   799  		t.Fatalf("Failed to load test certificates: %v", err)
   800  	}
   801  
   802  	// create our listener
   803  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   804  	if err != nil {
   805  		t.Fatalf("Failed to create listener: %v", err)
   806  	}
   807  	testAddress := lis.Addr().String()
   808  
   809  	srv, err := comm.NewGRPCServerFromListener(lis, comm.ServerConfig{
   810  		SecOpts: comm.SecureOptions{
   811  			UseTLS:      true,
   812  			Certificate: certPEMBlock,
   813  			Key:         keyPEMBlock,
   814  		},
   815  	})
   816  	// check for error
   817  	if err != nil {
   818  		t.Fatalf("Failed to return new GRPC server: %v", err)
   819  	}
   820  
   821  	// register the GRPC test server
   822  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   823  
   824  	// start the server
   825  	go srv.Start()
   826  	defer srv.Stop()
   827  
   828  	// should not be needed
   829  	time.Sleep(10 * time.Millisecond)
   830  
   831  	// create a CertPool for use by the client with the server cert only
   832  	certPoolServer, err := createCertPool([][]byte{certPEMBlock})
   833  	if err != nil {
   834  		t.Fatalf("Failed to load root certificates into pool: %v", err)
   835  	}
   836  	// create the client credentials
   837  	creds := credentials.NewClientTLSFromCert(certPoolServer, "")
   838  
   839  	// invoke the EmptyCall service
   840  	_, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds))
   841  
   842  	// client should be able to connect with Go 1.9
   843  	require.NoError(t, err, "Expected client to connect with server cert only")
   844  
   845  	// now use the CA certificate
   846  	// create a CertPool for use by the client with the intermediate root CA
   847  	certPoolCA, err := createCertPool([][]byte{intermediatePEMBlock})
   848  	require.NoError(t, err, "failed to load root certificates into pool")
   849  
   850  	creds = credentials.NewClientTLSFromCert(certPoolCA, "")
   851  
   852  	// invoke the EmptyCall service
   853  	_, err = invokeEmptyCall(testAddress, grpc.WithTransportCredentials(creds))
   854  	require.NoError(t, err, "client failed to invoke the EmptyCall service")
   855  }
   856  
   857  // utility function for testing client / server communication using TLS
   858  func runMutualAuth(t *testing.T, servers []testServer, trustedClients, unTrustedClients []*tls.Config) error {
   859  	// loop through all the test servers
   860  	for i := 0; i < len(servers); i++ {
   861  		// create listener
   862  		lis, err := net.Listen("tcp", "127.0.0.1:0")
   863  		if err != nil {
   864  			return err
   865  		}
   866  		srvAddr := lis.Addr().String()
   867  
   868  		// create GRPCServer
   869  		srv, err := comm.NewGRPCServerFromListener(lis, servers[i].config)
   870  		if err != nil {
   871  			return err
   872  		}
   873  
   874  		// MutualTLSRequired should be true
   875  		require.Equal(t, srv.MutualTLSRequired(), true)
   876  
   877  		// register the GRPC test server and start the GRPCServer
   878  		testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   879  		go srv.Start()
   880  		defer srv.Stop()
   881  
   882  		// should not be needed but just in case
   883  		time.Sleep(10 * time.Millisecond)
   884  
   885  		// loop through all the trusted clients
   886  		for j := 0; j < len(trustedClients); j++ {
   887  			// invoke the EmptyCall service
   888  			_, err = invokeEmptyCall(srvAddr, grpc.WithTransportCredentials(credentials.NewTLS(trustedClients[j])))
   889  			// we expect success from trusted clients
   890  			if err != nil {
   891  				return err
   892  			} else {
   893  				t.Logf("Trusted client%d successfully connected to %s", j, srvAddr)
   894  			}
   895  		}
   896  
   897  		// loop through all the untrusted clients
   898  		for k := 0; k < len(unTrustedClients); k++ {
   899  			// invoke the EmptyCall service
   900  			_, err = invokeEmptyCall(
   901  				srvAddr,
   902  				grpc.WithTransportCredentials(credentials.NewTLS(unTrustedClients[k])),
   903  			)
   904  			// we expect failure from untrusted clients
   905  			if err != nil {
   906  				t.Logf("Untrusted client%d was correctly rejected by %s", k, srvAddr)
   907  			} else {
   908  				return fmt.Errorf("Untrusted client %d should not have been able to connect to %s", k, srvAddr)
   909  			}
   910  		}
   911  	}
   912  
   913  	return nil
   914  }
   915  
   916  func TestMutualAuth(t *testing.T) {
   917  	t.Parallel()
   918  
   919  	tests := []struct {
   920  		name             string
   921  		servers          []testServer
   922  		trustedClients   []*tls.Config
   923  		unTrustedClients []*tls.Config
   924  	}{
   925  		{
   926  			name:             "ClientAuthRequiredWithSingleOrg",
   927  			servers:          testOrgs[0].testServers([][]byte{}),
   928  			trustedClients:   testOrgs[0].trustedClients([][]byte{}),
   929  			unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}),
   930  		},
   931  		{
   932  			name:             "ClientAuthRequiredWithChildClientOrg",
   933  			servers:          testOrgs[0].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}),
   934  			trustedClients:   testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA}),
   935  			unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}),
   936  		},
   937  		{
   938  			name: "ClientAuthRequiredWithMultipleChildClientOrgs",
   939  			servers: testOrgs[0].testServers(append([][]byte{},
   940  				testOrgs[0].childOrgs[0].rootCA,
   941  				testOrgs[0].childOrgs[1].rootCA,
   942  			)),
   943  			trustedClients: append(append([]*tls.Config{},
   944  				testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})...),
   945  				testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})...),
   946  			unTrustedClients: testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}),
   947  		},
   948  		{
   949  			name:             "ClientAuthRequiredWithDifferentServerAndClientOrgs",
   950  			servers:          testOrgs[0].testServers([][]byte{testOrgs[1].rootCA}),
   951  			trustedClients:   testOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}),
   952  			unTrustedClients: testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA}),
   953  		},
   954  		{
   955  			name:             "ClientAuthRequiredWithDifferentServerAndChildClientOrgs",
   956  			servers:          testOrgs[1].testServers([][]byte{testOrgs[0].childOrgs[0].rootCA}),
   957  			trustedClients:   testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}),
   958  			unTrustedClients: testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[1].rootCA}),
   959  		},
   960  	}
   961  
   962  	for _, test := range tests {
   963  		test := test
   964  		t.Run(test.name, func(t *testing.T) {
   965  			t.Parallel()
   966  			t.Logf("Running test %s ...", test.name)
   967  			testErr := runMutualAuth(t, test.servers, test.trustedClients, test.unTrustedClients)
   968  			require.NoError(t, testErr)
   969  		})
   970  	}
   971  }
   972  
   973  func TestSetClientRootCAs(t *testing.T) {
   974  	t.Parallel()
   975  
   976  	// get the config for one of our Org1 test servers
   977  	serverConfig := testOrgs[0].testServers([][]byte{})[0].config
   978  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   979  	require.NoError(t, err, "listen failed")
   980  	defer lis.Close()
   981  	address := lis.Addr().String()
   982  
   983  	// create a GRPCServer
   984  	srv, err := comm.NewGRPCServerFromListener(lis, serverConfig)
   985  	require.NoError(t, err, "failed to create GRPCServer")
   986  
   987  	// register the GRPC test server and start the GRPCServer
   988  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
   989  	go srv.Start()
   990  	defer srv.Stop()
   991  
   992  	// should not be needed but just in case
   993  	time.Sleep(10 * time.Millisecond)
   994  
   995  	// set up our test clients
   996  	// Org1
   997  	clientConfigOrg1Child1 := testOrgs[0].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0]
   998  	clientConfigOrg1Child2 := testOrgs[0].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0]
   999  	clientConfigsOrg1Children := []*tls.Config{clientConfigOrg1Child1, clientConfigOrg1Child2}
  1000  	org1ChildRootCAs := [][]byte{testOrgs[0].childOrgs[0].rootCA, testOrgs[0].childOrgs[1].rootCA}
  1001  	// Org2
  1002  	clientConfigOrg2Child1 := testOrgs[1].childOrgs[0].trustedClients([][]byte{testOrgs[0].rootCA})[0]
  1003  	clientConfigOrg2Child2 := testOrgs[1].childOrgs[1].trustedClients([][]byte{testOrgs[0].rootCA})[0]
  1004  	clientConfigsOrg2Children := []*tls.Config{clientConfigOrg2Child1, clientConfigOrg2Child2}
  1005  	org2ChildRootCAs := [][]byte{testOrgs[1].childOrgs[0].rootCA, testOrgs[1].childOrgs[1].rootCA}
  1006  
  1007  	// initially set client CAs to Org1 children
  1008  	err = srv.SetClientRootCAs(org1ChildRootCAs)
  1009  	require.NoError(t, err, "SetClientRootCAs failed")
  1010  
  1011  	// clientConfigsOrg1Children are currently trusted
  1012  	for _, clientConfig := range clientConfigsOrg1Children {
  1013  		// we expect success as these are trusted clients
  1014  		_, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig)))
  1015  		require.NoError(t, err, "trusted client should have connected")
  1016  	}
  1017  
  1018  	// clientConfigsOrg2Children are currently not trusted
  1019  	for _, clientConfig := range clientConfigsOrg2Children {
  1020  		// we expect failure as these are now untrusted clients
  1021  		_, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig)))
  1022  		require.Error(t, err, "untrusted client should not have been able to connect")
  1023  	}
  1024  
  1025  	// now set client CAs to Org2 children
  1026  	err = srv.SetClientRootCAs(org2ChildRootCAs)
  1027  	require.NoError(t, err, "SetClientRootCAs failed")
  1028  
  1029  	// now reverse trusted and not trusted
  1030  	// clientConfigsOrg1Children are currently trusted
  1031  	for _, clientConfig := range clientConfigsOrg2Children {
  1032  		// we expect success as these are trusted clients
  1033  		_, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig)))
  1034  		require.NoError(t, err, "trusted client should have connected")
  1035  	}
  1036  
  1037  	// clientConfigsOrg2Children are currently not trusted
  1038  	for _, clientConfig := range clientConfigsOrg1Children {
  1039  		// we expect failure as these are now untrusted clients
  1040  		_, err = invokeEmptyCall(address, grpc.WithTransportCredentials(credentials.NewTLS(clientConfig)))
  1041  		require.Error(t, err, "untrusted client should not have connected")
  1042  	}
  1043  }
  1044  
  1045  func TestUpdateTLSCert(t *testing.T) {
  1046  	t.Parallel()
  1047  
  1048  	readFile := func(path string) []byte {
  1049  		fName := filepath.Join("testdata", "dynamic_cert_update", path)
  1050  		data, err := ioutil.ReadFile(fName)
  1051  		if err != nil {
  1052  			panic(fmt.Errorf("Failed reading %s: %v", fName, err))
  1053  		}
  1054  		return data
  1055  	}
  1056  	loadBytes := func(prefix string) (key, cert, caCert []byte) {
  1057  		cert = readFile(filepath.Join(prefix, "server.crt"))
  1058  		key = readFile(filepath.Join(prefix, "server.key"))
  1059  		caCert = readFile(filepath.Join("ca.crt"))
  1060  		return
  1061  	}
  1062  
  1063  	key, cert, caCert := loadBytes("notlocalhost")
  1064  
  1065  	cfg := comm.ServerConfig{
  1066  		SecOpts: comm.SecureOptions{
  1067  			UseTLS:      true,
  1068  			Key:         key,
  1069  			Certificate: cert,
  1070  		},
  1071  	}
  1072  
  1073  	// create our listener
  1074  	lis, err := net.Listen("tcp", "127.0.0.1:0")
  1075  	require.NoError(t, err, "listen failed")
  1076  	testAddress := lis.Addr().String()
  1077  
  1078  	srv, err := comm.NewGRPCServerFromListener(lis, cfg)
  1079  	require.NoError(t, err)
  1080  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
  1081  
  1082  	go srv.Start()
  1083  	defer srv.Stop()
  1084  
  1085  	certPool := x509.NewCertPool()
  1086  	certPool.AppendCertsFromPEM(caCert)
  1087  
  1088  	probeServer := func() error {
  1089  		_, err = invokeEmptyCall(
  1090  			testAddress,
  1091  			grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{RootCAs: certPool})),
  1092  			grpc.WithBlock(),
  1093  		)
  1094  		return err
  1095  	}
  1096  
  1097  	// bootstrap TLS certificate has a SAN of "notlocalhost" so it should fail
  1098  	err = probeServer()
  1099  	require.Error(t, err)
  1100  	require.Contains(t, err.Error(), "context deadline exceeded")
  1101  
  1102  	// new TLS certificate has a SAN of "127.0.0.1" so it should succeed
  1103  	certPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.crt")
  1104  	keyPath := filepath.Join("testdata", "dynamic_cert_update", "localhost", "server.key")
  1105  	tlsCert, err := tls.LoadX509KeyPair(certPath, keyPath)
  1106  	require.NoError(t, err)
  1107  	srv.SetServerCertificate(tlsCert)
  1108  	err = probeServer()
  1109  	require.NoError(t, err)
  1110  
  1111  	// revert back to the old certificate, should fail.
  1112  	certPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.crt")
  1113  	keyPath = filepath.Join("testdata", "dynamic_cert_update", "notlocalhost", "server.key")
  1114  	tlsCert, err = tls.LoadX509KeyPair(certPath, keyPath)
  1115  	require.NoError(t, err)
  1116  	srv.SetServerCertificate(tlsCert)
  1117  
  1118  	err = probeServer()
  1119  	require.Error(t, err)
  1120  	require.Contains(t, err.Error(), "context deadline exceeded")
  1121  }
  1122  
  1123  func TestCipherSuites(t *testing.T) {
  1124  	t.Parallel()
  1125  
  1126  	certPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-cert.pem"))
  1127  	require.NoError(t, err)
  1128  	keyPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-server1-key.pem"))
  1129  	require.NoError(t, err)
  1130  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
  1131  	require.NoError(t, err)
  1132  	certPool, err := createCertPool([][]byte{caPEM})
  1133  	require.NoError(t, err)
  1134  
  1135  	serverConfig := comm.ServerConfig{
  1136  		SecOpts: comm.SecureOptions{
  1137  			Certificate: certPEM,
  1138  			Key:         keyPEM,
  1139  			UseTLS:      true,
  1140  		},
  1141  	}
  1142  
  1143  	fabricDefaultCipherSuite := func(cipher uint16) bool {
  1144  		for _, defaultCipher := range comm.DefaultTLSCipherSuites {
  1145  			if cipher == defaultCipher {
  1146  				return true
  1147  			}
  1148  		}
  1149  		return false
  1150  	}
  1151  
  1152  	var otherCipherSuites []uint16
  1153  	for _, cipher := range append(tls.CipherSuites(), tls.InsecureCipherSuites()...) {
  1154  		if !fabricDefaultCipherSuite(cipher.ID) {
  1155  			otherCipherSuites = append(otherCipherSuites, cipher.ID)
  1156  		}
  1157  	}
  1158  
  1159  	tests := []struct {
  1160  		name          string
  1161  		clientCiphers []uint16
  1162  		success       bool
  1163  		versions      []uint16
  1164  	}{
  1165  		{
  1166  			name:     "server default / client all",
  1167  			success:  true,
  1168  			versions: []uint16{tls.VersionTLS12, tls.VersionTLS13},
  1169  		},
  1170  		{
  1171  			name:          "server default / client match",
  1172  			clientCiphers: comm.DefaultTLSCipherSuites,
  1173  			success:       true,
  1174  			// Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites
  1175  			// https://github.com/golang/go/issues/29349
  1176  			versions: []uint16{tls.VersionTLS12},
  1177  		},
  1178  		{
  1179  			name:          "server default / client no match",
  1180  			clientCiphers: otherCipherSuites,
  1181  			success:       false,
  1182  			// Skip TLS1.3 as it ignores the Fabric DefaultCipherSuites
  1183  			// https://github.com/golang/go/issues/29349
  1184  			versions: []uint16{tls.VersionTLS12},
  1185  		},
  1186  	}
  1187  
  1188  	// create our listener
  1189  	lis, err := net.Listen("tcp", "127.0.0.1:0")
  1190  	require.NoError(t, err, "listen failed")
  1191  	testAddress := lis.Addr().String()
  1192  	srv, err := comm.NewGRPCServerFromListener(lis, serverConfig)
  1193  	require.NoError(t, err)
  1194  	go srv.Start()
  1195  
  1196  	for _, test := range tests {
  1197  		test := test
  1198  		t.Run(test.name, func(t *testing.T) {
  1199  			t.Parallel()
  1200  
  1201  			for _, tlsVersion := range test.versions {
  1202  				tlsConfig := &tls.Config{
  1203  					RootCAs:      certPool,
  1204  					CipherSuites: test.clientCiphers,
  1205  					MinVersion:   tlsVersion,
  1206  					MaxVersion:   tlsVersion,
  1207  				}
  1208  				_, err := tls.Dial("tcp", testAddress, tlsConfig)
  1209  				if test.success {
  1210  					require.NoError(t, err)
  1211  				} else {
  1212  					require.Error(t, err, "expected handshake failure")
  1213  					require.Contains(t, err.Error(), "handshake failure")
  1214  				}
  1215  			}
  1216  		})
  1217  	}
  1218  }
  1219  
  1220  func TestServerInterceptors(t *testing.T) {
  1221  	lis, err := net.Listen("tcp", "127.0.0.1:0")
  1222  	require.NoError(t, err, "listen failed")
  1223  	msg := "error from interceptor"
  1224  
  1225  	// set up interceptors
  1226  	usiCount := uint32(0)
  1227  	ssiCount := uint32(0)
  1228  	usi1 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
  1229  		atomic.AddUint32(&usiCount, 1)
  1230  		return handler(ctx, req)
  1231  	}
  1232  	usi2 := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
  1233  		atomic.AddUint32(&usiCount, 1)
  1234  		return nil, status.Error(codes.Aborted, msg)
  1235  	}
  1236  	ssi1 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
  1237  		atomic.AddUint32(&ssiCount, 1)
  1238  		return handler(srv, ss)
  1239  	}
  1240  	ssi2 := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
  1241  		atomic.AddUint32(&ssiCount, 1)
  1242  		return status.Error(codes.Aborted, msg)
  1243  	}
  1244  
  1245  	srvConfig := comm.ServerConfig{}
  1246  	srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi1)
  1247  	srvConfig.UnaryInterceptors = append(srvConfig.UnaryInterceptors, usi2)
  1248  	srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi1)
  1249  	srvConfig.StreamInterceptors = append(srvConfig.StreamInterceptors, ssi2)
  1250  
  1251  	srv, err := comm.NewGRPCServerFromListener(lis, srvConfig)
  1252  	require.NoError(t, err, "failed to create gRPC server")
  1253  	testpb.RegisterEmptyServiceServer(srv.Server(), &emptyServiceServer{})
  1254  	defer srv.Stop()
  1255  	go srv.Start()
  1256  
  1257  	_, err = invokeEmptyCall(
  1258  		lis.Addr().String(),
  1259  		grpc.WithBlock(),
  1260  		grpc.WithInsecure(),
  1261  	)
  1262  	require.Error(t, err)
  1263  	require.Equal(t, status.Convert(err).Message(), msg, "Expected error from second usi")
  1264  	require.Equal(t, uint32(2), atomic.LoadUint32(&usiCount), "Expected both usi handlers to be invoked")
  1265  
  1266  	_, err = invokeEmptyStream(
  1267  		lis.Addr().String(),
  1268  		grpc.WithBlock(),
  1269  		grpc.WithInsecure(),
  1270  	)
  1271  	require.Error(t, err)
  1272  	require.Equal(t, status.Convert(err).Message(), msg, "Expected error from second ssi")
  1273  	require.Equal(t, uint32(2), atomic.LoadUint32(&ssiCount), "Expected both ssi handlers to be invoked")
  1274  }