github.com/true-sqn/fabric@v2.1.1+incompatible/internal/pkg/comm/util.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/sha256"
    13  	"crypto/x509"
    14  	"encoding/pem"
    15  	"net"
    16  
    17  	"github.com/golang/protobuf/proto"
    18  	"github.com/pkg/errors"
    19  	"google.golang.org/grpc/credentials"
    20  	"google.golang.org/grpc/peer"
    21  )
    22  
    23  // AddPemToCertPool adds PEM-encoded certs to a cert pool
    24  func AddPemToCertPool(pemCerts []byte, pool *x509.CertPool) error {
    25  	certs, _, err := pemToX509Certs(pemCerts)
    26  	if err != nil {
    27  		return err
    28  	}
    29  	for _, cert := range certs {
    30  		pool.AddCert(cert)
    31  	}
    32  	return nil
    33  }
    34  
    35  // parse PEM-encoded certs
    36  func pemToX509Certs(pemCerts []byte) ([]*x509.Certificate, []string, error) {
    37  	var certs []*x509.Certificate
    38  	var subjects []string
    39  
    40  	// it's possible that multiple certs are encoded
    41  	for len(pemCerts) > 0 {
    42  		var block *pem.Block
    43  		block, pemCerts = pem.Decode(pemCerts)
    44  		if block == nil {
    45  			break
    46  		}
    47  
    48  		cert, err := x509.ParseCertificate(block.Bytes)
    49  		if err != nil {
    50  			return nil, []string{}, err
    51  		}
    52  
    53  		certs = append(certs, cert)
    54  		subjects = append(subjects, string(cert.RawSubject))
    55  	}
    56  
    57  	return certs, subjects, nil
    58  }
    59  
    60  // BindingInspector receives as parameters a gRPC context and an Envelope,
    61  // and verifies whether the message contains an appropriate binding to the context
    62  type BindingInspector func(context.Context, proto.Message) error
    63  
    64  // CertHashExtractor extracts a certificate from a proto.Message message
    65  type CertHashExtractor func(proto.Message) []byte
    66  
    67  // NewBindingInspector returns a BindingInspector according to whether
    68  // mutualTLS is configured or not, and according to a function that extracts
    69  // TLS certificate hashes from proto messages
    70  func NewBindingInspector(mutualTLS bool, extractTLSCertHash CertHashExtractor) BindingInspector {
    71  	if extractTLSCertHash == nil {
    72  		panic(errors.New("extractTLSCertHash parameter is nil"))
    73  	}
    74  	inspectMessage := mutualTLSBinding
    75  	if !mutualTLS {
    76  		inspectMessage = noopBinding
    77  	}
    78  	return func(ctx context.Context, msg proto.Message) error {
    79  		if msg == nil {
    80  			return errors.New("message is nil")
    81  		}
    82  		return inspectMessage(ctx, extractTLSCertHash(msg))
    83  	}
    84  }
    85  
    86  // mutualTLSBinding enforces the client to send its TLS cert hash in the message,
    87  // and then compares it to the computed hash that is derived
    88  // from the gRPC context.
    89  // In case they don't match, or the cert hash is missing from the request or
    90  // there is no TLS certificate to be excavated from the gRPC context,
    91  // an error is returned.
    92  func mutualTLSBinding(ctx context.Context, claimedTLScertHash []byte) error {
    93  	if len(claimedTLScertHash) == 0 {
    94  		return errors.Errorf("client didn't include its TLS cert hash")
    95  	}
    96  	actualTLScertHash := ExtractCertificateHashFromContext(ctx)
    97  	if len(actualTLScertHash) == 0 {
    98  		return errors.Errorf("client didn't send a TLS certificate")
    99  	}
   100  	if !bytes.Equal(actualTLScertHash, claimedTLScertHash) {
   101  		return errors.Errorf("claimed TLS cert hash is %v but actual TLS cert hash is %v", claimedTLScertHash, actualTLScertHash)
   102  	}
   103  	return nil
   104  }
   105  
   106  // noopBinding is a BindingInspector that always returns nil
   107  func noopBinding(_ context.Context, _ []byte) error {
   108  	return nil
   109  }
   110  
   111  // ExtractCertificateHashFromContext extracts the hash of the certificate from the given context.
   112  // If the certificate isn't present, nil is returned
   113  func ExtractCertificateHashFromContext(ctx context.Context) []byte {
   114  	rawCert := ExtractRawCertificateFromContext(ctx)
   115  	if len(rawCert) == 0 {
   116  		return nil
   117  	}
   118  	h := sha256.New()
   119  	h.Write(rawCert)
   120  	return h.Sum(nil)
   121  }
   122  
   123  // ExtractCertificateFromContext returns the TLS certificate (if applicable)
   124  // from the given context of a gRPC stream
   125  func ExtractCertificateFromContext(ctx context.Context) *x509.Certificate {
   126  	pr, extracted := peer.FromContext(ctx)
   127  	if !extracted {
   128  		return nil
   129  	}
   130  
   131  	authInfo := pr.AuthInfo
   132  	if authInfo == nil {
   133  		return nil
   134  	}
   135  
   136  	tlsInfo, isTLSConn := authInfo.(credentials.TLSInfo)
   137  	if !isTLSConn {
   138  		return nil
   139  	}
   140  	certs := tlsInfo.State.PeerCertificates
   141  	if len(certs) == 0 {
   142  		return nil
   143  	}
   144  	return certs[0]
   145  }
   146  
   147  // ExtractRawCertificateFromContext returns the raw TLS certificate (if applicable)
   148  // from the given context of a gRPC stream
   149  func ExtractRawCertificateFromContext(ctx context.Context) []byte {
   150  	cert := ExtractCertificateFromContext(ctx)
   151  	if cert == nil {
   152  		return nil
   153  	}
   154  	return cert.Raw
   155  }
   156  
   157  // GetLocalIP returns the non loopback local IP of the host
   158  func GetLocalIP() (string, error) {
   159  	addrs, err := net.InterfaceAddrs()
   160  	if err != nil {
   161  		return "", err
   162  	}
   163  	for _, address := range addrs {
   164  		// check the address type and if it is not a loopback then display it
   165  		if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
   166  			if ipnet.IP.To4() != nil {
   167  				return ipnet.IP.String(), nil
   168  			}
   169  		}
   170  	}
   171  	return "", errors.Errorf("no non-loopback, IPv4 interface detected")
   172  }