github.com/adnan-c/fabric_e2e_couchdb@v0.6.1-preview.0.20170228180935-21ce6b23cf91/gossip/comm/crypto_test.go (about)

     1  /*
     2  Copyright IBM Corp. 2016 All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  		 http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package comm
    18  
    19  import (
    20  	"crypto/tls"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	proto "github.com/hyperledger/fabric/protos/gossip"
    29  	"github.com/stretchr/testify/assert"
    30  	"golang.org/x/net/context"
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/credentials"
    33  )
    34  
    35  type gossipTestServer struct {
    36  	lock           sync.Mutex
    37  	remoteCertHash []byte
    38  	selfCertHash   []byte
    39  	ll             net.Listener
    40  	s              *grpc.Server
    41  }
    42  
    43  func createTestServer(t *testing.T, cert *tls.Certificate) *gossipTestServer {
    44  	tlsConf := &tls.Config{
    45  		Certificates:       []tls.Certificate{*cert},
    46  		ClientAuth:         tls.RequestClientCert,
    47  		InsecureSkipVerify: true,
    48  	}
    49  	s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConf)))
    50  	ll, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "", 5611))
    51  	assert.NoError(t, err, "%v", err)
    52  
    53  	srv := &gossipTestServer{s: s, ll: ll, selfCertHash: certHashFromRawCert(cert.Certificate[0])}
    54  	proto.RegisterGossipServer(s, srv)
    55  	go s.Serve(ll)
    56  	return srv
    57  }
    58  
    59  func (s *gossipTestServer) stop() {
    60  	s.s.Stop()
    61  	s.ll.Close()
    62  }
    63  
    64  func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
    65  	s.lock.Lock()
    66  	defer s.lock.Unlock()
    67  	s.remoteCertHash = extractCertificateHashFromContext(stream.Context())
    68  	return nil
    69  }
    70  
    71  func (s *gossipTestServer) getClientCertHash() []byte {
    72  	s.lock.Lock()
    73  	defer s.lock.Unlock()
    74  	return s.remoteCertHash
    75  }
    76  
    77  func (s *gossipTestServer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
    78  	return &proto.Empty{}, nil
    79  }
    80  
    81  func TestCertificateExtraction(t *testing.T) {
    82  	err := generateCertificates("key.pem", "cert.pem")
    83  	defer os.Remove("cert.pem")
    84  	defer os.Remove("key.pem")
    85  	assert.NoError(t, err, "%v", err)
    86  	serverCert, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
    87  	assert.NoError(t, err, "%v", err)
    88  
    89  	srv := createTestServer(t, &serverCert)
    90  	defer srv.stop()
    91  
    92  	generateCertificates("key2.pem", "cert2.pem")
    93  	defer os.Remove("cert2.pem")
    94  	defer os.Remove("key2.pem")
    95  	clientCert, err := tls.LoadX509KeyPair("cert2.pem", "key2.pem")
    96  	clientCertHash := certHashFromRawCert(clientCert.Certificate[0])
    97  	assert.NoError(t, err)
    98  	ta := credentials.NewTLS(&tls.Config{
    99  		Certificates:       []tls.Certificate{clientCert},
   100  		InsecureSkipVerify: true,
   101  	})
   102  	assert.NoError(t, err, "%v", err)
   103  	conn, err := grpc.Dial("localhost:5611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second))
   104  	assert.NoError(t, err, "%v", err)
   105  
   106  	cl := proto.NewGossipClient(conn)
   107  	stream, err := cl.GossipStream(context.Background())
   108  	assert.NoError(t, err, "%v", err)
   109  	if err != nil {
   110  		return
   111  	}
   112  
   113  	time.Sleep(time.Second)
   114  	clientSideCertHash := extractCertificateHashFromContext(stream.Context())
   115  	serverSideCertHash := srv.getClientCertHash()
   116  
   117  	assert.NotNil(t, clientSideCertHash)
   118  	assert.NotNil(t, serverSideCertHash)
   119  
   120  	assert.Equal(t, 32, len(clientSideCertHash), "client side cert hash is %v", clientSideCertHash)
   121  	assert.Equal(t, 32, len(serverSideCertHash), "server side cert hash is %v", serverSideCertHash)
   122  
   123  	assert.Equal(t, clientSideCertHash, srv.selfCertHash, "Server self hash isn't equal to client side hash")
   124  	assert.Equal(t, clientCertHash, srv.remoteCertHash, "Server side and client hash aren't equal")
   125  }