github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/prolly/sort/external.go

// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sort

import (
	"bufio"
	"container/heap"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
	"sort"

	"github.com/dolthub/dolt/go/store/util/tempfiles"
	"github.com/dolthub/dolt/go/store/val"
)

// tupleSorter inputs a series of unsorted tuples and outputs a sorted list
// of tuples. Batches of tuples sorted in memory are written to disk, and
// then k-way merge sorted to produce a final sorted list. The |fileMax|
// parameter limits the number of files spilled to disk at any given time.
// The maximum memory used will be |fileMax| * |batchSize|.
type tupleSorter struct {
	keyCmp    func(val.Tuple, val.Tuple) bool
	files     [][]keyIterable
	inProg    *keyMem
	fileMax   int
	fileCnt   int
	batchSize int
	tmpProv   tempfiles.TempFileProvider
}

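// NewTupleSorter returns a sorter that accumulates up to |batchSize| bytes of
// tuples in memory before spilling to disk, and compacts a level once it holds
// |fileMax| spill files. A minimal usage sketch (keyCmp, tmpProv, ctx, and
// tuples are assumed to be in scope; the numeric values are hypothetical):
//
//	s := NewTupleSorter(64*1024*1024, 64, keyCmp, tmpProv)
//	defer s.Close()
//	for _, t := range tuples {
//		if err := s.Insert(ctx, t); err != nil {
//			return err
//		}
//	}
//	sorted, err := s.Flush(ctx)
//	if err != nil {
//		return err
//	}
//	it, err := sorted.IterAll(ctx)
//	if err != nil {
//		return err
//	}
//	defer it.Close()
//	for {
//		t, err := it.Next(ctx)
//		if err == io.EOF {
//			break
//		} else if err != nil {
//			return err
//		}
//		// consume t in keyCmp order
//	}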
func NewTupleSorter(batchSize, fileMax int, keyCmp func(val.Tuple, val.Tuple) bool, tmpProv tempfiles.TempFileProvider) *tupleSorter {
	if fileMax%2 == 1 {
		// round down to even
		// fileMax/2 will be compact parallelism
		fileMax -= 1
	}
	ret := &tupleSorter{
		fileMax:   fileMax,
		batchSize: batchSize,
		keyCmp:    keyCmp,
		tmpProv:   tmpProv,
	}
	ret.inProg = newKeyMem(batchSize)
	return ret
}

func (a *tupleSorter) Flush(ctx context.Context) (iter keyIterable, err error) {
	// don't flush in-progress, just sort in memory
	a.inProg.sort(a.keyCmp)

	if len(a.files) == 0 {
		// don't go to disk if we didn't reach a mem flush
		return a.inProg, nil
	}

	var iterables []keyIterable
	iterables = append(iterables, a.inProg)
	for _, level := range a.files {
		for _, file := range level {
			iterables = append(iterables, file)
		}
	}

	newF, err := a.newFile()
	if err != nil {
		return nil, err
	}
	allKeys := newKeyFile(newF, a.inProg.byteLim)
	defer func() {
		if err != nil {
			allKeys.Close()
		}
	}()

	m, err := newFileMerger(ctx, a.keyCmp, allKeys, iterables...)
	if err != nil {
		return nil, err
	}
	defer m.Close()
	if err := m.run(ctx); err != nil {
		return nil, err
	}
	return allKeys, nil
}

func (a *tupleSorter) Insert(ctx context.Context, k val.Tuple) error {
	if !a.inProg.insert(k) {
		if err := a.flushMem(ctx); err != nil {
			return err
		}
		if !a.inProg.insert(k) {
			// even an empty batch cannot hold |k|; dropping it silently
			// would corrupt the sort, so surface an error instead
			return fmt.Errorf("tuple of size %d exceeds sorter batch size %d", len(k), a.batchSize)
		}
	}
	return nil
}

func (a *tupleSorter) Close() {
	for _, level := range a.files {
		for _, f := range level {
			f.Close()
		}
	}
}

func (a *tupleSorter) flushMem(ctx context.Context) error {
	// flush and replace |inProg|
	if a.inProg.Len() > 0 {
		newF, err := a.newFile()
		if err != nil {
			return err
		}
		newFile, err := a.inProg.flush(newF, a.keyCmp)
		if err != nil {
			return err
		}
		a.inProg = newKeyMem(a.batchSize)
		if len(a.files) == 0 {
			a.files = append(a.files, []keyIterable{newFile})
		} else {
			a.files[0] = append(a.files[0], newFile)
		}
		a.fileCnt++
	}
	for level, ok := a.shouldCompact(); ok; level, ok = a.shouldCompact() {
		if err := a.compact(ctx, level); err != nil {
			return err
		}
	}
	return nil
}

func (a *tupleSorter) newFile() (*os.File, error) {
	f, err := a.tmpProv.NewFile("", "key_file_")
	if err != nil {
		return nil, err
	}
	return f, nil
}

func (a *tupleSorter) shouldCompact() (int, bool) {
	for i, level := range a.files {
		if len(level) >= a.fileMax {
			return i, true
		}
	}
	return -1, false
}

// compact merges the first |fileMax| files in |a.files[level]| into a single
// sorted file which is added to |a.files[level+1]|.
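// A hypothetical trace with fileMax = 4: the first four in-memory flushes land
// in level 0; the fourth triggers shouldCompact, and compact merges all four
// files into one level-1 file, leaving level 0 empty. Once four level-1 files
// accumulate they are merged into a single level-2 file, and so on.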
func (a *tupleSorter) compact(ctx context.Context, level int) (err error) {
	newF, err := a.newFile()
	if err != nil {
		return err
	}
	outF := newKeyFile(newF, a.batchSize)
	// deferred (with |err| as a named return) so the output file is cleaned
	// up if the merge below fails
	defer func() {
		if err != nil {
			outF.Close()
		}
	}()

	fileLevel := a.files[level]
	m, err := newFileMerger(ctx, a.keyCmp, outF, fileLevel[:a.fileMax]...)
	if err != nil {
		return err
	}
	defer m.Close()
	if err := m.run(ctx); err != nil {
		return err
	}

	// zero out compacted level
	// add to next level
	a.files[level] = a.files[level][a.fileMax:]
	if len(a.files) <= level+1 {
		a.files = append(a.files, []keyIterable{outF})
	} else {
		a.files[level+1] = append(a.files[level+1], outF)
	}

	return nil
}

func newKeyMem(size int) *keyMem {
	return &keyMem{byteLim: size}
}

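// keyMem accumulates tuples in memory until their encoded size (each tuple
// plus a |keyLenSz|-byte length prefix) would exceed |byteLim|.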
type keyMem struct {
	keys    []val.Tuple
	byteCnt int
	byteLim int
}

func (k *keyMem) insert(key val.Tuple) bool {
	if len(key)+keyLenSz+k.byteCnt > k.byteLim {
		return false
	}

	k.keys = append(k.keys, key)
	k.byteCnt += len(key) + keyLenSz
	return true
}

func (k *keyMem) flush(f *os.File, cmp func(val.Tuple, val.Tuple) bool) (*keyFile, error) {
	k.sort(cmp)
	kf := newKeyFile(f, k.byteLim)
	for _, key := range k.keys {
		if err := kf.append(key); err != nil {
			return nil, err
		}
	}
	if err := kf.buf.Flush(); err != nil {
		return nil, err
	}
	return kf, nil
}

func (k *keyMem) IterAll(_ context.Context) (keyIter, error) {
	return &keyMemIter{keys: k.keys}, nil
}

// Close is a no-op; keyMem holds no file-backed resources.
func (k *keyMem) Close() {}

type keyMemIter struct {
	keys []val.Tuple
	i    int
}

func (i *keyMemIter) Next(ctx context.Context) (val.Tuple, error) {
	if i.i >= len(i.keys) {
		return nil, io.EOF
	}
	ret := i.keys[i.i]
	i.i++
	return ret, nil
}

func (i *keyMemIter) Close() {}

func (k *keyMem) Len() int {
	return len(k.keys)
}

// sort sorts the tuples in memory without flushing to disk
func (k *keyMem) sort(cmp func(val.Tuple, val.Tuple) bool) {
	sort.Slice(k.keys, func(i, j int) bool {
		return cmp(k.keys[i], k.keys[j])
	})
}

func newKeyFile(f *os.File, batchSize int) *keyFile {
	return &keyFile{f: f, buf: bufio.NewWriterSize(f, batchSize), batchSize: batchSize}
}

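// keyFile is a spill file of length-prefixed tuples with a |batchSize|-sized
// write buffer; writers must flush |buf| before the file is read back.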
type keyFile struct {
	f         *os.File
	buf       *bufio.Writer
	batchSize int
}

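// IterAll rewinds the file and hands ownership of the descriptor to the
// returned reader; |f.f| is nil'd out so a later keyFile.Close is a no-op and
// the reader's Close removes the file instead.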
func (f *keyFile) IterAll(ctx context.Context) (keyIter, error) {
	if f.batchSize == 0 {
		return nil, fmt.Errorf("invalid zero batch size")
	}
	if _, err := f.f.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	file := f.f
	f.f = nil
	return &keyFileReader{buf: bufio.NewReader(file), f: file}, nil
}

func (f *keyFile) Close() {
	if f != nil && f.f != nil {
		f.f.Close()
		os.Remove(f.f.Name())
	}
}

// append writes |keySize|key| to the intermediate file.
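// A record is framed as a |keyLenSz|-byte big-endian length followed by the
// raw tuple bytes: a hypothetical 3-byte tuple 0xAA 0xBB 0xCC is written as
// the 7 bytes 00 00 00 03 AA BB CC. keyFileReader.Next reverses this framing.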
func (f *keyFile) append(k val.Tuple) error {
	v := uint32(len(k))
	var sizeBuf [keyLenSz]byte
	writeUint32(sizeBuf[:], v)

	if _, err := f.buf.Write(sizeBuf[:]); err != nil {
		return err
	}
	if _, err := f.buf.Write(k); err != nil {
		return err
	}

	return nil
}

type keyFileReader struct {
	buf *bufio.Reader
	f   *os.File
}

const (
	keyLenSz = 4
)

func readUint32(buf []byte) uint32 {
	return binary.BigEndian.Uint32(buf)
}

func writeUint32(buf []byte, u uint32) {
	binary.BigEndian.PutUint32(buf, u)
}

func (r *keyFileReader) Next(ctx context.Context) (val.Tuple, error) {
	var keySizeBuf [keyLenSz]byte
	if _, err := io.ReadFull(r.buf, keySizeBuf[:]); err != nil {
		return nil, err
	}

	keySize := readUint32(keySizeBuf[:])
	key := make([]byte, keySize)
	if _, err := io.ReadFull(r.buf, key); err != nil {
		return nil, err
	}

	return key, nil
}

func (r *keyFileReader) Close() {
	if r != nil && r.f != nil {
		r.f.Close()
		os.Remove(r.f.Name())
	}
}

type keyIterable interface {
	IterAll(context.Context) (keyIter, error)
	Close()
}

type keyIter interface {
	Next(ctx context.Context) (val.Tuple, error)
	Close()
}
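
// Every keyIter in this file signals exhaustion by returning io.EOF from
// Next; consumers such as mergeFileReader.next treat io.EOF as a clean
// end-of-stream rather than an error.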

// mergeFileReader is the heap object for a k-way merge.
type mergeFileReader struct {
	// iter abstracts file or in-memory sorted tuples
	iter keyIter
	// head is the next tuple in the sorted list
	head val.Tuple
}

func (r *mergeFileReader) next(ctx context.Context) (bool, error) {
	var err error
	r.head, err = r.iter.Next(ctx)
	if err != nil {
		if errors.Is(err, io.EOF) {
			return false, nil
		}
		return false, err
	}
	return true, nil
}

func newMergeFileReader(ctx context.Context, iter keyIter) (*mergeFileReader, error) {
	root, err := iter.Next(ctx)
	if err != nil {
		return nil, err
	}
	return &mergeFileReader{iter: iter, head: root}, nil
}

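// mergeQueue implements heap.Interface so container/heap can maintain a
// min-heap of readers ordered by each reader's |head| tuple under |keyCmp|;
// heap.Pop therefore always yields the reader holding the smallest pending
// tuple.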
type mergeQueue struct {
	files  []*mergeFileReader
	keyCmp func(val.Tuple, val.Tuple) bool
}

func (mq mergeQueue) Len() int { return len(mq.files) }

func (mq mergeQueue) Less(i, j int) bool {
	// We want Pop to give us the lowest, not highest, priority so we use less than here.
	return mq.keyCmp(mq.files[i].head, mq.files[j].head)
}

func (mq mergeQueue) Swap(i, j int) {
	mq.files[i], mq.files[j] = mq.files[j], mq.files[i]
}

func (mq *mergeQueue) Push(x any) {
	item := x.(*mergeFileReader)
	mq.files = append(mq.files, item)
}

func (mq *mergeQueue) Pop() any {
	old := mq.files
	n := len(old)
	item := old[n-1]
	old[n-1] = nil // avoid memory leak
	mq.files = old[0 : n-1]
	return item
}

type fileMerger struct {
	mq  *mergeQueue
	out *keyFile
}

func newFileMerger(ctx context.Context, keyCmp func(val.Tuple, val.Tuple) bool, target *keyFile, files ...keyIterable) (m *fileMerger, err error) {
	var fileHeads []*mergeFileReader
	defer func() {
		if err != nil {
			for _, fh := range fileHeads {
				fh.iter.Close()
			}
		}
	}()

	for _, f := range files {
		iter, err := f.IterAll(ctx)
		if err != nil {
			return nil, err
		}
		reader, err := newMergeFileReader(ctx, iter)
		if err != nil {
			iter.Close()
			if !errors.Is(err, io.EOF) {
				return nil, err
			}
			// io.EOF means the input was empty; exclude it from the merge queue
		} else {
			fileHeads = append(fileHeads, reader)
		}
	}

	mq := &mergeQueue{files: fileHeads, keyCmp: keyCmp}
	heap.Init(mq)

	return &fileMerger{
		mq:  mq,
		out: target,
	}, nil
}

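// run drains the merge queue into |out|: it repeatedly pops the reader with
// the smallest head tuple, appends that tuple, and re-pushes the reader if it
// has more input. With k input iterators and n total tuples, this performs
// O(n log k) comparisons.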
func (m *fileMerger) run(ctx context.Context) error {
	for {
		if m.mq.Len() == 0 {
			return m.out.buf.Flush()
		}
		reader := heap.Pop(m.mq).(*mergeFileReader)
		if err := m.out.append(reader.head); err != nil {
			reader.iter.Close()
			return err
		}
		if ok, err := reader.next(ctx); ok {
			heap.Push(m.mq, reader)
		} else {
			reader.iter.Close()
			if err != nil {
				return err
			}
		}
	}
}

func (m *fileMerger) Close() {
	if m != nil && m.mq != nil {
		for _, f := range m.mq.files {
			f.iter.Close()
		}
	}
}