git.frostfs.info/TrueCloudLab/frostfs-sdk-go@v0.0.0-20241022124111-5361f0ecebd3/netmap/selector.go (about)

     1  package netmap
     2  
     3  import (
     4  	"cmp"
     5  	"fmt"
     6  	"slices"
     7  
     8  	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/netmap"
     9  	"git.frostfs.info/TrueCloudLab/hrw"
    10  )
    11  
    12  // processSelectors processes selectors and returns error is any of them is invalid.
    13  func (c *context) processSelectors(p PlacementPolicy) error {
    14  	for i := range p.selectors {
    15  		fName := p.selectors[i].GetFilter()
    16  		if fName != mainFilterName {
    17  			_, ok := c.processedFilters[p.selectors[i].GetFilter()]
    18  			if !ok {
    19  				return fmt.Errorf("%w: SELECT FROM '%s'", errFilterNotFound, fName)
    20  			}
    21  		}
    22  
    23  		sName := p.selectors[i].GetName()
    24  
    25  		c.processedSelectors[sName] = &p.selectors[i]
    26  
    27  		result, err := c.getSelection(p.selectors[i])
    28  		if err != nil {
    29  			return err
    30  		}
    31  
    32  		c.selections[sName] = result
    33  	}
    34  
    35  	return nil
    36  }
    37  
    38  // calcNodesCount returns number of buckets and minimum number of nodes in every bucket
    39  // for the given selector.
    40  func calcNodesCount(s netmap.Selector) (int, int) {
    41  	switch s.GetClause() {
    42  	case netmap.Same:
    43  		return 1, int(s.GetCount())
    44  	default:
    45  		return int(s.GetCount()), 1
    46  	}
    47  }
    48  
    49  // calcBucketWeight computes weight for a node bucket.
    50  func calcBucketWeight(ns nodes, a aggregator, wf weightFunc) float64 {
    51  	for i := range ns {
    52  		a.Add(wf(ns[i]))
    53  	}
    54  
    55  	return a.Compute()
    56  }
    57  
    58  // getSelection returns nodes grouped by s.attribute.
    59  // Last argument specifies if more buckets can be used to fulfill CBF.
    60  func (c *context) getSelection(s netmap.Selector) ([]nodes, error) {
    61  	bucketCount, nodesInBucket := calcNodesCount(s)
    62  	buckets := c.getSelectionBase(s)
    63  
    64  	if c.strict && len(buckets) < bucketCount {
    65  		return nil, fmt.Errorf("%w: '%s'", errNotEnoughNodes, s.GetName())
    66  	}
    67  
    68  	// We need deterministic output in case there is no pivot.
    69  	// If pivot is set, buckets are sorted by HRW.
    70  	// However, because initial order influences HRW order for buckets with equal weights,
    71  	// we also need to have deterministic input to HRW sorting routine.
    72  	if len(c.hrwSeed) == 0 {
    73  		if s.GetAttribute() == "" {
    74  			slices.SortFunc(buckets, func(b1, b2 nodeAttrPair) int {
    75  				return cmp.Compare(b1.nodes[0].Hash(), b2.nodes[0].Hash())
    76  			})
    77  		} else {
    78  			slices.SortFunc(buckets, func(b1, b2 nodeAttrPair) int {
    79  				return cmp.Compare(b1.attr, b2.attr)
    80  			})
    81  		}
    82  	}
    83  
    84  	maxNodesInBucket := nodesInBucket * int(c.cbf)
    85  	res := make([]nodes, 0, len(buckets))
    86  	fallback := make([]nodes, 0, len(buckets))
    87  
    88  	for i := range buckets {
    89  		ns := buckets[i].nodes
    90  		if len(ns) >= maxNodesInBucket {
    91  			res = append(res, ns[:maxNodesInBucket])
    92  		} else if len(ns) >= nodesInBucket {
    93  			fallback = append(fallback, ns)
    94  		}
    95  	}
    96  
    97  	if len(res) < bucketCount {
    98  		// Fallback to using minimum allowed backup factor (1).
    99  		res = append(res, fallback...)
   100  		if c.strict && len(res) < bucketCount {
   101  			return nil, fmt.Errorf("%w: '%s'", errNotEnoughNodes, s.GetName())
   102  		}
   103  	}
   104  
   105  	if len(c.hrwSeed) != 0 {
   106  		weights := make([]float64, len(res))
   107  		a := new(meanIQRAgg)
   108  		for i := range res {
   109  			a.clear()
   110  			weights[i] = calcBucketWeight(res[i], a, c.weightFunc)
   111  		}
   112  
   113  		hrw.SortHasherSliceByWeightValue(res, weights, c.hrwSeedHash)
   114  	}
   115  
   116  	if len(res) < bucketCount {
   117  		if len(res) == 0 {
   118  			return nil, errNotEnoughNodes
   119  		}
   120  		bucketCount = len(res)
   121  	}
   122  
   123  	if s.GetAttribute() == "" {
   124  		res, fallback = res[:bucketCount], res[bucketCount:]
   125  		for i := range fallback {
   126  			index := i % bucketCount
   127  			if len(res[index]) >= maxNodesInBucket {
   128  				break
   129  			}
   130  			res[index] = append(res[index], fallback[i]...)
   131  		}
   132  	}
   133  
   134  	return res[:bucketCount], nil
   135  }
   136  
// nodeAttrPair is a bucket of nodes produced by getSelectionBase together with
// the attribute value the nodes were grouped by. attr is left empty when the
// selector specifies no attribute, in which case the bucket holds a single node.
type nodeAttrPair struct {
	attr  string
	nodes nodes
}
   141  
   142  // getSelectionBase returns nodes grouped by selector attribute.
   143  // It it guaranteed that each pair will contain at least one node.
   144  func (c *context) getSelectionBase(s netmap.Selector) []nodeAttrPair {
   145  	fName := s.GetFilter()
   146  	f := c.processedFilters[fName]
   147  	isMain := fName == mainFilterName
   148  	result := []nodeAttrPair{}
   149  	nodeMap := map[string][]NodeInfo{}
   150  	attr := s.GetAttribute()
   151  
   152  	for i := range c.netMap.nodes {
   153  		if c.usedNodes[c.netMap.nodes[i].hash] {
   154  			continue
   155  		}
   156  		if isMain || c.match(f, c.netMap.nodes[i]) {
   157  			if attr == "" {
   158  				// Default attribute is transparent identifier which is different for every node.
   159  				result = append(result, nodeAttrPair{attr: "", nodes: nodes{c.netMap.nodes[i]}})
   160  			} else {
   161  				v := c.netMap.nodes[i].Attribute(attr)
   162  				nodeMap[v] = append(nodeMap[v], c.netMap.nodes[i])
   163  			}
   164  		}
   165  	}
   166  
   167  	if attr != "" {
   168  		for k, ns := range nodeMap {
   169  			result = append(result, nodeAttrPair{attr: k, nodes: ns})
   170  		}
   171  	}
   172  
   173  	if len(c.hrwSeed) != 0 {
   174  		var ws []float64
   175  		for i := range result {
   176  			ws = result[i].nodes.appendWeightsTo(c.weightFunc, ws[:0])
   177  			hrw.SortHasherSliceByWeightValue(result[i].nodes, ws, c.hrwSeedHash)
   178  		}
   179  	}
   180  
   181  	return result
   182  }