github.com/kaituanwang/hyperledger@v2.0.1+incompatible/core/comm/util_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm_test
     8  
     9  import (
    10  	"context"
    11  	"crypto/sha256"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"net"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/golang/protobuf/proto"
    20  	"github.com/hyperledger/fabric-protos-go/common"
    21  	"github.com/hyperledger/fabric/core/comm"
    22  	"github.com/hyperledger/fabric/core/comm/testpb"
    23  	"github.com/hyperledger/fabric/protoutil"
    24  	"github.com/stretchr/testify/assert"
    25  	"google.golang.org/grpc"
    26  	"google.golang.org/grpc/credentials"
    27  	"google.golang.org/grpc/peer"
    28  )
    29  
    30  func TestExtractCertificateHashFromContext(t *testing.T) {
    31  	t.Parallel()
    32  	assert.Nil(t, comm.ExtractCertificateHashFromContext(context.Background()))
    33  
    34  	p := &peer.Peer{}
    35  	ctx := peer.NewContext(context.Background(), p)
    36  	assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx))
    37  
    38  	p.AuthInfo = &nonTLSConnection{}
    39  	ctx = peer.NewContext(context.Background(), p)
    40  	assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx))
    41  
    42  	p.AuthInfo = credentials.TLSInfo{}
    43  	ctx = peer.NewContext(context.Background(), p)
    44  	assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx))
    45  
    46  	p.AuthInfo = credentials.TLSInfo{
    47  		State: tls.ConnectionState{
    48  			PeerCertificates: []*x509.Certificate{
    49  				{Raw: []byte{1, 2, 3}},
    50  			},
    51  		},
    52  	}
    53  	ctx = peer.NewContext(context.Background(), p)
    54  	h := sha256.New()
    55  	h.Write([]byte{1, 2, 3})
    56  	assert.Equal(t, h.Sum(nil), comm.ExtractCertificateHashFromContext(ctx))
    57  }
    58  
    59  type nonTLSConnection struct {
    60  }
    61  
    62  func (*nonTLSConnection) AuthType() string {
    63  	return ""
    64  }
    65  
    66  func TestBindingInspectorBadInit(t *testing.T) {
    67  	t.Parallel()
    68  	assert.Panics(t, func() {
    69  		comm.NewBindingInspector(false, nil)
    70  	})
    71  }
    72  
    73  func TestNoopBindingInspector(t *testing.T) {
    74  	t.Parallel()
    75  	extract := func(msg proto.Message) []byte {
    76  		return nil
    77  	}
    78  	assert.Nil(t, comm.NewBindingInspector(false, extract)(context.Background(), &common.Envelope{}))
    79  	err := comm.NewBindingInspector(false, extract)(context.Background(), nil)
    80  	assert.Error(t, err)
    81  	assert.Equal(t, "message is nil", err.Error())
    82  }
    83  
    84  func TestBindingInspector(t *testing.T) {
    85  	t.Parallel()
    86  	lis, err := net.Listen("tcp", "127.0.0.1:0")
    87  	if err != nil {
    88  		t.Fatalf("failed to create listener for test server: %v", err)
    89  	}
    90  
    91  	extract := func(msg proto.Message) []byte {
    92  		env, isEnvelope := msg.(*common.Envelope)
    93  		if !isEnvelope || env == nil {
    94  			return nil
    95  		}
    96  		ch, err := protoutil.ChannelHeader(env)
    97  		if err != nil {
    98  			return nil
    99  		}
   100  		return ch.TlsCertHash
   101  	}
   102  	srv := newInspectingServer(lis, comm.NewBindingInspector(true, extract))
   103  	go srv.Start()
   104  	defer srv.Stop()
   105  	time.Sleep(time.Second)
   106  
   107  	// Scenario I: Invalid header sent
   108  	err = srv.newInspection(t).inspectBinding(nil)
   109  	assert.Error(t, err)
   110  	assert.Contains(t, err.Error(), "client didn't include its TLS cert hash")
   111  
   112  	// Scenario II: invalid channel header
   113  	ch, _ := proto.Marshal(protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0))
   114  	// Corrupt channel header
   115  	ch = append(ch, 0)
   116  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
   117  	assert.Error(t, err)
   118  	assert.Contains(t, err.Error(), "client didn't include its TLS cert hash")
   119  
   120  	// Scenario III: No TLS cert hash in envelope
   121  	chanHdr := protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0)
   122  	ch, _ = proto.Marshal(chanHdr)
   123  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
   124  	assert.Error(t, err)
   125  	assert.Contains(t, err.Error(), "client didn't include its TLS cert hash")
   126  
   127  	// Scenario IV: Client sends its TLS cert hash as needed, but doesn't use mutual TLS
   128  	cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM))
   129  	h := sha256.New()
   130  	h.Write([]byte(cert.Certificate[0]))
   131  	chanHdr.TlsCertHash = h.Sum(nil)
   132  	ch, _ = proto.Marshal(chanHdr)
   133  	err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch))
   134  	assert.Error(t, err)
   135  	assert.Contains(t, err.Error(), "client didn't send a TLS certificate")
   136  
   137  	// Scenario V: Client uses mutual TLS but sends the wrong TLS cert hash
   138  	chanHdr.TlsCertHash = []byte{1, 2, 3}
   139  	chHdrWithWrongTLSCertHash, _ := proto.Marshal(chanHdr)
   140  	err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(chHdrWithWrongTLSCertHash))
   141  	assert.Error(t, err)
   142  	assert.Contains(t, err.Error(), "claimed TLS cert hash is [1 2 3] but actual TLS cert hash is")
   143  
   144  	// Scenario VI: Client uses mutual TLS and also sends the correct TLS cert hash
   145  	err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(ch))
   146  	assert.NoError(t, err)
   147  }
   148  
   149  func TestGetLocalIP(t *testing.T) {
   150  	ip, err := comm.GetLocalIP()
   151  	assert.NoError(t, err)
   152  	t.Log(ip)
   153  }
   154  
   155  type inspectingServer struct {
   156  	addr string
   157  	*comm.GRPCServer
   158  	lastContext atomic.Value
   159  	inspector   comm.BindingInspector
   160  }
   161  
   162  func (is *inspectingServer) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   163  	is.lastContext.Store(ctx)
   164  	return &testpb.Empty{}, nil
   165  }
   166  
   167  func (is *inspectingServer) inspect(envelope *common.Envelope) error {
   168  	return is.inspector(is.lastContext.Load().(context.Context), envelope)
   169  }
   170  
   171  func newInspectingServer(listener net.Listener, inspector comm.BindingInspector) *inspectingServer {
   172  	srv, err := comm.NewGRPCServerFromListener(listener, comm.ServerConfig{
   173  		ConnectionTimeout: 250 * time.Millisecond,
   174  		SecOpts: comm.SecureOptions{
   175  			UseTLS:      true,
   176  			Certificate: []byte(selfSignedCertPEM),
   177  			Key:         []byte(selfSignedKeyPEM),
   178  		}})
   179  	if err != nil {
   180  		panic(err)
   181  	}
   182  	is := &inspectingServer{
   183  		addr:       listener.Addr().String(),
   184  		GRPCServer: srv,
   185  		inspector:  inspector,
   186  	}
   187  	testpb.RegisterTestServiceServer(srv.Server(), is)
   188  	return is
   189  }
   190  
   191  type inspection struct {
   192  	tlsConfig *tls.Config
   193  	server    *inspectingServer
   194  	creds     credentials.TransportCredentials
   195  	t         *testing.T
   196  }
   197  
   198  func (is *inspectingServer) newInspection(t *testing.T) *inspection {
   199  	tlsConfig := &tls.Config{
   200  		RootCAs: x509.NewCertPool(),
   201  	}
   202  	tlsConfig.RootCAs.AppendCertsFromPEM([]byte(selfSignedCertPEM))
   203  	return &inspection{
   204  		server:    is,
   205  		creds:     credentials.NewTLS(tlsConfig),
   206  		t:         t,
   207  		tlsConfig: tlsConfig,
   208  	}
   209  }
   210  
   211  func (ins *inspection) withMutualTLS() *inspection {
   212  	cert, err := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM))
   213  	assert.NoError(ins.t, err)
   214  	ins.tlsConfig.Certificates = []tls.Certificate{cert}
   215  	ins.creds = credentials.NewTLS(ins.tlsConfig)
   216  	return ins
   217  }
   218  
   219  func (ins *inspection) inspectBinding(envelope *common.Envelope) error {
   220  	ctx := context.Background()
   221  	ctx, c := context.WithTimeout(ctx, time.Second*3)
   222  	defer c()
   223  	conn, err := grpc.DialContext(ctx, ins.server.addr, grpc.WithTransportCredentials(ins.creds), grpc.WithBlock())
   224  	assert.NoError(ins.t, err)
   225  	defer conn.Close()
   226  	_, err = testpb.NewTestServiceClient(conn).EmptyCall(context.Background(), &testpb.Empty{})
   227  	assert.NoError(ins.t, err)
   228  	return ins.server.inspect(envelope)
   229  }
   230  
   231  func envelopeWithChannelHeader(ch []byte) *common.Envelope {
   232  	pl := &common.Payload{
   233  		Header: &common.Header{
   234  			ChannelHeader: ch,
   235  		},
   236  	}
   237  	payload, _ := proto.Marshal(pl)
   238  	return &common.Envelope{
   239  		Payload: payload,
   240  	}
   241  }