vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/concatenate.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"sync"
    22  
    23  	"vitess.io/vitess/go/sqltypes"
    24  	querypb "vitess.io/vitess/go/vt/proto/query"
    25  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    26  	"vitess.io/vitess/go/vt/vterrors"
    27  )
    28  
// Concatenate is a Primitive used to concatenate results from multiple sources.
// The blank-identifier assignment below is a compile-time check that
// *Concatenate satisfies the Primitive interface.
var _ Primitive = (*Concatenate)(nil)
    31  
// Concatenate specifies the parameters for the concatenate primitive.
type Concatenate struct {
	// Sources are the input primitives whose results are concatenated.
	Sources []Primitive

	// These column offsets do not need to be type checked - they usually contain weight_string()
	// columns that are not going to be returned to the user.
	// Only the presence of a key is consulted when comparing fields; the values are unused.
	NoNeedToTypeCheck map[int]any
}
    40  
    41  // NewConcatenate creates a Concatenate primitive. The ignoreCols slice contains the offsets that
    42  // don't need to have the same type between sources -
    43  // weight_string() sometimes returns VARBINARY and sometimes VARCHAR
    44  func NewConcatenate(Sources []Primitive, ignoreCols []int) *Concatenate {
    45  	ignoreTypes := map[int]any{}
    46  	for _, i := range ignoreCols {
    47  		ignoreTypes[i] = nil
    48  	}
    49  	return &Concatenate{
    50  		Sources:           Sources,
    51  		NoNeedToTypeCheck: ignoreTypes,
    52  	}
    53  }
    54  
    55  // RouteType returns a description of the query routing type used by the primitive
    56  func (c *Concatenate) RouteType() string {
    57  	return "Concatenate"
    58  }
    59  
    60  // GetKeyspaceName specifies the Keyspace that this primitive routes to
    61  func (c *Concatenate) GetKeyspaceName() string {
    62  	res := c.Sources[0].GetKeyspaceName()
    63  	for i := 1; i < len(c.Sources); i++ {
    64  		res = formatTwoOptionsNicely(res, c.Sources[i].GetKeyspaceName())
    65  	}
    66  	return res
    67  }
    68  
    69  // GetTableName specifies the table that this primitive routes to.
    70  func (c *Concatenate) GetTableName() string {
    71  	res := c.Sources[0].GetTableName()
    72  	for i := 1; i < len(c.Sources); i++ {
    73  		res = formatTwoOptionsNicely(res, c.Sources[i].GetTableName())
    74  	}
    75  	return res
    76  }
    77  
    78  func formatTwoOptionsNicely(a, b string) string {
    79  	if a == b {
    80  		return a
    81  	}
    82  	return a + "_" + b
    83  }
    84  
// ErrWrongNumberOfColumnsInSelect is the error returned when the results or
// field lists being concatenated do not all have the same number of columns.
var ErrWrongNumberOfColumnsInSelect = vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.WrongNumberOfColumnsInSelect, "The used SELECT statements have a different number of columns")
    87  
    88  // TryExecute performs a non-streaming exec.
    89  func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    90  	res, err := c.execSources(ctx, vcursor, bindVars, wantfields)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	fields, err := c.getFields(res)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	var rowsAffected uint64
   101  	var rows [][]sqltypes.Value
   102  
   103  	for _, r := range res {
   104  		rowsAffected += r.RowsAffected
   105  
   106  		if len(rows) > 0 &&
   107  			len(r.Rows) > 0 &&
   108  			len(rows[0]) != len(r.Rows[0]) {
   109  			return nil, ErrWrongNumberOfColumnsInSelect
   110  		}
   111  
   112  		rows = append(rows, r.Rows...)
   113  	}
   114  
   115  	return &sqltypes.Result{
   116  		Fields:       fields,
   117  		RowsAffected: rowsAffected,
   118  		Rows:         rows,
   119  	}, nil
   120  }
   121  
   122  func (c *Concatenate) getFields(res []*sqltypes.Result) ([]*querypb.Field, error) {
   123  	if len(res) == 0 {
   124  		return nil, nil
   125  	}
   126  
   127  	var fields []*querypb.Field
   128  	for _, r := range res {
   129  		if r.Fields == nil {
   130  			continue
   131  		}
   132  		if fields == nil {
   133  			fields = r.Fields
   134  			continue
   135  		}
   136  
   137  		err := c.compareFields(fields, r.Fields)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  	}
   142  	return fields, nil
   143  }
   144  
   145  func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
   146  	if vcursor.Session().InTransaction() {
   147  		// as we are in a transaction, we need to execute all queries inside a single transaction
   148  		// therefore it needs a sequential execution.
   149  		return c.sequentialExec(ctx, vcursor, bindVars, wantfields)
   150  	}
   151  	// not in transaction, so execute in parallel.
   152  	return c.parallelExec(ctx, vcursor, bindVars, wantfields)
   153  }
   154  
   155  func (c *Concatenate) parallelExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
   156  	results := make([]*sqltypes.Result, len(c.Sources))
   157  	var outerErr error
   158  
   159  	ctx, cancel := context.WithCancel(ctx)
   160  	defer cancel()
   161  
   162  	var wg sync.WaitGroup
   163  	for i, source := range c.Sources {
   164  		currIndex, currSource := i, source
   165  		vars := copyBindVars(bindVars)
   166  		wg.Add(1)
   167  		go func() {
   168  			defer wg.Done()
   169  			result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields)
   170  			if err != nil {
   171  				outerErr = err
   172  				cancel()
   173  			}
   174  			results[currIndex] = result
   175  		}()
   176  	}
   177  	wg.Wait()
   178  	return results, outerErr
   179  }
   180  
   181  func (c *Concatenate) sequentialExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
   182  	results := make([]*sqltypes.Result, len(c.Sources))
   183  	for i, source := range c.Sources {
   184  		currIndex, currSource := i, source
   185  		vars := copyBindVars(bindVars)
   186  		result, err := vcursor.ExecutePrimitive(ctx, currSource, vars, wantfields)
   187  		if err != nil {
   188  			return nil, err
   189  		}
   190  		results[currIndex] = result
   191  	}
   192  	return results, nil
   193  }
   194  
   195  // TryStreamExecute performs a streaming exec.
   196  func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   197  	if vcursor.Session().InTransaction() {
   198  		// as we are in a transaction, we need to execute all queries inside a single transaction
   199  		// therefore it needs a sequential execution.
   200  		return c.sequentialStreamExec(ctx, vcursor, bindVars, wantfields, callback)
   201  	}
   202  	// not in transaction, so execute in parallel.
   203  	return c.parallelStreamExec(ctx, vcursor, bindVars, wantfields, callback)
   204  }
   205  
   206  func (c *Concatenate) parallelStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   207  	var seenFields []*querypb.Field
   208  	var outerErr error
   209  
   210  	var fieldsSent bool
   211  	var cbMu, fieldsMu sync.Mutex
   212  	var wg, fieldSendWg sync.WaitGroup
   213  	fieldSendWg.Add(1)
   214  
   215  	for i, source := range c.Sources {
   216  		wg.Add(1)
   217  		currIndex, currSource := i, source
   218  
   219  		go func() {
   220  			defer wg.Done()
   221  			err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, wantfields, func(resultChunk *sqltypes.Result) error {
   222  				// if we have fields to compare, make sure all the fields are all the same
   223  				if currIndex == 0 {
   224  					fieldsMu.Lock()
   225  					if !fieldsSent {
   226  						defer fieldSendWg.Done()
   227  						defer fieldsMu.Unlock()
   228  						seenFields = resultChunk.Fields
   229  						fieldsSent = true
   230  						// No other call can happen before this call.
   231  						return callback(resultChunk)
   232  					}
   233  					fieldsMu.Unlock()
   234  				}
   235  				fieldSendWg.Wait()
   236  				if resultChunk.Fields != nil {
   237  					err := c.compareFields(seenFields, resultChunk.Fields)
   238  					if err != nil {
   239  						return err
   240  					}
   241  				}
   242  				// This to ensure only one send happens back to the client.
   243  				cbMu.Lock()
   244  				defer cbMu.Unlock()
   245  				select {
   246  				case <-ctx.Done():
   247  					return nil
   248  				default:
   249  					return callback(resultChunk)
   250  				}
   251  			})
   252  			// This is to ensure other streams complete if the first stream failed to unlock the wait.
   253  			if currIndex == 0 {
   254  				fieldsMu.Lock()
   255  				if !fieldsSent {
   256  					fieldsSent = true
   257  					fieldSendWg.Done()
   258  				}
   259  				fieldsMu.Unlock()
   260  			}
   261  			if err != nil {
   262  				outerErr = err
   263  				ctx.Done()
   264  			}
   265  		}()
   266  
   267  	}
   268  	wg.Wait()
   269  	return outerErr
   270  }
   271  
   272  func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   273  	// all the below fields ensure that the fields are sent only once.
   274  	var seenFields []*querypb.Field
   275  	var fieldsMu sync.Mutex
   276  	var fieldsSent bool
   277  
   278  	for idx, source := range c.Sources {
   279  		err := vcursor.StreamExecutePrimitive(ctx, source, bindVars, wantfields, func(resultChunk *sqltypes.Result) error {
   280  			// if we have fields to compare, make sure all the fields are all the same
   281  			if idx == 0 {
   282  				fieldsMu.Lock()
   283  				defer fieldsMu.Unlock()
   284  				if !fieldsSent {
   285  					fieldsSent = true
   286  					seenFields = resultChunk.Fields
   287  					return callback(resultChunk)
   288  				}
   289  			}
   290  			if resultChunk.Fields != nil {
   291  				err := c.compareFields(seenFields, resultChunk.Fields)
   292  				if err != nil {
   293  					return err
   294  				}
   295  			}
   296  			// check if context has expired.
   297  			if ctx.Err() != nil {
   298  				return ctx.Err()
   299  			}
   300  			return callback(resultChunk)
   301  
   302  		})
   303  		if err != nil {
   304  			return err
   305  		}
   306  	}
   307  	return nil
   308  }
   309  
   310  // GetFields fetches the field info.
   311  func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   312  	// TODO: type coercions
   313  	res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  
   318  	for i := 1; i < len(c.Sources); i++ {
   319  		result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars)
   320  		if err != nil {
   321  			return nil, err
   322  		}
   323  		err = c.compareFields(res.Fields, result.Fields)
   324  		if err != nil {
   325  			return nil, err
   326  		}
   327  	}
   328  
   329  	return res, nil
   330  }
   331  
   332  // NeedsTransaction returns whether a transaction is needed for this primitive
   333  func (c *Concatenate) NeedsTransaction() bool {
   334  	for _, source := range c.Sources {
   335  		if source.NeedsTransaction() {
   336  			return true
   337  		}
   338  	}
   339  	return false
   340  }
   341  
   342  // Inputs returns the input primitives for this
   343  func (c *Concatenate) Inputs() []Primitive {
   344  	return c.Sources
   345  }
   346  
   347  func (c *Concatenate) description() PrimitiveDescription {
   348  	return PrimitiveDescription{OperatorType: c.RouteType()}
   349  }
   350  
   351  func (c *Concatenate) compareFields(fields1 []*querypb.Field, fields2 []*querypb.Field) error {
   352  	if len(fields1) != len(fields2) {
   353  		return ErrWrongNumberOfColumnsInSelect
   354  	}
   355  	for i, field1 := range fields1 {
   356  		if _, found := c.NoNeedToTypeCheck[i]; found {
   357  			continue
   358  		}
   359  		field2 := fields2[i]
   360  		if field1.Type != field2.Type {
   361  			return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "merging field of different types is not supported, name: (%v, %v) types: (%v, %v)", field1.Name, field2.Name, field1.Type, field2.Type)
   362  		}
   363  	}
   364  	return nil
   365  }