github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/deliver/binding_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package deliver
     8  
     9  import (
    10  	"context"
    11  	"crypto/sha256"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"net"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/golang/protobuf/proto"
    20  	"github.com/hechain20/hechain/internal/pkg/comm"
    21  	"github.com/hechain20/hechain/internal/pkg/comm/testpb"
    22  	"github.com/hechain20/hechain/protoutil"
    23  	"github.com/hyperledger/fabric-protos-go/common"
    24  	"github.com/stretchr/testify/require"
    25  	"google.golang.org/grpc"
    26  	"google.golang.org/grpc/credentials"
    27  )
    28  
    29  func TestBindingInspectorBadInit(t *testing.T) {
    30  	require.Panics(t, func() { NewBindingInspector(false, nil) })
    31  }
    32  
    33  func TestNoopBindingInspector(t *testing.T) {
    34  	extract := func(msg proto.Message) []byte {
    35  		return nil
    36  	}
    37  	require.Nil(t, NewBindingInspector(false, extract)(context.Background(), &common.Envelope{}))
    38  	err := NewBindingInspector(false, extract)(context.Background(), nil)
    39  	require.Error(t, err)
    40  	require.Equal(t, "message is nil", err.Error())
    41  }
    42  
    43  func TestBindingInspector(t *testing.T) {
    44  	lis, err := net.Listen("tcp", "127.0.0.1:0")
    45  	if err != nil {
    46  		t.Fatalf("failed to create listener for test server: %v", err)
    47  	}
    48  
    49  	extract := func(msg proto.Message) []byte {
    50  		env, isEnvelope := msg.(*common.Envelope)
    51  		if !isEnvelope || env == nil {
    52  			return nil
    53  		}
    54  		ch, err := protoutil.ChannelHeader(env)
    55  		if err != nil {
    56  			return nil
    57  		}
    58  		return ch.TlsCertHash
    59  	}
    60  	srv := newInspectingServer(lis, NewBindingInspector(true, extract))
    61  	go srv.Start()
    62  	defer srv.Stop()
    63  	time.Sleep(time.Second)
    64  
    65  	// Scenario I: Invalid header sent
    66  	err = srv.newInspection(t).inspectBinding(nil)
    67  	require.Error(t, err)
    68  	require.Contains(t, err.Error(), "client didn't include its TLS cert hash")
    69  
    70  	// Scenario II: invalid channel header
    71  	ch, _ := proto.Marshal(protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0))
    72  	// Corrupt channel header
    73  	ch = append(ch, 0)
    74  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
    75  	require.Error(t, err)
    76  	require.Contains(t, err.Error(), "client didn't include its TLS cert hash")
    77  
    78  	// Scenario III: No TLS cert hash in envelope
    79  	chanHdr := protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0)
    80  	ch, _ = proto.Marshal(chanHdr)
    81  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
    82  	require.Error(t, err)
    83  	require.Contains(t, err.Error(), "client didn't include its TLS cert hash")
    84  
    85  	// Scenario IV: Client sends its TLS cert hash as needed, but doesn't use mutual TLS
    86  	cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM))
    87  	h := sha256.New()
    88  	h.Write([]byte(cert.Certificate[0]))
    89  	chanHdr.TlsCertHash = h.Sum(nil)
    90  	ch, _ = proto.Marshal(chanHdr)
    91  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
    92  	require.Error(t, err)
    93  	require.Contains(t, err.Error(), "client didn't send a TLS certificate")
    94  
    95  	// Scenario V: Client uses mutual TLS but sends the wrong TLS cert hash
    96  	chanHdr.TlsCertHash = []byte{1, 2, 3}
    97  	chHdrWithWrongTLSCertHash, _ := proto.Marshal(chanHdr)
    98  	err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(chHdrWithWrongTLSCertHash))
    99  	require.Error(t, err)
   100  	require.Contains(t, err.Error(), "claimed TLS cert hash is [1 2 3] but actual TLS cert hash is")
   101  
   102  	// Scenario VI: Client uses mutual TLS and also sends the correct TLS cert hash
   103  	err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(ch))
   104  	require.NoError(t, err)
   105  }
   106  
   107  type inspectingServer struct {
   108  	addr string
   109  	*comm.GRPCServer
   110  	lastContext atomic.Value
   111  	inspector   BindingInspector
   112  }
   113  
   114  func (is *inspectingServer) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   115  	is.lastContext.Store(ctx)
   116  	return &testpb.Empty{}, nil
   117  }
   118  
   119  func (is *inspectingServer) inspect(envelope *common.Envelope) error {
   120  	return is.inspector(is.lastContext.Load().(context.Context), envelope)
   121  }
   122  
   123  func newInspectingServer(listener net.Listener, inspector BindingInspector) *inspectingServer {
   124  	srv, err := comm.NewGRPCServerFromListener(listener, comm.ServerConfig{
   125  		ConnectionTimeout: 250 * time.Millisecond,
   126  		SecOpts: comm.SecureOptions{
   127  			UseTLS:      true,
   128  			Certificate: []byte(selfSignedCertPEM),
   129  			Key:         []byte(selfSignedKeyPEM),
   130  		},
   131  	})
   132  	if err != nil {
   133  		panic(err)
   134  	}
   135  	is := &inspectingServer{
   136  		addr:       listener.Addr().String(),
   137  		GRPCServer: srv,
   138  		inspector:  inspector,
   139  	}
   140  	testpb.RegisterTestServiceServer(srv.Server(), is)
   141  	return is
   142  }
   143  
   144  type inspection struct {
   145  	tlsConfig *tls.Config
   146  	server    *inspectingServer
   147  	creds     credentials.TransportCredentials
   148  	t         *testing.T
   149  }
   150  
   151  func (is *inspectingServer) newInspection(t *testing.T) *inspection {
   152  	tlsConfig := &tls.Config{
   153  		RootCAs: x509.NewCertPool(),
   154  	}
   155  	tlsConfig.RootCAs.AppendCertsFromPEM([]byte(selfSignedCertPEM))
   156  	return &inspection{
   157  		server:    is,
   158  		creds:     credentials.NewTLS(tlsConfig),
   159  		t:         t,
   160  		tlsConfig: tlsConfig,
   161  	}
   162  }
   163  
   164  func (ins *inspection) withMutualTLS() *inspection {
   165  	cert, err := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM))
   166  	require.NoError(ins.t, err)
   167  	ins.tlsConfig.Certificates = []tls.Certificate{cert}
   168  	ins.creds = credentials.NewTLS(ins.tlsConfig)
   169  	return ins
   170  }
   171  
   172  func (ins *inspection) inspectBinding(envelope *common.Envelope) error {
   173  	ctx := context.Background()
   174  	ctx, c := context.WithTimeout(ctx, time.Second*3)
   175  	defer c()
   176  	conn, err := grpc.DialContext(ctx, ins.server.addr, grpc.WithTransportCredentials(ins.creds), grpc.WithBlock())
   177  	require.NoError(ins.t, err)
   178  	defer conn.Close()
   179  	_, err = testpb.NewTestServiceClient(conn).EmptyCall(context.Background(), &testpb.Empty{})
   180  	require.NoError(ins.t, err)
   181  	return ins.server.inspect(envelope)
   182  }
   183  
   184  func envelopeWithChannelHeader(ch []byte) *common.Envelope {
   185  	pl := &common.Payload{
   186  		Header: &common.Header{
   187  			ChannelHeader: ch,
   188  		},
   189  	}
   190  	payload, _ := proto.Marshal(pl)
   191  	return &common.Envelope{
   192  		Payload: payload,
   193  	}
   194  }
   195  
// Embedded certificates for testing.
// The self-signed certificate (selfSignedCertPEM) expires on 2028-08-18 UTC;
// regenerate the key and certificate together before then.
//
// selfSignedKeyPEM is an ECDSA P-256 private key in PEM "EC PRIVATE KEY" form.
// It is the private key matching selfSignedCertPEM. The tests load the two with
// tls.X509KeyPair, using them as the server's TLS identity and as the client
// certificate for mutual TLS.
var selfSignedKeyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMLemLh3+uDzww1pvqP6Xj2Z0Kc6yqf3RxyfTBNwRuuyoAoGCCqGSM49
AwEHoUQDQgAEDB3l94vM7EqKr2L/vhqU5IsEub0rviqCAaWGiVAPp3orb/LJqFLS
yo/k60rhUiir6iD4S4pb5TEb2ouWylQI3A==
-----END EC PRIVATE KEY-----
`
   204  
// selfSignedCertPEM is a self-signed certificate (subject O=Acme Co) for
// selfSignedKeyPEM, valid from 2018-08-21 to 2028-08-18 UTC. Its subject
// alternative names are DNS:localhost and IP:127.0.0.1, so test servers using
// it must be reached at one of those names.
// The same certificate also serves as the sole root CA on the client side.
var selfSignedCertPEM = `-----BEGIN CERTIFICATE-----
MIIBdDCCARqgAwIBAgIRAKCiW5r6W32jGUn+l9BORMAwCgYIKoZIzj0EAwIwEjEQ
MA4GA1UEChMHQWNtZSBDbzAeFw0xODA4MjExMDI1MzJaFw0yODA4MTgxMDI1MzJa
MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQM
HeX3i8zsSoqvYv++GpTkiwS5vSu+KoIBpYaJUA+neitv8smoUtLKj+TrSuFSKKvq
IPhLilvlMRvai5bKVAjco1EwTzAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYI
KwYBBQUHAwEwDAYDVR0TAQH/BAIwADAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A
AAEwCgYIKoZIzj0EAwIDSAAwRQIgOaYc3pdGf2j0uXRyvdBJq2PlK9FkgvsUjXOT
bQ9fWRkCIQCr1FiRRzapgtrnttDn3O2fhLlbrw67kClzY8pIIN42Qw==
-----END CERTIFICATE-----
`