github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/reshuffle_test.go (about) 1 // Copyright 2019 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package bigslice_test 6 7 import ( 8 "context" 9 "fmt" 10 "math/rand" 11 "reflect" 12 "sort" 13 "strings" 14 "testing" 15 16 "github.com/grailbio/bigslice" 17 "github.com/grailbio/bigslice/exec" 18 "github.com/grailbio/bigslice/frame" 19 "github.com/grailbio/bigslice/sliceio" 20 "github.com/grailbio/bigslice/slicetest" 21 ) 22 23 func reshuffleTest(t *testing.T, transform func(bigslice.Slice) bigslice.Slice) { 24 t.Helper() 25 const N = 100 26 col1 := make([]lengthHashKey, N) 27 col2 := make([]int, N) 28 want := map[lengthHashKey][]int{} // map of val1 -> []val2 29 rnd := rand.New(rand.NewSource(0)) 30 for i := range col1 { 31 s := strings.Repeat("a", rnd.Intn(N/10)) // We want many strings of each length. 32 col1[i] = lengthHashKey(s) 33 col2[i] = rnd.Intn(100) 34 want[col1[i]] = append(want[col1[i]], col2[i]) 35 } 36 for _, vals := range want { 37 sort.Ints(vals) // for equality test later 38 } 39 for m := 1; m < 10; m++ { 40 t.Run(fmt.Sprint(m), func(t *testing.T) { 41 slice := bigslice.Const(m, append([]lengthHashKey{}, col1...), append([]int{}, col2...)) 42 slice = transform(slice) 43 44 // map of col1 length -> set of shards that had keys of that length. 45 lengthShards := map[int]map[int]struct{}{} 46 got := map[lengthHashKey][]int{} // map of val1 -> []val2 47 var ( 48 val1 lengthHashKey 49 val2 int 50 ) 51 slice = bigslice.Scan(slice, func(shard int, scanner *sliceio.Scanner) error { 52 for scanner.Scan(context.Background(), &val1, &val2) { 53 if _, ok := lengthShards[len(val1)]; !ok { 54 lengthShards[len(val1)] = map[int]struct{}{} 55 } 56 lengthShards[len(val1)][shard] = struct{}{} 57 got[val1] = append(got[val1], val2) 58 } 59 return scanner.Err() 60 }) 61 62 sess := exec.Start(exec.Local) 63 defer sess.Shutdown() 64 _, err := sess.Run(context.Background(), bigslice.Func(func() bigslice.Slice { return slice })) 65 if err != nil { 66 t.Fatalf("run error %v", err) 67 } 68 69 for length, shards := range lengthShards { 70 if len(shards) != 1 { 71 t.Errorf("found keys of length %d in multiple shards: %v", length, shards) 72 } 73 } 74 75 for _, vals := range got { 76 sort.Ints(vals) 77 } 78 if !reflect.DeepEqual(got, want) { 79 t.Errorf("got %v, want %v", got, want) 80 } 81 }) 82 } 83 } 84 85 type lengthHashKey string 86 87 func init() { 88 frame.RegisterOps(func(slice []lengthHashKey) frame.Ops { 89 return frame.Ops{ 90 Less: func(i, j int) bool { return slice[i] < slice[j] }, 91 HashWithSeed: func(i int, _ uint32) uint32 { return uint32(len(slice[i])) }, 92 } 93 }) 94 } 95 96 func TestReshuffle(t *testing.T) { 97 reshuffleTest(t, bigslice.Reshuffle) 98 } 99 100 func TestRepartition(t *testing.T) { 101 reshuffleTest(t, func(slice bigslice.Slice) bigslice.Slice { 102 return bigslice.Repartition(slice, func(nshard int, key lengthHashKey, value int) int { 103 return len(key) % nshard 104 }) 105 }) 106 } 107 108 func TestRepartitionType(t *testing.T) { 109 slice := bigslice.Const(1, []int{}, []string{}) 110 expectTypeError(t, "repartition: expected func(int, int, string) int, got func() int", func() { 111 bigslice.Repartition(slice, func() int { return 0 }) 112 }) 113 expectTypeError(t, "repartition: expected func(int, int, string) int, got func(int, int, string)", func() { 114 bigslice.Repartition(slice, func(_ int, _ int, _ string) {}) 115 }) 116 } 117 118 func ExampleRepartition() { 119 // Count rows per shard before and after using Repartition to get ideal 120 // partitioning by taking advantage of the knowledge that our keys are 121 // sequential integers. 122 123 // countRowsPerShard is a utility that counts the number of rows per shard 124 // and stores it in rowsPerShard. 125 var rowsPerShard []int 126 countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice { 127 rowsPerShard = make([]int, numShards) 128 return bigslice.WriterFunc(slice, 129 func(shard int, _ struct{}, _ error, xs []int) error { 130 rowsPerShard[shard] += len(xs) 131 return nil 132 }, 133 ) 134 } 135 136 const numShards = 2 137 slice := bigslice.Const(numShards, []int{1, 2, 3, 4, 5, 6}) 138 139 slice0 := countRowsPerShard(numShards, slice) 140 fmt.Println("# default partitioning") 141 fmt.Println("## slice contents") 142 slicetest.Print(slice0) 143 fmt.Println("## row count per shard") 144 for shard, count := range rowsPerShard { 145 fmt.Printf("shard:%d count:%d\n", shard, count) 146 } 147 148 slice1 := bigslice.Repartition(slice, func(nshard, x int) int { 149 // Put everything in partition 0 for illustration. 150 return 0 151 }) 152 slice1 = countRowsPerShard(numShards, slice1) 153 fmt.Println("# repartitioned") 154 // Note that the slice contents are unchanged. 155 fmt.Println("## slice contents") 156 slicetest.Print(slice1) 157 // Note that the partitioning has changed. 158 fmt.Println("## row count per shard") 159 for shard, count := range rowsPerShard { 160 fmt.Printf("shard:%d count:%d\n", shard, count) 161 } 162 // Output: 163 // # default partitioning 164 // ## slice contents 165 // 1 166 // 2 167 // 3 168 // 4 169 // 5 170 // 6 171 // ## row count per shard 172 // shard:0 count:3 173 // shard:1 count:3 174 // # repartitioned 175 // ## slice contents 176 // 1 177 // 2 178 // 3 179 // 4 180 // 5 181 // 6 182 // ## row count per shard 183 // shard:0 count:6 184 // shard:1 count:0 185 } 186 187 func ExampleReshard() { 188 // Count rows per shard before and after using Reshard to change the number 189 // of shards from 2 to 4. 190 191 // countRowsPerShard is a utility that counts the number of rows per shard 192 // and stores it in rowsPerShard. 193 var rowsPerShard []int 194 countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice { 195 rowsPerShard = make([]int, numShards) 196 return bigslice.WriterFunc(slice, 197 func(shard int, _ struct{}, _ error, xs []int) error { 198 rowsPerShard[shard] += len(xs) 199 return nil 200 }, 201 ) 202 } 203 204 const beforeNumShards = 2 205 slice := bigslice.Const(beforeNumShards, []int{1, 2, 3, 4, 5, 6}) 206 207 before := countRowsPerShard(beforeNumShards, slice) 208 fmt.Println("# before") 209 fmt.Println("## slice contents") 210 slicetest.Print(before) 211 fmt.Println("## row count per shard") 212 for shard, count := range rowsPerShard { 213 fmt.Printf("shard:%d count:%d\n", shard, count) 214 } 215 216 // Reshard to 4 shards. 217 const afterNumShards = 4 218 after := bigslice.Reshard(slice, afterNumShards) 219 after = countRowsPerShard(afterNumShards, after) 220 fmt.Println("# after") 221 fmt.Println("## slice contents") 222 slicetest.Print(after) 223 fmt.Println("## row count per shard") 224 for shard, count := range rowsPerShard { 225 fmt.Printf("shard:%d count:%d\n", shard, count) 226 } 227 // Output: 228 // # before 229 // ## slice contents 230 // 1 231 // 2 232 // 3 233 // 4 234 // 5 235 // 6 236 // ## row count per shard 237 // shard:0 count:3 238 // shard:1 count:3 239 // # after 240 // ## slice contents 241 // 1 242 // 2 243 // 3 244 // 4 245 // 5 246 // 6 247 // ## row count per shard 248 // shard:0 count:2 249 // shard:1 count:1 250 // shard:2 count:1 251 // shard:3 count:2 252 } 253 254 func ExampleReshuffle() { 255 // Count rows per shard before and after a Reshuffle, showing same-keyed 256 // rows all go to the same shard. 257 258 // countRowsPerShard is a utility that counts the number of rows per shard 259 // and stores it in rowsPerShard. 260 var rowsPerShard []int 261 countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice { 262 rowsPerShard = make([]int, numShards) 263 return bigslice.WriterFunc(slice, 264 func(shard int, _ struct{}, _ error, xs []int) error { 265 rowsPerShard[shard] += len(xs) 266 return nil 267 }, 268 ) 269 } 270 271 const numShards = 2 272 slice := bigslice.Const(numShards, []int{1, 2, 3, 4, 5, 6}) 273 slice = bigslice.Map(slice, func(_ int) int { return 0 }) 274 275 before := countRowsPerShard(numShards, slice) 276 fmt.Println("# before") 277 fmt.Println("## slice contents") 278 slicetest.Print(before) 279 fmt.Println("## row count per shard") 280 for shard, count := range rowsPerShard { 281 fmt.Printf("shard:%d count:%d\n", shard, count) 282 } 283 284 after := bigslice.Reshuffle(slice) 285 after = countRowsPerShard(numShards, after) 286 fmt.Println("# after") 287 // We set all our keys to 0. After reshuffling, all rows will be in the same 288 // shard. 289 fmt.Println("## slice contents") 290 slicetest.Print(after) 291 fmt.Println("## row count per shard") 292 for shard, count := range rowsPerShard { 293 fmt.Printf("shard:%d count:%d\n", shard, count) 294 } 295 // Output: 296 // # before 297 // ## slice contents 298 // 0 299 // 0 300 // 0 301 // 0 302 // 0 303 // 0 304 // ## row count per shard 305 // shard:0 count:3 306 // shard:1 count:3 307 // # after 308 // ## slice contents 309 // 0 310 // 0 311 // 0 312 // 0 313 // 0 314 // 0 315 // ## row count per shard 316 // shard:0 count:6 317 // shard:1 count:0 318 }