github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/reshuffle.go (about) 1 // Copyright 2019 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package bigslice 6 7 import ( 8 "context" 9 "fmt" 10 "reflect" 11 12 "github.com/grailbio/bigslice/frame" 13 "github.com/grailbio/bigslice/slicefunc" 14 "github.com/grailbio/bigslice/sliceio" 15 "github.com/grailbio/bigslice/slicetype" 16 "github.com/grailbio/bigslice/typecheck" 17 ) 18 19 var ( 20 typeOfInt = reflect.TypeOf(int(0)) 21 sliceTypeInt = slicetype.New(typeOfInt) 22 ) 23 24 type reshuffleSlice struct { 25 name Name 26 partitioner Partitioner 27 Slice 28 } 29 30 // Reshuffle returns a slice that shuffles rows by prefix so that 31 // all rows with equal prefix values end up in the same shard. 32 // Rows are not sorted within a shard. 33 // 34 // The output slice has the same type as the input. 35 // 36 // TODO: Add ReshuffleSort, which also sorts keys within each shard. 37 func Reshuffle(slice Slice) Slice { 38 if err := canMakeCombiningFrame(slice); err != nil { 39 typecheck.Panic(1, err.Error()) 40 } 41 return &reshuffleSlice{MakeName("reshuffle"), nil, slice} 42 } 43 44 // Repartition (re-)partitions the slice according to the provided function 45 // fn, which is invoked for each record in the slice to assign that record's 46 // shard. The function is supplied with the number of shards to partition 47 // over as well as the column values; the assigned shard is returned. 48 // 49 // Schematically: 50 // 51 // Repartition(Slice<t1, t2, ..., tn> func(nshard int, v1 t1, ..., vn tn) int) Slice<t1, t2, ..., tn> 52 func Repartition(slice Slice, partition interface{}) Slice { 53 var ( 54 expectArg = slicetype.Append(sliceTypeInt, slice) 55 expectRet = sliceTypeInt 56 ) 57 fn, ok := slicefunc.Of(partition) 58 if !ok { 59 typecheck.Panicf(1, "repartition: not a function: %T", partition) 60 } 61 if !typecheck.Equal(fn.In, expectArg) || !typecheck.Equal(fn.Out, expectRet) { 62 typecheck.Panicf(1, "repartition: expected %s, got %T", slicetype.Signature(expectArg, expectRet), partition) 63 } 64 part := func(ctx context.Context, frame frame.Frame, nshard int, shards []int) { 65 args := make([]reflect.Value, slice.NumOut()+1) 66 args[0] = reflect.ValueOf(nshard) 67 for i := range shards { 68 for j := 0; j < slice.NumOut(); j++ { 69 args[j+1] = frame.Index(j, i) 70 } 71 result := fn.Call(ctx, args) 72 shards[i] = int(result[0].Int()) 73 } 74 } 75 return &reshuffleSlice{MakeName("repartition"), part, slice} 76 } 77 78 func (r *reshuffleSlice) Name() Name { return r.name } 79 func (*reshuffleSlice) NumDep() int { return 1 } 80 func (r *reshuffleSlice) Dep(i int) Dep { return Dep{r.Slice, true, r.partitioner, false} } 81 func (*reshuffleSlice) Combiner() slicefunc.Func { return slicefunc.Nil } 82 83 func (r *reshuffleSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader { 84 if len(deps) != 1 { 85 panic(fmt.Errorf("expected one dep, got %d", len(deps))) 86 } 87 return deps[0] 88 }