github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/sortio/sort.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package sortio provides facilities for sorting slice outputs
     6  // and merging and reducing sorted record streams.
     7  package sortio
     8  
     9  import (
    10  	"container/heap"
    11  	"context"
    12  	"math"
    13  	"sort"
    14  
    15  	"github.com/grailbio/base/log"
    16  	"github.com/grailbio/bigslice/frame"
    17  	"github.com/grailbio/bigslice/internal/defaultsize"
    18  	"github.com/grailbio/bigslice/sliceio"
    19  	"github.com/grailbio/bigslice/slicetype"
    20  )
    21  
// numCanaryRows is the number of rows SortReader reads into its first
// ("canary") batch, which is used to estimate the encoded size of a row
// before choosing the size of subsequent batches. It points at
// defaultsize.SortCanary. NOTE(review): presumably a pointer so the value
// can be overridden at runtime (e.g. in tests) — confirm.
var numCanaryRows = &defaultsize.SortCanary
    23  
    24  // SortReader sorts a Reader by its prefix columns. SortReader may
    25  // spill to disk, in which case it targets spill file sizes of
    26  // spillTarget (in bytes). Because the encoded size of objects is not
    27  // known in advance, sortReader uses a "canary" batch size of ~16k
    28  // rows in order to estimate the size of future reads. The estimate
    29  // is revisited on every subsequent fill and adjusted if it is
    30  // violated by more than 5%.
    31  func SortReader(ctx context.Context, spillTarget int, typ slicetype.Type, r sliceio.Reader) (sliceio.Reader, error) {
    32  	spill, err := sliceio.NewSpiller("sorter")
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	defer func() {
    37  		if cleanupErr := spill.Cleanup(); cleanupErr != nil {
    38  			// Consider temporary file cleanup to be best-effort.
    39  			log.Debug.Printf("%s: failed to clean up temporary files: %v",
    40  				spill, cleanupErr)
    41  		}
    42  	}()
    43  	f := frame.Make(typ, *numCanaryRows, *numCanaryRows)
    44  	for {
    45  		var n int
    46  		n, err = sliceio.ReadFull(ctx, r, f)
    47  		if err != nil && err != sliceio.EOF {
    48  			return nil, err
    49  		}
    50  		eof := err == sliceio.EOF
    51  		g := f.Slice(0, n)
    52  		sort.Sort(g)
    53  		var size int
    54  		size, err = spill.Spill(g)
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  		if eof {
    59  			break
    60  		}
    61  		bytesPerRow := size / n
    62  		targetRows := spillTarget / bytesPerRow
    63  		if targetRows < sliceio.SpillBatchSize {
    64  			targetRows = sliceio.SpillBatchSize
    65  		}
    66  		// If we're within 5%, that's ok.
    67  		if math.Abs(float64(f.Len()-targetRows)/float64(targetRows)) > 0.05 {
    68  			f = f.Ensure(targetRows)
    69  		}
    70  	}
    71  	readers, err := spill.ClosingReaders()
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	return NewMergeReader(ctx, typ, readers)
    76  }
    77  
// A FrameBuffer is a buffered frame. The frame is filled from
// a reader, and maintains a current index and length.
type FrameBuffer struct {
	// Frame is the buffer into which new data are read. The buffer is
	// always allocated externally and must be nonempty.
	frame.Frame
	// Reader is the slice reader from which the buffer is filled.
	sliceio.Reader
	// Index is the position (within Frame) of the next unconsumed row,
	// and Len is the number of valid rows stored by the most recent
	// Fill. The buffer is empty when Index == Len.
	Index, Len int

	// Off is the global offset of this framebuffer. It is added to
	// the index to compute the buffer's current position. In
	// NewMergeReader, Frame is a window of a larger shared frame and Off
	// is the start of that window, so Pos indexes the shared frame.
	Off int
}
    93  
    94  // Pos returns this frame buffer's current position.
    95  func (f FrameBuffer) Pos() int {
    96  	return f.Off + f.Index
    97  }
    98  
    99  // Fill (re-) fills the FrameBuffer when it's empty. An error
   100  // is returned if the underlying reader returns an error.
   101  // EOF is returned if no more data are available.
   102  func (f *FrameBuffer) Fill(ctx context.Context) error {
   103  	if f.Index != f.Len {
   104  		panic("FrameBuffer.Fill: fill on nonempty buffer")
   105  	}
   106  	var err error
   107  	f.Len, err = f.Reader.Read(ctx, f.Frame)
   108  	if err != nil && err != sliceio.EOF {
   109  		return err
   110  	}
   111  	if err == sliceio.EOF && f.Len > 0 {
   112  		err = nil
   113  	}
   114  	f.Index = 0
   115  	if f.Len == 0 && err == nil {
   116  		err = sliceio.EOF
   117  	}
   118  	return err
   119  }
   120  
// FrameBufferHeap implements a heap of FrameBuffers (it satisfies
// container/heap's heap.Interface), ordered by the provided LessFunc.
type FrameBufferHeap struct {
	// Buffers holds the heap elements in heap order.
	Buffers []*FrameBuffer
	// LessFunc compares the current index of buffers i and j: it should
	// report whether the current row of Buffers[i] sorts before the
	// current row of Buffers[j].
	LessFunc func(i, j int) bool
}
   128  
   129  func (f *FrameBufferHeap) Len() int { return len(f.Buffers) }
   130  func (f *FrameBufferHeap) Less(i, j int) bool {
   131  	return f.LessFunc(i, j)
   132  }
   133  func (f *FrameBufferHeap) Swap(i, j int) {
   134  	f.Buffers[i], f.Buffers[j] = f.Buffers[j], f.Buffers[i]
   135  }
   136  
   137  // Push pushes a FrameBuffer onto the heap.
   138  func (f *FrameBufferHeap) Push(x interface{}) {
   139  	buf := x.(*FrameBuffer)
   140  	f.Buffers = append(f.Buffers, buf)
   141  }
   142  
   143  // Pop removes the FrameBuffer with the smallest priority
   144  // from the heap.
   145  func (f *FrameBufferHeap) Pop() interface{} {
   146  	n := len(f.Buffers)
   147  	elem := f.Buffers[n-1]
   148  	f.Buffers = f.Buffers[:n-1]
   149  	return elem
   150  }
   151  
// mergeReader merges multiple (sorted) readers into a
// single sorted reader.
type mergeReader struct {
	// err is sticky: once set, every subsequent Read returns it. It
	// holds either the first refill error or sliceio.EOF at end of stream.
	err  error
	// heap holds the buffers that still have rows, ordered by their
	// current row.
	heap *FrameBufferHeap
}
   158  
   159  // NewMergeReader returns a new Reader that is sorted by its prefix columns. The
   160  // readers to be merged must already be sorted.
   161  func NewMergeReader(ctx context.Context, typ slicetype.Type, readers []sliceio.Reader) (sliceio.Reader, error) {
   162  	h := new(FrameBufferHeap)
   163  	h.Buffers = make([]*FrameBuffer, 0, len(readers))
   164  	n := len(readers) * sliceio.SpillBatchSize
   165  	f := frame.Make(typ, n, n)
   166  	h.LessFunc = func(i, j int) bool {
   167  		return f.Less(h.Buffers[i].Pos(), h.Buffers[j].Pos())
   168  	}
   169  	for i := range readers {
   170  		off := i * sliceio.SpillBatchSize
   171  		fr := &FrameBuffer{
   172  			Reader: readers[i],
   173  			Frame:  f.Slice(off, off+sliceio.SpillBatchSize),
   174  			Off:    off,
   175  		}
   176  		switch err := fr.Fill(ctx); {
   177  		case err == sliceio.EOF:
   178  			// No data. Skip.
   179  		case err != nil:
   180  			return nil, err
   181  		default:
   182  			h.Buffers = append(h.Buffers, fr)
   183  		}
   184  	}
   185  	heap.Init(h)
   186  	return &mergeReader{heap: h}, nil
   187  }
   188  
   189  // Read implements Reader.
   190  func (m *mergeReader) Read(ctx context.Context, out frame.Frame) (int, error) {
   191  	if m.err != nil {
   192  		return 0, m.err
   193  	}
   194  	var (
   195  		n   int
   196  		max = out.Len()
   197  	)
   198  	for n < max && len(m.heap.Buffers) > 0 {
   199  		idx := m.heap.Buffers[0].Index
   200  		frame.Copy(out.Slice(n, n+1), m.heap.Buffers[0].Slice(idx, idx+1))
   201  		n++
   202  		m.heap.Buffers[0].Index++
   203  		if m.heap.Buffers[0].Index == m.heap.Buffers[0].Len {
   204  			if err := m.heap.Buffers[0].Fill(ctx); err != nil && err != sliceio.EOF {
   205  				// TODO(jcharumilind): Close other buffer readers that have not
   206  				// yet returned an error; otherwise they leak.
   207  				m.err = err
   208  				return 0, err
   209  			} else if err == sliceio.EOF {
   210  				heap.Remove(m.heap, 0)
   211  			} else {
   212  				heap.Fix(m.heap, 0)
   213  			}
   214  		} else {
   215  			heap.Fix(m.heap, 0)
   216  		}
   217  	}
   218  	if n == 0 {
   219  		m.err = sliceio.EOF
   220  	}
   221  	return n, m.err
   222  }