github.com/defanghe/fabric@v2.1.1+incompatible/internal/pkg/comm/util_test.go (about) 1 /* 2 Copyright IBM Corp. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm_test 8 9 import ( 10 "context" 11 "crypto/sha256" 12 "crypto/tls" 13 "crypto/x509" 14 "net" 15 "sync/atomic" 16 "testing" 17 "time" 18 19 "github.com/golang/protobuf/proto" 20 "github.com/hyperledger/fabric-protos-go/common" 21 "github.com/hyperledger/fabric/internal/pkg/comm" 22 "github.com/hyperledger/fabric/internal/pkg/comm/testpb" 23 "github.com/hyperledger/fabric/protoutil" 24 "github.com/stretchr/testify/assert" 25 "google.golang.org/grpc" 26 "google.golang.org/grpc/credentials" 27 "google.golang.org/grpc/peer" 28 ) 29 30 func TestExtractCertificateHashFromContext(t *testing.T) { 31 t.Parallel() 32 assert.Nil(t, comm.ExtractCertificateHashFromContext(context.Background())) 33 34 p := &peer.Peer{} 35 ctx := peer.NewContext(context.Background(), p) 36 assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx)) 37 38 p.AuthInfo = &nonTLSConnection{} 39 ctx = peer.NewContext(context.Background(), p) 40 assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx)) 41 42 p.AuthInfo = credentials.TLSInfo{} 43 ctx = peer.NewContext(context.Background(), p) 44 assert.Nil(t, comm.ExtractCertificateHashFromContext(ctx)) 45 46 p.AuthInfo = credentials.TLSInfo{ 47 State: tls.ConnectionState{ 48 PeerCertificates: []*x509.Certificate{ 49 {Raw: []byte{1, 2, 3}}, 50 }, 51 }, 52 } 53 ctx = peer.NewContext(context.Background(), p) 54 h := sha256.New() 55 h.Write([]byte{1, 2, 3}) 56 assert.Equal(t, h.Sum(nil), comm.ExtractCertificateHashFromContext(ctx)) 57 } 58 59 type nonTLSConnection struct { 60 } 61 62 func (*nonTLSConnection) AuthType() string { 63 return "" 64 } 65 66 func TestBindingInspectorBadInit(t *testing.T) { 67 t.Parallel() 68 assert.Panics(t, func() { 69 comm.NewBindingInspector(false, nil) 70 }) 71 } 72 73 func TestNoopBindingInspector(t *testing.T) { 74 t.Parallel() 75 extract := func(msg proto.Message) []byte { 76 return nil 77 } 78 assert.Nil(t, comm.NewBindingInspector(false, extract)(context.Background(), &common.Envelope{})) 79 err := comm.NewBindingInspector(false, extract)(context.Background(), nil) 80 assert.Error(t, err) 81 assert.Equal(t, "message is nil", err.Error()) 82 } 83 84 func TestBindingInspector(t *testing.T) { 85 t.Parallel() 86 lis, err := net.Listen("tcp", "127.0.0.1:0") 87 if err != nil { 88 t.Fatalf("failed to create listener for test server: %v", err) 89 } 90 91 extract := func(msg proto.Message) []byte { 92 env, isEnvelope := msg.(*common.Envelope) 93 if !isEnvelope || env == nil { 94 return nil 95 } 96 ch, err := protoutil.ChannelHeader(env) 97 if err != nil { 98 return nil 99 } 100 return ch.TlsCertHash 101 } 102 srv := newInspectingServer(lis, comm.NewBindingInspector(true, extract)) 103 go srv.Start() 104 defer srv.Stop() 105 time.Sleep(time.Second) 106 107 // Scenario I: Invalid header sent 108 err = srv.newInspection(t).inspectBinding(nil) 109 assert.Error(t, err) 110 assert.Contains(t, err.Error(), "client didn't include its TLS cert hash") 111 112 // Scenario II: invalid channel header 113 ch, _ := proto.Marshal(protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0)) 114 // Corrupt channel header 115 ch = append(ch, 0) 116 err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch)) 117 assert.Error(t, err) 118 assert.Contains(t, err.Error(), "client didn't include its TLS cert hash") 119 120 // Scenario III: No TLS cert hash in envelope 121 chanHdr := protoutil.MakeChannelHeader(common.HeaderType_CONFIG, 0, "test", 0) 122 ch, _ = proto.Marshal(chanHdr) 123 err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch)) 124 assert.Error(t, err) 125 assert.Contains(t, err.Error(), "client didn't include its TLS cert hash") 126 127 // Scenario IV: Client sends its TLS cert hash as needed, but doesn't use mutual TLS 128 cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) 129 h := sha256.New() 130 h.Write([]byte(cert.Certificate[0])) 131 chanHdr.TlsCertHash = h.Sum(nil) 132 ch, _ = proto.Marshal(chanHdr) 133 err = srv.newInspection(t).inspectBinding(envelopeWithChannelHeader(ch)) 134 assert.Error(t, err) 135 assert.Contains(t, err.Error(), "client didn't send a TLS certificate") 136 137 // Scenario V: Client uses mutual TLS but sends the wrong TLS cert hash 138 chanHdr.TlsCertHash = []byte{1, 2, 3} 139 chHdrWithWrongTLSCertHash, _ := proto.Marshal(chanHdr) 140 err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(chHdrWithWrongTLSCertHash)) 141 assert.Error(t, err) 142 assert.Contains(t, err.Error(), "claimed TLS cert hash is [1 2 3] but actual TLS cert hash is") 143 144 // Scenario VI: Client uses mutual TLS and also sends the correct TLS cert hash 145 err = srv.newInspection(t).withMutualTLS().inspectBinding(envelopeWithChannelHeader(ch)) 146 assert.NoError(t, err) 147 } 148 149 func TestGetLocalIP(t *testing.T) { 150 ip, err := comm.GetLocalIP() 151 assert.NoError(t, err) 152 t.Log(ip) 153 } 154 155 type inspectingServer struct { 156 addr string 157 *comm.GRPCServer 158 lastContext atomic.Value 159 inspector comm.BindingInspector 160 } 161 162 func (is *inspectingServer) EmptyCall(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { 163 is.lastContext.Store(ctx) 164 return &testpb.Empty{}, nil 165 } 166 167 func (is *inspectingServer) inspect(envelope *common.Envelope) error { 168 return is.inspector(is.lastContext.Load().(context.Context), envelope) 169 } 170 171 func newInspectingServer(listener net.Listener, inspector comm.BindingInspector) *inspectingServer { 172 srv, err := comm.NewGRPCServerFromListener(listener, comm.ServerConfig{ 173 ConnectionTimeout: 250 * time.Millisecond, 174 SecOpts: comm.SecureOptions{ 175 UseTLS: true, 176 Certificate: []byte(selfSignedCertPEM), 177 Key: []byte(selfSignedKeyPEM), 178 }}) 179 if err != nil { 180 panic(err) 181 } 182 is := &inspectingServer{ 183 addr: listener.Addr().String(), 184 GRPCServer: srv, 185 inspector: inspector, 186 } 187 testpb.RegisterTestServiceServer(srv.Server(), is) 188 return is 189 } 190 191 type inspection struct { 192 tlsConfig *tls.Config 193 server *inspectingServer 194 creds credentials.TransportCredentials 195 t *testing.T 196 } 197 198 func (is *inspectingServer) newInspection(t *testing.T) *inspection { 199 tlsConfig := &tls.Config{ 200 RootCAs: x509.NewCertPool(), 201 } 202 tlsConfig.RootCAs.AppendCertsFromPEM([]byte(selfSignedCertPEM)) 203 return &inspection{ 204 server: is, 205 creds: credentials.NewTLS(tlsConfig), 206 t: t, 207 tlsConfig: tlsConfig, 208 } 209 } 210 211 func (ins *inspection) withMutualTLS() *inspection { 212 cert, err := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) 213 assert.NoError(ins.t, err) 214 ins.tlsConfig.Certificates = []tls.Certificate{cert} 215 ins.creds = credentials.NewTLS(ins.tlsConfig) 216 return ins 217 } 218 219 func (ins *inspection) inspectBinding(envelope *common.Envelope) error { 220 ctx := context.Background() 221 ctx, c := context.WithTimeout(ctx, time.Second*3) 222 defer c() 223 conn, err := grpc.DialContext(ctx, ins.server.addr, grpc.WithTransportCredentials(ins.creds), grpc.WithBlock()) 224 assert.NoError(ins.t, err) 225 defer conn.Close() 226 _, err = testpb.NewTestServiceClient(conn).EmptyCall(context.Background(), &testpb.Empty{}) 227 assert.NoError(ins.t, err) 228 return ins.server.inspect(envelope) 229 } 230 231 func envelopeWithChannelHeader(ch []byte) *common.Envelope { 232 pl := &common.Payload{ 233 Header: &common.Header{ 234 ChannelHeader: ch, 235 }, 236 } 237 payload, _ := proto.Marshal(pl) 238 return &common.Envelope{ 239 Payload: payload, 240 } 241 }