github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/deliver/binding.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package deliver
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  
    13  	"github.com/golang/protobuf/proto"
    14  	"github.com/hechain20/hechain/common/util"
    15  	"github.com/pkg/errors"
    16  )
    17  
    18  // BindingInspector receives as parameters a gRPC context and an Envelope,
    19  // and verifies whether the message contains an appropriate binding to the context
    20  type BindingInspector func(context.Context, proto.Message) error
    21  
    22  // CertHashExtractor extracts a certificate from a proto.Message message
    23  type CertHashExtractor func(proto.Message) []byte
    24  
    25  // NewBindingInspector returns a BindingInspector according to whether
    26  // mutualTLS is configured or not, and according to a function that extracts
    27  // TLS certificate hashes from proto messages
    28  func NewBindingInspector(mutualTLS bool, extractTLSCertHash CertHashExtractor) BindingInspector {
    29  	if extractTLSCertHash == nil {
    30  		panic(errors.New("extractTLSCertHash parameter is nil"))
    31  	}
    32  	inspectMessage := mutualTLSBinding
    33  	if !mutualTLS {
    34  		inspectMessage = noopBinding
    35  	}
    36  	return func(ctx context.Context, msg proto.Message) error {
    37  		if msg == nil {
    38  			return errors.New("message is nil")
    39  		}
    40  		return inspectMessage(ctx, extractTLSCertHash(msg))
    41  	}
    42  }
    43  
    44  // mutualTLSBinding enforces the client to send its TLS cert hash in the message,
    45  // and then compares it to the computed hash that is derived
    46  // from the gRPC context.
    47  // In case they don't match, or the cert hash is missing from the request or
    48  // there is no TLS certificate to be excavated from the gRPC context,
    49  // an error is returned.
    50  func mutualTLSBinding(ctx context.Context, claimedTLScertHash []byte) error {
    51  	if len(claimedTLScertHash) == 0 {
    52  		return errors.Errorf("client didn't include its TLS cert hash")
    53  	}
    54  	actualTLScertHash := util.ExtractCertificateHashFromContext(ctx)
    55  	if len(actualTLScertHash) == 0 {
    56  		return errors.Errorf("client didn't send a TLS certificate")
    57  	}
    58  	if !bytes.Equal(actualTLScertHash, claimedTLScertHash) {
    59  		return errors.Errorf("claimed TLS cert hash is %v but actual TLS cert hash is %v", claimedTLScertHash, actualTLScertHash)
    60  	}
    61  	return nil
    62  }
    63  
    64  // noopBinding is a BindingInspector that always returns nil
    65  func noopBinding(_ context.Context, _ []byte) error {
    66  	return nil
    67  }