github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/reshuffle.go (about)

     1  // Copyright 2019 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigslice
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"reflect"
    11  
    12  	"github.com/grailbio/bigslice/frame"
    13  	"github.com/grailbio/bigslice/slicefunc"
    14  	"github.com/grailbio/bigslice/sliceio"
    15  	"github.com/grailbio/bigslice/slicetype"
    16  	"github.com/grailbio/bigslice/typecheck"
    17  )
    18  
// typeOfInt is the reflect.Type of int, and sliceTypeInt is the slicetype
// constructed from that single type. Repartition uses sliceTypeInt both as
// the leading argument type and as the return type it expects of a
// partition function.
var (
	typeOfInt    = reflect.TypeOf(int(0))
	sliceTypeInt = slicetype.New(typeOfInt)
)
    23  
// reshuffleSlice is the Slice returned by Reshuffle and Repartition. It
// embeds the input Slice, which is also its only dependency.
//
// name is the value reported by Name. partitioner is handed to the
// dependency returned by Dep; Reshuffle leaves it nil, while Repartition
// sets it to a function wrapping the user-supplied partition function.
//
// The field order matters: the constructors in this file build values with
// positional composite literals.
type reshuffleSlice struct {
	name        Name
	partitioner Partitioner
	Slice
}
    29  
    30  // Reshuffle returns a slice that shuffles rows by prefix so that
    31  // all rows with equal prefix values end up in the same shard.
    32  // Rows are not sorted within a shard.
    33  //
    34  // The output slice has the same type as the input.
    35  //
    36  // TODO: Add ReshuffleSort, which also sorts keys within each shard.
    37  func Reshuffle(slice Slice) Slice {
    38  	if err := canMakeCombiningFrame(slice); err != nil {
    39  		typecheck.Panic(1, err.Error())
    40  	}
    41  	return &reshuffleSlice{MakeName("reshuffle"), nil, slice}
    42  }
    43  
    44  // Repartition (re-)partitions the slice according to the provided function
    45  // fn, which is invoked for each record in the slice to assign that record's
    46  // shard. The function is supplied with the number of shards to partition
    47  // over as well as the column values; the assigned shard is returned.
    48  //
    49  // Schematically:
    50  //
    51  //	Repartition(Slice<t1, t2, ..., tn> func(nshard int, v1 t1, ..., vn tn) int)  Slice<t1, t2, ..., tn>
    52  func Repartition(slice Slice, partition interface{}) Slice {
    53  	var (
    54  		expectArg = slicetype.Append(sliceTypeInt, slice)
    55  		expectRet = sliceTypeInt
    56  	)
    57  	fn, ok := slicefunc.Of(partition)
    58  	if !ok {
    59  		typecheck.Panicf(1, "repartition: not a function: %T", partition)
    60  	}
    61  	if !typecheck.Equal(fn.In, expectArg) || !typecheck.Equal(fn.Out, expectRet) {
    62  		typecheck.Panicf(1, "repartition: expected %s, got %T", slicetype.Signature(expectArg, expectRet), partition)
    63  	}
    64  	part := func(ctx context.Context, frame frame.Frame, nshard int, shards []int) {
    65  		args := make([]reflect.Value, slice.NumOut()+1)
    66  		args[0] = reflect.ValueOf(nshard)
    67  		for i := range shards {
    68  			for j := 0; j < slice.NumOut(); j++ {
    69  				args[j+1] = frame.Index(j, i)
    70  			}
    71  			result := fn.Call(ctx, args)
    72  			shards[i] = int(result[0].Int())
    73  		}
    74  	}
    75  	return &reshuffleSlice{MakeName("repartition"), part, slice}
    76  }
    77  
    78  func (r *reshuffleSlice) Name() Name             { return r.name }
    79  func (*reshuffleSlice) NumDep() int              { return 1 }
    80  func (r *reshuffleSlice) Dep(i int) Dep          { return Dep{r.Slice, true, r.partitioner, false} }
    81  func (*reshuffleSlice) Combiner() slicefunc.Func { return slicefunc.Nil }
    82  
    83  func (r *reshuffleSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader {
    84  	if len(deps) != 1 {
    85  		panic(fmt.Errorf("expected one dep, got %d", len(deps)))
    86  	}
    87  	return deps[0]
    88  }