vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/merge_sort.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "container/heap" 21 "context" 22 "io" 23 24 "vitess.io/vitess/go/mysql" 25 26 "vitess.io/vitess/go/sqltypes" 27 28 querypb "vitess.io/vitess/go/vt/proto/query" 29 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 30 "vitess.io/vitess/go/vt/vterrors" 31 ) 32 33 // StreamExecutor is a subset of Primitive that MergeSort 34 // requires its inputs to satisfy. 35 type StreamExecutor interface { 36 StreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error 37 } 38 39 var _ Primitive = (*MergeSort)(nil) 40 41 // MergeSort performs a merge-sort of rows returned by each Input. This should 42 // only be used for StreamExecute. One row from each stream is added to the 43 // merge-sorter heap. Every time a value is pulled out of the heap, 44 // a new value is added to it from the stream that was the source of the value that 45 // was pulled out. Since the input streams are sorted the same way that the heap is 46 // sorted, this guarantees that the merged stream will also be sorted the same way. 47 // MergeSort only supports the StreamExecute function of a Primitive. So, it cannot 48 // be used like other Primitives in VTGate. However, it satisfies the Primitive API 49 // so that vdiff can use it. 
In that situation, only StreamExecute is used. 50 type MergeSort struct { 51 Primitives []StreamExecutor 52 OrderBy []OrderByParams 53 ScatterErrorsAsWarnings bool 54 noInputs 55 noTxNeeded 56 } 57 58 // RouteType satisfies Primitive. 59 func (ms *MergeSort) RouteType() string { return "MergeSort" } 60 61 // GetKeyspaceName satisfies Primitive. 62 func (ms *MergeSort) GetKeyspaceName() string { return "" } 63 64 // GetTableName satisfies Primitive. 65 func (ms *MergeSort) GetTableName() string { return "" } 66 67 // TryExecute is not supported. 68 func (ms *MergeSort) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 69 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] Execute is not reachable") 70 } 71 72 // GetFields is not supported. 73 func (ms *MergeSort) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 74 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] GetFields is not reachable") 75 } 76 77 // TryStreamExecute performs a streaming exec. 78 func (ms *MergeSort) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 79 var cancel context.CancelFunc 80 ctx, cancel = context.WithCancel(ctx) 81 defer cancel() 82 gotFields := wantfields 83 handles := make([]*streamHandle, len(ms.Primitives)) 84 for i, input := range ms.Primitives { 85 handles[i] = runOneStream(ctx, vcursor, input, bindVars, gotFields) 86 if !ms.ScatterErrorsAsWarnings { 87 // we only need the fields from the first input, unless we allow ScatterErrorsAsWarnings. 
88 // in that case, we need to ask all the inputs for fields - we don't know which will return anything 89 gotFields = false 90 } 91 } 92 93 if wantfields { 94 err := ms.getStreamingFields(handles, callback) 95 if err != nil { 96 return err 97 } 98 } 99 100 comparers := extractSlices(ms.OrderBy) 101 sh := &scatterHeap{ 102 rows: make([]streamRow, 0, len(handles)), 103 comparers: comparers, 104 } 105 106 var errs []error 107 // Prime the heap. One element must be pulled from 108 // each stream. 109 for i, handle := range handles { 110 select { 111 case row, ok := <-handle.row: 112 if !ok { 113 if handle.err != nil { 114 if ms.ScatterErrorsAsWarnings { 115 errs = append(errs, handle.err) 116 break 117 } 118 return handle.err 119 } 120 // It's possible that a stream returns no rows. 121 // If so, don't add anything to the heap. 122 continue 123 } 124 sh.rows = append(sh.rows, streamRow{row: row, id: i}) 125 case <-ctx.Done(): 126 return ctx.Err() 127 } 128 } 129 heap.Init(sh) 130 if sh.err != nil { 131 return sh.err 132 } 133 134 // Iterate one row at a time: 135 // Pop a row from the heap and send it out. 136 // Then pull the next row from the stream the popped 137 // row came from and push it into the heap. 138 for len(sh.rows) != 0 { 139 sr := heap.Pop(sh).(streamRow) 140 if sh.err != nil { 141 // Unreachable: This should never fail. 
142 return sh.err 143 } 144 if err := callback(&sqltypes.Result{Rows: [][]sqltypes.Value{sr.row}}); err != nil { 145 return err 146 } 147 148 select { 149 case row, ok := <-handles[sr.id].row: 150 if !ok { 151 if handles[sr.id].err != nil { 152 return handles[sr.id].err 153 } 154 continue 155 } 156 sr.row = row 157 heap.Push(sh, sr) 158 if sh.err != nil { 159 return sh.err 160 } 161 case <-ctx.Done(): 162 return ctx.Err() 163 } 164 } 165 166 err := vterrors.Aggregate(errs) 167 if err != nil && ms.ScatterErrorsAsWarnings && len(errs) < len(handles) { 168 // we got errors, but not all shards failed, so we can hide the error and just warn instead 169 partialSuccessScatterQueries.Add(1) 170 sErr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError) 171 vcursor.Session().RecordWarning(&querypb.QueryWarning{Code: uint32(sErr.Num), Message: err.Error()}) 172 return nil 173 } 174 return err 175 } 176 177 func (ms *MergeSort) getStreamingFields(handles []*streamHandle, callback func(*sqltypes.Result) error) error { 178 var fields []*querypb.Field 179 180 if ms.ScatterErrorsAsWarnings { 181 for _, handle := range handles { 182 // Fetch field info from just one stream. 183 fields = <-handle.fields 184 // If fields is nil, it means there was an error. 185 if fields != nil { 186 break 187 } 188 } 189 } else { 190 // Fetch field info from just one stream. 191 fields = <-handles[0].fields 192 } 193 if fields == nil { 194 // something went wrong. 
need to figure out where the error can be 195 if !ms.ScatterErrorsAsWarnings { 196 return handles[0].err 197 } 198 199 var errs []error 200 for _, handle := range handles { 201 errs = append(errs, handle.err) 202 } 203 return vterrors.Aggregate(errs) 204 } 205 206 if err := callback(&sqltypes.Result{Fields: fields}); err != nil { 207 return err 208 } 209 return nil 210 } 211 212 func (ms *MergeSort) description() PrimitiveDescription { 213 other := map[string]any{ 214 "OrderBy": ms.OrderBy, 215 } 216 return PrimitiveDescription{ 217 OperatorType: "Sort", 218 Variant: "Merge", 219 Other: other, 220 } 221 } 222 223 // streamHandle is the rendez-vous point between each stream and the merge-sorter. 224 // The fields channel is used by the stream to transmit the field info, which 225 // is the first packet. Following this, the stream sends each row to the row 226 // channel. At the end of the stream, fields and row are closed. If there 227 // was an error, err is set before the channels are closed. The MergeSort 228 // routine that pulls the rows out of each streamHandle can abort the stream 229 // by calling canceling the context. 230 type streamHandle struct { 231 fields chan []*querypb.Field 232 row chan []sqltypes.Value 233 err error 234 } 235 236 // runOnestream starts a streaming query on one shard, and returns a streamHandle for it. 
func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle {
	handle := &streamHandle{
		fields: make(chan []*querypb.Field, 1),
		row:    make(chan []sqltypes.Value, 10),
	}

	go func() {
		// Deferred calls run in reverse order: row is closed first, then
		// fields. Both closes happen after handle.err is assigned below, so a
		// reader that observes a closed channel also observes the final err.
		defer close(handle.fields)
		defer close(handle.row)

		handle.err = input.StreamExecute(ctx, vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error {
			if len(qr.Fields) != 0 {
				// fields has a single buffered slot; if it is full and nobody
				// reads it, this send blocks until ctx is cancelled.
				select {
				case handle.fields <- qr.Fields:
				case <-ctx.Done():
					return io.EOF
				}
			}

			for _, row := range qr.Rows {
				select {
				case handle.row <- row:
				case <-ctx.Done():
					// Abort the stream once the merge-sorter has cancelled ctx.
					return io.EOF
				}
			}
			return nil
		})
	}()

	return handle
}

// A streamRow represents a row identified by the stream
// it came from. It is used as an element in scatterHeap.
type streamRow struct {
	row []sqltypes.Value
	id  int
}

// scatterHeap is the heap that is used for merge-sorting.
// You can push streamRow elements into it. Popping an
// element will return the one with the lowest value
// as defined by the orderBy criteria. If a comparison
// yielded an error, err is set. This must be checked
// after every heap operation.
type scatterHeap struct {
	rows      []streamRow
	err       error
	comparers []*comparer
}

// Len satisfies sort.Interface and heap.Interface.
func (sh *scatterHeap) Len() int {
	return len(sh.rows)
}

// Less satisfies sort.Interface and heap.Interface. Rows are compared with
// each comparer in turn, and the first non-equal result decides. Rows that
// compare equal on every comparer are not less than each other, as the
// ordering contract of sort.Interface requires. If a comparison fails (or
// already failed), the error is kept in sh.err and true is returned; callers
// must check sh.err after every heap operation.
func (sh *scatterHeap) Less(i, j int) bool {
	for _, c := range sh.comparers {
		if sh.err != nil {
			return true
		}
		cmp, err := c.compare(sh.rows[i].row, sh.rows[j].row)
		if err != nil {
			sh.err = err
			return true
		}
		if cmp != 0 {
			return cmp < 0
		}
	}
	// All ordering columns are equal: neither row sorts before the other.
	return false
}

// Swap satisfies sort.Interface and heap.Interface.
func (sh *scatterHeap) Swap(i, j int) {
	sh.rows[i], sh.rows[j] = sh.rows[j], sh.rows[i]
}

// Push satisfies heap.Interface. x must be a streamRow.
func (sh *scatterHeap) Push(x any) {
	sh.rows = append(sh.rows, x.(streamRow))
}

// Pop satisfies heap.Interface. It removes and returns the last element.
func (sh *scatterHeap) Pop() any {
	n := len(sh.rows)
	x := sh.rows[n-1]
	// Clear the vacated slot so the backing array stops referencing the row.
	sh.rows[n-1] = streamRow{}
	sh.rows = sh.rows[:n-1]
	return x
}