github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/creds_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"io/ioutil"
    14  	"net"
    15  	"path/filepath"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/hechain20/hechain/common/flogging/floggingtest"
    20  	"github.com/hechain20/hechain/internal/pkg/comm"
    21  	"github.com/stretchr/testify/require"
    22  )
    23  
    24  func TestCreds(t *testing.T) {
    25  	t.Parallel()
    26  
    27  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
    28  	if err != nil {
    29  		t.Fatalf("failed to read root certificate: %v", err)
    30  	}
    31  	certPool := x509.NewCertPool()
    32  	ok := certPool.AppendCertsFromPEM(caPEM)
    33  	if !ok {
    34  		t.Fatalf("failed to create certPool")
    35  	}
    36  	cert, err := tls.LoadX509KeyPair(
    37  		filepath.Join("testdata", "certs", "Org1-server1-cert.pem"),
    38  		filepath.Join("testdata", "certs", "Org1-server1-key.pem"),
    39  	)
    40  	if err != nil {
    41  		t.Fatalf("failed to load TLS certificate [%s]", err)
    42  	}
    43  
    44  	tlsConfig := &tls.Config{
    45  		Certificates: []tls.Certificate{cert},
    46  	}
    47  
    48  	config := comm.NewTLSConfig(tlsConfig)
    49  
    50  	logger, recorder := floggingtest.NewTestLogger(t)
    51  
    52  	creds := comm.NewServerTransportCredentials(config, logger)
    53  	_, _, err = creds.ClientHandshake(context.Background(), "", nil)
    54  	require.EqualError(t, err, comm.ErrClientHandshakeNotImplemented.Error())
    55  	err = creds.OverrideServerName("")
    56  	require.EqualError(t, err, comm.ErrOverrideHostnameNotSupported.Error())
    57  	require.Equal(t, "1.2", creds.Info().SecurityVersion)
    58  	require.Equal(t, "tls", creds.Info().SecurityProtocol)
    59  
    60  	lis, err := net.Listen("tcp", "localhost:0")
    61  	if err != nil {
    62  		t.Fatalf("failed to start listener [%s]", err)
    63  	}
    64  	defer lis.Close()
    65  
    66  	_, port, err := net.SplitHostPort(lis.Addr().String())
    67  	require.NoError(t, err)
    68  	addr := net.JoinHostPort("localhost", port)
    69  
    70  	handshake := func(wg *sync.WaitGroup) {
    71  		defer wg.Done()
    72  		conn, err := lis.Accept()
    73  		if err != nil {
    74  			t.Logf("failed to accept connection [%s]", err)
    75  		}
    76  		_, _, err = creds.ServerHandshake(conn)
    77  		if err != nil {
    78  			t.Logf("ServerHandshake error [%s]", err)
    79  		}
    80  	}
    81  
    82  	wg := &sync.WaitGroup{}
    83  	wg.Add(1)
    84  	go handshake(wg)
    85  	_, err = tls.Dial("tcp", addr, &tls.Config{RootCAs: certPool})
    86  	wg.Wait()
    87  	require.NoError(t, err)
    88  
    89  	wg = &sync.WaitGroup{}
    90  	wg.Add(1)
    91  	go handshake(wg)
    92  	_, err = tls.Dial("tcp", addr, &tls.Config{
    93  		RootCAs:    certPool,
    94  		MaxVersion: tls.VersionTLS10,
    95  	})
    96  	wg.Wait()
    97  	require.Contains(t, err.Error(), "protocol version not supported")
    98  	require.Contains(t, recorder.Messages()[1], "TLS handshake failed")
    99  }
   100  
   101  func TestNewTLSConfig(t *testing.T) {
   102  	t.Parallel()
   103  	tlsConfig := &tls.Config{}
   104  
   105  	config := comm.NewTLSConfig(tlsConfig)
   106  
   107  	require.NotEmpty(t, config, "TLSConfig is not empty")
   108  }
   109  
   110  func TestConfig(t *testing.T) {
   111  	t.Parallel()
   112  	config := comm.NewTLSConfig(&tls.Config{
   113  		ServerName: "bueno",
   114  	})
   115  
   116  	configCopy := config.Config()
   117  
   118  	certPool := x509.NewCertPool()
   119  	config.SetClientCAs(certPool)
   120  
   121  	require.NotEqual(t, config.Config(), &configCopy, "TLSConfig should have new certs")
   122  }
   123  
   124  func TestAddRootCA(t *testing.T) {
   125  	t.Parallel()
   126  
   127  	caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
   128  	require.NoError(t, err, "failed to read root certificate")
   129  
   130  	expectedCertPool := x509.NewCertPool()
   131  	ok := expectedCertPool.AppendCertsFromPEM(caPEM)
   132  	require.True(t, ok, "failed to create expected certPool")
   133  
   134  	cert := &x509.Certificate{EmailAddresses: []string{"test@foobar.com"}}
   135  	expectedCertPool.AddCert(cert)
   136  
   137  	certPool := x509.NewCertPool()
   138  	ok = certPool.AppendCertsFromPEM(caPEM)
   139  	require.True(t, ok, "failed to create certPool")
   140  
   141  	config := comm.NewTLSConfig(&tls.Config{ClientCAs: certPool})
   142  	require.Same(t, config.Config().ClientCAs, certPool)
   143  
   144  	// https://go-review.googlesource.com/c/go/+/229917
   145  	config.AddClientRootCA(cert)
   146  	require.Equal(t, certPool.Subjects(), expectedCertPool.Subjects(), "subjects in the pool should be equal")
   147  }
   148  
   149  func TestSetClientCAs(t *testing.T) {
   150  	t.Parallel()
   151  	tlsConfig := &tls.Config{
   152  		Certificates: []tls.Certificate{},
   153  	}
   154  	config := comm.NewTLSConfig(tlsConfig)
   155  
   156  	require.Empty(t, config.Config().ClientCAs, "No CertPool should be defined")
   157  
   158  	certPool := x509.NewCertPool()
   159  	config.SetClientCAs(certPool)
   160  
   161  	require.NotNil(t, config.Config().ClientCAs, "The CertPools' should not be the same")
   162  }