vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/pullout_subquery.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  
    23  	"vitess.io/vitess/go/sqltypes"
    24  	"vitess.io/vitess/go/vt/vterrors"
    25  
    26  	querypb "vitess.io/vitess/go/vt/proto/query"
    27  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    28  )
    29  
    30  var _ Primitive = (*PulloutSubquery)(nil)
    31  
    32  // PulloutSubquery executes a "pulled out" subquery and stores
    33  // the results in a bind variable.
    34  type PulloutSubquery struct {
    35  	Opcode PulloutOpcode
    36  
    37  	// SubqueryResult and HasValues are used to send in the bindvar used in the query to the underlying primitive
    38  	SubqueryResult string
    39  	HasValues      string
    40  
    41  	Subquery   Primitive
    42  	Underlying Primitive
    43  }
    44  
    45  // Inputs returns the input primitives for this join
    46  func (ps *PulloutSubquery) Inputs() []Primitive {
    47  	return []Primitive{ps.Subquery, ps.Underlying}
    48  }
    49  
    50  // RouteType returns a description of the query routing type used by the primitive
    51  func (ps *PulloutSubquery) RouteType() string {
    52  	return ps.Opcode.String()
    53  }
    54  
    55  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
    56  func (ps *PulloutSubquery) GetKeyspaceName() string {
    57  	return ps.Underlying.GetKeyspaceName()
    58  }
    59  
    60  // GetTableName specifies the table that this primitive routes to.
    61  func (ps *PulloutSubquery) GetTableName() string {
    62  	return ps.Underlying.GetTableName()
    63  }
    64  
    65  // TryExecute satisfies the Primitive interface.
    66  func (ps *PulloutSubquery) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    67  	combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	return vcursor.ExecutePrimitive(ctx, ps.Underlying, combinedVars, wantfields)
    72  }
    73  
    74  // TryStreamExecute performs a streaming exec.
    75  func (ps *PulloutSubquery) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    76  	combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars)
    77  	if err != nil {
    78  		return err
    79  	}
    80  	return vcursor.StreamExecutePrimitive(ctx, ps.Underlying, combinedVars, wantfields, callback)
    81  }
    82  
    83  // GetFields fetches the field info.
    84  func (ps *PulloutSubquery) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
    85  	combinedVars := make(map[string]*querypb.BindVariable, len(bindVars)+1)
    86  	for k, v := range bindVars {
    87  		combinedVars[k] = v
    88  	}
    89  	switch ps.Opcode {
    90  	case PulloutValue:
    91  		combinedVars[ps.SubqueryResult] = sqltypes.NullBindVariable
    92  	case PulloutIn, PulloutNotIn:
    93  		combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0)
    94  		combinedVars[ps.SubqueryResult] = &querypb.BindVariable{
    95  			Type:   querypb.Type_TUPLE,
    96  			Values: []*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(0))},
    97  		}
    98  	case PulloutExists:
    99  		combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0)
   100  	}
   101  	return ps.Underlying.GetFields(ctx, vcursor, combinedVars)
   102  }
   103  
   104  // NeedsTransaction implements the Primitive interface
   105  func (ps *PulloutSubquery) NeedsTransaction() bool {
   106  	return ps.Subquery.NeedsTransaction() || ps.Underlying.NeedsTransaction()
   107  }
   108  
   109  var (
   110  	errSqRow    = vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "subquery returned more than one row")
   111  	errSqColumn = vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "subquery returned more than one column")
   112  )
   113  
   114  func (ps *PulloutSubquery) execSubquery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (map[string]*querypb.BindVariable, error) {
   115  	subqueryBindVars := make(map[string]*querypb.BindVariable, len(bindVars))
   116  	for k, v := range bindVars {
   117  		subqueryBindVars[k] = v
   118  	}
   119  	result, err := vcursor.ExecutePrimitive(ctx, ps.Subquery, subqueryBindVars, false)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	combinedVars := make(map[string]*querypb.BindVariable, len(bindVars)+1)
   124  	for k, v := range bindVars {
   125  		combinedVars[k] = v
   126  	}
   127  	switch ps.Opcode {
   128  	case PulloutValue:
   129  		switch len(result.Rows) {
   130  		case 0:
   131  			combinedVars[ps.SubqueryResult] = sqltypes.NullBindVariable
   132  		case 1:
   133  			if len(result.Rows[0]) != 1 {
   134  				return nil, errSqColumn
   135  			}
   136  			combinedVars[ps.SubqueryResult] = sqltypes.ValueBindVariable(result.Rows[0][0])
   137  		default:
   138  			return nil, errSqRow
   139  		}
   140  	case PulloutIn, PulloutNotIn:
   141  		switch len(result.Rows) {
   142  		case 0:
   143  			combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0)
   144  			// Add a bogus value. It will not be checked.
   145  			combinedVars[ps.SubqueryResult] = &querypb.BindVariable{
   146  				Type:   querypb.Type_TUPLE,
   147  				Values: []*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(0))},
   148  			}
   149  		default:
   150  			if len(result.Rows[0]) != 1 {
   151  				return nil, errSqColumn
   152  			}
   153  			combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(1)
   154  			values := &querypb.BindVariable{
   155  				Type:   querypb.Type_TUPLE,
   156  				Values: make([]*querypb.Value, len(result.Rows)),
   157  			}
   158  			for i, v := range result.Rows {
   159  				values.Values[i] = sqltypes.ValueToProto(v[0])
   160  			}
   161  			combinedVars[ps.SubqueryResult] = values
   162  		}
   163  	case PulloutExists:
   164  		switch len(result.Rows) {
   165  		case 0:
   166  			combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0)
   167  		default:
   168  			combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(1)
   169  		}
   170  	}
   171  	return combinedVars, nil
   172  }
   173  
   174  func (ps *PulloutSubquery) description() PrimitiveDescription {
   175  	other := map[string]any{}
   176  	var pulloutVars []string
   177  	if ps.HasValues != "" {
   178  		pulloutVars = append(pulloutVars, ps.HasValues)
   179  	}
   180  	if ps.SubqueryResult != "" {
   181  		pulloutVars = append(pulloutVars, ps.SubqueryResult)
   182  	}
   183  	if len(pulloutVars) > 0 {
   184  		other["PulloutVars"] = pulloutVars
   185  	}
   186  	return PrimitiveDescription{
   187  		OperatorType: "Subquery",
   188  		Variant:      ps.Opcode.String(),
   189  		Other:        other,
   190  	}
   191  }
   192  
   193  // PulloutOpcode is a number representing the opcode
   194  // for the PulloutSubquery primitive.
   195  type PulloutOpcode int
   196  
   197  // This is the list of PulloutOpcode values.
   198  const (
   199  	PulloutValue = PulloutOpcode(iota)
   200  	PulloutIn
   201  	PulloutNotIn
   202  	PulloutExists
   203  )
   204  
   205  var pulloutName = map[PulloutOpcode]string{
   206  	PulloutValue:  "PulloutValue",
   207  	PulloutIn:     "PulloutIn",
   208  	PulloutNotIn:  "PulloutNotIn",
   209  	PulloutExists: "PulloutExists",
   210  }
   211  
   212  func (code PulloutOpcode) String() string {
   213  	return pulloutName[code]
   214  }
   215  
   216  // MarshalJSON serializes the PulloutOpcode as a JSON string.
   217  // It's used for testing and diagnostics.
   218  func (code PulloutOpcode) MarshalJSON() ([]byte, error) {
   219  	return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil
   220  }