github.com/ava-labs/avalanchego@v1.11.11/vms/platformvm/warp/validator.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package warp 5 6 import ( 7 "bytes" 8 "context" 9 "errors" 10 "fmt" 11 12 "golang.org/x/exp/maps" 13 14 "github.com/ava-labs/avalanchego/ids" 15 "github.com/ava-labs/avalanchego/snow/validators" 16 "github.com/ava-labs/avalanchego/utils" 17 "github.com/ava-labs/avalanchego/utils/crypto/bls" 18 "github.com/ava-labs/avalanchego/utils/math" 19 "github.com/ava-labs/avalanchego/utils/set" 20 ) 21 22 var ( 23 _ utils.Sortable[*Validator] = (*Validator)(nil) 24 25 ErrUnknownValidator = errors.New("unknown validator") 26 ErrWeightOverflow = errors.New("weight overflowed") 27 ) 28 29 // ValidatorState defines the functions that must be implemented to get 30 // the canonical validator set for warp message validation. 31 type ValidatorState interface { 32 GetValidatorSet(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) 33 } 34 35 type Validator struct { 36 PublicKey *bls.PublicKey 37 PublicKeyBytes []byte 38 Weight uint64 39 NodeIDs []ids.NodeID 40 } 41 42 func (v *Validator) Compare(o *Validator) int { 43 return bytes.Compare(v.PublicKeyBytes, o.PublicKeyBytes) 44 } 45 46 // GetCanonicalValidatorSet returns the validator set of [subnetID] at 47 // [pChcainHeight] in a canonical ordering. Also returns the total weight on 48 // [subnetID]. 49 func GetCanonicalValidatorSet( 50 ctx context.Context, 51 pChainState ValidatorState, 52 pChainHeight uint64, 53 subnetID ids.ID, 54 ) ([]*Validator, uint64, error) { 55 // Get the validator set at the given height. 56 vdrSet, err := pChainState.GetValidatorSet(ctx, pChainHeight, subnetID) 57 if err != nil { 58 return nil, 0, err 59 } 60 61 // Convert the validator set into the canonical ordering. 62 return FlattenValidatorSet(vdrSet) 63 } 64 65 // FlattenValidatorSet converts the provided [vdrSet] into a canonical ordering. 66 // Also returns the total weight of the validator set. 67 func FlattenValidatorSet(vdrSet map[ids.NodeID]*validators.GetValidatorOutput) ([]*Validator, uint64, error) { 68 var ( 69 vdrs = make(map[string]*Validator, len(vdrSet)) 70 totalWeight uint64 71 err error 72 ) 73 for _, vdr := range vdrSet { 74 totalWeight, err = math.Add(totalWeight, vdr.Weight) 75 if err != nil { 76 return nil, 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err) 77 } 78 79 if vdr.PublicKey == nil { 80 continue 81 } 82 83 pkBytes := bls.PublicKeyToUncompressedBytes(vdr.PublicKey) 84 uniqueVdr, ok := vdrs[string(pkBytes)] 85 if !ok { 86 uniqueVdr = &Validator{ 87 PublicKey: vdr.PublicKey, 88 PublicKeyBytes: pkBytes, 89 } 90 vdrs[string(pkBytes)] = uniqueVdr 91 } 92 93 uniqueVdr.Weight += vdr.Weight // Impossible to overflow here 94 uniqueVdr.NodeIDs = append(uniqueVdr.NodeIDs, vdr.NodeID) 95 } 96 97 // Sort validators by public key 98 vdrList := maps.Values(vdrs) 99 utils.Sort(vdrList) 100 return vdrList, totalWeight, nil 101 } 102 103 // FilterValidators returns the validators in [vdrs] whose bit is set to 1 in 104 // [indices]. 105 // 106 // Returns an error if [indices] references an unknown validator. 107 func FilterValidators( 108 indices set.Bits, 109 vdrs []*Validator, 110 ) ([]*Validator, error) { 111 // Verify that all alleged signers exist 112 if indices.BitLen() > len(vdrs) { 113 return nil, fmt.Errorf( 114 "%w: NumIndices (%d) >= NumFilteredValidators (%d)", 115 ErrUnknownValidator, 116 indices.BitLen()-1, // -1 to convert from length to index 117 len(vdrs), 118 ) 119 } 120 121 filteredVdrs := make([]*Validator, 0, len(vdrs)) 122 for i, vdr := range vdrs { 123 if !indices.Contains(i) { 124 continue 125 } 126 127 filteredVdrs = append(filteredVdrs, vdr) 128 } 129 return filteredVdrs, nil 130 } 131 132 // SumWeight returns the total weight of the provided validators. 133 func SumWeight(vdrs []*Validator) (uint64, error) { 134 var ( 135 weight uint64 136 err error 137 ) 138 for _, vdr := range vdrs { 139 weight, err = math.Add(weight, vdr.Weight) 140 if err != nil { 141 return 0, fmt.Errorf("%w: %w", ErrWeightOverflow, err) 142 } 143 } 144 return weight, nil 145 } 146 147 // AggregatePublicKeys returns the public key of the provided validators. 148 // 149 // Invariant: All of the public keys in [vdrs] are valid. 150 func AggregatePublicKeys(vdrs []*Validator) (*bls.PublicKey, error) { 151 pks := make([]*bls.PublicKey, len(vdrs)) 152 for i, vdr := range vdrs { 153 pks[i] = vdr.PublicKey 154 } 155 return bls.AggregatePublicKeys(pks) 156 }