github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/connection.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"sync"
    13  
    14  	"github.com/hechain20/hechain/common/channelconfig"
    15  	"github.com/hechain20/hechain/common/flogging"
    16  	"github.com/hechain20/hechain/msp"
    17  	"google.golang.org/grpc/credentials"
    18  )
    19  
    20  var commLogger = flogging.MustGetLogger("comm")
    21  
    22  // CredentialSupport type manages credentials used for gRPC client connections
    23  type CredentialSupport struct {
    24  	mutex             sync.RWMutex
    25  	appRootCAsByChain map[string][][]byte
    26  	serverRootCAs     [][]byte
    27  	clientCert        tls.Certificate
    28  }
    29  
    30  // NewCredentialSupport creates a CredentialSupport instance.
    31  func NewCredentialSupport(rootCAs ...[]byte) *CredentialSupport {
    32  	return &CredentialSupport{
    33  		appRootCAsByChain: make(map[string][][]byte),
    34  		serverRootCAs:     rootCAs,
    35  	}
    36  }
    37  
    38  // SetClientCertificate sets the tls.Certificate to use for gRPC client
    39  // connections
    40  func (cs *CredentialSupport) SetClientCertificate(cert tls.Certificate) {
    41  	cs.mutex.Lock()
    42  	cs.clientCert = cert
    43  	cs.mutex.Unlock()
    44  }
    45  
    46  // GetClientCertificate returns the client certificate of the CredentialSupport
    47  func (cs *CredentialSupport) GetClientCertificate() tls.Certificate {
    48  	cs.mutex.RLock()
    49  	defer cs.mutex.RUnlock()
    50  	return cs.clientCert
    51  }
    52  
    53  // GetPeerCredentials returns gRPC transport credentials for use by gRPC
    54  // clients which communicate with remote peer endpoints.
    55  func (cs *CredentialSupport) GetPeerCredentials() credentials.TransportCredentials {
    56  	cs.mutex.RLock()
    57  	defer cs.mutex.RUnlock()
    58  
    59  	var appRootCAs [][]byte
    60  	appRootCAs = append(appRootCAs, cs.serverRootCAs...)
    61  	for _, appRootCA := range cs.appRootCAsByChain {
    62  		appRootCAs = append(appRootCAs, appRootCA...)
    63  	}
    64  
    65  	certPool := x509.NewCertPool()
    66  	for _, appRootCA := range appRootCAs {
    67  		if !certPool.AppendCertsFromPEM(appRootCA) {
    68  			commLogger.Warningf("Failed adding certificates to peer's client TLS trust pool")
    69  		}
    70  	}
    71  
    72  	return credentials.NewTLS(&tls.Config{
    73  		Certificates: []tls.Certificate{cs.clientCert},
    74  		RootCAs:      certPool,
    75  	})
    76  }
    77  
    78  func (cs *CredentialSupport) AppRootCAsByChain() map[string][][]byte {
    79  	cs.mutex.RLock()
    80  	defer cs.mutex.RUnlock()
    81  	return cs.appRootCAsByChain
    82  }
    83  
    84  // BuildTrustedRootsForChain populates the appRootCAs and orderRootCAs maps by
    85  // getting the root and intermediate certs for all msps associated with the
    86  // MSPManager.
    87  func (cs *CredentialSupport) BuildTrustedRootsForChain(cm channelconfig.Resources) {
    88  	appOrgMSPs := make(map[string]struct{})
    89  	if ac, ok := cm.ApplicationConfig(); ok {
    90  		for _, appOrg := range ac.Organizations() {
    91  			appOrgMSPs[appOrg.MSPID()] = struct{}{}
    92  		}
    93  	}
    94  
    95  	ordOrgMSPs := make(map[string]struct{})
    96  	if ac, ok := cm.OrdererConfig(); ok {
    97  		for _, ordOrg := range ac.Organizations() {
    98  			ordOrgMSPs[ordOrg.MSPID()] = struct{}{}
    99  		}
   100  	}
   101  
   102  	cid := cm.ConfigtxValidator().ChannelID()
   103  	msps, err := cm.MSPManager().GetMSPs()
   104  	if err != nil {
   105  		commLogger.Errorf("Error getting root CAs for channel %s (%s)", cid, err)
   106  		return
   107  	}
   108  
   109  	var appRootCAs [][]byte
   110  	for k, v := range msps {
   111  		// we only support the fabric MSP
   112  		if v.GetType() != msp.FABRIC {
   113  			continue
   114  		}
   115  
   116  		for _, root := range v.GetTLSRootCerts() {
   117  			// check to see of this is an app org MSP
   118  			if _, ok := appOrgMSPs[k]; ok {
   119  				commLogger.Debugf("adding app root CAs for MSP [%s]", k)
   120  				appRootCAs = append(appRootCAs, root)
   121  			}
   122  		}
   123  		for _, intermediate := range v.GetTLSIntermediateCerts() {
   124  			// check to see of this is an app org MSP
   125  			if _, ok := appOrgMSPs[k]; ok {
   126  				commLogger.Debugf("adding app root CAs for MSP [%s]", k)
   127  				appRootCAs = append(appRootCAs, intermediate)
   128  			}
   129  		}
   130  	}
   131  
   132  	cs.mutex.Lock()
   133  	cs.appRootCAsByChain[cid] = appRootCAs
   134  	cs.mutex.Unlock()
   135  }