vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/concatenate.go (about) 1 /* 2 Copyright 2020 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "sync" 22 23 "vitess.io/vitess/go/sqltypes" 24 querypb "vitess.io/vitess/go/vt/proto/query" 25 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 26 "vitess.io/vitess/go/vt/vterrors" 27 ) 28 29 // Concatenate Primitive is used to concatenate results from multiple sources. 30 var _ Primitive = (*Concatenate)(nil) 31 32 // Concatenate specified the parameter for concatenate primitive 33 type Concatenate struct { 34 Sources []Primitive 35 36 // These column offsets do not need to be typed checked - they usually contain weight_string() 37 // columns that are not going to be returned to the user 38 NoNeedToTypeCheck map[int]any 39 } 40 41 // NewConcatenate creates a Concatenate primitive. The ignoreCols slice contains the offsets that 42 // don't need to have the same type between sources - 43 // weight_string() sometimes returns VARBINARY and sometimes VARCHAR 44 func NewConcatenate(Sources []Primitive, ignoreCols []int) *Concatenate { 45 ignoreTypes := map[int]any{} 46 for _, i := range ignoreCols { 47 ignoreTypes[i] = nil 48 } 49 return &Concatenate{ 50 Sources: Sources, 51 NoNeedToTypeCheck: ignoreTypes, 52 } 53 } 54 55 // RouteType returns a description of the query routing type used by the primitive 56 func (c *Concatenate) RouteType() string { 57 return "Concatenate" 58 } 59 60 // GetKeyspaceName specifies the Keyspace that this primitive routes to 61 func (c *Concatenate) GetKeyspaceName() string { 62 res := c.Sources[0].GetKeyspaceName() 63 for i := 1; i < len(c.Sources); i++ { 64 res = formatTwoOptionsNicely(res, c.Sources[i].GetKeyspaceName()) 65 } 66 return res 67 } 68 69 // GetTableName specifies the table that this primitive routes to. 70 func (c *Concatenate) GetTableName() string { 71 res := c.Sources[0].GetTableName() 72 for i := 1; i < len(c.Sources); i++ { 73 res = formatTwoOptionsNicely(res, c.Sources[i].GetTableName()) 74 } 75 return res 76 } 77 78 func formatTwoOptionsNicely(a, b string) string { 79 if a == b { 80 return a 81 } 82 return a + "_" + b 83 } 84 85 // ErrWrongNumberOfColumnsInSelect is an error 86 var ErrWrongNumberOfColumnsInSelect = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.WrongNumberOfColumnsInSelect, "The used SELECT statements have a different number of columns") 87 88 // TryExecute performs a non-streaming exec. 89 func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 90 res, err := c.execSources(ctx, vcursor, bindVars, wantfields) 91 if err != nil { 92 return nil, err 93 } 94 95 fields, err := c.getFields(res) 96 if err != nil { 97 return nil, err 98 } 99 100 var rowsAffected uint64 101 var rows [][]sqltypes.Value 102 103 for _, r := range res { 104 rowsAffected += r.RowsAffected 105 106 if len(rows) > 0 && 107 len(r.Rows) > 0 && 108 len(rows[0]) != len(r.Rows[0]) { 109 return nil, ErrWrongNumberOfColumnsInSelect 110 } 111 112 rows = append(rows, r.Rows...) 113 } 114 115 return &sqltypes.Result{ 116 Fields: fields, 117 RowsAffected: rowsAffected, 118 Rows: rows, 119 }, nil 120 } 121 122 func (c *Concatenate) getFields(res []*sqltypes.Result) ([]*querypb.Field, error) { 123 if len(res) == 0 { 124 return nil, nil 125 } 126 127 var fields []*querypb.Field 128 for _, r := range res { 129 if r.Fields == nil { 130 continue 131 } 132 if fields == nil { 133 fields = r.Fields 134 continue 135 } 136 137 err := c.compareFields(fields, r.Fields) 138 if err != nil { 139 return nil, err 140 } 141 } 142 return fields, nil 143 } 144 145 func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { 146 if vcursor.Session().InTransaction() { 147 // as we are in a transaction, we need to execute all queries inside a single transaction 148 // therefore it needs a sequential execution. 149 return c.sequentialExec(ctx, vcursor, bindVars, wantfields) 150 } 151 // not in transaction, so execute in parallel. 152 return c.parallelExec(ctx, vcursor, bindVars, wantfields) 153 } 154 155 func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { 156 results := make([]*sqltypes.Result, len(c.Sources)) 157 var outerErr error 158 159 ctx, cancel := context.WithCancel(ctx) 160 defer cancel() 161 162 var wg sync.WaitGroup 163 for i, source := range c.Sources { 164 currIndex, currSource := i, source 165 vars := copyBindVars(bindVars) 166 wg.Add(1) 167 go func() { 168 defer wg.Done() 169 result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) 170 if err != nil { 171 outerErr = err 172 cancel() 173 } 174 results[currIndex] = result 175 }() 176 } 177 wg.Wait() 178 return results, outerErr 179 } 180 181 func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { 182 results := make([]*sqltypes.Result, len(c.Sources)) 183 for i, source := range c.Sources { 184 currIndex, currSource := i, source 185 vars := copyBindVars(bindVars) 186 result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields) 187 if err != nil { 188 return nil, err 189 } 190 results[currIndex] = result 191 } 192 return results, nil 193 } 194 195 // TryStreamExecute performs a streaming exec. 196 func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 197 if vcursor.Session().InTransaction() { 198 // as we are in a transaction, we need to execute all queries inside a single transaction 199 // therefore it needs a sequential execution. 200 return c.sequentialStreamExec(ctx, vcursor, bindVars, wantfields, callback) 201 } 202 // not in transaction, so execute in parallel. 203 return c.parallelStreamExec(ctx, vcursor, bindVars, wantfields, callback) 204 } 205 206 func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 207 var seenFields []*querypb.Field 208 var outerErr error 209 210 var fieldsSent bool 211 var cbMu, fieldsMu sync.Mutex 212 var wg, fieldSendWg sync.WaitGroup 213 fieldSendWg.Add(1) 214 215 for i, source := range c.Sources { 216 wg.Add(1) 217 currIndex, currSource := i, source 218 219 go func() { 220 defer wg.Done() 221 err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { 222 // if we have fields to compare, make sure all the fields are all the same 223 if currIndex == 0 { 224 fieldsMu.Lock() 225 if !fieldsSent { 226 defer fieldSendWg.Done() 227 defer fieldsMu.Unlock() 228 seenFields = resultChunk.Fields 229 fieldsSent = true 230 // No other call can happen before this call. 231 return callback(resultChunk) 232 } 233 fieldsMu.Unlock() 234 } 235 fieldSendWg.Wait() 236 if resultChunk.Fields != nil { 237 err := c.compareFields(seenFields, resultChunk.Fields) 238 if err != nil { 239 return err 240 } 241 } 242 // This to ensure only one send happens back to the client. 243 cbMu.Lock() 244 defer cbMu.Unlock() 245 select { 246 case <-ctx.Done(): 247 return nil 248 default: 249 return callback(resultChunk) 250 } 251 }) 252 // This is to ensure other streams complete if the first stream failed to unlock the wait. 253 if currIndex == 0 { 254 fieldsMu.Lock() 255 if !fieldsSent { 256 fieldsSent = true 257 fieldSendWg.Done() 258 } 259 fieldsMu.Unlock() 260 } 261 if err != nil { 262 outerErr = err 263 ctx.Done() 264 } 265 }() 266 267 } 268 wg.Wait() 269 return outerErr 270 } 271 272 func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 273 // all the below fields ensure that the fields are sent only once. 274 var seenFields []*querypb.Field 275 var fieldsMu sync.Mutex 276 var fieldsSent bool 277 278 for idx, source := range c.Sources { 279 err := vcursor.StreamExecutePrimitive(ctx, source, bindVars, wantfields, func(resultChunk *sqltypes.Result) error { 280 // if we have fields to compare, make sure all the fields are all the same 281 if idx == 0 { 282 fieldsMu.Lock() 283 defer fieldsMu.Unlock() 284 if !fieldsSent { 285 fieldsSent = true 286 seenFields = resultChunk.Fields 287 return callback(resultChunk) 288 } 289 } 290 if resultChunk.Fields != nil { 291 err := c.compareFields(seenFields, resultChunk.Fields) 292 if err != nil { 293 return err 294 } 295 } 296 // check if context has expired. 297 if ctx.Err() != nil { 298 return ctx.Err() 299 } 300 return callback(resultChunk) 301 302 }) 303 if err != nil { 304 return err 305 } 306 } 307 return nil 308 } 309 310 // GetFields fetches the field info. 311 func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 312 // TODO: type coercions 313 res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars) 314 if err != nil { 315 return nil, err 316 } 317 318 for i := 1; i < len(c.Sources); i++ { 319 result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars) 320 if err != nil { 321 return nil, err 322 } 323 err = c.compareFields(res.Fields, result.Fields) 324 if err != nil { 325 return nil, err 326 } 327 } 328 329 return res, nil 330 } 331 332 // NeedsTransaction returns whether a transaction is needed for this primitive 333 func (c *Concatenate) NeedsTransaction() bool { 334 for _, source := range c.Sources { 335 if source.NeedsTransaction() { 336 return true 337 } 338 } 339 return false 340 } 341 342 // Inputs returns the input primitives for this 343 func (c *Concatenate) Inputs() []Primitive { 344 return c.Sources 345 } 346 347 func (c *Concatenate) description() PrimitiveDescription { 348 return PrimitiveDescription{OperatorType: c.RouteType()} 349 } 350 351 func (c *Concatenate) compareFields(fields1 []*querypb.Field, fields2 []*querypb.Field) error { 352 if len(fields1) != len(fields2) { 353 return ErrWrongNumberOfColumnsInSelect 354 } 355 for i, field1 := range fields1 { 356 if _, found := c.NoNeedToTypeCheck[i]; found { 357 continue 358 } 359 field2 := fields2[i] 360 if field1.Type != field2.Type { 361 return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "merging field of different types is not supported, name: (%v, %v) types: (%v, %v)", field1.Name, field2.Name, field1.Type, field2.Type) 362 } 363 } 364 return nil 365 }