// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package sortio

import (
	"container/heap"
	"context"
	"reflect"

	"github.com/grailbio/bigslice/frame"
	"github.com/grailbio/bigslice/internal/defaultsize"
	"github.com/grailbio/bigslice/slicefunc"
	"github.com/grailbio/bigslice/sliceio"
	"github.com/grailbio/bigslice/slicetype"
	"github.com/grailbio/bigslice/typecheck"
)

// defaultChunksize is the number of rows buffered per input reader while
// merging; each input gets its own window of this many rows in a shared
// frame. It is a variable rather than a constant, presumably so tests can
// shrink it — TODO confirm.
var defaultChunksize = defaultsize.Chunk

// reader is the sliceio.Reader returned by Reduce. It merges its sorted
// input readers by key and folds together the values of rows whose keys
// compare equal.
type reader struct {
	// typ is the row type; its single non-prefix column (the last
	// column) holds the value that is combined.
	typ slicetype.Type
	// NOTE(review): name is stored but not read anywhere in this file.
	name string
	// combiner folds two values that share a key into one value.
	combiner slicefunc.Func
	// readers are the sorted input streams being merged.
	readers []sliceio.Reader
	// err, once set, is returned by every subsequent Read.
	err error

	// heap orders the per-reader buffers by the row at each buffer's
	// current position; it is nil until the first Read.
	heap *FrameBufferHeap
	// frame is the backing storage shared by all buffers: one
	// defaultChunksize-row window per input reader.
	frame frame.Frame
}

// Reduce returns a Reader that merges and reduces a set of
// sorted (and possibly combined) readers. Reduce panics if
// the provided type is not reducible, that is, if it does not have
// exactly one column after its key prefix.
36 func Reduce(typ slicetype.Type, name string, readers []sliceio.Reader, combiner slicefunc.Func) sliceio.Reader { 37 if typ.NumOut()-typ.Prefix() != 1 { 38 typecheck.Panicf(1, "cannot reduce type %s", slicetype.String(typ)) 39 } 40 return &reader{ 41 typ: typ, 42 name: name, 43 readers: readers, 44 combiner: combiner, 45 } 46 } 47 48 func (r *reader) Read(ctx context.Context, out frame.Frame) (int, error) { 49 if r.err != nil { 50 return 0, r.err 51 } 52 if r.heap == nil { 53 n := len(r.readers) * defaultChunksize 54 r.frame = frame.Make(r.typ, n, n) 55 r.heap = new(FrameBufferHeap) 56 r.heap.LessFunc = func(i, j int) bool { 57 return r.frame.Less(r.heap.Buffers[i].Pos(), r.heap.Buffers[j].Pos()) 58 } 59 r.heap.Buffers = make([]*FrameBuffer, 0, len(r.readers)) 60 for i := range r.readers { 61 off := i * defaultChunksize 62 buf := &FrameBuffer{ 63 Frame: r.frame.Slice(off, off+defaultChunksize), 64 Reader: r.readers[i], 65 Off: off, 66 } 67 switch err := buf.Fill(ctx); { 68 case err == sliceio.EOF: 69 // No data. Skip. 70 case err != nil: 71 r.err = err 72 return 0, r.err 73 default: 74 r.heap.Buffers = append(r.heap.Buffers, buf) 75 } 76 } 77 heap.Init(r.heap) 78 } 79 var ( 80 n int 81 max = out.Len() 82 ) 83 for n < max && len(r.heap.Buffers) > 0 { 84 // Gather all the buffers that have the same key. Each parent 85 // reader has at most one entry for a given key, since they have 86 // already been combined. 87 var combine []*FrameBuffer 88 for len(combine) == 0 || len(r.heap.Buffers) > 0 && 89 !r.frame.Less(combine[0].Pos(), r.heap.Buffers[0].Pos()) { 90 buf := heap.Pop(r.heap).(*FrameBuffer) 91 combine = append(combine, buf) 92 } 93 // TODO(marius): pass a vector of values to be combined, if it is supported 94 // by the combiner. 
95 vcol := out.NumOut() - 1 96 var combined reflect.Value 97 for i, buf := range combine { 98 val := buf.Frame.Index(vcol, buf.Index) 99 if i == 0 { 100 combined = val 101 } else { 102 combined = r.combiner.Call(ctx, []reflect.Value{combined, val})[0] 103 } 104 } 105 // Emit the output before overwriting the frame. Copy key columns 106 // first, and then set the combined value. 107 frame.Copy(out.Slice(n, n+1), combine[0].Frame.Slice(combine[0].Index, combine[0].Index+1)) 108 out.Index(vcol, n).Set(combined) 109 110 for _, buf := range combine { 111 buf.Index++ 112 if buf.Index == buf.Len { 113 if err := buf.Fill(ctx); err != nil && err != sliceio.EOF { 114 r.err = err 115 return n, err 116 } else if err == nil { 117 heap.Push(r.heap, buf) 118 } // Otherwise it's EOF and we drop it from the heap. 119 } else { 120 heap.Push(r.heap, buf) 121 } 122 } 123 n++ 124 } 125 var err error 126 if len(r.heap.Buffers) == 0 { 127 err = sliceio.EOF 128 } 129 return n, err 130 }