github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/sortio/reader.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package sortio
     6  
     7  import (
     8  	"container/heap"
     9  	"context"
    10  	"reflect"
    11  
    12  	"github.com/grailbio/bigslice/frame"
    13  	"github.com/grailbio/bigslice/internal/defaultsize"
    14  	"github.com/grailbio/bigslice/slicefunc"
    15  	"github.com/grailbio/bigslice/sliceio"
    16  	"github.com/grailbio/bigslice/slicetype"
    17  	"github.com/grailbio/bigslice/typecheck"
    18  )
    19  
// defaultChunksize is the number of rows of the shared merge frame
// that is reserved for each input reader (see (*reader).Read). It is
// initialized from defaultsize.Chunk.
var defaultChunksize = defaultsize.Chunk
    21  
// reader is the sliceio.Reader returned by Reduce. It merges a set of
// readers, combining the values of rows that compare equal on their
// keys.
type reader struct {
	// typ is the row type; it is used to allocate the merge frame.
	typ      slicetype.Type
	// name is the name passed to Reduce. It is not read by the code
	// in this file.
	name     string
	// combiner is called with two values that share a key and returns
	// their combined value.
	combiner slicefunc.Func
	// readers are the inputs being merged.
	readers  []sliceio.Reader
	// err is a sticky error: once set, every later Read returns it.
	err      error

	// heap holds one buffer per non-exhausted reader, ordered through
	// its LessFunc by each buffer's current row. It is nil until the
	// first call to Read.
	heap  *FrameBufferHeap
	// frame is the backing storage shared by all buffers; each reader
	// owns a defaultChunksize-row slice of it.
	frame frame.Frame
}
    32  
    33  // Reduce returns a Reader that merges and reduces a set of
    34  // sorted (and possibly combined) readers. Reduce panics if
    35  // the provided type is not reducable.
    36  func Reduce(typ slicetype.Type, name string, readers []sliceio.Reader, combiner slicefunc.Func) sliceio.Reader {
    37  	if typ.NumOut()-typ.Prefix() != 1 {
    38  		typecheck.Panicf(1, "cannot reduce type %s", slicetype.String(typ))
    39  	}
    40  	return &reader{
    41  		typ:      typ,
    42  		name:     name,
    43  		readers:  readers,
    44  		combiner: combiner,
    45  	}
    46  }
    47  
    48  func (r *reader) Read(ctx context.Context, out frame.Frame) (int, error) {
    49  	if r.err != nil {
    50  		return 0, r.err
    51  	}
    52  	if r.heap == nil {
    53  		n := len(r.readers) * defaultChunksize
    54  		r.frame = frame.Make(r.typ, n, n)
    55  		r.heap = new(FrameBufferHeap)
    56  		r.heap.LessFunc = func(i, j int) bool {
    57  			return r.frame.Less(r.heap.Buffers[i].Pos(), r.heap.Buffers[j].Pos())
    58  		}
    59  		r.heap.Buffers = make([]*FrameBuffer, 0, len(r.readers))
    60  		for i := range r.readers {
    61  			off := i * defaultChunksize
    62  			buf := &FrameBuffer{
    63  				Frame:  r.frame.Slice(off, off+defaultChunksize),
    64  				Reader: r.readers[i],
    65  				Off:    off,
    66  			}
    67  			switch err := buf.Fill(ctx); {
    68  			case err == sliceio.EOF:
    69  				// No data. Skip.
    70  			case err != nil:
    71  				r.err = err
    72  				return 0, r.err
    73  			default:
    74  				r.heap.Buffers = append(r.heap.Buffers, buf)
    75  			}
    76  		}
    77  		heap.Init(r.heap)
    78  	}
    79  	var (
    80  		n   int
    81  		max = out.Len()
    82  	)
    83  	for n < max && len(r.heap.Buffers) > 0 {
    84  		// Gather all the buffers that have the same key. Each parent
    85  		// reader has at most one entry for a given key, since they have
    86  		// already been combined.
    87  		var combine []*FrameBuffer
    88  		for len(combine) == 0 || len(r.heap.Buffers) > 0 &&
    89  			!r.frame.Less(combine[0].Pos(), r.heap.Buffers[0].Pos()) {
    90  			buf := heap.Pop(r.heap).(*FrameBuffer)
    91  			combine = append(combine, buf)
    92  		}
    93  		// TODO(marius): pass a vector of values to be combined, if it is supported
    94  		// by the combiner.
    95  		vcol := out.NumOut() - 1
    96  		var combined reflect.Value
    97  		for i, buf := range combine {
    98  			val := buf.Frame.Index(vcol, buf.Index)
    99  			if i == 0 {
   100  				combined = val
   101  			} else {
   102  				combined = r.combiner.Call(ctx, []reflect.Value{combined, val})[0]
   103  			}
   104  		}
   105  		// Emit the output before overwriting the frame. Copy key columns
   106  		// first, and then set the combined value.
   107  		frame.Copy(out.Slice(n, n+1), combine[0].Frame.Slice(combine[0].Index, combine[0].Index+1))
   108  		out.Index(vcol, n).Set(combined)
   109  
   110  		for _, buf := range combine {
   111  			buf.Index++
   112  			if buf.Index == buf.Len {
   113  				if err := buf.Fill(ctx); err != nil && err != sliceio.EOF {
   114  					r.err = err
   115  					return n, err
   116  				} else if err == nil {
   117  					heap.Push(r.heap, buf)
   118  				} // Otherwise it's EOF and we drop it from the heap.
   119  			} else {
   120  				heap.Push(r.heap, buf)
   121  			}
   122  		}
   123  		n++
   124  	}
   125  	var err error
   126  	if len(r.heap.Buffers) == 0 {
   127  		err = sliceio.EOF
   128  	}
   129  	return n, err
   130  }