vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/memory_sort.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"container/heap"
    21  	"context"
    22  	"fmt"
    23  	"math"
    24  	"reflect"
    25  	"sort"
    26  	"strings"
    27  
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    29  
    30  	"vitess.io/vitess/go/sqltypes"
    31  	querypb "vitess.io/vitess/go/vt/proto/query"
    32  )
    33  
// Compile-time check that *MemorySort implements the Primitive interface.
var _ Primitive = (*MemorySort)(nil)

// MemorySort is a primitive that performs in-memory sorting.
//
// It reads all rows produced by Input, orders them using OrderBy and, when
// UpperLimit is set, returns at most the evaluated number of rows.
type MemorySort struct {
	// UpperLimit, when non-nil, is evaluated against the bind variables to
	// obtain the maximum number of rows to return. When nil, no limit applies.
	UpperLimit evalengine.Expr
	// OrderBy lists the sort keys; they are turned into comparers via
	// extractSlices and consulted in order (presumably matching the ORDER BY
	// clause — confirm against the planner).
	OrderBy    []OrderByParams
	// Input is the primitive whose rows are sorted.
	Input      Primitive

	// TruncateColumnCount specifies the number of columns to return
	// in the final result. Rest of the columns are truncated
	// from the result received. If 0, no truncation happens.
	TruncateColumnCount int `json:",omitempty"`
}
    47  
    48  // RouteType returns a description of the query routing type used by the primitive.
    49  func (ms *MemorySort) RouteType() string {
    50  	return ms.Input.RouteType()
    51  }
    52  
    53  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
    54  func (ms *MemorySort) GetKeyspaceName() string {
    55  	return ms.Input.GetKeyspaceName()
    56  }
    57  
    58  // GetTableName specifies the table that this primitive routes to.
    59  func (ms *MemorySort) GetTableName() string {
    60  	return ms.Input.GetTableName()
    61  }
    62  
    63  // SetTruncateColumnCount sets the truncate column count.
    64  func (ms *MemorySort) SetTruncateColumnCount(count int) {
    65  	ms.TruncateColumnCount = count
    66  }
    67  
    68  // TryExecute satisfies the Primitive interface.
    69  func (ms *MemorySort) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    70  	count, err := ms.fetchCount(vcursor, bindVars)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	result, err := vcursor.ExecutePrimitive(ctx, ms.Input, bindVars, wantfields)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	sh := &sortHeap{
    80  		rows:      result.Rows,
    81  		comparers: extractSlices(ms.OrderBy),
    82  	}
    83  	sort.Sort(sh)
    84  	if sh.err != nil {
    85  		return nil, sh.err
    86  	}
    87  	result.Rows = sh.rows
    88  	if len(result.Rows) > count {
    89  		result.Rows = result.Rows[:count]
    90  	}
    91  	return result.Truncate(ms.TruncateColumnCount), nil
    92  }
    93  
    94  // TryStreamExecute satisfies the Primitive interface.
    95  func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    96  	count, err := ms.fetchCount(vcursor, bindVars)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	cb := func(qr *sqltypes.Result) error {
   102  		return callback(qr.Truncate(ms.TruncateColumnCount))
   103  	}
   104  
   105  	// You have to reverse the ordering because the highest values
   106  	// must be dropped once the upper limit is reached.
   107  	sh := &sortHeap{
   108  		comparers: extractSlices(ms.OrderBy),
   109  		reverse:   true,
   110  	}
   111  	err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
   112  		if len(qr.Fields) != 0 {
   113  			if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil {
   114  				return err
   115  			}
   116  		}
   117  		for _, row := range qr.Rows {
   118  			heap.Push(sh, row)
   119  			// Remove the highest element from the heap if the size is more than the count
   120  			// This optimization means that the maximum size of the heap is going to be (count + 1)
   121  			for len(sh.rows) > count {
   122  				_ = heap.Pop(sh)
   123  			}
   124  		}
   125  		if vcursor.ExceedsMaxMemoryRows(len(sh.rows)) {
   126  			return fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows())
   127  		}
   128  		return nil
   129  	})
   130  	if err != nil {
   131  		return err
   132  	}
   133  	if sh.err != nil {
   134  		return sh.err
   135  	}
   136  	// Set ordering to normal for the final ordering.
   137  	sh.reverse = false
   138  	sort.Sort(sh)
   139  	if sh.err != nil {
   140  		// Unreachable.
   141  		return sh.err
   142  	}
   143  	return cb(&sqltypes.Result{Rows: sh.rows})
   144  }
   145  
   146  // GetFields satisfies the Primitive interface.
   147  func (ms *MemorySort) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   148  	return ms.Input.GetFields(ctx, vcursor, bindVars)
   149  }
   150  
   151  // Inputs returns the input to memory sort
   152  func (ms *MemorySort) Inputs() []Primitive {
   153  	return []Primitive{ms.Input}
   154  }
   155  
   156  // NeedsTransaction implements the Primitive interface
   157  func (ms *MemorySort) NeedsTransaction() bool {
   158  	return ms.Input.NeedsTransaction()
   159  }
   160  
   161  func (ms *MemorySort) fetchCount(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (int, error) {
   162  	if ms.UpperLimit == nil {
   163  		return math.MaxInt64, nil
   164  	}
   165  	env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation())
   166  	resolved, err := env.Evaluate(ms.UpperLimit)
   167  	if err != nil {
   168  		return 0, err
   169  	}
   170  	num, err := resolved.Value().ToUint64()
   171  	if err != nil {
   172  		return 0, err
   173  	}
   174  	count := int(num)
   175  	if count < 0 {
   176  		return 0, fmt.Errorf("requested limit is out of range: %v", num)
   177  	}
   178  	return count, nil
   179  }
   180  
   181  func (ms *MemorySort) description() PrimitiveDescription {
   182  	orderByIndexes := GenericJoin(ms.OrderBy, orderByParamsToString)
   183  	other := map[string]any{"OrderBy": orderByIndexes}
   184  	if ms.TruncateColumnCount > 0 {
   185  		other["ResultColumns"] = ms.TruncateColumnCount
   186  	}
   187  	return PrimitiveDescription{
   188  		OperatorType: "Sort",
   189  		Variant:      "Memory",
   190  		Other:        other,
   191  	}
   192  }
   193  
   194  func orderByParamsToString(i any) string {
   195  	return i.(OrderByParams).String()
   196  }
   197  
   198  // GenericJoin will iterate over arrays, slices or maps, and executes the f function to get a
   199  // string representation of each element, and then uses strings.Join() join all the strings into a single one
   200  func GenericJoin(input any, f func(any) string) string {
   201  	sl := reflect.ValueOf(input)
   202  	var keys []string
   203  	switch sl.Kind() {
   204  	case reflect.Slice:
   205  		for i := 0; i < sl.Len(); i++ {
   206  			keys = append(keys, f(sl.Index(i).Interface()))
   207  		}
   208  	case reflect.Map:
   209  		for _, k := range sl.MapKeys() {
   210  			keys = append(keys, f(k.Interface()))
   211  		}
   212  	default:
   213  		panic("GenericJoin doesn't know how to deal with " + sl.Kind().String())
   214  	}
   215  	return strings.Join(keys, ", ")
   216  }
   217  
// sortHeap is sorted based on the orderBy params.
// Implementation is similar to scatterHeap.
//
// It implements both sort.Interface and heap.Interface over rows. Those
// interfaces cannot return errors, so Less records the first comparison error
// in err. Callers must check err after sorting or heap operations.
type sortHeap struct {
	// rows holds the rows being ordered.
	rows      [][]sqltypes.Value
	// comparers are consulted in order; a later comparer only decides when all
	// earlier ones report the rows as equal.
	comparers []*comparer
	// reverse negates every non-zero comparison result, inverting the order.
	reverse   bool
	// err is the first error returned by a comparer, if any.
	err       error
}
   226  
   227  // Len satisfies sort.Interface and heap.Interface.
   228  func (sh *sortHeap) Len() int {
   229  	return len(sh.rows)
   230  }
   231  
   232  // Less satisfies sort.Interface and heap.Interface.
   233  func (sh *sortHeap) Less(i, j int) bool {
   234  	for _, c := range sh.comparers {
   235  		if sh.err != nil {
   236  			return true
   237  		}
   238  		cmp, err := c.compare(sh.rows[i], sh.rows[j])
   239  		if err != nil {
   240  			sh.err = err
   241  			return true
   242  		}
   243  		if cmp == 0 {
   244  			continue
   245  		}
   246  		if sh.reverse {
   247  			cmp = -cmp
   248  		}
   249  		return cmp < 0
   250  	}
   251  	return true
   252  }
   253  
   254  // Swap satisfies sort.Interface and heap.Interface.
   255  func (sh *sortHeap) Swap(i, j int) {
   256  	sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i]
   257  }
   258  
   259  // Push satisfies heap.Interface.
   260  func (sh *sortHeap) Push(x any) {
   261  	sh.rows = append(sh.rows, x.([]sqltypes.Value))
   262  }
   263  
   264  // Pop satisfies heap.Interface.
   265  func (sh *sortHeap) Pop() any {
   266  	n := len(sh.rows)
   267  	x := sh.rows[n-1]
   268  	sh.rows = sh.rows[:n-1]
   269  	return x
   270  }