github.com/Hnampk/my-fabric@v0.0.0-20201028083322-75069da399c0/gossip/comm/crypto_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"context"
    11  	"crypto/ecdsa"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"encoding/pem"
    17  	"fmt"
    18  	"math/big"
    19  	"net"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    25  	"github.com/hyperledger/fabric/gossip/util"
    26  	"github.com/stretchr/testify/assert"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/credentials"
    29  )
    30  
    31  type gossipTestServer struct {
    32  	lock           sync.Mutex
    33  	remoteCertHash []byte
    34  	selfCertHash   []byte
    35  	ll             net.Listener
    36  	s              *grpc.Server
    37  }
    38  
// init invokes util.SetupTestLogging once when the package is loaded, before
// any test in it runs.
func init() {
	util.SetupTestLogging()
}
    42  
    43  func createTestServer(t *testing.T, cert *tls.Certificate) (srv *gossipTestServer, ll net.Listener) {
    44  	tlsConf := &tls.Config{
    45  		Certificates:       []tls.Certificate{*cert},
    46  		ClientAuth:         tls.RequestClientCert,
    47  		InsecureSkipVerify: true,
    48  	}
    49  	s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConf)))
    50  	ll, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:0"))
    51  	assert.NoError(t, err, "%v", err)
    52  
    53  	srv = &gossipTestServer{s: s, ll: ll, selfCertHash: certHashFromRawCert(cert.Certificate[0])}
    54  	proto.RegisterGossipServer(s, srv)
    55  	go s.Serve(ll)
    56  	return srv, ll
    57  }
    58  
    59  func (s *gossipTestServer) stop() {
    60  	s.s.Stop()
    61  	s.ll.Close()
    62  }
    63  
    64  func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
    65  	s.lock.Lock()
    66  	defer s.lock.Unlock()
    67  	s.remoteCertHash = extractCertificateHashFromContext(stream.Context())
    68  	return nil
    69  }
    70  
    71  func (s *gossipTestServer) getClientCertHash() []byte {
    72  	s.lock.Lock()
    73  	defer s.lock.Unlock()
    74  	return s.remoteCertHash
    75  }
    76  
    77  func (s *gossipTestServer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
    78  	return &proto.Empty{}, nil
    79  }
    80  
    81  func TestCertificateExtraction(t *testing.T) {
    82  	cert := GenerateCertificatesOrPanic()
    83  	srv, ll := createTestServer(t, &cert)
    84  	defer srv.stop()
    85  
    86  	clientCert := GenerateCertificatesOrPanic()
    87  	clientCertHash := certHashFromRawCert(clientCert.Certificate[0])
    88  	ta := credentials.NewTLS(&tls.Config{
    89  		Certificates:       []tls.Certificate{clientCert},
    90  		InsecureSkipVerify: true,
    91  	})
    92  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    93  	defer cancel()
    94  	conn, err := grpc.DialContext(ctx, ll.Addr().String(), grpc.WithTransportCredentials(ta), grpc.WithBlock())
    95  	assert.NoError(t, err, "%v", err)
    96  
    97  	cl := proto.NewGossipClient(conn)
    98  	stream, err := cl.GossipStream(context.Background())
    99  	assert.NoError(t, err, "%v", err)
   100  	if err != nil {
   101  		return
   102  	}
   103  
   104  	time.Sleep(time.Second)
   105  	clientSideCertHash := extractCertificateHashFromContext(stream.Context())
   106  	serverSideCertHash := srv.getClientCertHash()
   107  
   108  	assert.NotNil(t, clientSideCertHash)
   109  	assert.NotNil(t, serverSideCertHash)
   110  
   111  	assert.Equal(t, 32, len(clientSideCertHash), "client side cert hash is %v", clientSideCertHash)
   112  	assert.Equal(t, 32, len(serverSideCertHash), "server side cert hash is %v", serverSideCertHash)
   113  
   114  	assert.Equal(t, clientSideCertHash, srv.selfCertHash, "Server self hash isn't equal to client side hash")
   115  	assert.Equal(t, clientCertHash, srv.remoteCertHash, "Server side and client hash aren't equal")
   116  }
   117  
   118  // GenerateCertificatesOrPanic generates a a random pair of public and private keys
   119  // and return TLS certificate.
   120  func GenerateCertificatesOrPanic() tls.Certificate {
   121  	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   122  	if err != nil {
   123  		panic(err)
   124  	}
   125  	sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
   126  	if err != nil {
   127  		panic(err)
   128  	}
   129  	template := x509.Certificate{
   130  		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
   131  		SerialNumber: sn,
   132  		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   133  	}
   134  	rawBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
   135  	if err != nil {
   136  		panic(err)
   137  	}
   138  	privBytes, err := x509.MarshalECPrivateKey(privateKey)
   139  	if err != nil {
   140  		panic(err)
   141  	}
   142  	encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rawBytes})
   143  	encodedKey := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
   144  	cert, err := tls.X509KeyPair(encodedCert, encodedKey)
   145  	if err != nil {
   146  		panic(err)
   147  	}
   148  	return cert
   149  }