github.com/evdatsion/aphelion-dpos-bft@v0.32.1/types/part_set_test.go (about)

     1  package types
     2  
     3  import (
     4  	"io/ioutil"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"github.com/stretchr/testify/require"
     9  
    10  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    11  )
    12  
    13  const (
    14  	testPartSize = 65536 // 64KB ...  4096 // 4KB
    15  )
    16  
    17  func TestBasicPartSet(t *testing.T) {
    18  	// Construct random data of size partSize * 100
    19  	data := cmn.RandBytes(testPartSize * 100)
    20  	partSet := NewPartSetFromData(data, testPartSize)
    21  
    22  	assert.NotEmpty(t, partSet.Hash())
    23  	assert.Equal(t, 100, partSet.Total())
    24  	assert.Equal(t, 100, partSet.BitArray().Size())
    25  	assert.True(t, partSet.HashesTo(partSet.Hash()))
    26  	assert.True(t, partSet.IsComplete())
    27  	assert.Equal(t, 100, partSet.Count())
    28  
    29  	// Test adding parts to a new partSet.
    30  	partSet2 := NewPartSetFromHeader(partSet.Header())
    31  
    32  	assert.True(t, partSet2.HasHeader(partSet.Header()))
    33  	for i := 0; i < partSet.Total(); i++ {
    34  		part := partSet.GetPart(i)
    35  		//t.Logf("\n%v", part)
    36  		added, err := partSet2.AddPart(part)
    37  		if !added || err != nil {
    38  			t.Errorf("Failed to add part %v, error: %v", i, err)
    39  		}
    40  	}
    41  	// adding part with invalid index
    42  	added, err := partSet2.AddPart(&Part{Index: 10000})
    43  	assert.False(t, added)
    44  	assert.Error(t, err)
    45  	// adding existing part
    46  	added, err = partSet2.AddPart(partSet2.GetPart(0))
    47  	assert.False(t, added)
    48  	assert.Nil(t, err)
    49  
    50  	assert.Equal(t, partSet.Hash(), partSet2.Hash())
    51  	assert.Equal(t, 100, partSet2.Total())
    52  	assert.True(t, partSet2.IsComplete())
    53  
    54  	// Reconstruct data, assert that they are equal.
    55  	data2Reader := partSet2.GetReader()
    56  	data2, err := ioutil.ReadAll(data2Reader)
    57  	require.NoError(t, err)
    58  
    59  	assert.Equal(t, data, data2)
    60  }
    61  
    62  func TestWrongProof(t *testing.T) {
    63  	// Construct random data of size partSize * 100
    64  	data := cmn.RandBytes(testPartSize * 100)
    65  	partSet := NewPartSetFromData(data, testPartSize)
    66  
    67  	// Test adding a part with wrong data.
    68  	partSet2 := NewPartSetFromHeader(partSet.Header())
    69  
    70  	// Test adding a part with wrong trail.
    71  	part := partSet.GetPart(0)
    72  	part.Proof.Aunts[0][0] += byte(0x01)
    73  	added, err := partSet2.AddPart(part)
    74  	if added || err == nil {
    75  		t.Errorf("Expected to fail adding a part with bad trail.")
    76  	}
    77  
    78  	// Test adding a part with wrong bytes.
    79  	part = partSet.GetPart(1)
    80  	part.Bytes[0] += byte(0x01)
    81  	added, err = partSet2.AddPart(part)
    82  	if added || err == nil {
    83  		t.Errorf("Expected to fail adding a part with bad bytes.")
    84  	}
    85  }
    86  
    87  func TestPartSetHeaderValidateBasic(t *testing.T) {
    88  	testCases := []struct {
    89  		testName              string
    90  		malleatePartSetHeader func(*PartSetHeader)
    91  		expectErr             bool
    92  	}{
    93  		{"Good PartSet", func(psHeader *PartSetHeader) {}, false},
    94  		{"Negative Total", func(psHeader *PartSetHeader) { psHeader.Total = -2 }, true},
    95  		{"Invalid Hash", func(psHeader *PartSetHeader) { psHeader.Hash = make([]byte, 1) }, true},
    96  	}
    97  	for _, tc := range testCases {
    98  		t.Run(tc.testName, func(t *testing.T) {
    99  			data := cmn.RandBytes(testPartSize * 100)
   100  			ps := NewPartSetFromData(data, testPartSize)
   101  			psHeader := ps.Header()
   102  			tc.malleatePartSetHeader(&psHeader)
   103  			assert.Equal(t, tc.expectErr, psHeader.ValidateBasic() != nil, "Validate Basic had an unexpected result")
   104  		})
   105  	}
   106  }
   107  
   108  func TestPartValidateBasic(t *testing.T) {
   109  	testCases := []struct {
   110  		testName     string
   111  		malleatePart func(*Part)
   112  		expectErr    bool
   113  	}{
   114  		{"Good Part", func(pt *Part) {}, false},
   115  		{"Negative index", func(pt *Part) { pt.Index = -1 }, true},
   116  		{"Too big part", func(pt *Part) { pt.Bytes = make([]byte, BlockPartSizeBytes+1) }, true},
   117  	}
   118  
   119  	for _, tc := range testCases {
   120  		t.Run(tc.testName, func(t *testing.T) {
   121  			data := cmn.RandBytes(testPartSize * 100)
   122  			ps := NewPartSetFromData(data, testPartSize)
   123  			part := ps.GetPart(0)
   124  			tc.malleatePart(part)
   125  			assert.Equal(t, tc.expectErr, part.ValidateBasic() != nil, "Validate Basic had an unexpected result")
   126  		})
   127  	}
   128  }