kythe.io@v0.0.68-0.20240422202219-7225dbc01741/kythe/go/util/disksort/disksort.go (about)

     1  /*
     2   * Copyright 2015 The Kythe Authors. All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *   http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package disksort implements sorting algorithms for sets of data too large to
    18  // fit fully in-memory.  If the number of elements becomes too large, data are
    19  // paged onto the disk.
    20  package disksort // import "kythe.io/kythe/go/util/disksort"
    21  
    22  import (
    23  	"bufio"
    24  	"container/heap"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"os"
    30  	"path/filepath"
    31  	"strings"
    32  
    33  	"kythe.io/kythe/go/platform/delimited"
    34  	"kythe.io/kythe/go/util/log"
    35  	"kythe.io/kythe/go/util/sortutil"
    36  
    37  	"github.com/golang/snappy"
    38  )
    39  
    40  // Interface is the standard interface for disk sorting algorithms.  Each
    41  // element in the set of data to be sorted is added to the sorter with Add.
    42  // Once all elements are added, Read can then be called to retrieve each element
    43  // sequentially in sorted order.  Once Read is called, no other operations on
    44  // the sorter are allowed.
    45  type Interface interface {
    46  	// Add adds a new element to the set of data to be sorted.
    47  	Add(i any) error
    48  
    49  	// Iterator returns an Iterator to read each of the elements previously added
    50  	// to the set of data to be sorted.  Once Iterator is called, no more data may
    51  	// be added to the sorter.  This is a lower-level version of Read.  Iterator
    52  	// and Read may only be called once.
    53  	Iterator() (Iterator, error)
    54  
    55  	// Read calls f on each elements previously added the set of data to be
    56  	// sorted.  If f returns an error, it is returned immediately and f is no
    57  	// longer called.  Once Read is called, no more data may be added to the
    58  	// sorter.  Iterator and Read may only be called once.
    59  	Read(f func(any) error) error
    60  }
    61  
    62  // An Iterator reads each element, in order, in a sorted dataset.
    63  type Iterator interface {
    64  	// Next returns the next ordered element.  If none exist, an io.EOF error is
    65  	// returned.
    66  	Next() (any, error)
    67  
    68  	// Close releases all of the Iterator's used resources.  Each Iterator must be
    69  	// closed after the client's last call to Next or stray temporary files may be
    70  	// left on disk.
    71  	Close() error
    72  }
    73  
    74  // Marshaler is an interface to functions that can binary encode/decode
    75  // elements.
    76  type Marshaler interface {
    77  	// Marshal binary encodes the given element.
    78  	Marshal(any) ([]byte, error)
    79  
    80  	// Unmarshal decodes the given encoding of an element.
    81  	Unmarshal([]byte) (any, error)
    82  }
    83  
    84  type mergeSorter struct {
    85  	opts MergeOptions
    86  
    87  	buffer  []any
    88  	workDir string
    89  	shards  []string
    90  
    91  	bufferSize int
    92  
    93  	finalized bool
    94  }
    95  
    96  // DefaultMaxInMemory is the default number of elements to keep in-memory during
    97  // a merge sort.
    98  const DefaultMaxInMemory = 32000
    99  
   100  // DefaultMaxBytesInMemory is the default maximum total size of elements to keep
   101  // in-memory during a merge sort.
   102  const DefaultMaxBytesInMemory = 1024 * 1024 * 256
   103  
   104  // MergeOptions specifies how to sort elements.
   105  type MergeOptions struct {
   106  	// Name is optionally used as part of the path for temporary file shards.
   107  	Name string
   108  
   109  	// Lesser is the comparison function for sorting the given elements.
   110  	Lesser sortutil.Lesser
   111  	// Marshaler is used for encoding/decoding elements in temporary file shards.
   112  	Marshaler Marshaler
   113  
   114  	// WorkDir is the directory used for writing temporary file shards.  If empty,
   115  	// the default directory for temporary files is used.
   116  	WorkDir string
   117  
   118  	// MaxInMemory is the maximum number of elements to keep in-memory before
   119  	// paging them to a temporary file shard.  If non-positive, DefaultMaxInMemory
   120  	// is used.
   121  	MaxInMemory int
   122  
   123  	// MaxBytesInMemory is the maximum total size of elements to keep in-memory
   124  	// before paging them to a temporary file shard.  An element's size is
   125  	// determined by its `Size() int` method. If non-positive,
   126  	// DefaultMaxBytesInMemory is used.
   127  	MaxBytesInMemory int
   128  
   129  	// CompressShards determines whether the temporary file shards should be
   130  	// compressed.
   131  	CompressShards bool
   132  }
   133  
   134  type sizer interface{ Size() int }
   135  
   136  // NewMergeSorter returns a new disk sorter using a mergesort algorithm.
   137  func NewMergeSorter(opts MergeOptions) (Interface, error) {
   138  	if opts.Lesser == nil {
   139  		return nil, errors.New("missing Lesser")
   140  	} else if opts.Marshaler == nil {
   141  		return nil, errors.New("missing Marshaler")
   142  	}
   143  
   144  	name := strings.Replace(opts.Name, string(filepath.Separator), ".", -1)
   145  	if name == "" {
   146  		name = "external.merge.sort"
   147  	}
   148  	dir, err := ioutil.TempDir(opts.WorkDir, name)
   149  	if err != nil {
   150  		return nil, fmt.Errorf("error creating temporary work directory: %v", err)
   151  	}
   152  
   153  	if opts.MaxInMemory <= 0 {
   154  		opts.MaxInMemory = DefaultMaxInMemory
   155  	}
   156  	if opts.MaxBytesInMemory <= 0 {
   157  		opts.MaxBytesInMemory = DefaultMaxBytesInMemory
   158  	}
   159  
   160  	return &mergeSorter{
   161  		opts:    opts,
   162  		buffer:  make([]any, 0, opts.MaxInMemory),
   163  		workDir: dir,
   164  	}, nil
   165  }
   166  
   167  var (
   168  	// ErrAlreadyFinalized is returned from Interface#Add and Interface#Read when
   169  	// Interface#Read has already been called, freezing the sort's inputs/outputs.
   170  	ErrAlreadyFinalized = errors.New("sorter already finalized")
   171  )
   172  
   173  // Add implements part of the Interface interface.
   174  func (m *mergeSorter) Add(i any) error {
   175  	if m.finalized {
   176  		return ErrAlreadyFinalized
   177  	}
   178  
   179  	m.buffer = append(m.buffer, i)
   180  	if sizer, ok := i.(sizer); ok {
   181  		m.bufferSize += sizer.Size()
   182  	}
   183  
   184  	if len(m.buffer) >= m.opts.MaxInMemory || m.bufferSize >= m.opts.MaxBytesInMemory {
   185  		return m.dumpShard()
   186  	}
   187  	return nil
   188  }
   189  
   190  type mergeIterator struct {
   191  	buffer []any
   192  
   193  	merger    *sortutil.ByLesser
   194  	marshaler Marshaler
   195  	workDir   string
   196  }
   197  
   198  const ioBufferSize = 2 << 15
   199  
   200  // Iterator implements part of the Interface interface.
   201  func (m *mergeSorter) Iterator() (iter Iterator, err error) {
   202  	if m.finalized {
   203  		return nil, ErrAlreadyFinalized
   204  	}
   205  	m.finalized = true // signal that further operations should fail
   206  
   207  	it := &mergeIterator{workDir: m.workDir, marshaler: m.opts.Marshaler}
   208  
   209  	if len(m.shards) == 0 {
   210  		// Fast path for a single, in-memory shard
   211  		it.buffer, m.buffer = m.buffer, nil
   212  		sortutil.Sort(m.opts.Lesser, it.buffer)
   213  		return it, nil
   214  	}
   215  
   216  	// This is a heap storing the head of each shard.
   217  	merger := &sortutil.ByLesser{
   218  		Lesser: &mergeElementLesser{Lesser: m.opts.Lesser},
   219  	}
   220  	it.merger = merger
   221  
   222  	defer func() {
   223  		// Try to cleanup on errors
   224  		if err != nil {
   225  			if cErr := it.Close(); cErr != nil {
   226  				log.Warningf("error closing Iterator after error: %v", cErr)
   227  			}
   228  		}
   229  	}()
   230  
   231  	// Push all of the in-memory elements into the merger heap.
   232  	for _, el := range m.buffer {
   233  		heap.Push(merger, &mergeElement{el: el})
   234  	}
   235  	m.buffer = nil
   236  
   237  	// Initialize the merger heap by reading the first element of each shard.
   238  	for _, shard := range m.shards {
   239  		f, err := os.OpenFile(shard, os.O_RDONLY, shardFileMode)
   240  		if err != nil {
   241  			return nil, fmt.Errorf("error opening shard %q: %v", shard, err)
   242  		}
   243  
   244  		var r io.Reader
   245  		if m.opts.CompressShards {
   246  			r = snappy.NewReader(f)
   247  		} else {
   248  			r = bufio.NewReaderSize(f, ioBufferSize)
   249  		}
   250  
   251  		rd := delimited.NewReader(r)
   252  		first, err := rd.Next()
   253  		if err != nil {
   254  			f.Close()
   255  			return nil, fmt.Errorf("error reading beginning of shard %q: %v", shard, err)
   256  		}
   257  		el, err := m.opts.Marshaler.Unmarshal(first)
   258  		if err != nil {
   259  			f.Close()
   260  			return nil, fmt.Errorf("error unmarshaling beginning of shard %q: %v", shard, err)
   261  		}
   262  
   263  		heap.Push(merger, &mergeElement{el: el, rd: rd, f: f})
   264  	}
   265  
   266  	return it, nil
   267  }
   268  
   269  // Next implements part of the Iterator interface.
   270  func (i *mergeIterator) Next() (any, error) {
   271  	if i.merger == nil {
   272  		// Fast path for a single, in-memory shard
   273  		if len(i.buffer) == 0 {
   274  			return nil, io.EOF
   275  		}
   276  		val := i.buffer[0]
   277  		i.buffer = i.buffer[1:]
   278  		return val, nil
   279  	}
   280  
   281  	if i.merger.Len() == 0 {
   282  		return nil, io.EOF
   283  	}
   284  
   285  	// While the merger heap is non-empty:
   286  	//   x := peek the head of the heap
   287  	//   pass x.el to the user-specific function
   288  	//   read the next element in x.rd; fix the merger heap order
   289  	x := i.merger.Slice[0].(*mergeElement)
   290  	el := x.el
   291  
   292  	if x.rd == nil {
   293  		heap.Pop(i.merger)
   294  	} else {
   295  		// Read and parse the next value on the same shard
   296  		rec, err := x.rd.Next()
   297  		if err != nil {
   298  			_ = x.f.Close()           // ignore errors (file is only open for reading)
   299  			_ = os.Remove(x.f.Name()) // ignore errors (os.RemoveAll used in Close)
   300  			heap.Pop(i.merger)
   301  			if err != io.EOF {
   302  				return nil, fmt.Errorf("error reading shard: %v", err)
   303  			}
   304  		} else {
   305  			next, err := i.marshaler.Unmarshal(rec)
   306  			if err != nil {
   307  				return nil, fmt.Errorf("error unmarshaling element: %v", err)
   308  			}
   309  
   310  			// Reuse mergeElement, reorder it in the merger heap with the next value
   311  			x.el = next
   312  			heap.Fix(i.merger, 0)
   313  		}
   314  	}
   315  
   316  	return el, nil
   317  }
   318  
   319  // Close implements part of the Iterator interface.
   320  func (i *mergeIterator) Close() error {
   321  	i.buffer = nil
   322  	if i.merger != nil {
   323  		for _, x := range i.merger.Slice {
   324  			el := x.(*mergeElement)
   325  			if el.f != nil {
   326  				el.f.Close() // ignore errors (file is only open for reading)
   327  			}
   328  		}
   329  		i.merger = nil
   330  	}
   331  	if rmErr := os.RemoveAll(i.workDir); rmErr != nil {
   332  		return fmt.Errorf("error removing temporary directory %q: %v", i.workDir, rmErr)
   333  	}
   334  	return nil
   335  }
   336  
   337  // Read implements part of the Interface interface.
   338  func (m *mergeSorter) Read(f func(i any) error) (err error) {
   339  	it, err := m.Iterator()
   340  	if err != nil {
   341  		return err
   342  	}
   343  	defer func() {
   344  		if cErr := it.Close(); cErr != nil {
   345  			if err == nil {
   346  				err = cErr
   347  			} else {
   348  				log.Warningf("error closing Iterator: %v", cErr)
   349  			}
   350  		}
   351  	}()
   352  	for {
   353  		val, err := it.Next()
   354  		if err == io.EOF {
   355  			return nil
   356  		} else if err != nil {
   357  			return err
   358  		}
   359  		if err := f(val); err != nil {
   360  			return err
   361  		}
   362  	}
   363  }
   364  
   365  const shardFileMode = 0600 | os.ModeExclusive | os.ModeAppend | os.ModeTemporary | os.ModeSticky
   366  
   367  func (m *mergeSorter) dumpShard() (err error) {
   368  	defer func() {
   369  		m.buffer = make([]any, 0, m.opts.MaxInMemory)
   370  		m.bufferSize = 0
   371  	}()
   372  
   373  	// Create a new shard file
   374  	shardPath := filepath.Join(m.workDir, fmt.Sprintf("shard.%.6d", len(m.shards)))
   375  	file, err := os.OpenFile(shardPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, shardFileMode)
   376  	if err != nil {
   377  		return fmt.Errorf("error creating shard: %v", err)
   378  	}
   379  	defer func() {
   380  		replaceErrIfNil(&err, "error closing shard: %v", file.Close())
   381  	}()
   382  
   383  	// Buffer writing to the shard
   384  	var buf interface {
   385  		io.Writer
   386  		Flush() error
   387  	}
   388  	if m.opts.CompressShards {
   389  		buf = snappy.NewBufferedWriter(file)
   390  	} else {
   391  		buf = bufio.NewWriterSize(file, ioBufferSize)
   392  	}
   393  
   394  	defer func() {
   395  		replaceErrIfNil(&err, "error flushing shard: %v", buf.Flush())
   396  	}()
   397  
   398  	// Sort the in-memory buffer of elements
   399  	sortutil.Sort(m.opts.Lesser, m.buffer)
   400  
   401  	// Write each element of the in-memory to shard file, in sorted order
   402  	wr := delimited.NewWriter(buf)
   403  	for len(m.buffer) > 0 {
   404  		rec, err := m.opts.Marshaler.Marshal(m.buffer[0])
   405  		if err != nil {
   406  			return fmt.Errorf("marshaling error: %v", err)
   407  		}
   408  		if _, err := wr.WriteRecord(rec); err != nil {
   409  			return fmt.Errorf("writing error: %v", err)
   410  		}
   411  		m.buffer = m.buffer[1:]
   412  	}
   413  
   414  	m.shards = append(m.shards, shardPath)
   415  	return nil
   416  }
   417  
   418  func replaceErrIfNil(err *error, s string, newError error) {
   419  	if newError != nil && *err == nil {
   420  		*err = fmt.Errorf(s, newError)
   421  	}
   422  }
   423  
   424  type mergeElement struct {
   425  	el any
   426  	rd *delimited.Reader
   427  	f  *os.File
   428  }
   429  
   430  type mergeElementLesser struct{ sortutil.Lesser }
   431  
   432  // Less implements the sortutil.Lesser interface.
   433  func (m *mergeElementLesser) Less(a, b any) bool {
   434  	x, y := a.(*mergeElement), b.(*mergeElement)
   435  	return m.Lesser.Less(x.el, y.el)
   436  }