vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/semi_join.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/sqltypes"
    25  	querypb "vitess.io/vitess/go/vt/proto/query"
    26  )
    27  
// Compile-time check that *SemiJoin implements Primitive.
var _ Primitive = (*SemiJoin)(nil)

// SemiJoin specifies the parameters for a SemiJoin primitive.
//
// For every row produced by Left, Right is executed with values from that row
// bound as join variables (see Vars). The LHS row is emitted, projected
// through Cols, only if Right produces at least one row. Right's columns are
// never part of the output.
type SemiJoin struct {
	// Left and Right are the LHS and RHS primitives
	// of the SemiJoin. They can be any primitive.
	Left, Right Primitive `json:",omitempty"`

	// Cols defines which columns from the left
	// results should be used to build the
	// return result. For results coming from the
	// left query, the index values go as -1, -2, etc.
	// If Cols is {-1, -2}, it means that
	// the returned result will be {Left0, Left1}.
	// Only negative (left) indexes are meaningful here, because the
	// semi join never returns RHS columns.
	Cols []int `json:",omitempty"`

	// Vars defines the list of SemiJoinVars that need to
	// be built from the LHS result before invoking
	// the RHS subquery. Each entry maps a bind-variable
	// name to the index of the LHS column whose value
	// is bound under that name for the current LHS row.
	Vars map[string]int `json:",omitempty"`
}
    49  
    50  // TryExecute performs a non-streaming exec.
    51  func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    52  	joinVars := make(map[string]*querypb.BindVariable)
    53  	lresult, err := vcursor.ExecutePrimitive(ctx, jn.Left, bindVars, wantfields)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)}
    58  	for _, lrow := range lresult.Rows {
    59  		for k, col := range jn.Vars {
    60  			joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
    61  		}
    62  		rresult, err := vcursor.ExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		if len(rresult.Rows) > 0 {
    67  			result.Rows = append(result.Rows, projectRows(lrow, jn.Cols))
    68  		}
    69  	}
    70  	return result, nil
    71  }
    72  
    73  // TryStreamExecute performs a streaming exec.
    74  func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    75  	joinVars := make(map[string]*querypb.BindVariable)
    76  	err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
    77  		result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)}
    78  		for _, lrow := range lresult.Rows {
    79  			for k, col := range jn.Vars {
    80  				joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
    81  			}
    82  			rowAdded := false
    83  			err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
    84  				if len(rresult.Rows) > 0 && !rowAdded {
    85  					result.Rows = append(result.Rows, projectRows(lrow, jn.Cols))
    86  					rowAdded = true
    87  				}
    88  				return nil
    89  			})
    90  			if err != nil {
    91  				return err
    92  			}
    93  		}
    94  		return callback(result)
    95  	})
    96  	return err
    97  }
    98  
    99  // GetFields fetches the field info.
   100  func (jn *SemiJoin) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   101  	return jn.Left.GetFields(ctx, vcursor, bindVars)
   102  }
   103  
   104  // Inputs returns the input primitives for this SemiJoin
   105  func (jn *SemiJoin) Inputs() []Primitive {
   106  	return []Primitive{jn.Left, jn.Right}
   107  }
   108  
   109  // RouteType returns a description of the query routing type used by the primitive
   110  func (jn *SemiJoin) RouteType() string {
   111  	return "SemiJoin"
   112  }
   113  
   114  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
   115  func (jn *SemiJoin) GetKeyspaceName() string {
   116  	if jn.Left.GetKeyspaceName() == jn.Right.GetKeyspaceName() {
   117  		return jn.Left.GetKeyspaceName()
   118  	}
   119  	return jn.Left.GetKeyspaceName() + "_" + jn.Right.GetKeyspaceName()
   120  }
   121  
   122  // GetTableName specifies the table that this primitive routes to.
   123  func (jn *SemiJoin) GetTableName() string {
   124  	return jn.Left.GetTableName() + "_" + jn.Right.GetTableName()
   125  }
   126  
   127  // NeedsTransaction implements the Primitive interface
   128  func (jn *SemiJoin) NeedsTransaction() bool {
   129  	return jn.Right.NeedsTransaction() || jn.Left.NeedsTransaction()
   130  }
   131  
   132  func (jn *SemiJoin) description() PrimitiveDescription {
   133  	other := map[string]any{
   134  		"TableName":        jn.GetTableName(),
   135  		"ProjectedIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(jn.Cols)), ","), "[]"),
   136  	}
   137  	if len(jn.Vars) > 0 {
   138  		other["JoinVars"] = orderedStringIntMap(jn.Vars)
   139  	}
   140  	return PrimitiveDescription{
   141  		OperatorType: "SemiJoin",
   142  		Other:        other,
   143  	}
   144  }
   145  
   146  func projectFields(lfields []*querypb.Field, cols []int) []*querypb.Field {
   147  	if lfields == nil {
   148  		return nil
   149  	}
   150  	fields := make([]*querypb.Field, len(cols))
   151  	for i, index := range cols {
   152  		fields[i] = lfields[-index-1]
   153  	}
   154  	return fields
   155  }
   156  
   157  func projectRows(lrow []sqltypes.Value, cols []int) []sqltypes.Value {
   158  	row := make([]sqltypes.Value, len(cols))
   159  	for i, index := range cols {
   160  		if index < 0 {
   161  			row[i] = lrow[-index-1]
   162  		}
   163  	}
   164  	return row
   165  }