github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/cogroup.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigslice
     6  
     7  import (
     8  	"container/heap"
     9  	"context"
    10  	"reflect"
    11  
    12  	"github.com/grailbio/bigslice/frame"
    13  	"github.com/grailbio/bigslice/slicefunc"
    14  	"github.com/grailbio/bigslice/sliceio"
    15  	"github.com/grailbio/bigslice/slicetype"
    16  	"github.com/grailbio/bigslice/sortio"
    17  	"github.com/grailbio/bigslice/typecheck"
    18  )
    19  
    20  type cogroupSlice struct {
    21  	name     Name
    22  	slices   []Slice
    23  	out      []reflect.Type
    24  	prefix   int
    25  	numShard int
    26  }
    27  
    28  // Cogroup returns a slice that, for each key in any slice, contains
    29  // the group of values for that key, in each slice. Schematically:
    30  //
    31  //	Cogroup(Slice<tk1, ..., tkp, t11, ..., t1n>, Slice<tk1, ..., tkp, t21, ..., t2n>, ..., Slice<tk1, ..., tkp, tm1, ..., tmn>)
    32  //		Slice<tk1, ..., tkp, []t11, ..., []t1n, []t21, ..., []tmn>
    33  //
    34  // It thus implements a form of generalized JOIN and GROUP.
    35  //
    36  // Cogroup uses the prefix columns of each slice as its key; keys must be
    37  // partitionable.
    38  //
    39  // TODO(marius): don't require spilling to disk when the input data
    40  // set is small enough.
    41  //
    42  // TODO(marius): consider providing a version that returns scanners
    43  // in the returned slice, so that we can stream through. This would
    44  // require some changes downstream, however, so that buffering and
    45  // encoding functionality also know how to read scanner values.
    46  func Cogroup(slices ...Slice) Slice {
    47  	if len(slices) == 0 {
    48  		typecheck.Panic(1, "cogroup: expected at least one slice")
    49  	}
    50  	var keyTypes []reflect.Type
    51  	for i, slice := range slices {
    52  		if slice.NumOut() == 0 {
    53  			typecheck.Panicf(1, "cogroup: slice %d has no columns", i)
    54  		}
    55  		if i == 0 {
    56  			keyTypes = make([]reflect.Type, slice.Prefix())
    57  			for j := range keyTypes {
    58  				keyTypes[j] = slice.Out(j)
    59  			}
    60  		} else {
    61  			if got, want := slice.Prefix(), len(keyTypes); got != want {
    62  				typecheck.Panicf(1, "cogroup: prefix mismatch: expected %d but got %d", want, got)
    63  			}
    64  			for j := range keyTypes {
    65  				if got, want := slice.Out(j), keyTypes[j]; got != want {
    66  					typecheck.Panicf(1, "cogroup: key column type mismatch: expected %s but got %s", want, got)
    67  				}
    68  			}
    69  		}
    70  	}
    71  	for i := range keyTypes {
    72  		if !frame.CanHash(keyTypes[i]) {
    73  			typecheck.Panicf(1, "cogroup: key column(%d) type %s cannot be hashed", i, keyTypes[i])
    74  		}
    75  		if !frame.CanCompare(keyTypes[i]) {
    76  			typecheck.Panicf(1, "cogroup: key column(%d) type %s cannot be sorted", i, keyTypes[i])
    77  		}
    78  	}
    79  	out := keyTypes
    80  	for _, slice := range slices {
    81  		for i := len(keyTypes); i < slice.NumOut(); i++ {
    82  			out = append(out, reflect.SliceOf(slice.Out(i)))
    83  		}
    84  	}
    85  
    86  	// Pick the max of the number of parent shards, so that the input
    87  	// will be partitioned as widely as the user desires.
    88  	var numShard int
    89  	for _, slice := range slices {
    90  		if slice.NumShard() > numShard {
    91  			numShard = slice.NumShard()
    92  		}
    93  	}
    94  
    95  	return &cogroupSlice{
    96  		name:     MakeName("cogroup"),
    97  		numShard: numShard,
    98  		slices:   slices,
    99  		out:      out,
   100  		prefix:   len(keyTypes),
   101  	}
   102  }
   103  
   104  func (c *cogroupSlice) Name() Name             { return c.name }
   105  func (c *cogroupSlice) NumShard() int          { return c.numShard }
   106  func (c *cogroupSlice) ShardType() ShardType   { return HashShard }
   107  func (c *cogroupSlice) NumOut() int            { return len(c.out) }
   108  func (c *cogroupSlice) Out(i int) reflect.Type { return c.out[i] }
   109  func (c *cogroupSlice) Prefix() int            { return c.prefix }
   110  func (c *cogroupSlice) NumDep() int            { return len(c.slices) }
   111  func (c *cogroupSlice) Dep(i int) Dep          { return Dep{c.slices[i], true, nil, false} }
   112  func (*cogroupSlice) Combiner() slicefunc.Func { return slicefunc.Nil }
   113  
   114  type cogroupReader struct {
   115  	err error
   116  	op  *cogroupSlice
   117  
   118  	readers []sliceio.Reader
   119  
   120  	heap *sortio.FrameBufferHeap
   121  }
   122  
   123  func (c *cogroupReader) Read(ctx context.Context, out frame.Frame) (int, error) {
   124  	const (
   125  		bufferSize = 128
   126  		spillSize  = 1 << 25
   127  	)
   128  	if c.err != nil {
   129  		return 0, c.err
   130  	}
   131  	if c.heap == nil {
   132  		c.heap = new(sortio.FrameBufferHeap)
   133  		c.heap.Buffers = make([]*sortio.FrameBuffer, 0, len(c.readers))
   134  		// Maintain a compare buffer that's used to compare values across
   135  		// the heterogeneously typed buffers.
   136  		// TODO(marius): the extra copy and indirection here is unnecessary.
   137  		lessBuf := frame.Make(slicetype.New(c.op.out[:c.op.prefix]...), 2, 2).Prefixed(c.op.prefix)
   138  		c.heap.LessFunc = func(i, j int) bool {
   139  			ib, jb := c.heap.Buffers[i], c.heap.Buffers[j]
   140  			for i := 0; i < c.op.prefix; i++ {
   141  				lessBuf.Index(i, 0).Set(ib.Frame.Index(i, ib.Index))
   142  				lessBuf.Index(i, 1).Set(jb.Frame.Index(i, jb.Index))
   143  			}
   144  			return lessBuf.Less(0, 1)
   145  		}
   146  
   147  		// Sort each partition one-by-one. Since tasks are scheduled
   148  		// to map onto a single CPU, we attain parallelism through sharding
   149  		// at a higher level.
   150  		for i := range c.readers {
   151  			// Do the actual sort. Aim for ~30 MB spill files.
   152  			// TODO(marius): make spill sizes configurable, or dependent
   153  			// on the environment: for example, we could pass down a memory
   154  			// allotment to each task from the scheduler.
   155  			var sorted sliceio.Reader
   156  			sorted, c.err = sortio.SortReader(ctx, spillSize, c.op.Dep(i), c.readers[i])
   157  			if c.err != nil {
   158  				// TODO(marius): in case this fails, we may leave open file
   159  				// descriptors. We should make sure we close readers that
   160  				// implement Discard.
   161  				return 0, c.err
   162  			}
   163  			buf := &sortio.FrameBuffer{
   164  				Frame:  frame.Make(c.op.Dep(i), bufferSize, bufferSize),
   165  				Reader: sorted,
   166  				Off:    i * bufferSize,
   167  			}
   168  			switch err := buf.Fill(ctx); {
   169  			case err == sliceio.EOF:
   170  				// No data. Skip.
   171  			case err != nil:
   172  				c.err = err
   173  				return 0, err
   174  			default:
   175  				c.heap.Buffers = append(c.heap.Buffers, buf)
   176  			}
   177  		}
   178  	}
   179  	heap.Init(c.heap)
   180  
   181  	// Now that we're sorted, perform a merge from each dependency.
   182  	var (
   183  		n       int
   184  		max     = out.Len()
   185  		lessBuf = frame.Make(slicetype.New(c.op.out[:c.op.prefix]...), 2, 2).Prefixed(c.op.prefix)
   186  	)
   187  	if max == 0 {
   188  		panic("bigslice.Cogroup: max == 0")
   189  	}
   190  	// BUG: this is gnarly
   191  	for n < max && len(c.heap.Buffers) > 0 {
   192  		// First, gather all the records that have the same key.
   193  		row := make([]frame.Frame, len(c.readers))
   194  		var (
   195  			key  = make([]reflect.Value, c.op.prefix)
   196  			last = -1
   197  		)
   198  		// TODO(marius): the extra copy and indirection here is unnecessary.
   199  		less := func() bool {
   200  			buf := c.heap.Buffers[0]
   201  			for i := 0; i < c.op.prefix; i++ {
   202  				lessBuf.Index(i, 0).Set(row[last].Index(i, 0))
   203  				lessBuf.Index(i, 1).Set(buf.Frame.Index(i, buf.Index))
   204  			}
   205  			return lessBuf.Less(0, 1)
   206  		}
   207  
   208  		for last < 0 || len(c.heap.Buffers) > 0 && !less() {
   209  			// first key: need to pick the smallest one
   210  			buf := c.heap.Buffers[0]
   211  			idx := buf.Off / bufferSize
   212  			row[idx] = frame.AppendFrame(row[idx], buf.Slice(buf.Index, buf.Index+1))
   213  			buf.Index++
   214  			if last < 0 {
   215  				for i := 0; i < c.op.prefix; i++ {
   216  					key[i] = row[idx].Index(i, 0)
   217  				}
   218  			}
   219  			last = idx
   220  			if buf.Index == buf.Len {
   221  				if err := buf.Fill(ctx); err != nil && err != sliceio.EOF {
   222  					c.err = err
   223  					return n, err
   224  				} else if err == sliceio.EOF {
   225  					heap.Remove(c.heap, 0)
   226  				} else {
   227  					heap.Fix(c.heap, 0)
   228  				}
   229  			} else {
   230  				heap.Fix(c.heap, 0)
   231  			}
   232  		}
   233  
   234  		// Now that we've gathered all the row values for a given key,
   235  		// push them into our output.
   236  		var j int
   237  		for i := range key {
   238  			out.Index(j, n).Set(key[i])
   239  			j++
   240  		}
   241  		// Note that here we are assuming that the key column is always first;
   242  		// elsewhere we don't really make this assumption, even though it is
   243  		// enforced when constructing a cogroup.
   244  		for i := range row {
   245  			typ := c.op.Dep(i)
   246  			if row[i].Len() == 0 {
   247  				for k := len(key); k < typ.NumOut(); k++ {
   248  					out.Index(j, n).Set(reflect.Zero(c.op.out[j]))
   249  					j++
   250  				}
   251  			} else {
   252  				for k := len(key); k < typ.NumOut(); k++ {
   253  					// TODO(marius): precompute type checks here.
   254  					out.Index(j, n).Set(row[i].Value(k))
   255  					j++
   256  				}
   257  			}
   258  		}
   259  		n++
   260  	}
   261  	if n == 0 {
   262  		c.err = sliceio.EOF
   263  	}
   264  	return n, c.err
   265  }
   266  
   267  func (c *cogroupSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader {
   268  	return &cogroupReader{
   269  		op:      c,
   270  		readers: deps,
   271  	}
   272  }