github.com/prysmaticlabs/prysm@v1.4.4/shared/aggregation/attestations/attestations.go (about)

     1  package attestations
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
     6  	"github.com/prysmaticlabs/prysm/shared/aggregation"
     7  	"github.com/prysmaticlabs/prysm/shared/bls"
     8  	"github.com/prysmaticlabs/prysm/shared/copyutil"
     9  	"github.com/prysmaticlabs/prysm/shared/featureconfig"
    10  	"github.com/sirupsen/logrus"
    11  )
    12  
    13  const (
    14  	// NaiveAggregation is an aggregation strategy without any optimizations.
    15  	NaiveAggregation AttestationAggregationStrategy = "naive"
    16  
    17  	// MaxCoverAggregation is a strategy based on Maximum Coverage greedy algorithm.
    18  	MaxCoverAggregation AttestationAggregationStrategy = "max_cover"
    19  
    20  	// OptMaxCoverAggregation is a strategy based on Maximum Coverage greedy algorithm.
    21  	// This new variant is optimized and relies on Bitlist64 (once fully tested, `max_cover`
    22  	// strategy will be replaced with this one).
    23  	OptMaxCoverAggregation AttestationAggregationStrategy = "opt_max_cover"
    24  )
    25  
    26  // AttestationAggregationStrategy defines attestation aggregation strategy.
    27  type AttestationAggregationStrategy string
    28  
    29  // attList represents list of attestations, defined for easier en masse operations (filtering, sorting).
    30  type attList []*ethpb.Attestation
    31  
    32  // BLS aggregate signature aliases for testing / benchmark substitution. These methods are
    33  // significantly more expensive than the inner logic of AggregateAttestations so they must be
    34  // substituted for benchmarks which analyze AggregateAttestations.
    35  var aggregateSignatures = bls.AggregateSignatures
    36  var signatureFromBytes = bls.SignatureFromBytes
    37  
    38  var _ = logrus.WithField("prefix", "aggregation.attestations")
    39  
    40  // ErrInvalidAttestationCount is returned when insufficient number
    41  // of attestations is provided for aggregation.
    42  var ErrInvalidAttestationCount = errors.New("invalid number of attestations")
    43  
    44  // Aggregate aggregates attestations. The minimal number of attestations is returned.
    45  // Aggregation occurs in-place i.e. contents of input array will be modified. Should you need to
    46  // preserve input attestations, clone them before aggregating:
    47  //
    48  //   clonedAtts := make([]*ethpb.Attestation, len(atts))
    49  //   for i, a := range atts {
    50  //       clonedAtts[i] = stateTrie.CopyAttestation(a)
    51  //   }
    52  //   aggregatedAtts, err := attaggregation.Aggregate(clonedAtts)
    53  func Aggregate(atts []*ethpb.Attestation) ([]*ethpb.Attestation, error) {
    54  	strategy := AttestationAggregationStrategy(featureconfig.Get().AttestationAggregationStrategy)
    55  	switch strategy {
    56  	case "", NaiveAggregation:
    57  		return NaiveAttestationAggregation(atts)
    58  	case MaxCoverAggregation:
    59  		return MaxCoverAttestationAggregation(atts)
    60  	case OptMaxCoverAggregation:
    61  		return optMaxCoverAttestationAggregation(atts)
    62  	default:
    63  		return nil, errors.Wrapf(aggregation.ErrInvalidStrategy, "%q", strategy)
    64  	}
    65  }
    66  
    67  // AggregatePair aggregates pair of attestations a1 and a2 together.
    68  func AggregatePair(a1, a2 *ethpb.Attestation) (*ethpb.Attestation, error) {
    69  	o, err := a1.AggregationBits.Overlaps(a2.AggregationBits)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	if o {
    74  		return nil, aggregation.ErrBitsOverlap
    75  	}
    76  
    77  	baseAtt := copyutil.CopyAttestation(a1)
    78  	newAtt := copyutil.CopyAttestation(a2)
    79  	if newAtt.AggregationBits.Count() > baseAtt.AggregationBits.Count() {
    80  		baseAtt, newAtt = newAtt, baseAtt
    81  	}
    82  
    83  	c, err := baseAtt.AggregationBits.Contains(newAtt.AggregationBits)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	if c {
    88  		return baseAtt, nil
    89  	}
    90  
    91  	newBits, err := baseAtt.AggregationBits.Or(newAtt.AggregationBits)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	newSig, err := signatureFromBytes(newAtt.Signature)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	baseSig, err := signatureFromBytes(baseAtt.Signature)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	aggregatedSig := aggregateSignatures([]bls.Signature{baseSig, newSig})
   105  	baseAtt.Signature = aggregatedSig.Marshal()
   106  	baseAtt.AggregationBits = newBits
   107  
   108  	return baseAtt, nil
   109  }