vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/merge_sort.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"container/heap"
    21  	"context"
    22  	"io"
    23  
    24  	"vitess.io/vitess/go/mysql"
    25  
    26  	"vitess.io/vitess/go/sqltypes"
    27  
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    30  	"vitess.io/vitess/go/vt/vterrors"
    31  )
    32  
// StreamExecutor is a subset of Primitive that MergeSort
// requires its inputs to satisfy. MergeSort calls StreamExecute once per
// input, each call on its own goroutine, and consumes the results that the
// input delivers through callback.
type StreamExecutor interface {
	StreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error
}
    38  
// Compile-time assertion that *MergeSort implements the Primitive interface.
var _ Primitive = (*MergeSort)(nil)
    40  
// MergeSort performs a merge-sort of rows returned by each Input. This should
// only be used for StreamExecute. One row from each stream is added to the
// merge-sorter heap. Every time a value is pulled out of the heap,
// a new value is added to it from the stream that was the source of the value that
// was pulled out. Since the input streams are sorted the same way that the heap is
// sorted, this guarantees that the merged stream will also be sorted the same way.
// MergeSort only supports the StreamExecute function of a Primitive. So, it cannot
// be used like other Primitives in VTGate. However, it satisfies the Primitive API
// so that vdiff can use it. In that situation, only StreamExecute is used.
//
// Primitives holds the inputs to merge; one stream is started per entry.
// OrderBy describes the ordering that the heap comparers are built from.
// ScatterErrorsAsWarnings changes error handling: every input is asked for
// field info, an input that fails before delivering its first row is recorded
// instead of aborting the merge, and as long as at least one input did not
// fail that way the aggregated error is reported as a session warning rather
// than returned.
//
// noInputs and noTxNeeded are embedded helpers declared elsewhere in the
// package (presumably supplying the remaining Primitive methods - confirm).
type MergeSort struct {
	Primitives              []StreamExecutor
	OrderBy                 []OrderByParams
	ScatterErrorsAsWarnings bool
	noInputs
	noTxNeeded
}
    57  
    58  // RouteType satisfies Primitive.
    59  func (ms *MergeSort) RouteType() string { return "MergeSort" }
    60  
    61  // GetKeyspaceName satisfies Primitive.
    62  func (ms *MergeSort) GetKeyspaceName() string { return "" }
    63  
    64  // GetTableName satisfies Primitive.
    65  func (ms *MergeSort) GetTableName() string { return "" }
    66  
    67  // TryExecute is not supported.
    68  func (ms *MergeSort) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    69  	return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] Execute is not reachable")
    70  }
    71  
    72  // GetFields is not supported.
    73  func (ms *MergeSort) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
    74  	return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields is not reachable")
    75  }
    76  
    77  // TryStreamExecute performs a streaming exec.
    78  func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    79  	var cancel context.CancelFunc
    80  	ctx, cancel = context.WithCancel(ctx)
    81  	defer cancel()
    82  	gotFields := wantfields
    83  	handles := make([]*streamHandle, len(ms.Primitives))
    84  	for i, input := range ms.Primitives {
    85  		handles[i] = runOneStream(ctx, vcursor, input, bindVars, gotFields)
    86  		if !ms.ScatterErrorsAsWarnings {
    87  			// we only need the fields from the first input, unless we allow ScatterErrorsAsWarnings.
    88  			// in that case, we need to ask all the inputs for fields - we don't know which will return anything
    89  			gotFields = false
    90  		}
    91  	}
    92  
    93  	if wantfields {
    94  		err := ms.getStreamingFields(handles, callback)
    95  		if err != nil {
    96  			return err
    97  		}
    98  	}
    99  
   100  	comparers := extractSlices(ms.OrderBy)
   101  	sh := &scatterHeap{
   102  		rows:      make([]streamRow, 0, len(handles)),
   103  		comparers: comparers,
   104  	}
   105  
   106  	var errs []error
   107  	// Prime the heap. One element must be pulled from
   108  	// each stream.
   109  	for i, handle := range handles {
   110  		select {
   111  		case row, ok := <-handle.row:
   112  			if !ok {
   113  				if handle.err != nil {
   114  					if ms.ScatterErrorsAsWarnings {
   115  						errs = append(errs, handle.err)
   116  						break
   117  					}
   118  					return handle.err
   119  				}
   120  				// It's possible that a stream returns no rows.
   121  				// If so, don't add anything to the heap.
   122  				continue
   123  			}
   124  			sh.rows = append(sh.rows, streamRow{row: row, id: i})
   125  		case <-ctx.Done():
   126  			return ctx.Err()
   127  		}
   128  	}
   129  	heap.Init(sh)
   130  	if sh.err != nil {
   131  		return sh.err
   132  	}
   133  
   134  	// Iterate one row at a time:
   135  	// Pop a row from the heap and send it out.
   136  	// Then pull the next row from the stream the popped
   137  	// row came from and push it into the heap.
   138  	for len(sh.rows) != 0 {
   139  		sr := heap.Pop(sh).(streamRow)
   140  		if sh.err != nil {
   141  			// Unreachable: This should never fail.
   142  			return sh.err
   143  		}
   144  		if err := callback(&sqltypes.Result{Rows: [][]sqltypes.Value{sr.row}}); err != nil {
   145  			return err
   146  		}
   147  
   148  		select {
   149  		case row, ok := <-handles[sr.id].row:
   150  			if !ok {
   151  				if handles[sr.id].err != nil {
   152  					return handles[sr.id].err
   153  				}
   154  				continue
   155  			}
   156  			sr.row = row
   157  			heap.Push(sh, sr)
   158  			if sh.err != nil {
   159  				return sh.err
   160  			}
   161  		case <-ctx.Done():
   162  			return ctx.Err()
   163  		}
   164  	}
   165  
   166  	err := vterrors.Aggregate(errs)
   167  	if err != nil && ms.ScatterErrorsAsWarnings && len(errs) < len(handles) {
   168  		// we got errors, but not all shards failed, so we can hide the error and just warn instead
   169  		partialSuccessScatterQueries.Add(1)
   170  		sErr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError)
   171  		vcursor.Session().RecordWarning(&querypb.QueryWarning{Code: uint32(sErr.Num), Message: err.Error()})
   172  		return nil
   173  	}
   174  	return err
   175  }
   176  
   177  func (ms *MergeSort) getStreamingFields(handles []*streamHandle, callback func(*sqltypes.Result) error) error {
   178  	var fields []*querypb.Field
   179  
   180  	if ms.ScatterErrorsAsWarnings {
   181  		for _, handle := range handles {
   182  			// Fetch field info from just one stream.
   183  			fields = <-handle.fields
   184  			// If fields is nil, it means there was an error.
   185  			if fields != nil {
   186  				break
   187  			}
   188  		}
   189  	} else {
   190  		// Fetch field info from just one stream.
   191  		fields = <-handles[0].fields
   192  	}
   193  	if fields == nil {
   194  		// something went wrong. need to figure out where the error can be
   195  		if !ms.ScatterErrorsAsWarnings {
   196  			return handles[0].err
   197  		}
   198  
   199  		var errs []error
   200  		for _, handle := range handles {
   201  			errs = append(errs, handle.err)
   202  		}
   203  		return vterrors.Aggregate(errs)
   204  	}
   205  
   206  	if err := callback(&sqltypes.Result{Fields: fields}); err != nil {
   207  		return err
   208  	}
   209  	return nil
   210  }
   211  
   212  func (ms *MergeSort) description() PrimitiveDescription {
   213  	other := map[string]any{
   214  		"OrderBy": ms.OrderBy,
   215  	}
   216  	return PrimitiveDescription{
   217  		OperatorType: "Sort",
   218  		Variant:      "Merge",
   219  		Other:        other,
   220  	}
   221  }
   222  
// streamHandle is the rendez-vous point between each stream and the merge-sorter.
// The fields channel is used by the stream to transmit the field info, which
// is the first packet. Following this, the stream sends each row to the row
// channel. At the end of the stream, fields and row are closed. If there
// was an error, err is set before the channels are closed, so err may only be
// read after one of the channels has been observed closed. The MergeSort
// routine that pulls the rows out of each streamHandle can abort the stream
// by canceling the context.
type streamHandle struct {
	fields chan []*querypb.Field
	row    chan []sqltypes.Value
	err    error
}
   235  
   236  // runOnestream starts a streaming query on one shard, and returns a streamHandle for it.
   237  func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle {
   238  	handle := &streamHandle{
   239  		fields: make(chan []*querypb.Field, 1),
   240  		row:    make(chan []sqltypes.Value, 10),
   241  	}
   242  
   243  	go func() {
   244  		defer close(handle.fields)
   245  		defer close(handle.row)
   246  
   247  		handle.err = input.StreamExecute(ctx, vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error {
   248  			if len(qr.Fields) != 0 {
   249  				select {
   250  				case handle.fields <- qr.Fields:
   251  				case <-ctx.Done():
   252  					return io.EOF
   253  				}
   254  			}
   255  
   256  			for _, row := range qr.Rows {
   257  				select {
   258  				case handle.row <- row:
   259  				case <-ctx.Done():
   260  					return io.EOF
   261  				}
   262  			}
   263  			return nil
   264  		})
   265  	}()
   266  
   267  	return handle
   268  }
   269  
// A streamRow represents a row identified by the stream
// it came from. It is used as an element in scatterHeap.
// id is the index of the originating stream in the merge-sorter's
// slice of stream handles, which is how the next row is fetched from
// the same stream after this one has been popped.
type streamRow struct {
	row []sqltypes.Value
	id  int
}
   276  
// scatterHeap is the heap that is used for merge-sorting.
// You can push streamRow elements into it. Popping an
// element will return the one with the lowest value
// as defined by the orderBy criteria. If a comparison
// yielded an error, err is set. This must be checked
// after every heap operation, because the heap interface
// offers no way for Less to report a failure directly.
// comparers are applied in order; a later comparer is only
// consulted when all earlier ones found the rows equal.
type scatterHeap struct {
	rows      []streamRow
	err       error
	comparers []*comparer
}
   288  
   289  // Len satisfies sort.Interface and heap.Interface.
   290  func (sh *scatterHeap) Len() int {
   291  	return len(sh.rows)
   292  }
   293  
   294  // Less satisfies sort.Interface and heap.Interface.
   295  func (sh *scatterHeap) Less(i, j int) bool {
   296  	for _, c := range sh.comparers {
   297  		if sh.err != nil {
   298  			return true
   299  		}
   300  		// First try to compare the columns that we want to order
   301  		cmp, err := c.compare(sh.rows[i].row, sh.rows[j].row)
   302  		if err != nil {
   303  			sh.err = err
   304  			return true
   305  		}
   306  		if cmp == 0 {
   307  			continue
   308  		}
   309  		return cmp < 0
   310  	}
   311  	return true
   312  }
   313  
   314  // Swap satisfies sort.Interface and heap.Interface.
   315  func (sh *scatterHeap) Swap(i, j int) {
   316  	sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i]
   317  }
   318  
   319  // Push satisfies heap.Interface.
   320  func (sh *scatterHeap) Push(x any) {
   321  	sh.rows = append(sh.rows, x.(streamRow))
   322  }
   323  
   324  // Pop satisfies heap.Interface.
   325  func (sh *scatterHeap) Pop() any {
   326  	n := len(sh.rows)
   327  	x := sh.rows[n-1]
   328  	sh.rows = sh.rows[:n-1]
   329  	return x
   330  }