vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/memory_sort.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "container/heap" 21 "context" 22 "fmt" 23 "math" 24 "reflect" 25 "sort" 26 "strings" 27 28 "vitess.io/vitess/go/vt/vtgate/evalengine" 29 30 "vitess.io/vitess/go/sqltypes" 31 querypb "vitess.io/vitess/go/vt/proto/query" 32 ) 33 34 var _ Primitive = (*MemorySort)(nil) 35 36 // MemorySort is a primitive that performs in-memory sorting. 37 type MemorySort struct { 38 UpperLimit evalengine.Expr 39 OrderBy []OrderByParams 40 Input Primitive 41 42 // TruncateColumnCount specifies the number of columns to return 43 // in the final result. Rest of the columns are truncated 44 // from the result received. If 0, no truncation happens. 45 TruncateColumnCount int `json:",omitempty"` 46 } 47 48 // RouteType returns a description of the query routing type used by the primitive. 49 func (ms *MemorySort) RouteType() string { 50 return ms.Input.RouteType() 51 } 52 53 // GetKeyspaceName specifies the Keyspace that this primitive routes to. 54 func (ms *MemorySort) GetKeyspaceName() string { 55 return ms.Input.GetKeyspaceName() 56 } 57 58 // GetTableName specifies the table that this primitive routes to. 59 func (ms *MemorySort) GetTableName() string { 60 return ms.Input.GetTableName() 61 } 62 63 // SetTruncateColumnCount sets the truncate column count. 64 func (ms *MemorySort) SetTruncateColumnCount(count int) { 65 ms.TruncateColumnCount = count 66 } 67 68 // TryExecute satisfies the Primitive interface. 69 func (ms *MemorySort) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 70 count, err := ms.fetchCount(vcursor, bindVars) 71 if err != nil { 72 return nil, err 73 } 74 75 result, err := vcursor.ExecutePrimitive(ctx, ms.Input, bindVars, wantfields) 76 if err != nil { 77 return nil, err 78 } 79 sh := &sortHeap{ 80 rows: result.Rows, 81 comparers: extractSlices(ms.OrderBy), 82 } 83 sort.Sort(sh) 84 if sh.err != nil { 85 return nil, sh.err 86 } 87 result.Rows = sh.rows 88 if len(result.Rows) > count { 89 result.Rows = result.Rows[:count] 90 } 91 return result.Truncate(ms.TruncateColumnCount), nil 92 } 93 94 // TryStreamExecute satisfies the Primitive interface. 95 func (ms *MemorySort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 96 count, err := ms.fetchCount(vcursor, bindVars) 97 if err != nil { 98 return err 99 } 100 101 cb := func(qr *sqltypes.Result) error { 102 return callback(qr.Truncate(ms.TruncateColumnCount)) 103 } 104 105 // You have to reverse the ordering because the highest values 106 // must be dropped once the upper limit is reached. 107 sh := &sortHeap{ 108 comparers: extractSlices(ms.OrderBy), 109 reverse: true, 110 } 111 err = vcursor.StreamExecutePrimitive(ctx, ms.Input, bindVars, wantfields, func(qr *sqltypes.Result) error { 112 if len(qr.Fields) != 0 { 113 if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil { 114 return err 115 } 116 } 117 for _, row := range qr.Rows { 118 heap.Push(sh, row) 119 // Remove the highest element from the heap if the size is more than the count 120 // This optimization means that the maximum size of the heap is going to be (count + 1) 121 for len(sh.rows) > count { 122 _ = heap.Pop(sh) 123 } 124 } 125 if vcursor.ExceedsMaxMemoryRows(len(sh.rows)) { 126 return fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows()) 127 } 128 return nil 129 }) 130 if err != nil { 131 return err 132 } 133 if sh.err != nil { 134 return sh.err 135 } 136 // Set ordering to normal for the final ordering. 137 sh.reverse = false 138 sort.Sort(sh) 139 if sh.err != nil { 140 // Unreachable. 141 return sh.err 142 } 143 return cb(&sqltypes.Result{Rows: sh.rows}) 144 } 145 146 // GetFields satisfies the Primitive interface. 147 func (ms *MemorySort) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 148 return ms.Input.GetFields(ctx, vcursor, bindVars) 149 } 150 151 // Inputs returns the input to memory sort 152 func (ms *MemorySort) Inputs() []Primitive { 153 return []Primitive{ms.Input} 154 } 155 156 // NeedsTransaction implements the Primitive interface 157 func (ms *MemorySort) NeedsTransaction() bool { 158 return ms.Input.NeedsTransaction() 159 } 160 161 func (ms *MemorySort) fetchCount(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (int, error) { 162 if ms.UpperLimit == nil { 163 return math.MaxInt64, nil 164 } 165 env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation()) 166 resolved, err := env.Evaluate(ms.UpperLimit) 167 if err != nil { 168 return 0, err 169 } 170 num, err := resolved.Value().ToUint64() 171 if err != nil { 172 return 0, err 173 } 174 count := int(num) 175 if count < 0 { 176 return 0, fmt.Errorf("requested limit is out of range: %v", num) 177 } 178 return count, nil 179 } 180 181 func (ms *MemorySort) description() PrimitiveDescription { 182 orderByIndexes := GenericJoin(ms.OrderBy, orderByParamsToString) 183 other := map[string]any{"OrderBy": orderByIndexes} 184 if ms.TruncateColumnCount > 0 { 185 other["ResultColumns"] = ms.TruncateColumnCount 186 } 187 return PrimitiveDescription{ 188 OperatorType: "Sort", 189 Variant: "Memory", 190 Other: other, 191 } 192 } 193 194 func orderByParamsToString(i any) string { 195 return i.(OrderByParams).String() 196 } 197 198 // GenericJoin will iterate over arrays, slices or maps, and executes the f function to get a 199 // string representation of each element, and then uses strings.Join() join all the strings into a single one 200 func GenericJoin(input any, f func(any) string) string { 201 sl := reflect.ValueOf(input) 202 var keys []string 203 switch sl.Kind() { 204 case reflect.Slice: 205 for i := 0; i < sl.Len(); i++ { 206 keys = append(keys, f(sl.Index(i).Interface())) 207 } 208 case reflect.Map: 209 for _, k := range sl.MapKeys() { 210 keys = append(keys, f(k.Interface())) 211 } 212 default: 213 panic("GenericJoin doesn't know how to deal with " + sl.Kind().String()) 214 } 215 return strings.Join(keys, ", ") 216 } 217 218 // sortHeap is sorted based on the orderBy params. 219 // Implementation is similar to scatterHeap 220 type sortHeap struct { 221 rows [][]sqltypes.Value 222 comparers []*comparer 223 reverse bool 224 err error 225 } 226 227 // Len satisfies sort.Interface and heap.Interface. 228 func (sh *sortHeap) Len() int { 229 return len(sh.rows) 230 } 231 232 // Less satisfies sort.Interface and heap.Interface. 233 func (sh *sortHeap) Less(i, j int) bool { 234 for _, c := range sh.comparers { 235 if sh.err != nil { 236 return true 237 } 238 cmp, err := c.compare(sh.rows[i], sh.rows[j]) 239 if err != nil { 240 sh.err = err 241 return true 242 } 243 if cmp == 0 { 244 continue 245 } 246 if sh.reverse { 247 cmp = -cmp 248 } 249 return cmp < 0 250 } 251 return true 252 } 253 254 // Swap satisfies sort.Interface and heap.Interface. 255 func (sh *sortHeap) Swap(i, j int) { 256 sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i] 257 } 258 259 // Push satisfies heap.Interface. 260 func (sh *sortHeap) Push(x any) { 261 sh.rows = append(sh.rows, x.([]sqltypes.Value)) 262 } 263 264 // Pop satisfies heap.Interface. 265 func (sh *sortHeap) Pop() any { 266 n := len(sh.rows) 267 x := sh.rows[n-1] 268 sh.rows = sh.rows[:n-1] 269 return x 270 }