github.com/prysmaticlabs/prysm@v1.4.4/shared/aggregation/attestations/attestations_test.go (about)

     1  package attestations
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"sort"
     7  	"testing"
     8  
     9  	"github.com/prysmaticlabs/go-bitfield"
    10  	ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
    11  	"github.com/prysmaticlabs/prysm/shared/aggregation"
    12  	aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing"
    13  	"github.com/prysmaticlabs/prysm/shared/bls"
    14  	"github.com/prysmaticlabs/prysm/shared/featureconfig"
    15  	"github.com/prysmaticlabs/prysm/shared/params"
    16  	"github.com/prysmaticlabs/prysm/shared/sszutil"
    17  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    18  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    19  	"github.com/sirupsen/logrus"
    20  )
    21  
    22  func TestMain(m *testing.M) {
    23  	logrus.SetLevel(logrus.DebugLevel)
    24  	logrus.SetOutput(ioutil.Discard)
    25  	resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
    26  		AttestationAggregationStrategy: string(OptMaxCoverAggregation),
    27  	})
    28  	defer resetCfg()
    29  	m.Run()
    30  }
    31  
    32  func TestAggregateAttestations_AggregatePair(t *testing.T) {
    33  	tests := []struct {
    34  		a1   *ethpb.Attestation
    35  		a2   *ethpb.Attestation
    36  		want *ethpb.Attestation
    37  	}{
    38  		{
    39  			a1:   &ethpb.Attestation{AggregationBits: []byte{}},
    40  			a2:   &ethpb.Attestation{AggregationBits: []byte{}},
    41  			want: &ethpb.Attestation{AggregationBits: []byte{}},
    42  		},
    43  		{
    44  			a1:   &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x03}},
    45  			a2:   &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x02}},
    46  			want: &ethpb.Attestation{AggregationBits: []byte{0x03}},
    47  		},
    48  		{
    49  			a1:   &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x02}},
    50  			a2:   &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x03}},
    51  			want: &ethpb.Attestation{AggregationBits: []byte{0x03}},
    52  		},
    53  	}
    54  	for _, tt := range tests {
    55  		got, err := AggregatePair(tt.a1, tt.a2)
    56  		require.NoError(t, err)
    57  		require.Equal(t, true, sszutil.DeepEqual(got, tt.want))
    58  	}
    59  }
    60  
    61  func TestAggregateAttestations_AggregatePair_OverlapFails(t *testing.T) {
    62  	tests := []struct {
    63  		a1 *ethpb.Attestation
    64  		a2 *ethpb.Attestation
    65  	}{
    66  		{
    67  			a1: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x1F}},
    68  			a2: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x11}},
    69  		},
    70  		{
    71  			a1: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0xFF, 0x85}},
    72  			a2: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x13, 0x8F}},
    73  		},
    74  	}
    75  	for _, tt := range tests {
    76  		_, err := AggregatePair(tt.a1, tt.a2)
    77  		require.ErrorContains(t, aggregation.ErrBitsOverlap.Error(), err)
    78  	}
    79  }
    80  
    81  func TestAggregateAttestations_AggregatePair_DiffLengthFails(t *testing.T) {
    82  	tests := []struct {
    83  		a1 *ethpb.Attestation
    84  		a2 *ethpb.Attestation
    85  	}{
    86  		{
    87  			a1: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x0F}},
    88  			a2: &ethpb.Attestation{AggregationBits: bitfield.Bitlist{0x11}},
    89  		},
    90  	}
    91  	for _, tt := range tests {
    92  		_, err := AggregatePair(tt.a1, tt.a2)
    93  		require.ErrorContains(t, bitfield.ErrBitlistDifferentLength.Error(), err)
    94  	}
    95  }
    96  
    97  func TestAggregateAttestations_Aggregate(t *testing.T) {
    98  	// Each test defines the aggregation bitfield inputs and the wanted output result.
    99  	bitlistLen := params.BeaconConfig().MaxValidatorsPerCommittee
   100  	tests := []struct {
   101  		name   string
   102  		inputs []bitfield.Bitlist
   103  		want   []bitfield.Bitlist
   104  		err    error
   105  	}{
   106  		{
   107  			name:   "empty list",
   108  			inputs: []bitfield.Bitlist{},
   109  			want:   []bitfield.Bitlist{},
   110  		},
   111  		{
   112  			name: "single attestation",
   113  			inputs: []bitfield.Bitlist{
   114  				{0b00000010, 0b1},
   115  			},
   116  			want: []bitfield.Bitlist{
   117  				{0b00000010, 0b1},
   118  			},
   119  		},
   120  		{
   121  			name: "two attestations with no overlap",
   122  			inputs: []bitfield.Bitlist{
   123  				{0b00000001, 0b1},
   124  				{0b00000010, 0b1},
   125  			},
   126  			want: []bitfield.Bitlist{
   127  				{0b00000011, 0b1},
   128  			},
   129  		},
   130  		{
   131  			name:   "256 attestations with single bit set",
   132  			inputs: aggtesting.BitlistsWithSingleBitSet(256, bitlistLen),
   133  			want: []bitfield.Bitlist{
   134  				aggtesting.BitlistWithAllBitsSet(256),
   135  			},
   136  		},
   137  		{
   138  			name:   "1024 attestations with single bit set",
   139  			inputs: aggtesting.BitlistsWithSingleBitSet(1024, bitlistLen),
   140  			want: []bitfield.Bitlist{
   141  				aggtesting.BitlistWithAllBitsSet(1024),
   142  			},
   143  		},
   144  		{
   145  			name: "two attestations with overlap",
   146  			inputs: []bitfield.Bitlist{
   147  				{0b00000101, 0b1},
   148  				{0b00000110, 0b1},
   149  			},
   150  			want: []bitfield.Bitlist{
   151  				{0b00000101, 0b1},
   152  				{0b00000110, 0b1},
   153  			},
   154  		},
   155  		{
   156  			name: "some attestations overlap",
   157  			inputs: []bitfield.Bitlist{
   158  				{0b00001001, 0b1},
   159  				{0b00010110, 0b1},
   160  				{0b00001010, 0b1},
   161  				{0b00110001, 0b1},
   162  			},
   163  			want: []bitfield.Bitlist{
   164  				{0b00111011, 0b1},
   165  				{0b00011111, 0b1},
   166  			},
   167  		},
   168  		{
   169  			name: "some attestations produce duplicates which are removed",
   170  			inputs: []bitfield.Bitlist{
   171  				{0b00000101, 0b1},
   172  				{0b00000110, 0b1},
   173  				{0b00001010, 0b1},
   174  				{0b00001001, 0b1},
   175  			},
   176  			want: []bitfield.Bitlist{
   177  				{0b00001111, 0b1}, // both 0&1 and 2&3 produce this bitlist
   178  			},
   179  		},
   180  		{
   181  			name: "two attestations where one is fully contained within the other",
   182  			inputs: []bitfield.Bitlist{
   183  				{0b00000001, 0b1},
   184  				{0b00000011, 0b1},
   185  			},
   186  			want: []bitfield.Bitlist{
   187  				{0b00000011, 0b1},
   188  			},
   189  		},
   190  		{
   191  			name: "two attestations where one is fully contained within the other reversed",
   192  			inputs: []bitfield.Bitlist{
   193  				{0b00000011, 0b1},
   194  				{0b00000001, 0b1},
   195  			},
   196  			want: []bitfield.Bitlist{
   197  				{0b00000011, 0b1},
   198  			},
   199  		},
   200  		{
   201  			name: "attestations with different bitlist lengths",
   202  			inputs: []bitfield.Bitlist{
   203  				{0b00000011, 0b10},
   204  				{0b00000111, 0b100},
   205  				{0b00000100, 0b1},
   206  			},
   207  			want: []bitfield.Bitlist{
   208  				{0b00000011, 0b10},
   209  				{0b00000111, 0b100},
   210  				{0b00000100, 0b1},
   211  			},
   212  			err: bitfield.ErrBitlistDifferentLength,
   213  		},
   214  	}
   215  
   216  	for _, tt := range tests {
   217  		runner := func() {
   218  			got, err := Aggregate(aggtesting.MakeAttestationsFromBitlists(tt.inputs))
   219  			if tt.err != nil {
   220  				require.ErrorContains(t, tt.err.Error(), err)
   221  				return
   222  			}
   223  			require.NoError(t, err)
   224  			sort.Slice(got, func(i, j int) bool {
   225  				return got[i].AggregationBits.Bytes()[0] < got[j].AggregationBits.Bytes()[0]
   226  			})
   227  			sort.Slice(tt.want, func(i, j int) bool {
   228  				return tt.want[i].Bytes()[0] < tt.want[j].Bytes()[0]
   229  			})
   230  			assert.Equal(t, len(tt.want), len(got))
   231  			for i, w := range tt.want {
   232  				assert.DeepEqual(t, w.Bytes(), got[i].AggregationBits.Bytes())
   233  			}
   234  		}
   235  		t.Run(fmt.Sprintf("%s/%s", tt.name, NaiveAggregation), func(t *testing.T) {
   236  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   237  				AttestationAggregationStrategy: string(NaiveAggregation),
   238  			})
   239  			defer resetCfg()
   240  			runner()
   241  		})
   242  		t.Run(fmt.Sprintf("%s/%s", tt.name, MaxCoverAggregation), func(t *testing.T) {
   243  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   244  				AttestationAggregationStrategy: string(MaxCoverAggregation),
   245  			})
   246  			defer resetCfg()
   247  			runner()
   248  		})
   249  		t.Run(fmt.Sprintf("%s/%s", tt.name, OptMaxCoverAggregation), func(t *testing.T) {
   250  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   251  				AttestationAggregationStrategy: string(OptMaxCoverAggregation),
   252  			})
   253  			defer resetCfg()
   254  			runner()
   255  		})
   256  	}
   257  
   258  	t.Run("invalid strategy", func(t *testing.T) {
   259  		resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   260  			AttestationAggregationStrategy: "foobar",
   261  		})
   262  		defer resetCfg()
   263  		_, err := Aggregate(aggtesting.MakeAttestationsFromBitlists([]bitfield.Bitlist{}))
   264  		assert.ErrorContains(t, "\"foobar\": invalid aggregation strategy", err)
   265  	})
   266  
   267  	t.Run("broken attestation bitset", func(t *testing.T) {
   268  		wantErr := "bitlist cannot be nil or empty: invalid max_cover problem"
   269  		t.Run(string(MaxCoverAggregation), func(t *testing.T) {
   270  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   271  				AttestationAggregationStrategy: string(MaxCoverAggregation),
   272  			})
   273  			defer resetCfg()
   274  			_, err := Aggregate(aggtesting.MakeAttestationsFromBitlists([]bitfield.Bitlist{
   275  				{0b00000011, 0b0},
   276  				{0b00000111, 0b100},
   277  				{0b00000100, 0b1},
   278  			}))
   279  			assert.ErrorContains(t, wantErr, err)
   280  		})
   281  		t.Run(string(OptMaxCoverAggregation), func(t *testing.T) {
   282  			resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   283  				AttestationAggregationStrategy: string(OptMaxCoverAggregation),
   284  			})
   285  			defer resetCfg()
   286  			_, err := Aggregate(aggtesting.MakeAttestationsFromBitlists([]bitfield.Bitlist{
   287  				{0b00000011, 0b0},
   288  				{0b00000111, 0b100},
   289  				{0b00000100, 0b1},
   290  			}))
   291  			assert.ErrorContains(t, wantErr, err)
   292  		})
   293  	})
   294  
   295  	t.Run("candidate swapping when aggregating", func(t *testing.T) {
   296  		// The first item cannot be aggregated, and should be pushed down the list,
   297  		// by two swaps with aggregated items (aggregation is done in-place, so the very same
   298  		// underlying array is used for storing both aggregated and non-aggregated items).
   299  		resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{
   300  			AttestationAggregationStrategy: string(OptMaxCoverAggregation),
   301  		})
   302  		defer resetCfg()
   303  		got, err := Aggregate(aggtesting.MakeAttestationsFromBitlists([]bitfield.Bitlist{
   304  			{0b10000000, 0b1},
   305  			{0b11000101, 0b1},
   306  			{0b00011000, 0b1},
   307  			{0b01010100, 0b1},
   308  			{0b10001010, 0b1},
   309  		}))
   310  		want := []bitfield.Bitlist{
   311  			{0b11011101, 0b1},
   312  			{0b11011110, 0b1},
   313  			{0b10000000, 0b1},
   314  		}
   315  		assert.NoError(t, err)
   316  		assert.Equal(t, len(want), len(got))
   317  		for i, w := range want {
   318  			assert.DeepEqual(t, w.Bytes(), got[i].AggregationBits.Bytes())
   319  		}
   320  	})
   321  }
   322  
   323  func TestAggregateAttestations_PerformanceComparison(t *testing.T) {
   324  	// Tests below are examples of cases where max-cover's greedy approach outperforms the original
   325  	// naive aggregation (which is very much dependent on order in which items are fed into it).
   326  	tests := []struct {
   327  		name     string
   328  		bitsList [][]byte
   329  	}{
   330  		{
   331  			name: "test1",
   332  			bitsList: [][]byte{
   333  				{0b00000100, 0b1},
   334  				{0b00000010, 0b1},
   335  				{0b00000001, 0b1},
   336  				{0b00011001, 0b1},
   337  			},
   338  		},
   339  		{
   340  			name: "test2",
   341  			bitsList: [][]byte{
   342  				{0b10010001, 0b1},
   343  				{0b00100000, 0b1},
   344  				{0b01101110, 0b1},
   345  			},
   346  		},
   347  		{
   348  			name: "test3",
   349  			bitsList: [][]byte{
   350  				{0b00100000, 0b00000011, 0b1},
   351  				{0b00011100, 0b11000000, 0b1},
   352  				{0b11111100, 0b00000000, 0b1},
   353  				{0b00000011, 0b10000000, 0b1},
   354  				{0b11100011, 0b00000000, 0b1},
   355  			},
   356  		},
   357  	}
   358  
   359  	scoreAtts := func(atts []*ethpb.Attestation) uint64 {
   360  		score := uint64(0)
   361  		sort.Slice(atts, func(i, j int) bool {
   362  			return atts[i].AggregationBits.Count() > atts[j].AggregationBits.Count()
   363  		})
   364  		// Score the best aggregate.
   365  		if len(atts) > 0 {
   366  			score = atts[0].AggregationBits.Count()
   367  		}
   368  		return score
   369  	}
   370  
   371  	generateAtts := func(bitsList [][]byte) []*ethpb.Attestation {
   372  		sign := bls.NewAggregateSignature().Marshal()
   373  		atts := make([]*ethpb.Attestation, 0)
   374  		for _, b := range bitsList {
   375  			atts = append(atts, &ethpb.Attestation{
   376  				AggregationBits: b,
   377  				Signature:       sign,
   378  			})
   379  		}
   380  		return atts
   381  	}
   382  
   383  	for _, tt := range tests {
   384  		t.Run(tt.name, func(t *testing.T) {
   385  			atts, err := NaiveAttestationAggregation(generateAtts(tt.bitsList))
   386  			require.NoError(t, err)
   387  			score1 := scoreAtts(atts)
   388  
   389  			atts, err = MaxCoverAttestationAggregation(generateAtts(tt.bitsList))
   390  			require.NoError(t, err)
   391  			score2 := scoreAtts(atts)
   392  
   393  			t.Logf("native = %d, max-cover: %d\n", score1, score2)
   394  			assert.Equal(t, true, score1 <= score2,
   395  				"max-cover failed to produce higher score (naive: %d, max-cover: %d)", score1, score2)
   396  		})
   397  	}
   398  }