github.com/true-sqn/fabric@v2.1.1+incompatible/internal/pkg/comm/creds_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"io/ioutil"
    14  	"net"
    15  	"path/filepath"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/hyperledger/fabric/common/flogging/floggingtest"
    20  	"github.com/hyperledger/fabric/internal/pkg/comm"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
    24  func TestCreds(t *testing.T) {
    25  	t.Parallel()
    26  
    27  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
    28  	if err != nil {
    29  		t.Fatalf("failed to read root certificate: %v", err)
    30  	}
    31  	certPool := x509.NewCertPool()
    32  	ok := certPool.AppendCertsFromPEM(caPEM)
    33  	if !ok {
    34  		t.Fatalf("failed to create certPool")
    35  	}
    36  	cert, err := tls.LoadX509KeyPair(
    37  		filepath.Join("testdata", "certs", "Org1-server1-cert.pem"),
    38  		filepath.Join("testdata", "certs", "Org1-server1-key.pem"),
    39  	)
    40  	if err != nil {
    41  		t.Fatalf("failed to load TLS certificate [%s]", err)
    42  	}
    43  
    44  	tlsConfig := &tls.Config{
    45  		Certificates: []tls.Certificate{cert},
    46  	}
    47  
    48  	config := comm.NewTLSConfig(tlsConfig)
    49  
    50  	logger, recorder := floggingtest.NewTestLogger(t)
    51  
    52  	creds := comm.NewServerTransportCredentials(config, logger)
    53  	_, _, err = creds.ClientHandshake(context.Background(), "", nil)
    54  	assert.EqualError(t, err, comm.ErrClientHandshakeNotImplemented.Error())
    55  	err = creds.OverrideServerName("")
    56  	assert.EqualError(t, err, comm.ErrOverrideHostnameNotSupported.Error())
    57  	assert.Equal(t, "1.2", creds.Info().SecurityVersion)
    58  	assert.Equal(t, "tls", creds.Info().SecurityProtocol)
    59  
    60  	lis, err := net.Listen("tcp", "localhost:0")
    61  	if err != nil {
    62  		t.Fatalf("failed to start listener [%s]", err)
    63  	}
    64  	defer lis.Close()
    65  
    66  	_, port, err := net.SplitHostPort(lis.Addr().String())
    67  	assert.NoError(t, err)
    68  	addr := net.JoinHostPort("localhost", port)
    69  
    70  	handshake := func(wg *sync.WaitGroup) {
    71  		defer wg.Done()
    72  		conn, err := lis.Accept()
    73  		if err != nil {
    74  			t.Logf("failed to accept connection [%s]", err)
    75  		}
    76  		_, _, err = creds.ServerHandshake(conn)
    77  		if err != nil {
    78  			t.Logf("ServerHandshake error [%s]", err)
    79  		}
    80  	}
    81  
    82  	wg := &sync.WaitGroup{}
    83  	wg.Add(1)
    84  	go handshake(wg)
    85  	_, err = tls.Dial("tcp", addr, &tls.Config{RootCAs: certPool})
    86  	wg.Wait()
    87  	assert.NoError(t, err)
    88  
    89  	wg = &sync.WaitGroup{}
    90  	wg.Add(1)
    91  	go handshake(wg)
    92  	_, err = tls.Dial("tcp", addr, &tls.Config{
    93  		RootCAs:    certPool,
    94  		MaxVersion: tls.VersionTLS10,
    95  	})
    96  	wg.Wait()
    97  	assert.Contains(t, err.Error(), "protocol version not supported")
    98  	assert.Contains(t, recorder.Messages()[0], "TLS handshake failed with error")
    99  }
   100  
   101  func TestNewTLSConfig(t *testing.T) {
   102  	t.Parallel()
   103  	tlsConfig := &tls.Config{}
   104  
   105  	config := comm.NewTLSConfig(tlsConfig)
   106  
   107  	assert.NotEmpty(t, config, "TLSConfig is not empty")
   108  }
   109  
   110  func TestConfig(t *testing.T) {
   111  	t.Parallel()
   112  	config := comm.NewTLSConfig(&tls.Config{
   113  		ServerName: "bueno",
   114  	})
   115  
   116  	configCopy := config.Config()
   117  
   118  	certPool := x509.NewCertPool()
   119  	config.SetClientCAs(certPool)
   120  
   121  	assert.NotEqual(t, config.Config(), &configCopy, "TLSConfig should have new certs")
   122  }
   123  
   124  func TestAddRootCA(t *testing.T) {
   125  	t.Parallel()
   126  
   127  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
   128  	if err != nil {
   129  		t.Fatalf("failed to read root certificate: %v", err)
   130  	}
   131  
   132  	cert := &x509.Certificate{
   133  		EmailAddresses: []string{"test@foobar.com"},
   134  	}
   135  
   136  	expectedCertPool := x509.NewCertPool()
   137  	ok := expectedCertPool.AppendCertsFromPEM(caPEM)
   138  	if !ok {
   139  		t.Fatalf("failed to create expected certPool")
   140  	}
   141  
   142  	expectedCertPool.AddCert(cert)
   143  
   144  	certPool := x509.NewCertPool()
   145  	ok = certPool.AppendCertsFromPEM(caPEM)
   146  	if !ok {
   147  		t.Fatalf("failed to create certPool")
   148  	}
   149  
   150  	tlsConfig := &tls.Config{
   151  		ClientCAs: certPool,
   152  	}
   153  	config := comm.NewTLSConfig(tlsConfig)
   154  
   155  	assert.Equal(t, config.Config().ClientCAs, certPool)
   156  
   157  	config.AddClientRootCA(cert)
   158  
   159  	assert.Equal(t, config.Config().ClientCAs, expectedCertPool, "The CertPools should be equal")
   160  }
   161  
   162  func TestSetClientCAs(t *testing.T) {
   163  	t.Parallel()
   164  	tlsConfig := &tls.Config{
   165  		Certificates: []tls.Certificate{},
   166  	}
   167  	config := comm.NewTLSConfig(tlsConfig)
   168  
   169  	assert.Empty(t, config.Config().ClientCAs, "No CertPool should be defined")
   170  
   171  	certPool := x509.NewCertPool()
   172  	config.SetClientCAs(certPool)
   173  
   174  	assert.NotNil(t, config.Config().ClientCAs, "The CertPools' should not be the same")
   175  }