github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/internal/pkg/comm/creds_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"io/ioutil"
    14  	"net"
    15  	"path/filepath"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/require"
    20  
    21  	"github.com/osdi23p228/fabric/common/flogging/floggingtest"
    22  	"github.com/osdi23p228/fabric/internal/pkg/comm"
    23  	"github.com/stretchr/testify/assert"
    24  )
    25  
    26  func TestCreds(t *testing.T) {
    27  	t.Parallel()
    28  
    29  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
    30  	if err != nil {
    31  		t.Fatalf("failed to read root certificate: %v", err)
    32  	}
    33  	certPool := x509.NewCertPool()
    34  	ok := certPool.AppendCertsFromPEM(caPEM)
    35  	if !ok {
    36  		t.Fatalf("failed to create certPool")
    37  	}
    38  	cert, err := tls.LoadX509KeyPair(
    39  		filepath.Join("testdata", "certs", "Org1-server1-cert.pem"),
    40  		filepath.Join("testdata", "certs", "Org1-server1-key.pem"),
    41  	)
    42  	if err != nil {
    43  		t.Fatalf("failed to load TLS certificate [%s]", err)
    44  	}
    45  
    46  	tlsConfig := &tls.Config{
    47  		Certificates: []tls.Certificate{cert},
    48  	}
    49  
    50  	config := comm.NewTLSConfig(tlsConfig)
    51  
    52  	logger, recorder := floggingtest.NewTestLogger(t)
    53  
    54  	creds := comm.NewServerTransportCredentials(config, logger)
    55  	_, _, err = creds.ClientHandshake(context.Background(), "", nil)
    56  	assert.EqualError(t, err, comm.ErrClientHandshakeNotImplemented.Error())
    57  	err = creds.OverrideServerName("")
    58  	assert.EqualError(t, err, comm.ErrOverrideHostnameNotSupported.Error())
    59  	assert.Equal(t, "1.2", creds.Info().SecurityVersion)
    60  	assert.Equal(t, "tls", creds.Info().SecurityProtocol)
    61  
    62  	lis, err := net.Listen("tcp", "localhost:0")
    63  	if err != nil {
    64  		t.Fatalf("failed to start listener [%s]", err)
    65  	}
    66  	defer lis.Close()
    67  
    68  	_, port, err := net.SplitHostPort(lis.Addr().String())
    69  	assert.NoError(t, err)
    70  	addr := net.JoinHostPort("localhost", port)
    71  
    72  	handshake := func(wg *sync.WaitGroup) {
    73  		defer wg.Done()
    74  		conn, err := lis.Accept()
    75  		if err != nil {
    76  			t.Logf("failed to accept connection [%s]", err)
    77  		}
    78  		_, _, err = creds.ServerHandshake(conn)
    79  		if err != nil {
    80  			t.Logf("ServerHandshake error [%s]", err)
    81  		}
    82  	}
    83  
    84  	wg := &sync.WaitGroup{}
    85  	wg.Add(1)
    86  	go handshake(wg)
    87  	_, err = tls.Dial("tcp", addr, &tls.Config{RootCAs: certPool})
    88  	wg.Wait()
    89  	assert.NoError(t, err)
    90  
    91  	wg = &sync.WaitGroup{}
    92  	wg.Add(1)
    93  	go handshake(wg)
    94  	_, err = tls.Dial("tcp", addr, &tls.Config{
    95  		RootCAs:    certPool,
    96  		MaxVersion: tls.VersionTLS10,
    97  	})
    98  	wg.Wait()
    99  	require.Contains(t, err.Error(), "protocol version not supported")
   100  	require.Contains(t, recorder.Messages()[1], "TLS handshake failed")
   101  }
   102  
   103  func TestNewTLSConfig(t *testing.T) {
   104  	t.Parallel()
   105  	tlsConfig := &tls.Config{}
   106  
   107  	config := comm.NewTLSConfig(tlsConfig)
   108  
   109  	assert.NotEmpty(t, config, "TLSConfig is not empty")
   110  }
   111  
   112  func TestConfig(t *testing.T) {
   113  	t.Parallel()
   114  	config := comm.NewTLSConfig(&tls.Config{
   115  		ServerName: "bueno",
   116  	})
   117  
   118  	configCopy := config.Config()
   119  
   120  	certPool := x509.NewCertPool()
   121  	config.SetClientCAs(certPool)
   122  
   123  	assert.NotEqual(t, config.Config(), &configCopy, "TLSConfig should have new certs")
   124  }
   125  
   126  func TestAddRootCA(t *testing.T) {
   127  	t.Parallel()
   128  
   129  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
   130  	if err != nil {
   131  		t.Fatalf("failed to read root certificate: %v", err)
   132  	}
   133  
   134  	cert := &x509.Certificate{
   135  		EmailAddresses: []string{"test@foobar.com"},
   136  	}
   137  
   138  	expectedCertPool := x509.NewCertPool()
   139  	ok := expectedCertPool.AppendCertsFromPEM(caPEM)
   140  	if !ok {
   141  		t.Fatalf("failed to create expected certPool")
   142  	}
   143  
   144  	expectedCertPool.AddCert(cert)
   145  
   146  	certPool := x509.NewCertPool()
   147  	ok = certPool.AppendCertsFromPEM(caPEM)
   148  	if !ok {
   149  		t.Fatalf("failed to create certPool")
   150  	}
   151  
   152  	tlsConfig := &tls.Config{
   153  		ClientCAs: certPool,
   154  	}
   155  	config := comm.NewTLSConfig(tlsConfig)
   156  
   157  	assert.Equal(t, config.Config().ClientCAs, certPool)
   158  
   159  	config.AddClientRootCA(cert)
   160  
   161  	assert.Equal(t, config.Config().ClientCAs, expectedCertPool, "The CertPools should be equal")
   162  }
   163  
   164  func TestSetClientCAs(t *testing.T) {
   165  	t.Parallel()
   166  	tlsConfig := &tls.Config{
   167  		Certificates: []tls.Certificate{},
   168  	}
   169  	config := comm.NewTLSConfig(tlsConfig)
   170  
   171  	assert.Empty(t, config.Config().ClientCAs, "No CertPool should be defined")
   172  
   173  	certPool := x509.NewCertPool()
   174  	config.SetClientCAs(certPool)
   175  
   176  	assert.NotNil(t, config.Config().ClientCAs, "The CertPools' should not be the same")
   177  }