vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/hash_join.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/mysql/collations"
    25  	"vitess.io/vitess/go/sqltypes"
    26  	querypb "vitess.io/vitess/go/vt/proto/query"
    27  	"vitess.io/vitess/go/vt/sqlparser"
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    29  )
    30  
// Compile-time assertion that *HashJoin satisfies the Primitive interface.
var _ Primitive = (*HashJoin)(nil)
    32  
    33  // HashJoin specifies the parameters for a join primitive
    34  // Hash joins work by fetch all the input from the LHS, and building a hash map, known as the probe table, for this input.
    35  // The key to the map is the hashcode of the value for column that we are joining by.
    36  // Then the RHS is fetched, and we can check if the rows from the RHS matches any from the LHS.
    37  // When they match by hash code, we double-check that we are not working with a false positive by comparing the values.
    38  type HashJoin struct {
    39  	Opcode JoinOpcode
    40  
    41  	// Left and Right are the LHS and RHS primitives
    42  	// of the Join. They can be any primitive.
    43  	Left, Right Primitive `json:",omitempty"`
    44  
    45  	// Cols defines which columns from the left
    46  	// or right results should be used to build the
    47  	// return result. For results coming from the
    48  	// left query, the index values go as -1, -2, etc.
    49  	// For the right query, they're 1, 2, etc.
    50  	// If Cols is {-1, -2, 1, 2}, it means that
    51  	// the returned result will be {Left0, Left1, Right0, Right1}.
    52  	Cols []int `json:",omitempty"`
    53  
    54  	// The keys correspond to the column offset in the inputs where
    55  	// the join columns can be found
    56  	LHSKey, RHSKey int
    57  
    58  	// The join condition. Used for plan descriptions
    59  	ASTPred sqlparser.Expr
    60  
    61  	// collation and type are used to hash the incoming values correctly
    62  	Collation      collations.ID
    63  	ComparisonType querypb.Type
    64  }
    65  
    66  // TryExecute implements the Primitive interface
    67  func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    68  	lresult, err := vcursor.ExecutePrimitive(ctx, hj.Left, bindVars, wantfields)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	// build the probe table from the LHS result
    74  	probeTable, err := hj.buildProbeTable(lresult)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	rresult, err := vcursor.ExecutePrimitive(ctx, hj.Right, bindVars, wantfields)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	result := &sqltypes.Result{
    85  		Fields: joinFields(lresult.Fields, rresult.Fields, hj.Cols),
    86  	}
    87  
    88  	for _, currentRHSRow := range rresult.Rows {
    89  		joinVal := currentRHSRow[hj.RHSKey]
    90  		if joinVal.IsNull() {
    91  			continue
    92  		}
    93  		hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		lftRows := probeTable[hashcode]
    98  		for _, currentLHSRow := range lftRows {
    99  			lhsVal := currentLHSRow[hj.LHSKey]
   100  			// hash codes can give false positives, so we need to check with a real comparison as well
   101  			cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, collations.Unknown)
   102  			if err != nil {
   103  				return nil, err
   104  			}
   105  
   106  			if cmp == 0 {
   107  				// we have a match!
   108  				result.Rows = append(result.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols))
   109  			}
   110  		}
   111  	}
   112  
   113  	return result, nil
   114  }
   115  
   116  func (hj *HashJoin) buildProbeTable(lresult *sqltypes.Result) (map[evalengine.HashCode][]sqltypes.Row, error) {
   117  	probeTable := map[evalengine.HashCode][]sqltypes.Row{}
   118  	for _, current := range lresult.Rows {
   119  		joinVal := current[hj.LHSKey]
   120  		if joinVal.IsNull() {
   121  			continue
   122  		}
   123  		hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		probeTable[hashcode] = append(probeTable[hashcode], current)
   128  	}
   129  	return probeTable, nil
   130  }
   131  
   132  // TryStreamExecute implements the Primitive interface
   133  func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
   134  	// build the probe table from the LHS result
   135  	probeTable := map[evalengine.HashCode][]sqltypes.Row{}
   136  	var lfields []*querypb.Field
   137  	err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error {
   138  		if len(lfields) == 0 && len(result.Fields) != 0 {
   139  			lfields = result.Fields
   140  		}
   141  		for _, current := range result.Rows {
   142  			joinVal := current[hj.LHSKey]
   143  			if joinVal.IsNull() {
   144  				continue
   145  			}
   146  			hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType)
   147  			if err != nil {
   148  				return err
   149  			}
   150  			probeTable[hashcode] = append(probeTable[hashcode], current)
   151  		}
   152  		return nil
   153  	})
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	return vcursor.StreamExecutePrimitive(ctx, hj.Right, bindVars, wantfields, func(result *sqltypes.Result) error {
   159  		// compare the results coming from the RHS with the probe-table
   160  		res := &sqltypes.Result{}
   161  		if len(result.Fields) != 0 {
   162  			res = &sqltypes.Result{
   163  				Fields: joinFields(lfields, result.Fields, hj.Cols),
   164  			}
   165  		}
   166  		for _, currentRHSRow := range result.Rows {
   167  			joinVal := currentRHSRow[hj.RHSKey]
   168  			if joinVal.IsNull() {
   169  				continue
   170  			}
   171  			hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType)
   172  			if err != nil {
   173  				return err
   174  			}
   175  			lftRows := probeTable[hashcode]
   176  			for _, currentLHSRow := range lftRows {
   177  				lhsVal := currentLHSRow[hj.LHSKey]
   178  				// hash codes can give false positives, so we need to check with a real comparison as well
   179  				cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, hj.Collation)
   180  				if err != nil {
   181  					return err
   182  				}
   183  
   184  				if cmp == 0 {
   185  					// we have a match!
   186  					res.Rows = append(res.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols))
   187  				}
   188  			}
   189  		}
   190  		if len(res.Rows) != 0 || len(res.Fields) != 0 {
   191  			return callback(res)
   192  		}
   193  		return nil
   194  	})
   195  }
   196  
   197  // RouteType implements the Primitive interface
   198  func (hj *HashJoin) RouteType() string {
   199  	return "HashJoin"
   200  }
   201  
   202  // GetKeyspaceName implements the Primitive interface
   203  func (hj *HashJoin) GetKeyspaceName() string {
   204  	if hj.Left.GetKeyspaceName() == hj.Right.GetKeyspaceName() {
   205  		return hj.Left.GetKeyspaceName()
   206  	}
   207  	return hj.Left.GetKeyspaceName() + "_" + hj.Right.GetKeyspaceName()
   208  }
   209  
   210  // GetTableName implements the Primitive interface
   211  func (hj *HashJoin) GetTableName() string {
   212  	return hj.Left.GetTableName() + "_" + hj.Right.GetTableName()
   213  }
   214  
   215  // GetFields implements the Primitive interface
   216  func (hj *HashJoin) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   217  	joinVars := make(map[string]*querypb.BindVariable)
   218  	lresult, err := hj.Left.GetFields(ctx, vcursor, bindVars)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	result := &sqltypes.Result{}
   223  	rresult, err := hj.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	result.Fields = joinFields(lresult.Fields, rresult.Fields, hj.Cols)
   228  	return result, nil
   229  }
   230  
   231  // NeedsTransaction implements the Primitive interface
   232  func (hj *HashJoin) NeedsTransaction() bool {
   233  	return hj.Right.NeedsTransaction() || hj.Left.NeedsTransaction()
   234  }
   235  
   236  // Inputs implements the Primitive interface
   237  func (hj *HashJoin) Inputs() []Primitive {
   238  	return []Primitive{hj.Left, hj.Right}
   239  }
   240  
   241  // description implements the Primitive interface
   242  func (hj *HashJoin) description() PrimitiveDescription {
   243  	other := map[string]any{
   244  		"TableName":         hj.GetTableName(),
   245  		"JoinColumnIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(hj.Cols)), ","), "[]"),
   246  		"Predicate":         sqlparser.String(hj.ASTPred),
   247  		"ComparisonType":    hj.ComparisonType.String(),
   248  	}
   249  	coll := collations.Local().LookupByID(hj.Collation)
   250  	if coll != nil {
   251  		other["Collation"] = coll.Name()
   252  	}
   253  	return PrimitiveDescription{
   254  		OperatorType: "Join",
   255  		Variant:      "Hash" + hj.Opcode.String(),
   256  		Other:        other,
   257  	}
   258  }