/*
Copyright hechain. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package deliver

import (
	"bytes"
	"context"

	"github.com/golang/protobuf/proto"
	"github.com/hechain20/hechain/common/util"
	"github.com/pkg/errors"
)

// BindingInspector takes a gRPC context and a message and reports, via a
// non-nil error, whether the message fails to carry an acceptable binding to
// that context.
type BindingInspector func(context.Context, proto.Message) error

// CertHashExtractor returns the TLS certificate hash that a client embedded in
// a proto.Message.
type CertHashExtractor func(proto.Message) []byte

// NewBindingInspector builds a BindingInspector.
//
// When mutualTLS is true, the returned inspector requires the hash produced by
// extractTLSCertHash to be non-empty and equal to the certificate hash derived
// from the gRPC context. When mutualTLS is false, every non-nil message is
// accepted. In both modes a nil message is rejected, and extractTLSCertHash is
// invoked on every non-nil message.
//
// It panics if extractTLSCertHash is nil, since that is a wiring error.
func NewBindingInspector(mutualTLS bool, extractTLSCertHash CertHashExtractor) BindingInspector {
	if extractTLSCertHash == nil {
		panic(errors.New("extractTLSCertHash parameter is nil"))
	}

	// Pick the binding policy once, up front, rather than on every call.
	var bind func(context.Context, []byte) error
	if mutualTLS {
		bind = mutualTLSBinding
	} else {
		bind = noopBinding
	}

	return func(ctx context.Context, msg proto.Message) error {
		// NOTE(review): this only catches an untyped nil interface; a typed nil
		// pointer wrapped in proto.Message is passed on to extractTLSCertHash.
		if msg == nil {
			return errors.New("message is nil")
		}
		claimed := extractTLSCertHash(msg)
		return bind(ctx, claimed)
	}
}

// mutualTLSBinding requires the message to carry the client's TLS certificate
// hash and checks it against the hash that util.ExtractCertificateHashFromContext
// derives from the gRPC context.
// It returns an error when the message carries no hash, when no hash can be
// derived from the context (no client TLS certificate), or when the two hashes
// differ.
50 func mutualTLSBinding(ctx context.Context, claimedTLScertHash []byte) error { 51 if len(claimedTLScertHash) == 0 { 52 return errors.Errorf("client didn't include its TLS cert hash") 53 } 54 actualTLScertHash := util.ExtractCertificateHashFromContext(ctx) 55 if len(actualTLScertHash) == 0 { 56 return errors.Errorf("client didn't send a TLS certificate") 57 } 58 if !bytes.Equal(actualTLScertHash, claimedTLScertHash) { 59 return errors.Errorf("claimed TLS cert hash is %v but actual TLS cert hash is %v", claimedTLScertHash, actualTLScertHash) 60 } 61 return nil 62 } 63 64 // noopBinding is a BindingInspector that always returns nil 65 func noopBinding(_ context.Context, _ []byte) error { 66 return nil 67 }