github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/gossip/comm/crypto_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"context"
    11  	"crypto/ecdsa"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"encoding/pem"
    17  	"math/big"
    18  	"net"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  
    23  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    24  	"github.com/osdi23p228/fabric/gossip/util"
    25  	"github.com/stretchr/testify/assert"
    26  	"google.golang.org/grpc"
    27  	"google.golang.org/grpc/credentials"
    28  )
    29  
// gossipTestServer is a minimal gossip gRPC service used by the tests in this
// file. It records the certificate hash extracted from the context of incoming
// gossip streams so that a test can compare it with the hash it expects.
type gossipTestServer struct {
	// lock is held by GossipStream and getClientCertHash while they access
	// remoteCertHash.
	lock sync.Mutex
	// remoteCertHash is the hash extracted from the context of the most
	// recent GossipStream call.
	remoteCertHash []byte
	// selfCertHash is the hash of the certificate the server itself presents;
	// createTestServer fills it in at construction time.
	selfCertHash []byte
	// ll is the listener the gRPC server accepts connections on.
	ll net.Listener
	// s is the gRPC server this service is registered with.
	s *grpc.Server
}
    37  
// init runs once when the package is loaded and calls util.SetupTestLogging
// before any test in the package executes.
func init() {
	util.SetupTestLogging()
}
    41  
    42  func createTestServer(t *testing.T, cert *tls.Certificate) (srv *gossipTestServer, ll net.Listener) {
    43  	tlsConf := &tls.Config{
    44  		Certificates:       []tls.Certificate{*cert},
    45  		ClientAuth:         tls.RequestClientCert,
    46  		InsecureSkipVerify: true,
    47  	}
    48  	s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConf)))
    49  	ll, err := net.Listen("tcp", "127.0.0.1:0")
    50  	assert.NoError(t, err, "%v", err)
    51  
    52  	srv = &gossipTestServer{s: s, ll: ll, selfCertHash: certHashFromRawCert(cert.Certificate[0])}
    53  	proto.RegisterGossipServer(s, srv)
    54  	go s.Serve(ll)
    55  	return srv, ll
    56  }
    57  
    58  func (s *gossipTestServer) stop() {
    59  	s.s.Stop()
    60  	s.ll.Close()
    61  }
    62  
    63  func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
    64  	s.lock.Lock()
    65  	defer s.lock.Unlock()
    66  	s.remoteCertHash = extractCertificateHashFromContext(stream.Context())
    67  	return nil
    68  }
    69  
    70  func (s *gossipTestServer) getClientCertHash() []byte {
    71  	s.lock.Lock()
    72  	defer s.lock.Unlock()
    73  	return s.remoteCertHash
    74  }
    75  
    76  func (s *gossipTestServer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
    77  	return &proto.Empty{}, nil
    78  }
    79  
    80  func TestCertificateExtraction(t *testing.T) {
    81  	cert := GenerateCertificatesOrPanic()
    82  	srv, ll := createTestServer(t, &cert)
    83  	defer srv.stop()
    84  
    85  	clientCert := GenerateCertificatesOrPanic()
    86  	clientCertHash := certHashFromRawCert(clientCert.Certificate[0])
    87  	ta := credentials.NewTLS(&tls.Config{
    88  		Certificates:       []tls.Certificate{clientCert},
    89  		InsecureSkipVerify: true,
    90  	})
    91  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    92  	defer cancel()
    93  	conn, err := grpc.DialContext(ctx, ll.Addr().String(), grpc.WithTransportCredentials(ta), grpc.WithBlock())
    94  	assert.NoError(t, err, "%v", err)
    95  
    96  	cl := proto.NewGossipClient(conn)
    97  	stream, err := cl.GossipStream(context.Background())
    98  	assert.NoError(t, err, "%v", err)
    99  	if err != nil {
   100  		return
   101  	}
   102  
   103  	time.Sleep(time.Second)
   104  	clientSideCertHash := extractCertificateHashFromContext(stream.Context())
   105  	serverSideCertHash := srv.getClientCertHash()
   106  
   107  	assert.NotNil(t, clientSideCertHash)
   108  	assert.NotNil(t, serverSideCertHash)
   109  
   110  	assert.Equal(t, 32, len(clientSideCertHash), "client side cert hash is %v", clientSideCertHash)
   111  	assert.Equal(t, 32, len(serverSideCertHash), "server side cert hash is %v", serverSideCertHash)
   112  
   113  	assert.Equal(t, clientSideCertHash, srv.selfCertHash, "Server self hash isn't equal to client side hash")
   114  	assert.Equal(t, clientCertHash, srv.remoteCertHash, "Server side and client hash aren't equal")
   115  }
   116  
   117  // GenerateCertificatesOrPanic generates a a random pair of public and private keys
   118  // and return TLS certificate.
   119  func GenerateCertificatesOrPanic() tls.Certificate {
   120  	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   121  	if err != nil {
   122  		panic(err)
   123  	}
   124  	sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
   125  	if err != nil {
   126  		panic(err)
   127  	}
   128  	template := x509.Certificate{
   129  		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
   130  		SerialNumber: sn,
   131  		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   132  	}
   133  	rawBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
   134  	if err != nil {
   135  		panic(err)
   136  	}
   137  	privBytes, err := x509.MarshalECPrivateKey(privateKey)
   138  	if err != nil {
   139  		panic(err)
   140  	}
   141  	encodedCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rawBytes})
   142  	encodedKey := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
   143  	cert, err := tls.X509KeyPair(encodedCert, encodedKey)
   144  	if err != nil {
   145  		panic(err)
   146  	}
   147  	return cert
   148  }