vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/route.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"sort"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	"vitess.io/vitess/go/mysql/collations"
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    29  
    30  	"vitess.io/vitess/go/mysql"
    31  	"vitess.io/vitess/go/sqltypes"
    32  	"vitess.io/vitess/go/stats"
    33  	"vitess.io/vitess/go/vt/key"
    34  	"vitess.io/vitess/go/vt/srvtopo"
    35  	"vitess.io/vitess/go/vt/vterrors"
    36  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    37  
    38  	querypb "vitess.io/vitess/go/vt/proto/query"
    39  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    40  )
    41  
// Compile-time assertion that *Route implements the Primitive interface.
var _ Primitive = (*Route)(nil)

// Route represents the instructions to route a read query to
// one or many vttablets.
type Route struct {
	// TargetTabletType specifies an explicit target destination tablet type
	// this is only used in conjunction with TargetDestination
	TargetTabletType topodatapb.TabletType

	// Query specifies the query to be executed.
	Query string

	// TableName specifies the tables to send the query to.
	// It is returned by GetTableName and reported as "Table" in the plan description.
	TableName string

	// FieldQuery specifies the query to be executed for a GetFieldInfo request.
	// GetFields sends it to a single arbitrary shard of the keyspace.
	FieldQuery string

	// OrderBy specifies the key order for merge sorting. This will be
	// set only for scatter queries that need the results to be
	// merge-sorted.
	// When non-empty, buffered execution sorts the combined rows in memory
	// and streaming execution merge-sorts the per-shard streams.
	OrderBy []OrderByParams

	// TruncateColumnCount specifies the number of columns to return
	// in the final result. Rest of the columns are truncated
	// from the result received. If 0, no truncation happens.
	TruncateColumnCount int

	// QueryTimeout contains the optional timeout (in milliseconds) to apply to this query.
	// TryExecute resolves it through the session (see addQueryTimeout), while
	// TryStreamExecute applies it directly and only when it is non-zero.
	QueryTimeout int

	// ScatterErrorsAsWarnings is true if results should be returned even if some shards have an error.
	// The shard errors are then recorded as session warnings; if every shard
	// fails, the aggregated error is still returned.
	ScatterErrorsAsWarnings bool

	// RoutingParameters parameters required for query routing.
	// It is embedded, so fields such as Opcode and Keyspace are read directly
	// from the Route (e.g. route.Opcode, route.Keyspace).
	*RoutingParameters

	// NoRoutesSpecialHandling will make the route send a query to arbitrary shard if the routing logic can't find
	// the correct shard. This is important for queries where no matches does not mean empty result - examples would be:
	// select count(*) from tbl where lookupColumn = 'not there'
	// select exists(<subq>)
	NoRoutesSpecialHandling bool

	// Route does not take inputs
	noInputs

	// Route does not need transaction handling
	noTxNeeded
}
    91  
    92  // NewSimpleRoute creates a Route with the bare minimum of parameters.
    93  func NewSimpleRoute(opcode Opcode, keyspace *vindexes.Keyspace) *Route {
    94  	return &Route{
    95  		RoutingParameters: &RoutingParameters{
    96  			Opcode:   opcode,
    97  			Keyspace: keyspace,
    98  		},
    99  	}
   100  }
   101  
   102  // NewRoute creates a Route.
   103  func NewRoute(opcode Opcode, keyspace *vindexes.Keyspace, query, fieldQuery string) *Route {
   104  	return &Route{
   105  		RoutingParameters: &RoutingParameters{
   106  			Opcode:   opcode,
   107  			Keyspace: keyspace,
   108  		},
   109  		Query:      query,
   110  		FieldQuery: fieldQuery,
   111  	}
   112  }
   113  
// OrderByParams specifies the parameters for ordering.
// This is used for merge-sorting scatter queries.
type OrderByParams struct {
	// Col identifies the column to order by (presumably its position in the
	// result row — confirm against the comparers that consume it).
	Col int
	// WeightStringCol is the weight_string column that will be used for sorting.
	// It is set to -1 if such a column is not added to the query
	WeightStringCol int
	// Desc is true for descending order; String renders it as DESC, otherwise ASC.
	Desc bool
	// StarColFixedIndex is printed by String in place of Col when it is larger
	// than Col. NOTE(review): presumably the column position after `*`
	// expansion — confirm against the planner.
	StarColFixedIndex int
	// v3 specific boolean. Used to also add weight strings originating from GroupBys to the Group by clause
	FromGroupBy bool
	// Collation ID for comparison using collation.
	// When it is collations.Unknown, String omits the COLLATE suffix.
	CollationID collations.ID
}
   128  
   129  // String returns a string. Used for plan descriptions
   130  func (obp OrderByParams) String() string {
   131  	val := strconv.Itoa(obp.Col)
   132  	if obp.StarColFixedIndex > obp.Col {
   133  		val = strconv.Itoa(obp.StarColFixedIndex)
   134  	}
   135  	if obp.WeightStringCol != -1 && obp.WeightStringCol != obp.Col {
   136  		val = fmt.Sprintf("(%s|%d)", val, obp.WeightStringCol)
   137  	}
   138  	if obp.Desc {
   139  		val += " DESC"
   140  	} else {
   141  		val += " ASC"
   142  	}
   143  	if obp.CollationID != collations.Unknown {
   144  		collation := collations.Local().LookupByID(obp.CollationID)
   145  		val += " COLLATE " + collation.Name()
   146  	}
   147  	return val
   148  }
   149  
   150  var (
   151  	partialSuccessScatterQueries = stats.NewCounter("PartialSuccessScatterQueries", "Count of partially successful scatter queries")
   152  )
   153  
   154  // RouteType returns a description of the query routing type used by the primitive
   155  func (route *Route) RouteType() string {
   156  	return route.Opcode.String()
   157  }
   158  
   159  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
   160  func (route *Route) GetKeyspaceName() string {
   161  	return route.Keyspace.Name
   162  }
   163  
   164  // GetTableName specifies the table that this primitive routes to.
   165  func (route *Route) GetTableName() string {
   166  	return route.TableName
   167  }
   168  
   169  // SetTruncateColumnCount sets the truncate column count.
   170  func (route *Route) SetTruncateColumnCount(count int) {
   171  	route.TruncateColumnCount = count
   172  }
   173  
   174  // TryExecute performs a non-streaming exec.
   175  func (route *Route) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
   176  	ctx, cancelFunc := addQueryTimeout(ctx, vcursor, route.QueryTimeout)
   177  	defer cancelFunc()
   178  	qr, err := route.executeInternal(ctx, vcursor, bindVars, wantfields)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	return qr.Truncate(route.TruncateColumnCount), nil
   183  }
   184  
   185  // addQueryTimeout adds a query timeout to the context it receives and returns the modified context along with the cancel function.
   186  func addQueryTimeout(ctx context.Context, vcursor VCursor, queryTimeout int) (context.Context, context.CancelFunc) {
   187  	timeout := vcursor.Session().GetQueryTimeout(queryTimeout)
   188  	if timeout != 0 {
   189  		return context.WithTimeout(ctx, time.Duration(timeout)*time.Millisecond)
   190  	}
   191  	return ctx, func() {}
   192  }
   193  
   194  type cxtKey int
   195  
   196  const (
   197  	IgnoreReserveTxn cxtKey = iota
   198  )
   199  
   200  func (route *Route) executeInternal(
   201  	ctx context.Context,
   202  	vcursor VCursor,
   203  	bindVars map[string]*querypb.BindVariable,
   204  	wantfields bool,
   205  ) (*sqltypes.Result, error) {
   206  	rss, bvs, err := route.findRoute(ctx, vcursor, bindVars)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	return route.executeShards(ctx, vcursor, bindVars, wantfields, rss, bvs)
   212  }
   213  
   214  func (route *Route) executeShards(
   215  	ctx context.Context,
   216  	vcursor VCursor,
   217  	bindVars map[string]*querypb.BindVariable,
   218  	wantfields bool,
   219  	rss []*srvtopo.ResolvedShard,
   220  	bvs []map[string]*querypb.BindVariable,
   221  ) (*sqltypes.Result, error) {
   222  	// Select Next - sequence query does not need to be executed in a dedicated connection (reserved or transaction)
   223  	if route.Opcode == Next {
   224  		ctx = context.WithValue(ctx, IgnoreReserveTxn, true)
   225  	}
   226  
   227  	// No route.
   228  	if len(rss) == 0 {
   229  		if !route.NoRoutesSpecialHandling {
   230  			if wantfields {
   231  				return route.GetFields(ctx, vcursor, bindVars)
   232  			}
   233  			return &sqltypes.Result{}, nil
   234  		}
   235  		// Here we were earlier returning no rows back.
   236  		// But this was incorrect for queries like select count(*) from user where name='x'
   237  		// If the lookup_vindex for name, returns no shards, we still want a result from here
   238  		// with a single row with 0 as the output.
   239  		// However, at this level it is hard to distinguish between the cases that need a result
   240  		// and the ones that don't. So, we are sending the query to any shard! This is safe because
   241  		// the query contains a predicate that make it not match any rows on that shard. (If they did,
   242  		// we should have gotten that shard back already from findRoute)
   243  		var err error
   244  		rss, bvs, err = route.anyShard(ctx, vcursor, bindVars)
   245  		if err != nil {
   246  			return nil, err
   247  		}
   248  	}
   249  
   250  	queries := getQueries(route.Query, bvs)
   251  	result, errs := vcursor.ExecuteMultiShard(ctx, route, rss, queries, false /* rollbackOnError */, false /* canAutocommit */)
   252  
   253  	if errs != nil {
   254  		errs = filterOutNilErrors(errs)
   255  		if !route.ScatterErrorsAsWarnings || len(errs) == len(rss) {
   256  			return nil, vterrors.Aggregate(errs)
   257  		}
   258  
   259  		partialSuccessScatterQueries.Add(1)
   260  
   261  		for _, err := range errs {
   262  			serr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError)
   263  			vcursor.Session().RecordWarning(&querypb.QueryWarning{Code: uint32(serr.Num), Message: err.Error()})
   264  		}
   265  	}
   266  
   267  	if len(route.OrderBy) == 0 {
   268  		return result, nil
   269  	}
   270  
   271  	return route.sort(result)
   272  }
   273  
   274  func filterOutNilErrors(errs []error) []error {
   275  	var errors []error
   276  	for _, err := range errs {
   277  		if err != nil {
   278  			errors = append(errors, err)
   279  		}
   280  	}
   281  	return errors
   282  }
   283  
   284  // TryStreamExecute performs a streaming exec.
   285  func (route *Route) TryStreamExecute(
   286  	ctx context.Context,
   287  	vcursor VCursor,
   288  	bindVars map[string]*querypb.BindVariable,
   289  	wantfields bool,
   290  	callback func(*sqltypes.Result) error,
   291  ) error {
   292  	if route.QueryTimeout != 0 {
   293  		var cancel context.CancelFunc
   294  		ctx, cancel = context.WithTimeout(ctx, time.Duration(route.QueryTimeout)*time.Millisecond)
   295  		defer cancel()
   296  	}
   297  	rss, bvs, err := route.findRoute(ctx, vcursor, bindVars)
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	return route.streamExecuteShards(ctx, vcursor, bindVars, wantfields, callback, rss, bvs)
   303  }
   304  
   305  func (route *Route) streamExecuteShards(
   306  	ctx context.Context,
   307  	vcursor VCursor,
   308  	bindVars map[string]*querypb.BindVariable,
   309  	wantfields bool,
   310  	callback func(*sqltypes.Result) error,
   311  	rss []*srvtopo.ResolvedShard,
   312  	bvs []map[string]*querypb.BindVariable,
   313  ) error {
   314  	// No route.
   315  	if len(rss) == 0 {
   316  		if !route.NoRoutesSpecialHandling {
   317  			if wantfields {
   318  				r, err := route.GetFields(ctx, vcursor, bindVars)
   319  				if err != nil {
   320  					return err
   321  				}
   322  				return callback(r)
   323  			}
   324  			return nil
   325  		}
   326  		// Here we were earlier returning no rows back.
   327  		// But this was incorrect for queries like select count(*) from user where name='x'
   328  		// If the lookup_vindex for name, returns no shards, we still want a result from here
   329  		// with a single row with 0 as the output.
   330  		// However, at this level it is hard to distinguish between the cases that need a result
   331  		// and the ones that don't. So, we are sending the query to any shard! This is safe because
   332  		// the query contains a predicate that make it not match any rows on that shard. (If they did,
   333  		// we should have gotten that shard back already from findRoute)
   334  		var err error
   335  		rss, bvs, err = route.anyShard(ctx, vcursor, bindVars)
   336  		if err != nil {
   337  			return err
   338  		}
   339  	}
   340  
   341  	if len(route.OrderBy) == 0 {
   342  		errs := vcursor.StreamExecuteMulti(ctx, route, route.Query, rss, bvs, false /* rollbackOnError */, false /* autocommit */, func(qr *sqltypes.Result) error {
   343  			return callback(qr.Truncate(route.TruncateColumnCount))
   344  		})
   345  		if len(errs) > 0 {
   346  			if !route.ScatterErrorsAsWarnings || len(errs) == len(rss) {
   347  				return vterrors.Aggregate(errs)
   348  			}
   349  			partialSuccessScatterQueries.Add(1)
   350  			for _, err := range errs {
   351  				sErr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError)
   352  				vcursor.Session().RecordWarning(&querypb.QueryWarning{Code: uint32(sErr.Num), Message: err.Error()})
   353  			}
   354  		}
   355  		return nil
   356  	}
   357  
   358  	// There is an order by. We have to merge-sort.
   359  	return route.mergeSort(ctx, vcursor, bindVars, wantfields, callback, rss, bvs)
   360  }
   361  
   362  func (route *Route) mergeSort(
   363  	ctx context.Context,
   364  	vcursor VCursor,
   365  	bindVars map[string]*querypb.BindVariable,
   366  	wantfields bool,
   367  	callback func(*sqltypes.Result) error,
   368  	rss []*srvtopo.ResolvedShard,
   369  	bvs []map[string]*querypb.BindVariable,
   370  ) error {
   371  	prims := make([]StreamExecutor, 0, len(rss))
   372  	for i, rs := range rss {
   373  		prims = append(prims, &shardRoute{
   374  			query:     route.Query,
   375  			rs:        rs,
   376  			bv:        bvs[i],
   377  			primitive: route,
   378  		})
   379  	}
   380  	ms := MergeSort{
   381  		Primitives:              prims,
   382  		OrderBy:                 route.OrderBy,
   383  		ScatterErrorsAsWarnings: route.ScatterErrorsAsWarnings,
   384  	}
   385  	return vcursor.StreamExecutePrimitive(ctx, &ms, bindVars, wantfields, func(qr *sqltypes.Result) error {
   386  		return callback(qr.Truncate(route.TruncateColumnCount))
   387  	})
   388  }
   389  
   390  // GetFields fetches the field info.
   391  func (route *Route) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   392  	rss, _, err := vcursor.ResolveDestinations(ctx, route.Keyspace.Name, nil, []key.Destination{key.DestinationAnyShard{}})
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  	if len(rss) != 1 {
   397  		// This code is unreachable. It's just a sanity check.
   398  		return nil, fmt.Errorf("no shards for keyspace: %s", route.Keyspace.Name)
   399  	}
   400  	qr, err := execShard(ctx, route, vcursor, route.FieldQuery, bindVars, rss[0], false /* rollbackOnError */, false /* canAutocommit */)
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  	return qr.Truncate(route.TruncateColumnCount), nil
   405  }
   406  
   407  func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) {
   408  	var err error
   409  	// Since Result is immutable, we make a copy.
   410  	// The copy can be shallow because we won't be changing
   411  	// the contents of any row.
   412  	out := in.ShallowCopy()
   413  
   414  	comparers := extractSlices(route.OrderBy)
   415  
   416  	sort.Slice(out.Rows, func(i, j int) bool {
   417  		var cmp int
   418  		if err != nil {
   419  			return true
   420  		}
   421  		// If there are any errors below, the function sets
   422  		// the external err and returns true. Once err is set,
   423  		// all subsequent calls return true. This will make
   424  		// Slice think that all elements are in the correct
   425  		// order and return more quickly.
   426  		for _, c := range comparers {
   427  			cmp, err = c.compare(out.Rows[i], out.Rows[j])
   428  			if err != nil {
   429  				return true
   430  			}
   431  			if cmp == 0 {
   432  				continue
   433  			}
   434  			return cmp < 0
   435  		}
   436  		return true
   437  	})
   438  
   439  	return out.Truncate(route.TruncateColumnCount), err
   440  }
   441  
   442  func (route *Route) description() PrimitiveDescription {
   443  	other := map[string]any{
   444  		"Query":      route.Query,
   445  		"Table":      route.GetTableName(),
   446  		"FieldQuery": route.FieldQuery,
   447  	}
   448  	if route.Vindex != nil {
   449  		other["Vindex"] = route.Vindex.String()
   450  	}
   451  	if route.Values != nil {
   452  		formattedValues := make([]string, 0, len(route.Values))
   453  		for _, value := range route.Values {
   454  			formattedValues = append(formattedValues, evalengine.FormatExpr(value))
   455  		}
   456  		other["Values"] = formattedValues
   457  	}
   458  	if len(route.SysTableTableSchema) != 0 {
   459  		sysTabSchema := "["
   460  		for idx, tableSchema := range route.SysTableTableSchema {
   461  			if idx != 0 {
   462  				sysTabSchema += ", "
   463  			}
   464  			sysTabSchema += evalengine.FormatExpr(tableSchema)
   465  		}
   466  		sysTabSchema += "]"
   467  		other["SysTableTableSchema"] = sysTabSchema
   468  	}
   469  	if len(route.SysTableTableName) != 0 {
   470  		var sysTableName []string
   471  		for k, v := range route.SysTableTableName {
   472  			sysTableName = append(sysTableName, k+":"+evalengine.FormatExpr(v))
   473  		}
   474  		sort.Strings(sysTableName)
   475  		other["SysTableTableName"] = "[" + strings.Join(sysTableName, ", ") + "]"
   476  	}
   477  	orderBy := GenericJoin(route.OrderBy, orderByToString)
   478  	if orderBy != "" {
   479  		other["OrderBy"] = orderBy
   480  	}
   481  	if route.TruncateColumnCount > 0 {
   482  		other["ResultColumns"] = route.TruncateColumnCount
   483  	}
   484  	if route.ScatterErrorsAsWarnings {
   485  		other["ScatterErrorsAsWarnings"] = true
   486  	}
   487  	if route.QueryTimeout > 0 {
   488  		other["QueryTimeout"] = route.QueryTimeout
   489  	}
   490  	return PrimitiveDescription{
   491  		OperatorType:      "Route",
   492  		Variant:           route.Opcode.String(),
   493  		Keyspace:          route.Keyspace,
   494  		TargetDestination: route.TargetDestination,
   495  		Other:             other,
   496  	}
   497  }
   498  
   499  func (route *Route) executeAfterLookup(
   500  	ctx context.Context,
   501  	vcursor VCursor,
   502  	bindVars map[string]*querypb.BindVariable,
   503  	wantfields bool,
   504  	ids []sqltypes.Value,
   505  	dest []key.Destination,
   506  ) (*sqltypes.Result, error) {
   507  	protoIds := make([]*querypb.Value, 0, len(ids))
   508  	for _, id := range ids {
   509  		protoIds = append(protoIds, sqltypes.ValueToProto(id))
   510  	}
   511  	rss, _, err := vcursor.ResolveDestinations(ctx, route.Keyspace.Name, protoIds, dest)
   512  	if err != nil {
   513  		return nil, err
   514  	}
   515  	bvs := make([]map[string]*querypb.BindVariable, len(rss))
   516  	for i := range bvs {
   517  		bvs[i] = bindVars
   518  	}
   519  	return route.executeShards(ctx, vcursor, bindVars, wantfields, rss, bvs)
   520  }
   521  
   522  func (route *Route) streamExecuteAfterLookup(
   523  	ctx context.Context,
   524  	vcursor VCursor,
   525  	bindVars map[string]*querypb.BindVariable,
   526  	wantfields bool,
   527  	callback func(*sqltypes.Result) error,
   528  	ids []sqltypes.Value,
   529  	dest []key.Destination,
   530  ) error {
   531  	protoIds := make([]*querypb.Value, 0, len(ids))
   532  	for _, id := range ids {
   533  		protoIds = append(protoIds, sqltypes.ValueToProto(id))
   534  	}
   535  	rss, _, err := vcursor.ResolveDestinations(ctx, route.Keyspace.Name, protoIds, dest)
   536  	if err != nil {
   537  		return err
   538  	}
   539  	bvs := make([]map[string]*querypb.BindVariable, len(rss))
   540  	for i := range bvs {
   541  		bvs[i] = bindVars
   542  	}
   543  	return route.streamExecuteShards(ctx, vcursor, bindVars, wantfields, callback, rss, bvs)
   544  }
   545  
   546  func execShard(
   547  	ctx context.Context,
   548  	primitive Primitive,
   549  	vcursor VCursor,
   550  	query string,
   551  	bindVars map[string]*querypb.BindVariable,
   552  	rs *srvtopo.ResolvedShard,
   553  	rollbackOnError, canAutocommit bool,
   554  ) (*sqltypes.Result, error) {
   555  	autocommit := canAutocommit && vcursor.AutocommitApproval()
   556  	result, errs := vcursor.ExecuteMultiShard(ctx, primitive, []*srvtopo.ResolvedShard{rs}, []*querypb.BoundQuery{
   557  		{
   558  			Sql:           query,
   559  			BindVariables: bindVars,
   560  		},
   561  	}, rollbackOnError, autocommit)
   562  	return result, vterrors.Aggregate(errs)
   563  }
   564  
   565  func getQueries(query string, bvs []map[string]*querypb.BindVariable) []*querypb.BoundQuery {
   566  	queries := make([]*querypb.BoundQuery, len(bvs))
   567  	for i, bv := range bvs {
   568  		queries[i] = &querypb.BoundQuery{
   569  			Sql:           query,
   570  			BindVariables: bv,
   571  		}
   572  	}
   573  	return queries
   574  }
   575  
   576  func orderByToString(in any) string {
   577  	return in.(OrderByParams).String()
   578  }