github.com/prysmaticlabs/prysm@v1.4.4/shared/aggregation/sync_contribution/naive_test.go (about)

     1  package sync_contribution
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"testing"
     7  
     8  	"github.com/prysmaticlabs/go-bitfield"
     9  	prysmv2 "github.com/prysmaticlabs/prysm/proto/prysm/v2"
    10  	"github.com/prysmaticlabs/prysm/shared/aggregation"
    11  	aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing"
    12  	"github.com/prysmaticlabs/prysm/shared/bls"
    13  	"github.com/prysmaticlabs/prysm/shared/featureconfig"
    14  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    15  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    16  )
    17  
    18  func TestAggregateAttestations_aggregate(t *testing.T) {
    19  	tests := []struct {
    20  		a1   *prysmv2.SyncCommitteeContribution
    21  		a2   *prysmv2.SyncCommitteeContribution
    22  		want *prysmv2.SyncCommitteeContribution
    23  	}{
    24  		{
    25  			a1:   &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()},
    26  			a2:   &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()},
    27  			want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}},
    28  		},
    29  		{
    30  			a1:   &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()},
    31  			a2:   &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()},
    32  			want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}},
    33  		},
    34  	}
    35  	for _, tt := range tests {
    36  		got, err := aggregate(tt.a1, tt.a2)
    37  		require.NoError(t, err)
    38  		require.DeepSSZEqual(t, tt.want.AggregationBits, got.AggregationBits)
    39  	}
    40  }
    41  
    42  func TestAggregateAttestations_aggregate_OverlapFails(t *testing.T) {
    43  	tests := []struct {
    44  		a1 *prysmv2.SyncCommitteeContribution
    45  		a2 *prysmv2.SyncCommitteeContribution
    46  	}{
    47  		{
    48  			a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x1F}},
    49  			a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x11}},
    50  		},
    51  		{
    52  			a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0xFF, 0x85}},
    53  			a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x13, 0x8F}},
    54  		},
    55  	}
    56  	for _, tt := range tests {
    57  		_, err := aggregate(tt.a1, tt.a2)
    58  		require.ErrorContains(t, aggregation.ErrBitsOverlap.Error(), err)
    59  	}
    60  }
    61  
    62  func TestAggregateAttestations_Aggregate(t *testing.T) {
    63  	tests := []struct {
    64  		name   string
    65  		inputs []bitfield.Bitvector128
    66  		want   []bitfield.Bitvector128
    67  	}{
    68  		{
    69  			name:   "empty list",
    70  			inputs: []bitfield.Bitvector128{},
    71  			want:   []bitfield.Bitvector128{},
    72  		},
    73  		{
    74  			name: "single attestation",
    75  			inputs: []bitfield.Bitvector128{
    76  				{0b00000010},
    77  			},
    78  			want: []bitfield.Bitvector128{
    79  				{0b00000010},
    80  			},
    81  		},
    82  		{
    83  			name: "two attestations with no overlap",
    84  			inputs: []bitfield.Bitvector128{
    85  				{0b00000001},
    86  				{0b00000010},
    87  			},
    88  			want: []bitfield.Bitvector128{
    89  				{0b00000011},
    90  			},
    91  		},
    92  		{
    93  			name: "two attestations with overlap",
    94  			inputs: []bitfield.Bitvector128{
    95  				{0b00000101},
    96  				{0b00000110},
    97  			},
    98  			want: []bitfield.Bitvector128{
    99  				{0b00000101},
   100  				{0b00000110},
   101  			},
   102  		},
   103  		{
   104  			name: "some attestations overlap",
   105  			inputs: []bitfield.Bitvector128{
   106  				{0b00001001},
   107  				{0b00010110},
   108  				{0b00001010},
   109  				{0b00110001},
   110  			},
   111  			want: []bitfield.Bitvector128{
   112  				{0b00111011},
   113  				{0b00011111},
   114  			},
   115  		},
   116  		{
   117  			name: "some attestations produce duplicates which are removed",
   118  			inputs: []bitfield.Bitvector128{
   119  				{0b00000101},
   120  				{0b00000110},
   121  				{0b00001010},
   122  				{0b00001001},
   123  			},
   124  			want: []bitfield.Bitvector128{
   125  				{0b00001111}, // both 0&1 and 2&3 produce this bitlist
   126  			},
   127  		},
   128  		{
   129  			name: "two attestations where one is fully contained within the other",
   130  			inputs: []bitfield.Bitvector128{
   131  				{0b00000001},
   132  				{0b00000011},
   133  			},
   134  			want: []bitfield.Bitvector128{
   135  				{0b00000011},
   136  			},
   137  		},
   138  		{
   139  			name: "two attestations where one is fully contained within the other reversed",
   140  			inputs: []bitfield.Bitvector128{
   141  				{0b00000011},
   142  				{0b00000001},
   143  			},
   144  			want: []bitfield.Bitvector128{
   145  				{0b00000011},
   146  			},
   147  		},
   148  	}
   149  
   150  	for _, tt := range tests {
   151  		runner := func() {
   152  			got, err := Aggregate(aggtesting.MakeSyncContributionsFromBitVector(tt.inputs))
   153  			require.NoError(t, err)
   154  			sort.Slice(got, func(i, j int) bool {
   155  				return got[i].AggregationBits.Bytes()[0] < got[j].AggregationBits.Bytes()[0]
   156  			})
   157  			sort.Slice(tt.want, func(i, j int) bool {
   158  				return tt.want[i].Bytes()[0] < tt.want[j].Bytes()[0]
   159  			})
   160  			assert.Equal(t, len(tt.want), len(got))
   161  			for i, w := range tt.want {
   162  				assert.DeepEqual(t, w.Bytes(), got[i].AggregationBits.Bytes())
   163  			}
   164  		}
   165  		t.Run(fmt.Sprintf("%s/%s", tt.name, NaiveAggregation), func(t *testing.T) {
   166  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   167  				AttestationAggregationStrategy: string(NaiveAggregation),
   168  			})
   169  			defer resetCfg()
   170  			runner()
   171  		})
   172  	}
   173  }