go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/tq/internal/partition/partition.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package partition encapsulates partitioning and querying of a large keyspace
// whose keys can't be expressed even as uint64.
    17  //
    18  // All to/from string functions use hex encoding.
    19  package partition
    20  
    21  import (
    22  	"container/list"
    23  	"fmt"
    24  	"math/big"
    25  	"sort"
    26  	"strings"
    27  
    28  	"go.chromium.org/luci/common/errors"
    29  )
    30  
    31  // Partition represents a range [Low..High).
    32  type Partition struct {
    33  	Low  big.Int // inclusive
    34  	High big.Int // exclusive. May be equal to max SHA2 hash value + 1.
    35  }
    36  
    37  // SortedPartitions are disjoint partitions sorted by ascending .Low field.
    38  type SortedPartitions []*Partition
    39  
    40  func FromInts(low, high int64) *Partition {
    41  	if low > high {
    42  		panic(errors.Reason("Partition %d..%d is invalid", low, high))
    43  	}
    44  	p := &Partition{}
    45  	p.Low.SetInt64(low)
    46  	p.High.SetInt64(high)
    47  	return p
    48  }
    49  
    50  func SpanInclusive(low, highInclusive string) (*Partition, error) {
    51  	p := &Partition{}
    52  	if err := setBigIntFromString(&p.Low, low); err != nil {
    53  		return nil, err
    54  	}
    55  	if err := setBigIntFromString(&p.High, highInclusive); err != nil {
    56  		return nil, err
    57  	}
    58  	p.High.Add(&p.High, bigInt1) // s.high++
    59  	if p.Low.Cmp(&p.High) > 0 {
    60  		return nil, errors.Reason("Partition %s is invalid", p.String()).Err()
    61  	}
    62  	return p, nil
    63  }
    64  
    65  func Universe(keySpaceBytes int) *Partition {
    66  	p := &Partition{}
    67  	p.High.SetBit(&p.High, keySpaceBytes*8, 1) // 2^(keySpaceBytes*8)
    68  	return p
    69  }
    70  
    71  func FromString(s string) (*Partition, error) {
    72  	i := strings.Index(s, "_")
    73  	if i <= 0 || i == len(s)-1 {
    74  		return nil, errors.Reason("partition %q has invalid format", s).Err()
    75  	}
    76  	p := &Partition{}
    77  	if err := setBigIntFromString(&p.Low, s[:i]); err != nil {
    78  		return nil, err
    79  	}
    80  	if err := setBigIntFromString(&p.High, s[i+1:]); err != nil {
    81  		return nil, err
    82  	}
    83  	if p.Low.Cmp(&p.High) > 0 {
    84  		return nil, errors.Reason("Partition %s is invalid", p.String()).Err()
    85  	}
    86  	return p, nil
    87  }
    88  
    89  func (p Partition) String() string {
    90  	return fmt.Sprintf("%s_%s", p.Low.Text(16 /*hex*/), p.High.Text(16 /*hex*/))
    91  }
    92  
    93  func (p Partition) MarshalJSON() ([]byte, error) {
    94  	return []byte(fmt.Sprintf(`"%s_%s"`, p.Low.Text(16 /*hex*/), p.High.Text(16 /*hex*/))), nil
    95  }
    96  
    97  func (p *Partition) UnmarshalJSON(bs []byte) error {
    98  	s := string(bs)
    99  	switch {
   100  	case s == `null`:
   101  		return nil
   102  	case len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"':
   103  		return errors.Reason("invalid JSON-serialized partition %q", s).Err()
   104  	default:
   105  		if tmp, err := FromString(s[1 : len(s)-1]); err != nil {
   106  			return err
   107  		} else {
   108  			*p = *tmp
   109  			return nil
   110  		}
   111  	}
   112  }
   113  
   114  func (p Partition) Copy() *Partition {
   115  	r := &Partition{}
   116  	r.Low.Set(&p.Low)
   117  	r.High.Set(&p.High)
   118  	return r
   119  }
   120  
   121  func (p Partition) QueryBounds(keySpaceBytes int) (low, high string) {
   122  	low = paddedHex(&p.Low, keySpaceBytes)
   123  	if !inKeySpace(&p.High, keySpaceBytes) {
   124  		// In practice, this should mean p.high == 2^(keySpaceBytes*8).
   125  		high = "g" // all hex strings are smaller than "g".
   126  	} else {
   127  		high = paddedHex(&p.High, keySpaceBytes)
   128  	}
   129  	return
   130  }
   131  
   132  func (p Partition) Split(shards int) SortedPartitions {
   133  	if shards <= 0 {
   134  		panic(">=1 shard required")
   135  	}
   136  	var increment, remainder, cur big.Int
   137  	increment.QuoRem(
   138  		cur.Sub(&p.High, &p.Low),
   139  		big.NewInt(int64(shards)),
   140  		&remainder)
   141  	if remainder.Cmp(bigInt0) > 0 {
   142  		increment.Add(&increment, bigInt1)
   143  	}
   144  
   145  	partitions := make([]*Partition, 0, shards)
   146  	cur.Set(&p.Low)
   147  	for cur.Cmp(&p.High) < 0 {
   148  		next := &Partition{}
   149  		next.Low.Set(&cur)
   150  		next.High.Add(&cur, &increment)
   151  		cur.Set(&next.High)
   152  		partitions = append(partitions, next)
   153  	}
   154  	// Due to int division to compute the increment, the last partition may
   155  	// overshoot, so ensure it ends exactly at the end of the original.
   156  	partitions[len(partitions)-1].High = p.High
   157  	return partitions
   158  }
   159  
   160  // EducatedSplitAfter splits partition after a given boundary assuming constant
   161  // density s.t. each shard has approximately targetItems.
   162  //
   163  // Caps the number of resulting partitions to at most maxShards.
   164  // panics if called on invalid data.
   165  func (p Partition) EducatedSplitAfter(exclusive string, beforeItems, targetItems, maxShards int) SortedPartitions {
   166  	remaining := Partition{}
   167  	if err := setBigIntFromString(&remaining.Low, exclusive); err != nil {
   168  		panic(err)
   169  	}
   170  	if p.Low.Cmp(&remaining.Low) > 0 { // low > remaining.Low
   171  		panic("must be within the partition")
   172  	}
   173  	if p.High.Cmp(&remaining.Low) <= 0 { // high <= remaining.Low
   174  		panic("must be within the partition")
   175  	}
   176  	remaining.Low.Add(&remaining.Low, bigInt1) // remaining.Low++
   177  	remaining.High.Set(&p.High)
   178  
   179  	// Compute expShards as
   180  	//
   181  	//     beforeItems / len(before) * len(remaining) / targetItems
   182  	//
   183  	// in a somewhat readable way as
   184  	//
   185  	//     (beforeItems * len(remaining)) / ( targetItems * len(before))
   186  	//
   187  	// NOTE: this can be optimized if needed to avoid excessive memory allocations
   188  	// in bit.Int at the cost of readability.
   189  	iBefore := big.NewInt(int64(beforeItems))
   190  	iTarget := big.NewInt(int64(targetItems))
   191  	var expShards, iRemainder big.Int
   192  	expShards.QuoRem(
   193  		(&big.Int{}).Mul(iBefore, distance(&remaining.Low, &remaining.High)),
   194  		(&big.Int{}).Mul(iTarget, distance(&p.Low, &remaining.Low)),
   195  		&iRemainder,
   196  	)
   197  	if iRemainder.Cmp(bigInt0) > 0 {
   198  		expShards.Add(&expShards, bigInt1)
   199  	}
   200  	shards := maxShards
   201  	if expShards.Cmp(big.NewInt(int64(maxShards))) < 0 {
   202  		shards = int(expShards.Int64())
   203  	}
   204  	return remaining.Split(shards)
   205  }
   206  
   207  // SortedPartitionsBuilder constructs a sequence of partitions by excluding
   208  // chunks from a starting partion.
   209  //
   210  // Not intended to scale to large number of exclusion operations.
   211  type SortedPartitionsBuilder struct {
   212  	// l holds partitions in sorted order, leading to O(len(l)) runtime of the
   213  	// Exclude().
   214  	//
   215  	// For max performance with >~20 exclusions, an interval tree should be used
   216  	// instead. Unfortunately, due to lack of generics in Go, most interval tree
   217  	// libraries expect float64 or int64 nounds, not big.Int.
   218  	l *list.List
   219  }
   220  
   221  func NewSortedPartitionsBuilder(p *Partition) SortedPartitionsBuilder {
   222  	b := SortedPartitionsBuilder{l: list.New()}
   223  	b.l.PushBack(p.Copy())
   224  	return b
   225  }
   226  
   227  func (b *SortedPartitionsBuilder) IsEmpty() bool {
   228  	return b.l.Len() == 0
   229  }
   230  
   231  func (b *SortedPartitionsBuilder) Result() SortedPartitions {
   232  	r := make([]*Partition, 0, b.l.Len())
   233  	for el := b.l.Front(); el != nil; el = el.Next() {
   234  		r = append(r, el.Value.(*Partition))
   235  	}
   236  	return r
   237  }
   238  
   239  func (b *SortedPartitionsBuilder) Exclude(exclude *Partition) {
   240  	for el := b.l.Front(); el != nil; {
   241  		avail := el.Value.(*Partition)
   242  		switch {
   243  		case exclude.Low.Cmp(&avail.High) >= 0:
   244  			// avail < exclude
   245  			el = el.Next()
   246  
   247  		case exclude.High.Cmp(&avail.Low) <= 0:
   248  			// exclude < avail
   249  			return
   250  
   251  		case exclude.Low.Cmp(&avail.Low) <= 0:
   252  			// front excluded
   253  			if exclude.High.Cmp(&avail.High) >= 0 {
   254  				// back also excluded
   255  				next := el.Next()
   256  				b.l.Remove(el)
   257  				el = next
   258  			} else {
   259  				// only back remains.
   260  				avail.Low.Set(&exclude.High)
   261  				return
   262  			}
   263  
   264  		case exclude.High.Cmp(&avail.High) >= 0:
   265  			// only front remains.
   266  			avail.High.Set(&exclude.Low)
   267  			el = el.Next()
   268  
   269  		default:
   270  			// middle is excluded.
   271  			second := &Partition{}
   272  			second.Low.Set(&exclude.High)
   273  			second.High.Set(&avail.High)
   274  			avail.High.Set(&exclude.Low)
   275  			b.l.InsertAfter(second, el)
   276  			return
   277  		}
   278  	}
   279  }
   280  
   281  // OnlyIn efficiently returns a subsequence of the `n` sorted by key objects
   282  // whose key belongs to one of the partitions.
   283  //
   284  // Calls use(i,j) for each objects[i:j] which belong to the range.
   285  func (ps SortedPartitions) OnlyIn(n int, key func(i int) string, use func(l, h int), keySpaceBytes int) {
   286  	k := 0
   287  	// Remaining slice is [k..n)
   288  	for len(ps) > 0 && k < n {
   289  		lowStr, highStr := ps[0].QueryBounds(keySpaceBytes)
   290  		fr := sort.Search(n-k, func(i int) bool { return key(k+i) >= lowStr })
   291  		if fr == n-k {
   292  			return
   293  		}
   294  		to := sort.Search(n-k-fr, func(i int) bool { return key(fr+k+i) >= highStr })
   295  		if to > 0 {
   296  			use(fr+k, k+fr+to)
   297  		}
   298  		// Can be optimized more by doing binary search over `ps` if fr == to == 0.
   299  		k = k + fr + to
   300  		ps = ps[1:]
   301  	}
   302  }
   303  
   304  // helpers
   305  
   306  var (
   307  	// these are effectively constants predefined to avoid needless memory allocations.
   308  
   309  	bigInt0 = big.NewInt(0)
   310  	bigInt1 = big.NewInt(1)
   311  )
   312  
   313  func distance(low, high *big.Int) *big.Int {
   314  	return (&big.Int{}).Sub(high, low)
   315  }
   316  
   317  func setBigIntFromString(b *big.Int, s string) error {
   318  	if _, ok := b.SetString(s, 16 /*hex*/); !ok {
   319  		return errors.Reason("invalid bigint hex %q", s).Err()
   320  	}
   321  	if b.Sign() == -1 {
   322  		return errors.Reason("negative value %q not allowed", s).Err()
   323  	}
   324  	return nil
   325  }
   326  
   327  func paddedHex(b *big.Int, keySpaceBytes int) string {
   328  	s := b.Text(16 /*hex*/)
   329  	return strings.Repeat("0", keySpaceBytes*2-len(s)) + s
   330  }
   331  
   332  // inKeySpace returns whether v does not exceed keyspace upper boundary.
   333  func inKeySpace(v *big.Int, keySpaceBytes int) bool {
   334  	return v.BitLen() <= keySpaceBytes*8
   335  }