go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/tq/internal/partition/partition.go (about) 1 // Copyright 2020 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package partition encapsulates partitioning and querying large keyspace which 16 // can't be expressed even as uint64. 17 // 18 // All to/from string functions use hex encoding. 19 package partition 20 21 import ( 22 "container/list" 23 "fmt" 24 "math/big" 25 "sort" 26 "strings" 27 28 "go.chromium.org/luci/common/errors" 29 ) 30 31 // Partition represents a range [Low..High). 32 type Partition struct { 33 Low big.Int // inclusive 34 High big.Int // exclusive. May be equal to max SHA2 hash value + 1. 35 } 36 37 // SortedPartitions are disjoint partitions sorted by ascending .Low field. 38 type SortedPartitions []*Partition 39 40 func FromInts(low, high int64) *Partition { 41 if low > high { 42 panic(errors.Reason("Partition %d..%d is invalid", low, high)) 43 } 44 p := &Partition{} 45 p.Low.SetInt64(low) 46 p.High.SetInt64(high) 47 return p 48 } 49 50 func SpanInclusive(low, highInclusive string) (*Partition, error) { 51 p := &Partition{} 52 if err := setBigIntFromString(&p.Low, low); err != nil { 53 return nil, err 54 } 55 if err := setBigIntFromString(&p.High, highInclusive); err != nil { 56 return nil, err 57 } 58 p.High.Add(&p.High, bigInt1) // s.high++ 59 if p.Low.Cmp(&p.High) > 0 { 60 return nil, errors.Reason("Partition %s is invalid", p.String()).Err() 61 } 62 return p, nil 63 } 64 65 func Universe(keySpaceBytes int) *Partition { 66 p := &Partition{} 67 p.High.SetBit(&p.High, keySpaceBytes*8, 1) // 2^(keySpaceBytes*8) 68 return p 69 } 70 71 func FromString(s string) (*Partition, error) { 72 i := strings.Index(s, "_") 73 if i <= 0 || i == len(s)-1 { 74 return nil, errors.Reason("partition %q has invalid format", s).Err() 75 } 76 p := &Partition{} 77 if err := setBigIntFromString(&p.Low, s[:i]); err != nil { 78 return nil, err 79 } 80 if err := setBigIntFromString(&p.High, s[i+1:]); err != nil { 81 return nil, err 82 } 83 if p.Low.Cmp(&p.High) > 0 { 84 return nil, errors.Reason("Partition %s is invalid", p.String()).Err() 85 } 86 return p, nil 87 } 88 89 func (p Partition) String() string { 90 return fmt.Sprintf("%s_%s", p.Low.Text(16 /*hex*/), p.High.Text(16 /*hex*/)) 91 } 92 93 func (p Partition) MarshalJSON() ([]byte, error) { 94 return []byte(fmt.Sprintf(`"%s_%s"`, p.Low.Text(16 /*hex*/), p.High.Text(16 /*hex*/))), nil 95 } 96 97 func (p *Partition) UnmarshalJSON(bs []byte) error { 98 s := string(bs) 99 switch { 100 case s == `null`: 101 return nil 102 case len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"': 103 return errors.Reason("invalid JSON-serialized partition %q", s).Err() 104 default: 105 if tmp, err := FromString(s[1 : len(s)-1]); err != nil { 106 return err 107 } else { 108 *p = *tmp 109 return nil 110 } 111 } 112 } 113 114 func (p Partition) Copy() *Partition { 115 r := &Partition{} 116 r.Low.Set(&p.Low) 117 r.High.Set(&p.High) 118 return r 119 } 120 121 func (p Partition) QueryBounds(keySpaceBytes int) (low, high string) { 122 low = paddedHex(&p.Low, keySpaceBytes) 123 if !inKeySpace(&p.High, keySpaceBytes) { 124 // In practice, this should mean p.high == 2^(keySpaceBytes*8). 125 high = "g" // all hex strings are smaller than "g". 126 } else { 127 high = paddedHex(&p.High, keySpaceBytes) 128 } 129 return 130 } 131 132 func (p Partition) Split(shards int) SortedPartitions { 133 if shards <= 0 { 134 panic(">=1 shard required") 135 } 136 var increment, remainder, cur big.Int 137 increment.QuoRem( 138 cur.Sub(&p.High, &p.Low), 139 big.NewInt(int64(shards)), 140 &remainder) 141 if remainder.Cmp(bigInt0) > 0 { 142 increment.Add(&increment, bigInt1) 143 } 144 145 partitions := make([]*Partition, 0, shards) 146 cur.Set(&p.Low) 147 for cur.Cmp(&p.High) < 0 { 148 next := &Partition{} 149 next.Low.Set(&cur) 150 next.High.Add(&cur, &increment) 151 cur.Set(&next.High) 152 partitions = append(partitions, next) 153 } 154 // Due to int division to compute the increment, the last partition may 155 // overshoot, so ensure it ends exactly at the end of the original. 156 partitions[len(partitions)-1].High = p.High 157 return partitions 158 } 159 160 // EducatedSplitAfter splits partition after a given boundary assuming constant 161 // density s.t. each shard has approximately targetItems. 162 // 163 // Caps the number of resulting partitions to at most maxShards. 164 // panics if called on invalid data. 165 func (p Partition) EducatedSplitAfter(exclusive string, beforeItems, targetItems, maxShards int) SortedPartitions { 166 remaining := Partition{} 167 if err := setBigIntFromString(&remaining.Low, exclusive); err != nil { 168 panic(err) 169 } 170 if p.Low.Cmp(&remaining.Low) > 0 { // low > remaining.Low 171 panic("must be within the partition") 172 } 173 if p.High.Cmp(&remaining.Low) <= 0 { // high <= remaining.Low 174 panic("must be within the partition") 175 } 176 remaining.Low.Add(&remaining.Low, bigInt1) // remaining.Low++ 177 remaining.High.Set(&p.High) 178 179 // Compute expShards as 180 // 181 // beforeItems / len(before) * len(remaining) / targetItems 182 // 183 // in a somewhat readable way as 184 // 185 // (beforeItems * len(remaining)) / ( targetItems * len(before)) 186 // 187 // NOTE: this can be optimized if needed to avoid excessive memory allocations 188 // in bit.Int at the cost of readability. 189 iBefore := big.NewInt(int64(beforeItems)) 190 iTarget := big.NewInt(int64(targetItems)) 191 var expShards, iRemainder big.Int 192 expShards.QuoRem( 193 (&big.Int{}).Mul(iBefore, distance(&remaining.Low, &remaining.High)), 194 (&big.Int{}).Mul(iTarget, distance(&p.Low, &remaining.Low)), 195 &iRemainder, 196 ) 197 if iRemainder.Cmp(bigInt0) > 0 { 198 expShards.Add(&expShards, bigInt1) 199 } 200 shards := maxShards 201 if expShards.Cmp(big.NewInt(int64(maxShards))) < 0 { 202 shards = int(expShards.Int64()) 203 } 204 return remaining.Split(shards) 205 } 206 207 // SortedPartitionsBuilder constructs a sequence of partitions by excluding 208 // chunks from a starting partion. 209 // 210 // Not intended to scale to large number of exclusion operations. 211 type SortedPartitionsBuilder struct { 212 // l holds partitions in sorted order, leading to O(len(l)) runtime of the 213 // Exclude(). 214 // 215 // For max performance with >~20 exclusions, an interval tree should be used 216 // instead. Unfortunately, due to lack of generics in Go, most interval tree 217 // libraries expect float64 or int64 nounds, not big.Int. 218 l *list.List 219 } 220 221 func NewSortedPartitionsBuilder(p *Partition) SortedPartitionsBuilder { 222 b := SortedPartitionsBuilder{l: list.New()} 223 b.l.PushBack(p.Copy()) 224 return b 225 } 226 227 func (b *SortedPartitionsBuilder) IsEmpty() bool { 228 return b.l.Len() == 0 229 } 230 231 func (b *SortedPartitionsBuilder) Result() SortedPartitions { 232 r := make([]*Partition, 0, b.l.Len()) 233 for el := b.l.Front(); el != nil; el = el.Next() { 234 r = append(r, el.Value.(*Partition)) 235 } 236 return r 237 } 238 239 func (b *SortedPartitionsBuilder) Exclude(exclude *Partition) { 240 for el := b.l.Front(); el != nil; { 241 avail := el.Value.(*Partition) 242 switch { 243 case exclude.Low.Cmp(&avail.High) >= 0: 244 // avail < exclude 245 el = el.Next() 246 247 case exclude.High.Cmp(&avail.Low) <= 0: 248 // exclude < avail 249 return 250 251 case exclude.Low.Cmp(&avail.Low) <= 0: 252 // front excluded 253 if exclude.High.Cmp(&avail.High) >= 0 { 254 // back also excluded 255 next := el.Next() 256 b.l.Remove(el) 257 el = next 258 } else { 259 // only back remains. 260 avail.Low.Set(&exclude.High) 261 return 262 } 263 264 case exclude.High.Cmp(&avail.High) >= 0: 265 // only front remains. 266 avail.High.Set(&exclude.Low) 267 el = el.Next() 268 269 default: 270 // middle is excluded. 271 second := &Partition{} 272 second.Low.Set(&exclude.High) 273 second.High.Set(&avail.High) 274 avail.High.Set(&exclude.Low) 275 b.l.InsertAfter(second, el) 276 return 277 } 278 } 279 } 280 281 // OnlyIn efficiently returns a subsequence of the `n` sorted by key objects 282 // whose key belongs to one of the partitions. 283 // 284 // Calls use(i,j) for each objects[i:j] which belong to the range. 285 func (ps SortedPartitions) OnlyIn(n int, key func(i int) string, use func(l, h int), keySpaceBytes int) { 286 k := 0 287 // Remaining slice is [k..n) 288 for len(ps) > 0 && k < n { 289 lowStr, highStr := ps[0].QueryBounds(keySpaceBytes) 290 fr := sort.Search(n-k, func(i int) bool { return key(k+i) >= lowStr }) 291 if fr == n-k { 292 return 293 } 294 to := sort.Search(n-k-fr, func(i int) bool { return key(fr+k+i) >= highStr }) 295 if to > 0 { 296 use(fr+k, k+fr+to) 297 } 298 // Can be optimized more by doing binary search over `ps` if fr == to == 0. 299 k = k + fr + to 300 ps = ps[1:] 301 } 302 } 303 304 // helpers 305 306 var ( 307 // these are effectively constants predefined to avoid needless memory allocations. 308 309 bigInt0 = big.NewInt(0) 310 bigInt1 = big.NewInt(1) 311 ) 312 313 func distance(low, high *big.Int) *big.Int { 314 return (&big.Int{}).Sub(high, low) 315 } 316 317 func setBigIntFromString(b *big.Int, s string) error { 318 if _, ok := b.SetString(s, 16 /*hex*/); !ok { 319 return errors.Reason("invalid bigint hex %q", s).Err() 320 } 321 if b.Sign() == -1 { 322 return errors.Reason("negative value %q not allowed", s).Err() 323 } 324 return nil 325 } 326 327 func paddedHex(b *big.Int, keySpaceBytes int) string { 328 s := b.Text(16 /*hex*/) 329 return strings.Repeat("0", keySpaceBytes*2-len(s)) + s 330 } 331 332 // inKeySpace returns whether v does not exceed keyspace upper boundary. 333 func inKeySpace(v *big.Int, keySpaceBytes int) bool { 334 return v.BitLen() <= keySpaceBytes*8 335 }