github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/policies/policy_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package policies
     8  
     9  import (
    10  	"fmt"
    11  	"reflect"
    12  	"strconv"
    13  	"testing"
    14  
    15  	"github.com/golang/protobuf/proto"
    16  	"github.com/hechain20/hechain/common/crypto/tlsgen"
    17  	"github.com/hechain20/hechain/common/flogging/floggingtest"
    18  	"github.com/hechain20/hechain/common/policies/mocks"
    19  	mspi "github.com/hechain20/hechain/msp"
    20  	"github.com/hechain20/hechain/protoutil"
    21  	cb "github.com/hyperledger/fabric-protos-go/common"
    22  	"github.com/hyperledger/fabric-protos-go/msp"
    23  	"github.com/pkg/errors"
    24  	"github.com/stretchr/testify/require"
    25  	"go.uber.org/zap/zapcore"
    26  )
    27  
    28  //go:generate counterfeiter -o mocks/identity_deserializer.go --fake-name IdentityDeserializer . identityDeserializer
    29  type identityDeserializer interface {
    30  	mspi.IdentityDeserializer
    31  }
    32  
    33  //go:generate counterfeiter -o mocks/identity.go --fake-name Identity . identity
    34  type identity interface {
    35  	mspi.Identity
    36  }
    37  
    38  type mockProvider struct{}
    39  
    40  func (mpp mockProvider) NewPolicy(data []byte) (Policy, proto.Message, error) {
    41  	return nil, nil, nil
    42  }
    43  
    44  const mockType = int32(0)
    45  
    46  func defaultProviders() map[int32]Provider {
    47  	providers := make(map[int32]Provider)
    48  	providers[mockType] = &mockProvider{}
    49  	return providers
    50  }
    51  
    52  func TestUnnestedManager(t *testing.T) {
    53  	config := &cb.ConfigGroup{
    54  		Policies: map[string]*cb.ConfigPolicy{
    55  			"1": {Policy: &cb.Policy{Type: mockType}},
    56  			"2": {Policy: &cb.Policy{Type: mockType}},
    57  			"3": {Policy: &cb.Policy{Type: mockType}},
    58  		},
    59  	}
    60  
    61  	m, err := NewManagerImpl("test", defaultProviders(), config)
    62  	require.NoError(t, err)
    63  	require.NotNil(t, m)
    64  
    65  	_, ok := m.Manager([]string{"subGroup"})
    66  	require.False(t, ok, "Should not have found a subgroup manager")
    67  
    68  	r, ok := m.Manager([]string{})
    69  	require.True(t, ok, "Should have found the root manager")
    70  	require.Equal(t, m, r)
    71  
    72  	require.Len(t, m.Policies, len(config.Policies))
    73  
    74  	for policyName := range config.Policies {
    75  		_, ok := m.GetPolicy(policyName)
    76  		require.True(t, ok, "Should have found policy %s", policyName)
    77  	}
    78  }
    79  
    80  func TestNestedManager(t *testing.T) {
    81  	config := &cb.ConfigGroup{
    82  		Policies: map[string]*cb.ConfigPolicy{
    83  			"n0a": {Policy: &cb.Policy{Type: mockType}},
    84  			"n0b": {Policy: &cb.Policy{Type: mockType}},
    85  			"n0c": {Policy: &cb.Policy{Type: mockType}},
    86  		},
    87  		Groups: map[string]*cb.ConfigGroup{
    88  			"nest1": {
    89  				Policies: map[string]*cb.ConfigPolicy{
    90  					"n1a": {Policy: &cb.Policy{Type: mockType}},
    91  					"n1b": {Policy: &cb.Policy{Type: mockType}},
    92  					"n1c": {Policy: &cb.Policy{Type: mockType}},
    93  				},
    94  				Groups: map[string]*cb.ConfigGroup{
    95  					"nest2a": {
    96  						Policies: map[string]*cb.ConfigPolicy{
    97  							"n2a_1": {Policy: &cb.Policy{Type: mockType}},
    98  							"n2a_2": {Policy: &cb.Policy{Type: mockType}},
    99  							"n2a_3": {Policy: &cb.Policy{Type: mockType}},
   100  						},
   101  					},
   102  					"nest2b": {
   103  						Policies: map[string]*cb.ConfigPolicy{
   104  							"n2b_1": {Policy: &cb.Policy{Type: mockType}},
   105  							"n2b_2": {Policy: &cb.Policy{Type: mockType}},
   106  							"n2b_3": {Policy: &cb.Policy{Type: mockType}},
   107  						},
   108  					},
   109  				},
   110  			},
   111  		},
   112  	}
   113  
   114  	m, err := NewManagerImpl("nest0", defaultProviders(), config)
   115  	require.NoError(t, err)
   116  	require.NotNil(t, m)
   117  
   118  	r, ok := m.Manager([]string{})
   119  	require.True(t, ok, "Should have found the root manager")
   120  	require.Equal(t, m, r)
   121  
   122  	n1, ok := m.Manager([]string{"nest1"})
   123  	require.True(t, ok)
   124  	n2a, ok := m.Manager([]string{"nest1", "nest2a"})
   125  	require.True(t, ok)
   126  	n2b, ok := m.Manager([]string{"nest1", "nest2b"})
   127  	require.True(t, ok)
   128  
   129  	n2as, ok := n1.Manager([]string{"nest2a"})
   130  	require.True(t, ok)
   131  	require.Equal(t, n2a, n2as)
   132  	n2bs, ok := n1.Manager([]string{"nest2b"})
   133  	require.True(t, ok)
   134  	require.Equal(t, n2b, n2bs)
   135  
   136  	absPrefix := PathSeparator + "nest0" + PathSeparator
   137  	for policyName := range config.Policies {
   138  		_, ok := m.GetPolicy(policyName)
   139  		require.True(t, ok, "Should have found policy %s", policyName)
   140  
   141  		absName := absPrefix + policyName
   142  		_, ok = m.GetPolicy(absName)
   143  		require.True(t, ok, "Should have found absolute policy %s", absName)
   144  	}
   145  
   146  	for policyName := range config.Groups["nest1"].Policies {
   147  		_, ok := n1.GetPolicy(policyName)
   148  		require.True(t, ok, "Should have found policy %s", policyName)
   149  
   150  		relPathFromBase := "nest1" + PathSeparator + policyName
   151  		_, ok = m.GetPolicy(relPathFromBase)
   152  		require.True(t, ok, "Should have found policy %s", policyName)
   153  
   154  		for i, abs := range []Manager{n1, m} {
   155  			absName := absPrefix + relPathFromBase
   156  			_, ok = abs.GetPolicy(absName)
   157  			require.True(t, ok, "Should have found absolutely policy for manager %d", i)
   158  		}
   159  	}
   160  
   161  	for policyName := range config.Groups["nest1"].Groups["nest2a"].Policies {
   162  		_, ok := n2a.GetPolicy(policyName)
   163  		require.True(t, ok, "Should have found policy %s", policyName)
   164  
   165  		relPathFromN1 := "nest2a" + PathSeparator + policyName
   166  		_, ok = n1.GetPolicy(relPathFromN1)
   167  		require.True(t, ok, "Should have found policy %s", policyName)
   168  
   169  		relPathFromBase := "nest1" + PathSeparator + relPathFromN1
   170  		_, ok = m.GetPolicy(relPathFromBase)
   171  		require.True(t, ok, "Should have found policy %s", policyName)
   172  
   173  		for i, abs := range []Manager{n2a, n1, m} {
   174  			absName := absPrefix + relPathFromBase
   175  			_, ok = abs.GetPolicy(absName)
   176  			require.True(t, ok, "Should have found absolutely policy for manager %d", i)
   177  		}
   178  	}
   179  
   180  	for policyName := range config.Groups["nest1"].Groups["nest2b"].Policies {
   181  		_, ok := n2b.GetPolicy(policyName)
   182  		require.True(t, ok, "Should have found policy %s", policyName)
   183  
   184  		relPathFromN1 := "nest2b" + PathSeparator + policyName
   185  		_, ok = n1.GetPolicy(relPathFromN1)
   186  		require.True(t, ok, "Should have found policy %s", policyName)
   187  
   188  		relPathFromBase := "nest1" + PathSeparator + relPathFromN1
   189  		_, ok = m.GetPolicy(relPathFromBase)
   190  		require.True(t, ok, "Should have found policy %s", policyName)
   191  
   192  		for i, abs := range []Manager{n2b, n1, m} {
   193  			absName := absPrefix + relPathFromBase
   194  			_, ok = abs.GetPolicy(absName)
   195  			require.True(t, ok, "Should have found absolutely policy for manager %d", i)
   196  		}
   197  	}
   198  }
   199  
   200  func TestPrincipalUniqueSet(t *testing.T) {
   201  	var principalSet PrincipalSet
   202  	addPrincipal := func(i int) {
   203  		principalSet = append(principalSet, &msp.MSPPrincipal{
   204  			PrincipalClassification: msp.MSPPrincipal_Classification(i),
   205  			Principal:               []byte(fmt.Sprintf("%d", i)),
   206  		})
   207  	}
   208  
   209  	addPrincipal(1)
   210  	addPrincipal(2)
   211  	addPrincipal(2)
   212  	addPrincipal(3)
   213  	addPrincipal(3)
   214  	addPrincipal(3)
   215  
   216  	for principal, plurality := range principalSet.UniqueSet() {
   217  		require.Equal(t, int(principal.PrincipalClassification), plurality)
   218  		require.Equal(t, fmt.Sprintf("%d", plurality), string(principal.Principal))
   219  	}
   220  
   221  	v := reflect.Indirect(reflect.ValueOf(msp.MSPPrincipal{}))
   222  	// Ensure msp.MSPPrincipal has only 2 fields.
   223  	// This is essential for 'UniqueSet' to work properly
   224  	// XXX This is a rather brittle check and brittle way to fix the test
   225  	// There seems to be an assumption that the number of fields in the proto
   226  	// struct matches the number of fields in the proto message
   227  	require.Equal(t, 5, v.NumField())
   228  }
   229  
   230  func TestPrincipalSetContainingOnly(t *testing.T) {
   231  	var principalSets PrincipalSets
   232  	var principalSet PrincipalSet
   233  	for j := 0; j < 3; j++ {
   234  		for i := 0; i < 10; i++ {
   235  			principalSet = append(principalSet, &msp.MSPPrincipal{
   236  				PrincipalClassification: msp.MSPPrincipal_IDENTITY,
   237  				Principal:               []byte(fmt.Sprintf("%d", j*10+i)),
   238  			})
   239  		}
   240  		principalSets = append(principalSets, principalSet)
   241  		principalSet = nil
   242  	}
   243  
   244  	between20And30 := func(principal *msp.MSPPrincipal) bool {
   245  		n, _ := strconv.ParseInt(string(principal.Principal), 10, 32)
   246  		return n >= 20 && n <= 29
   247  	}
   248  
   249  	principalSets = principalSets.ContainingOnly(between20And30)
   250  
   251  	require.Len(t, principalSets, 1)
   252  	require.True(t, principalSets[0].ContainingOnly(between20And30))
   253  }
   254  
   255  func TestSignatureSetToValidIdentities(t *testing.T) {
   256  	sd := []*protoutil.SignedData{
   257  		{
   258  			Data:      []byte("data1"),
   259  			Identity:  []byte("identity1"),
   260  			Signature: []byte("signature1"),
   261  		},
   262  		{
   263  			Data:      []byte("data1"),
   264  			Identity:  []byte("identity1"),
   265  			Signature: []byte("signature1"),
   266  		},
   267  	}
   268  
   269  	fIDDs := &mocks.IdentityDeserializer{}
   270  	fID := &mocks.Identity{}
   271  	fID.VerifyReturns(nil)
   272  	fID.GetIdentifierReturns(&mspi.IdentityIdentifier{
   273  		Id:    "id",
   274  		Mspid: "mspid",
   275  	})
   276  	fIDDs.DeserializeIdentityReturns(fID, nil)
   277  
   278  	ids := SignatureSetToValidIdentities(sd, fIDDs)
   279  	require.Len(t, ids, 1)
   280  	require.NotNil(t, ids[0].GetIdentifier())
   281  	require.Equal(t, "id", ids[0].GetIdentifier().Id)
   282  	require.Equal(t, "mspid", ids[0].GetIdentifier().Mspid)
   283  	data, sig := fID.VerifyArgsForCall(0)
   284  	require.Equal(t, []byte("data1"), data)
   285  	require.Equal(t, []byte("signature1"), sig)
   286  	sidBytes := fIDDs.DeserializeIdentityArgsForCall(0)
   287  	require.Equal(t, []byte("identity1"), sidBytes)
   288  }
   289  
   290  func TestSignatureSetToValidIdentitiesDeserializeErr(t *testing.T) {
   291  	oldLogger := logger
   292  	l, recorder := floggingtest.NewTestLogger(t, floggingtest.AtLevel(zapcore.InfoLevel))
   293  	logger = l
   294  	defer func() { logger = oldLogger }()
   295  
   296  	fakeIdentityDeserializer := &mocks.IdentityDeserializer{}
   297  	fakeIdentityDeserializer.DeserializeIdentityReturns(nil, errors.New("mango"))
   298  
   299  	// generate actual x509 certificate
   300  	ca, err := tlsgen.NewCA()
   301  	require.NoError(t, err)
   302  	client1, err := ca.NewClientCertKeyPair()
   303  	require.NoError(t, err)
   304  	id := &msp.SerializedIdentity{
   305  		Mspid:   "MyMSP",
   306  		IdBytes: client1.Cert,
   307  	}
   308  	idBytes, err := proto.Marshal(id)
   309  	require.NoError(t, err)
   310  
   311  	tests := []struct {
   312  		spec                     string
   313  		signedData               []*protoutil.SignedData
   314  		expectedLogEntryContains []string
   315  	}{
   316  		{
   317  			spec: "deserialize identity error - identity is random bytes",
   318  			signedData: []*protoutil.SignedData{
   319  				{
   320  					Identity: []byte("identity1"),
   321  				},
   322  			},
   323  			expectedLogEntryContains: []string{"invalid identity", fmt.Sprintf("serialized-identity=%x", []byte("identity1")), "error=mango"},
   324  		},
   325  		{
   326  			spec: "deserialize identity error - actual certificate",
   327  			signedData: []*protoutil.SignedData{
   328  				{
   329  					Identity: idBytes,
   330  				},
   331  			},
   332  			expectedLogEntryContains: []string{"invalid identity", fmt.Sprintf("mspid=MyMSP subject=%s issuer=%s serialnumber=%d", client1.TLSCert.Subject, client1.TLSCert.Issuer, client1.TLSCert.SerialNumber), "error=mango"},
   333  		},
   334  	}
   335  
   336  	for _, tc := range tests {
   337  		t.Run(tc.spec, func(t *testing.T) {
   338  			ids := SignatureSetToValidIdentities(tc.signedData, fakeIdentityDeserializer)
   339  			require.Len(t, ids, 0)
   340  			assertLogContains(t, recorder, tc.expectedLogEntryContains...)
   341  		})
   342  	}
   343  }
   344  
   345  func TestSignatureSetToValidIdentitiesVerifyErr(t *testing.T) {
   346  	sd := []*protoutil.SignedData{
   347  		{
   348  			Data:      []byte("data1"),
   349  			Identity:  []byte("identity1"),
   350  			Signature: []byte("signature1"),
   351  		},
   352  	}
   353  
   354  	fIDDs := &mocks.IdentityDeserializer{}
   355  	fID := &mocks.Identity{}
   356  	fID.VerifyReturns(errors.New("bad signature"))
   357  	fID.GetIdentifierReturns(&mspi.IdentityIdentifier{
   358  		Id:    "id",
   359  		Mspid: "mspid",
   360  	})
   361  	fIDDs.DeserializeIdentityReturns(fID, nil)
   362  
   363  	ids := SignatureSetToValidIdentities(sd, fIDDs)
   364  	require.Len(t, ids, 0)
   365  	data, sig := fID.VerifyArgsForCall(0)
   366  	require.Equal(t, []byte("data1"), data)
   367  	require.Equal(t, []byte("signature1"), sig)
   368  	sidBytes := fIDDs.DeserializeIdentityArgsForCall(0)
   369  	require.Equal(t, []byte("identity1"), sidBytes)
   370  }
   371  
   372  func assertLogContains(t *testing.T, r *floggingtest.Recorder, ss ...string) {
   373  	defer r.Reset()
   374  	entries := r.Entries()
   375  	for _, entry := range entries {
   376  		fmt.Println(entry)
   377  	}
   378  	for _, s := range ss {
   379  		require.NotEmpty(t, r.EntriesContaining(s))
   380  	}
   381  }