vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/vindex_func.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"fmt"
    23  
    24  	"vitess.io/vitess/go/vt/vterrors"
    25  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    26  
    27  	"vitess.io/vitess/go/sqltypes"
    28  	"vitess.io/vitess/go/vt/key"
    29  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    30  
    31  	querypb "vitess.io/vitess/go/vt/proto/query"
    32  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    33  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    34  )
    35  
// Compile-time check that *VindexFunc implements Primitive.
var _ Primitive = (*VindexFunc)(nil)

// VindexFunc is a primitive that performs vindex functions.
// It evaluates Value, maps the result through Vindex and emits one row per
// resulting destination, with the columns selected by Cols.
type VindexFunc struct {
	// Opcode selects the vindex function; its name is reported by RouteType
	// and used as the Variant in the plan description.
	Opcode VindexOpcode
	// Fields is the field info for the result.
	Fields []*querypb.Field
	// Cols lists, in output order, which value goes into each result column:
	// 0 for the input id, 1 for the raw keyspace id, 2 for the key range
	// start, 3 for the key range end, 4 for the keyspace id as a hex string,
	// 5 for the key range formatted as a string. Any other value makes row
	// building fail with an out-of-range error.
	Cols []int
	// TODO(sougou): add support for MultiColumn.
	// Vindex is the vindex the evaluated value(s) are mapped through.
	Vindex vindexes.SingleColumn
	// Value is the expression whose result (a single value or a tuple of
	// values) is mapped through Vindex.
	Value  evalengine.Expr

	// VindexFunc does not take inputs
	noInputs

	// VindexFunc does not need to work inside a tx
	noTxNeeded
}
    55  
    56  // VindexOpcode is the opcode for a VindexFunc.
    57  type VindexOpcode int
    58  
    59  // These are opcode values for VindexFunc.
    60  const (
    61  	VindexNone = VindexOpcode(iota)
    62  	VindexMap
    63  	NumVindexCodes
    64  )
    65  
    66  var vindexOpcodeName = map[VindexOpcode]string{
    67  	VindexMap: "VindexMap",
    68  }
    69  
    70  // MarshalJSON serializes the VindexOpcode into a JSON representation.
    71  // It's used for testing and diagnostics.
    72  func (code VindexOpcode) MarshalJSON() ([]byte, error) {
    73  	return json.Marshal(vindexOpcodeName[code])
    74  }
    75  
    76  // RouteType returns a description of the query routing type used by the primitive
    77  func (vf *VindexFunc) RouteType() string {
    78  	return vindexOpcodeName[vf.Opcode]
    79  }
    80  
    81  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
    82  func (vf *VindexFunc) GetKeyspaceName() string {
    83  	return ""
    84  }
    85  
    86  // GetTableName specifies the table that this primitive routes to.
    87  func (vf *VindexFunc) GetTableName() string {
    88  	return ""
    89  }
    90  
    91  // TryExecute performs a non-streaming exec.
    92  func (vf *VindexFunc) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    93  	return vf.mapVindex(ctx, vcursor, bindVars)
    94  }
    95  
    96  // TryStreamExecute performs a streaming exec.
    97  func (vf *VindexFunc) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    98  	r, err := vf.mapVindex(ctx, vcursor, bindVars)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	if err := callback(r.Metadata()); err != nil {
   103  		return err
   104  	}
   105  	return callback(&sqltypes.Result{Rows: r.Rows})
   106  }
   107  
   108  // GetFields fetches the field info.
   109  func (vf *VindexFunc) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   110  	return &sqltypes.Result{Fields: vf.Fields}, nil
   111  }
   112  
   113  func (vf *VindexFunc) mapVindex(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   114  	env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation())
   115  	k, err := env.Evaluate(vf.Value)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	var values []sqltypes.Value
   120  	if k.Value().Type() == querypb.Type_TUPLE {
   121  		values = k.TupleValues()
   122  	} else {
   123  		values = append(values, k.Value())
   124  	}
   125  	result := &sqltypes.Result{
   126  		Fields: vf.Fields,
   127  	}
   128  	destinations, err := vf.Vindex.Map(ctx, vcursor, values)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	if len(destinations) != len(values) {
   133  		// should never happen
   134  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Vindex.Map() length mismatch: input values count is %d, output destinations count is %d",
   135  			len(values), len(destinations))
   136  	}
   137  	for i, value := range values {
   138  		vkey, err := evalengine.Cast(value, sqltypes.VarBinary)
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		switch d := destinations[i].(type) {
   143  		case key.DestinationKeyRange:
   144  			if d.KeyRange != nil {
   145  				row, err := vf.buildRow(vkey, nil, d.KeyRange)
   146  				if err != nil {
   147  					return result, err
   148  				}
   149  				result.Rows = append(result.Rows, row)
   150  			}
   151  		case key.DestinationKeyspaceID:
   152  			if len(d) > 0 {
   153  				if vcursor != nil {
   154  					resolvedShards, _, err := vcursor.ResolveDestinations(ctx, vcursor.GetKeyspace(), nil, []key.Destination{d})
   155  					if err != nil {
   156  						return nil, err
   157  					}
   158  					if len(resolvedShards) > 0 {
   159  						kr, err := key.ParseShardingSpec(resolvedShards[0].Target.Shard)
   160  						if err != nil {
   161  							return nil, err
   162  						}
   163  						row, err := vf.buildRow(vkey, d, kr[0])
   164  						if err != nil {
   165  							return result, err
   166  						}
   167  						result.Rows = append(result.Rows, row)
   168  						break
   169  					}
   170  				}
   171  
   172  				row, err := vf.buildRow(vkey, d, nil)
   173  				if err != nil {
   174  					return result, err
   175  				}
   176  				result.Rows = append(result.Rows, row)
   177  			}
   178  		case key.DestinationKeyspaceIDs:
   179  			for _, ksid := range d {
   180  				row, err := vf.buildRow(vkey, ksid, nil)
   181  				if err != nil {
   182  					return result, err
   183  				}
   184  				result.Rows = append(result.Rows, row)
   185  			}
   186  		case key.DestinationNone:
   187  			// Nothing to do.
   188  		default:
   189  			return result, vterrors.NewErrorf(vtrpcpb.Code_INTERNAL, vterrors.WrongTypeForVar, "unexpected destination type: %T", d)
   190  		}
   191  	}
   192  	return result, nil
   193  }
   194  
   195  func (vf *VindexFunc) buildRow(id sqltypes.Value, ksid []byte, kr *topodatapb.KeyRange) ([]sqltypes.Value, error) {
   196  	row := make([]sqltypes.Value, 0, len(vf.Fields))
   197  	for _, col := range vf.Cols {
   198  		switch col {
   199  		case 0:
   200  			row = append(row, id)
   201  		case 1:
   202  			if ksid != nil {
   203  				row = append(row, sqltypes.MakeTrusted(sqltypes.VarBinary, ksid))
   204  			} else {
   205  				row = append(row, sqltypes.NULL)
   206  			}
   207  		case 2:
   208  			if kr != nil {
   209  				row = append(row, sqltypes.MakeTrusted(sqltypes.VarBinary, kr.Start))
   210  			} else {
   211  				row = append(row, sqltypes.NULL)
   212  			}
   213  		case 3:
   214  			if kr != nil {
   215  				row = append(row, sqltypes.MakeTrusted(sqltypes.VarBinary, kr.End))
   216  			} else {
   217  				row = append(row, sqltypes.NULL)
   218  			}
   219  		case 4:
   220  			if ksid != nil {
   221  				row = append(row, sqltypes.NewVarBinary(fmt.Sprintf("%x", ksid)))
   222  			} else {
   223  				row = append(row, sqltypes.NULL)
   224  			}
   225  		case 5:
   226  			if ksid != nil {
   227  				row = append(row, sqltypes.NewVarBinary(key.KeyRangeString(kr)))
   228  			} else {
   229  				row = append(row, sqltypes.NULL)
   230  			}
   231  		default:
   232  			return row, vterrors.NewErrorf(vtrpcpb.Code_OUT_OF_RANGE, vterrors.BadFieldError, "column %v out of range", col)
   233  		}
   234  	}
   235  	return row, nil
   236  }
   237  
   238  func (vf *VindexFunc) description() PrimitiveDescription {
   239  	fields := map[string]string{}
   240  	for _, field := range vf.Fields {
   241  		fields[field.Name] = field.Type.String()
   242  	}
   243  
   244  	other := map[string]any{
   245  		"Fields":  fields,
   246  		"Columns": vf.Cols,
   247  		"Value":   evalengine.FormatExpr(vf.Value),
   248  	}
   249  	if vf.Vindex != nil {
   250  		other["Vindex"] = vf.Vindex.String()
   251  	}
   252  
   253  	return PrimitiveDescription{
   254  		OperatorType: "VindexFunc",
   255  		Variant:      vindexOpcodeName[vf.Opcode],
   256  		Other:        other,
   257  	}
   258  }