github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/util/net_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package util
     8  
     9  import (
    10  	"context"
    11  	"crypto/sha256"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"testing"
    15  
    16  	"github.com/stretchr/testify/require"
    17  	"google.golang.org/grpc/credentials"
    18  	"google.golang.org/grpc/peer"
    19  )
    20  
    21  type addr struct{}
    22  
    23  func (*addr) Network() string {
    24  	return ""
    25  }
    26  
    27  func (*addr) String() string {
    28  	return "1.2.3.4:5000"
    29  }
    30  
    31  func TestExtractAddress(t *testing.T) {
    32  	ctx := context.Background()
    33  	require.Zero(t, ExtractRemoteAddress(ctx))
    34  
    35  	ctx = peer.NewContext(ctx, &peer.Peer{
    36  		Addr: &addr{},
    37  	})
    38  	require.Equal(t, "1.2.3.4:5000", ExtractRemoteAddress(ctx))
    39  }
    40  
    41  func TestExtractCertificateHashFromContext(t *testing.T) {
    42  	require.Nil(t, ExtractCertificateHashFromContext(context.Background()))
    43  
    44  	p := &peer.Peer{}
    45  	ctx := peer.NewContext(context.Background(), p)
    46  	require.Nil(t, ExtractCertificateHashFromContext(ctx))
    47  
    48  	p.AuthInfo = &nonTLSConnection{}
    49  	ctx = peer.NewContext(context.Background(), p)
    50  	require.Nil(t, ExtractCertificateHashFromContext(ctx))
    51  
    52  	p.AuthInfo = credentials.TLSInfo{}
    53  	ctx = peer.NewContext(context.Background(), p)
    54  	require.Nil(t, ExtractCertificateHashFromContext(ctx))
    55  
    56  	p.AuthInfo = credentials.TLSInfo{
    57  		State: tls.ConnectionState{
    58  			PeerCertificates: []*x509.Certificate{
    59  				{Raw: []byte{1, 2, 3}},
    60  			},
    61  		},
    62  	}
    63  	ctx = peer.NewContext(context.Background(), p)
    64  	h := sha256.New()
    65  	h.Write([]byte{1, 2, 3})
    66  	require.Equal(t, h.Sum(nil), ExtractCertificateHashFromContext(ctx))
    67  }
    68  
    69  type nonTLSConnection struct{}
    70  
    71  func (*nonTLSConnection) AuthType() string {
    72  	return ""
    73  }