github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/reshuffle_test.go (about)

     1  // Copyright 2019 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigslice_test
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"math/rand"
    11  	"reflect"
    12  	"sort"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/grailbio/bigslice"
    17  	"github.com/grailbio/bigslice/exec"
    18  	"github.com/grailbio/bigslice/frame"
    19  	"github.com/grailbio/bigslice/sliceio"
    20  	"github.com/grailbio/bigslice/slicetest"
    21  )
    22  
    23  func reshuffleTest(t *testing.T, transform func(bigslice.Slice) bigslice.Slice) {
    24  	t.Helper()
    25  	const N = 100
    26  	col1 := make([]lengthHashKey, N)
    27  	col2 := make([]int, N)
    28  	want := map[lengthHashKey][]int{} // map of val1 -> []val2
    29  	rnd := rand.New(rand.NewSource(0))
    30  	for i := range col1 {
    31  		s := strings.Repeat("a", rnd.Intn(N/10)) // We want many strings of each length.
    32  		col1[i] = lengthHashKey(s)
    33  		col2[i] = rnd.Intn(100)
    34  		want[col1[i]] = append(want[col1[i]], col2[i])
    35  	}
    36  	for _, vals := range want {
    37  		sort.Ints(vals) // for equality test later
    38  	}
    39  	for m := 1; m < 10; m++ {
    40  		t.Run(fmt.Sprint(m), func(t *testing.T) {
    41  			slice := bigslice.Const(m, append([]lengthHashKey{}, col1...), append([]int{}, col2...))
    42  			slice = transform(slice)
    43  
    44  			// map of col1 length -> set of shards that had keys of that length.
    45  			lengthShards := map[int]map[int]struct{}{}
    46  			got := map[lengthHashKey][]int{} // map of val1 -> []val2
    47  			var (
    48  				val1 lengthHashKey
    49  				val2 int
    50  			)
    51  			slice = bigslice.Scan(slice, func(shard int, scanner *sliceio.Scanner) error {
    52  				for scanner.Scan(context.Background(), &val1, &val2) {
    53  					if _, ok := lengthShards[len(val1)]; !ok {
    54  						lengthShards[len(val1)] = map[int]struct{}{}
    55  					}
    56  					lengthShards[len(val1)][shard] = struct{}{}
    57  					got[val1] = append(got[val1], val2)
    58  				}
    59  				return scanner.Err()
    60  			})
    61  
    62  			sess := exec.Start(exec.Local)
    63  			defer sess.Shutdown()
    64  			_, err := sess.Run(context.Background(), bigslice.Func(func() bigslice.Slice { return slice }))
    65  			if err != nil {
    66  				t.Fatalf("run error %v", err)
    67  			}
    68  
    69  			for length, shards := range lengthShards {
    70  				if len(shards) != 1 {
    71  					t.Errorf("found keys of length %d in multiple shards: %v", length, shards)
    72  				}
    73  			}
    74  
    75  			for _, vals := range got {
    76  				sort.Ints(vals)
    77  			}
    78  			if !reflect.DeepEqual(got, want) {
    79  				t.Errorf("got %v, want %v", got, want)
    80  			}
    81  		})
    82  	}
    83  }
    84  
// lengthHashKey is a string key type used by the tests in this file. The
// frame.Ops registered for it in init hash a key by its length alone, so keys
// of equal length always produce the same hash value.
type lengthHashKey string
    86  
    87  func init() {
    88  	frame.RegisterOps(func(slice []lengthHashKey) frame.Ops {
    89  		return frame.Ops{
    90  			Less:         func(i, j int) bool { return slice[i] < slice[j] },
    91  			HashWithSeed: func(i int, _ uint32) uint32 { return uint32(len(slice[i])) },
    92  		}
    93  	})
    94  }
    95  
    96  func TestReshuffle(t *testing.T) {
    97  	reshuffleTest(t, bigslice.Reshuffle)
    98  }
    99  
   100  func TestRepartition(t *testing.T) {
   101  	reshuffleTest(t, func(slice bigslice.Slice) bigslice.Slice {
   102  		return bigslice.Repartition(slice, func(nshard int, key lengthHashKey, value int) int {
   103  			return len(key) % nshard
   104  		})
   105  	})
   106  }
   107  
// TestRepartitionType checks the type errors reported when Repartition is
// given a partition function with the wrong signature. For a slice with int
// and string columns, the expected messages name func(int, int, string) int
// as the required signature.
//
// NOTE(review): expectTypeError is defined elsewhere; presumably it runs the
// given func and compares the resulting type error against the message —
// confirm against its definition before restructuring these calls.
func TestRepartitionType(t *testing.T) {
	slice := bigslice.Const(1, []int{}, []string{})
	// A partition function that takes no parameters.
	expectTypeError(t, "repartition: expected func(int, int, string) int, got func() int", func() {
		bigslice.Repartition(slice, func() int { return 0 })
	})
	// A partition function with the right parameters but no int result.
	expectTypeError(t, "repartition: expected func(int, int, string) int, got func(int, int, string)", func() {
		bigslice.Repartition(slice, func(_ int, _ int, _ string) {})
	})
}
   117  
   118  func ExampleRepartition() {
   119  	// Count rows per shard before and after using Repartition to get ideal
   120  	// partitioning by taking advantage of the knowledge that our keys are
   121  	// sequential integers.
   122  
   123  	// countRowsPerShard is a utility that counts the number of rows per shard
   124  	// and stores it in rowsPerShard.
   125  	var rowsPerShard []int
   126  	countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice {
   127  		rowsPerShard = make([]int, numShards)
   128  		return bigslice.WriterFunc(slice,
   129  			func(shard int, _ struct{}, _ error, xs []int) error {
   130  				rowsPerShard[shard] += len(xs)
   131  				return nil
   132  			},
   133  		)
   134  	}
   135  
   136  	const numShards = 2
   137  	slice := bigslice.Const(numShards, []int{1, 2, 3, 4, 5, 6})
   138  
   139  	slice0 := countRowsPerShard(numShards, slice)
   140  	fmt.Println("# default partitioning")
   141  	fmt.Println("## slice contents")
   142  	slicetest.Print(slice0)
   143  	fmt.Println("## row count per shard")
   144  	for shard, count := range rowsPerShard {
   145  		fmt.Printf("shard:%d count:%d\n", shard, count)
   146  	}
   147  
   148  	slice1 := bigslice.Repartition(slice, func(nshard, x int) int {
   149  		// Put everything in partition 0 for illustration.
   150  		return 0
   151  	})
   152  	slice1 = countRowsPerShard(numShards, slice1)
   153  	fmt.Println("# repartitioned")
   154  	// Note that the slice contents are unchanged.
   155  	fmt.Println("## slice contents")
   156  	slicetest.Print(slice1)
   157  	// Note that the partitioning has changed.
   158  	fmt.Println("## row count per shard")
   159  	for shard, count := range rowsPerShard {
   160  		fmt.Printf("shard:%d count:%d\n", shard, count)
   161  	}
   162  	// Output:
   163  	// # default partitioning
   164  	// ## slice contents
   165  	// 1
   166  	// 2
   167  	// 3
   168  	// 4
   169  	// 5
   170  	// 6
   171  	// ## row count per shard
   172  	// shard:0 count:3
   173  	// shard:1 count:3
   174  	// # repartitioned
   175  	// ## slice contents
   176  	// 1
   177  	// 2
   178  	// 3
   179  	// 4
   180  	// 5
   181  	// 6
   182  	// ## row count per shard
   183  	// shard:0 count:6
   184  	// shard:1 count:0
   185  }
   186  
   187  func ExampleReshard() {
   188  	// Count rows per shard before and after using Reshard to change the number
   189  	// of shards from 2 to 4.
   190  
   191  	// countRowsPerShard is a utility that counts the number of rows per shard
   192  	// and stores it in rowsPerShard.
   193  	var rowsPerShard []int
   194  	countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice {
   195  		rowsPerShard = make([]int, numShards)
   196  		return bigslice.WriterFunc(slice,
   197  			func(shard int, _ struct{}, _ error, xs []int) error {
   198  				rowsPerShard[shard] += len(xs)
   199  				return nil
   200  			},
   201  		)
   202  	}
   203  
   204  	const beforeNumShards = 2
   205  	slice := bigslice.Const(beforeNumShards, []int{1, 2, 3, 4, 5, 6})
   206  
   207  	before := countRowsPerShard(beforeNumShards, slice)
   208  	fmt.Println("# before")
   209  	fmt.Println("## slice contents")
   210  	slicetest.Print(before)
   211  	fmt.Println("## row count per shard")
   212  	for shard, count := range rowsPerShard {
   213  		fmt.Printf("shard:%d count:%d\n", shard, count)
   214  	}
   215  
   216  	// Reshard to 4 shards.
   217  	const afterNumShards = 4
   218  	after := bigslice.Reshard(slice, afterNumShards)
   219  	after = countRowsPerShard(afterNumShards, after)
   220  	fmt.Println("# after")
   221  	fmt.Println("## slice contents")
   222  	slicetest.Print(after)
   223  	fmt.Println("## row count per shard")
   224  	for shard, count := range rowsPerShard {
   225  		fmt.Printf("shard:%d count:%d\n", shard, count)
   226  	}
   227  	// Output:
   228  	// # before
   229  	// ## slice contents
   230  	// 1
   231  	// 2
   232  	// 3
   233  	// 4
   234  	// 5
   235  	// 6
   236  	// ## row count per shard
   237  	// shard:0 count:3
   238  	// shard:1 count:3
   239  	// # after
   240  	// ## slice contents
   241  	// 1
   242  	// 2
   243  	// 3
   244  	// 4
   245  	// 5
   246  	// 6
   247  	// ## row count per shard
   248  	// shard:0 count:2
   249  	// shard:1 count:1
   250  	// shard:2 count:1
   251  	// shard:3 count:2
   252  }
   253  
   254  func ExampleReshuffle() {
   255  	// Count rows per shard before and after a Reshuffle, showing same-keyed
   256  	// rows all go to the same shard.
   257  
   258  	// countRowsPerShard is a utility that counts the number of rows per shard
   259  	// and stores it in rowsPerShard.
   260  	var rowsPerShard []int
   261  	countRowsPerShard := func(numShards int, slice bigslice.Slice) bigslice.Slice {
   262  		rowsPerShard = make([]int, numShards)
   263  		return bigslice.WriterFunc(slice,
   264  			func(shard int, _ struct{}, _ error, xs []int) error {
   265  				rowsPerShard[shard] += len(xs)
   266  				return nil
   267  			},
   268  		)
   269  	}
   270  
   271  	const numShards = 2
   272  	slice := bigslice.Const(numShards, []int{1, 2, 3, 4, 5, 6})
   273  	slice = bigslice.Map(slice, func(_ int) int { return 0 })
   274  
   275  	before := countRowsPerShard(numShards, slice)
   276  	fmt.Println("# before")
   277  	fmt.Println("## slice contents")
   278  	slicetest.Print(before)
   279  	fmt.Println("## row count per shard")
   280  	for shard, count := range rowsPerShard {
   281  		fmt.Printf("shard:%d count:%d\n", shard, count)
   282  	}
   283  
   284  	after := bigslice.Reshuffle(slice)
   285  	after = countRowsPerShard(numShards, after)
   286  	fmt.Println("# after")
   287  	// We set all our keys to 0. After reshuffling, all rows will be in the same
   288  	// shard.
   289  	fmt.Println("## slice contents")
   290  	slicetest.Print(after)
   291  	fmt.Println("## row count per shard")
   292  	for shard, count := range rowsPerShard {
   293  		fmt.Printf("shard:%d count:%d\n", shard, count)
   294  	}
   295  	// Output:
   296  	// # before
   297  	// ## slice contents
   298  	// 0
   299  	// 0
   300  	// 0
   301  	// 0
   302  	// 0
   303  	// 0
   304  	// ## row count per shard
   305  	// shard:0 count:3
   306  	// shard:1 count:3
   307  	// # after
   308  	// ## slice contents
   309  	// 0
   310  	// 0
   311  	// 0
   312  	// 0
   313  	// 0
   314  	// 0
   315  	// ## row count per shard
   316  	// shard:0 count:6
   317  	// shard:1 count:0
   318  }