github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/policies/policy_test.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package policies 8 9 import ( 10 "fmt" 11 "reflect" 12 "strconv" 13 "testing" 14 15 "github.com/golang/protobuf/proto" 16 "github.com/hechain20/hechain/common/crypto/tlsgen" 17 "github.com/hechain20/hechain/common/flogging/floggingtest" 18 "github.com/hechain20/hechain/common/policies/mocks" 19 mspi "github.com/hechain20/hechain/msp" 20 "github.com/hechain20/hechain/protoutil" 21 cb "github.com/hyperledger/fabric-protos-go/common" 22 "github.com/hyperledger/fabric-protos-go/msp" 23 "github.com/pkg/errors" 24 "github.com/stretchr/testify/require" 25 "go.uber.org/zap/zapcore" 26 ) 27 28 //go:generate counterfeiter -o mocks/identity_deserializer.go --fake-name IdentityDeserializer . identityDeserializer 29 type identityDeserializer interface { 30 mspi.IdentityDeserializer 31 } 32 33 //go:generate counterfeiter -o mocks/identity.go --fake-name Identity . identity 34 type identity interface { 35 mspi.Identity 36 } 37 38 type mockProvider struct{} 39 40 func (mpp mockProvider) NewPolicy(data []byte) (Policy, proto.Message, error) { 41 return nil, nil, nil 42 } 43 44 const mockType = int32(0) 45 46 func defaultProviders() map[int32]Provider { 47 providers := make(map[int32]Provider) 48 providers[mockType] = &mockProvider{} 49 return providers 50 } 51 52 func TestUnnestedManager(t *testing.T) { 53 config := &cb.ConfigGroup{ 54 Policies: map[string]*cb.ConfigPolicy{ 55 "1": {Policy: &cb.Policy{Type: mockType}}, 56 "2": {Policy: &cb.Policy{Type: mockType}}, 57 "3": {Policy: &cb.Policy{Type: mockType}}, 58 }, 59 } 60 61 m, err := NewManagerImpl("test", defaultProviders(), config) 62 require.NoError(t, err) 63 require.NotNil(t, m) 64 65 _, ok := m.Manager([]string{"subGroup"}) 66 require.False(t, ok, "Should not have found a subgroup manager") 67 68 r, ok := m.Manager([]string{}) 69 require.True(t, ok, "Should have found the root manager") 70 require.Equal(t, m, r) 71 72 require.Len(t, m.Policies, len(config.Policies)) 73 74 for policyName := range config.Policies { 75 _, ok := m.GetPolicy(policyName) 76 require.True(t, ok, "Should have found policy %s", policyName) 77 } 78 } 79 80 func TestNestedManager(t *testing.T) { 81 config := &cb.ConfigGroup{ 82 Policies: map[string]*cb.ConfigPolicy{ 83 "n0a": {Policy: &cb.Policy{Type: mockType}}, 84 "n0b": {Policy: &cb.Policy{Type: mockType}}, 85 "n0c": {Policy: &cb.Policy{Type: mockType}}, 86 }, 87 Groups: map[string]*cb.ConfigGroup{ 88 "nest1": { 89 Policies: map[string]*cb.ConfigPolicy{ 90 "n1a": {Policy: &cb.Policy{Type: mockType}}, 91 "n1b": {Policy: &cb.Policy{Type: mockType}}, 92 "n1c": {Policy: &cb.Policy{Type: mockType}}, 93 }, 94 Groups: map[string]*cb.ConfigGroup{ 95 "nest2a": { 96 Policies: map[string]*cb.ConfigPolicy{ 97 "n2a_1": {Policy: &cb.Policy{Type: mockType}}, 98 "n2a_2": {Policy: &cb.Policy{Type: mockType}}, 99 "n2a_3": {Policy: &cb.Policy{Type: mockType}}, 100 }, 101 }, 102 "nest2b": { 103 Policies: map[string]*cb.ConfigPolicy{ 104 "n2b_1": {Policy: &cb.Policy{Type: mockType}}, 105 "n2b_2": {Policy: &cb.Policy{Type: mockType}}, 106 "n2b_3": {Policy: &cb.Policy{Type: mockType}}, 107 }, 108 }, 109 }, 110 }, 111 }, 112 } 113 114 m, err := NewManagerImpl("nest0", defaultProviders(), config) 115 require.NoError(t, err) 116 require.NotNil(t, m) 117 118 r, ok := m.Manager([]string{}) 119 require.True(t, ok, "Should have found the root manager") 120 require.Equal(t, m, r) 121 122 n1, ok := m.Manager([]string{"nest1"}) 123 require.True(t, ok) 124 n2a, ok := m.Manager([]string{"nest1", "nest2a"}) 125 require.True(t, ok) 126 n2b, ok := m.Manager([]string{"nest1", "nest2b"}) 127 require.True(t, ok) 128 129 n2as, ok := n1.Manager([]string{"nest2a"}) 130 require.True(t, ok) 131 require.Equal(t, n2a, n2as) 132 n2bs, ok := n1.Manager([]string{"nest2b"}) 133 require.True(t, ok) 134 require.Equal(t, n2b, n2bs) 135 136 absPrefix := PathSeparator + "nest0" + PathSeparator 137 for policyName := range config.Policies { 138 _, ok := m.GetPolicy(policyName) 139 require.True(t, ok, "Should have found policy %s", policyName) 140 141 absName := absPrefix + policyName 142 _, ok = m.GetPolicy(absName) 143 require.True(t, ok, "Should have found absolute policy %s", absName) 144 } 145 146 for policyName := range config.Groups["nest1"].Policies { 147 _, ok := n1.GetPolicy(policyName) 148 require.True(t, ok, "Should have found policy %s", policyName) 149 150 relPathFromBase := "nest1" + PathSeparator + policyName 151 _, ok = m.GetPolicy(relPathFromBase) 152 require.True(t, ok, "Should have found policy %s", policyName) 153 154 for i, abs := range []Manager{n1, m} { 155 absName := absPrefix + relPathFromBase 156 _, ok = abs.GetPolicy(absName) 157 require.True(t, ok, "Should have found absolutely policy for manager %d", i) 158 } 159 } 160 161 for policyName := range config.Groups["nest1"].Groups["nest2a"].Policies { 162 _, ok := n2a.GetPolicy(policyName) 163 require.True(t, ok, "Should have found policy %s", policyName) 164 165 relPathFromN1 := "nest2a" + PathSeparator + policyName 166 _, ok = n1.GetPolicy(relPathFromN1) 167 require.True(t, ok, "Should have found policy %s", policyName) 168 169 relPathFromBase := "nest1" + PathSeparator + relPathFromN1 170 _, ok = m.GetPolicy(relPathFromBase) 171 require.True(t, ok, "Should have found policy %s", policyName) 172 173 for i, abs := range []Manager{n2a, n1, m} { 174 absName := absPrefix + relPathFromBase 175 _, ok = abs.GetPolicy(absName) 176 require.True(t, ok, "Should have found absolutely policy for manager %d", i) 177 } 178 } 179 180 for policyName := range config.Groups["nest1"].Groups["nest2b"].Policies { 181 _, ok := n2b.GetPolicy(policyName) 182 require.True(t, ok, "Should have found policy %s", policyName) 183 184 relPathFromN1 := "nest2b" + PathSeparator + policyName 185 _, ok = n1.GetPolicy(relPathFromN1) 186 require.True(t, ok, "Should have found policy %s", policyName) 187 188 relPathFromBase := "nest1" + PathSeparator + relPathFromN1 189 _, ok = m.GetPolicy(relPathFromBase) 190 require.True(t, ok, "Should have found policy %s", policyName) 191 192 for i, abs := range []Manager{n2b, n1, m} { 193 absName := absPrefix + relPathFromBase 194 _, ok = abs.GetPolicy(absName) 195 require.True(t, ok, "Should have found absolutely policy for manager %d", i) 196 } 197 } 198 } 199 200 func TestPrincipalUniqueSet(t *testing.T) { 201 var principalSet PrincipalSet 202 addPrincipal := func(i int) { 203 principalSet = append(principalSet, &msp.MSPPrincipal{ 204 PrincipalClassification: msp.MSPPrincipal_Classification(i), 205 Principal: []byte(fmt.Sprintf("%d", i)), 206 }) 207 } 208 209 addPrincipal(1) 210 addPrincipal(2) 211 addPrincipal(2) 212 addPrincipal(3) 213 addPrincipal(3) 214 addPrincipal(3) 215 216 for principal, plurality := range principalSet.UniqueSet() { 217 require.Equal(t, int(principal.PrincipalClassification), plurality) 218 require.Equal(t, fmt.Sprintf("%d", plurality), string(principal.Principal)) 219 } 220 221 v := reflect.Indirect(reflect.ValueOf(msp.MSPPrincipal{})) 222 // Ensure msp.MSPPrincipal has only 2 fields. 223 // This is essential for 'UniqueSet' to work properly 224 // XXX This is a rather brittle check and brittle way to fix the test 225 // There seems to be an assumption that the number of fields in the proto 226 // struct matches the number of fields in the proto message 227 require.Equal(t, 5, v.NumField()) 228 } 229 230 func TestPrincipalSetContainingOnly(t *testing.T) { 231 var principalSets PrincipalSets 232 var principalSet PrincipalSet 233 for j := 0; j < 3; j++ { 234 for i := 0; i < 10; i++ { 235 principalSet = append(principalSet, &msp.MSPPrincipal{ 236 PrincipalClassification: msp.MSPPrincipal_IDENTITY, 237 Principal: []byte(fmt.Sprintf("%d", j*10+i)), 238 }) 239 } 240 principalSets = append(principalSets, principalSet) 241 principalSet = nil 242 } 243 244 between20And30 := func(principal *msp.MSPPrincipal) bool { 245 n, _ := strconv.ParseInt(string(principal.Principal), 10, 32) 246 return n >= 20 && n <= 29 247 } 248 249 principalSets = principalSets.ContainingOnly(between20And30) 250 251 require.Len(t, principalSets, 1) 252 require.True(t, principalSets[0].ContainingOnly(between20And30)) 253 } 254 255 func TestSignatureSetToValidIdentities(t *testing.T) { 256 sd := []*protoutil.SignedData{ 257 { 258 Data: []byte("data1"), 259 Identity: []byte("identity1"), 260 Signature: []byte("signature1"), 261 }, 262 { 263 Data: []byte("data1"), 264 Identity: []byte("identity1"), 265 Signature: []byte("signature1"), 266 }, 267 } 268 269 fIDDs := &mocks.IdentityDeserializer{} 270 fID := &mocks.Identity{} 271 fID.VerifyReturns(nil) 272 fID.GetIdentifierReturns(&mspi.IdentityIdentifier{ 273 Id: "id", 274 Mspid: "mspid", 275 }) 276 fIDDs.DeserializeIdentityReturns(fID, nil) 277 278 ids := SignatureSetToValidIdentities(sd, fIDDs) 279 require.Len(t, ids, 1) 280 require.NotNil(t, ids[0].GetIdentifier()) 281 require.Equal(t, "id", ids[0].GetIdentifier().Id) 282 require.Equal(t, "mspid", ids[0].GetIdentifier().Mspid) 283 data, sig := fID.VerifyArgsForCall(0) 284 require.Equal(t, []byte("data1"), data) 285 require.Equal(t, []byte("signature1"), sig) 286 sidBytes := fIDDs.DeserializeIdentityArgsForCall(0) 287 require.Equal(t, []byte("identity1"), sidBytes) 288 } 289 290 func TestSignatureSetToValidIdentitiesDeserializeErr(t *testing.T) { 291 oldLogger := logger 292 l, recorder := floggingtest.NewTestLogger(t, floggingtest.AtLevel(zapcore.InfoLevel)) 293 logger = l 294 defer func() { logger = oldLogger }() 295 296 fakeIdentityDeserializer := &mocks.IdentityDeserializer{} 297 fakeIdentityDeserializer.DeserializeIdentityReturns(nil, errors.New("mango")) 298 299 // generate actual x509 certificate 300 ca, err := tlsgen.NewCA() 301 require.NoError(t, err) 302 client1, err := ca.NewClientCertKeyPair() 303 require.NoError(t, err) 304 id := &msp.SerializedIdentity{ 305 Mspid: "MyMSP", 306 IdBytes: client1.Cert, 307 } 308 idBytes, err := proto.Marshal(id) 309 require.NoError(t, err) 310 311 tests := []struct { 312 spec string 313 signedData []*protoutil.SignedData 314 expectedLogEntryContains []string 315 }{ 316 { 317 spec: "deserialize identity error - identity is random bytes", 318 signedData: []*protoutil.SignedData{ 319 { 320 Identity: []byte("identity1"), 321 }, 322 }, 323 expectedLogEntryContains: []string{"invalid identity", fmt.Sprintf("serialized-identity=%x", []byte("identity1")), "error=mango"}, 324 }, 325 { 326 spec: "deserialize identity error - actual certificate", 327 signedData: []*protoutil.SignedData{ 328 { 329 Identity: idBytes, 330 }, 331 }, 332 expectedLogEntryContains: []string{"invalid identity", fmt.Sprintf("mspid=MyMSP subject=%s issuer=%s serialnumber=%d", client1.TLSCert.Subject, client1.TLSCert.Issuer, client1.TLSCert.SerialNumber), "error=mango"}, 333 }, 334 } 335 336 for _, tc := range tests { 337 t.Run(tc.spec, func(t *testing.T) { 338 ids := SignatureSetToValidIdentities(tc.signedData, fakeIdentityDeserializer) 339 require.Len(t, ids, 0) 340 assertLogContains(t, recorder, tc.expectedLogEntryContains...) 341 }) 342 } 343 } 344 345 func TestSignatureSetToValidIdentitiesVerifyErr(t *testing.T) { 346 sd := []*protoutil.SignedData{ 347 { 348 Data: []byte("data1"), 349 Identity: []byte("identity1"), 350 Signature: []byte("signature1"), 351 }, 352 } 353 354 fIDDs := &mocks.IdentityDeserializer{} 355 fID := &mocks.Identity{} 356 fID.VerifyReturns(errors.New("bad signature")) 357 fID.GetIdentifierReturns(&mspi.IdentityIdentifier{ 358 Id: "id", 359 Mspid: "mspid", 360 }) 361 fIDDs.DeserializeIdentityReturns(fID, nil) 362 363 ids := SignatureSetToValidIdentities(sd, fIDDs) 364 require.Len(t, ids, 0) 365 data, sig := fID.VerifyArgsForCall(0) 366 require.Equal(t, []byte("data1"), data) 367 require.Equal(t, []byte("signature1"), sig) 368 sidBytes := fIDDs.DeserializeIdentityArgsForCall(0) 369 require.Equal(t, []byte("identity1"), sidBytes) 370 } 371 372 func assertLogContains(t *testing.T, r *floggingtest.Recorder, ss ...string) { 373 defer r.Reset() 374 entries := r.Entries() 375 for _, entry := range entries { 376 fmt.Println(entry) 377 } 378 for _, s := range ss { 379 require.NotEmpty(t, r.EntriesContaining(s)) 380 } 381 }