vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/join.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/sqltypes"
    25  	querypb "vitess.io/vitess/go/vt/proto/query"
    26  )
    27  
// Compile-time assertion that *Join implements the Primitive interface.
var _ Primitive = (*Join)(nil)
    29  
    30  // Join specifies the parameters for a join primitive.
    31  type Join struct {
    32  	Opcode JoinOpcode
    33  	// Left and Right are the LHS and RHS primitives
    34  	// of the Join. They can be any primitive.
    35  	Left, Right Primitive `json:",omitempty"`
    36  
    37  	// Cols defines which columns from the left
    38  	// or right results should be used to build the
    39  	// return result. For results coming from the
    40  	// left query, the index values go as -1, -2, etc.
    41  	// For the right query, they're 1, 2, etc.
    42  	// If Cols is {-1, -2, 1, 2}, it means that
    43  	// the returned result will be {Left0, Left1, Right0, Right1}.
    44  	Cols []int `json:",omitempty"`
    45  
    46  	// Vars defines the list of joinVars that need to
    47  	// be built from the LHS result before invoking
    48  	// the RHS subqquery.
    49  	Vars map[string]int `json:",omitempty"`
    50  }
    51  
    52  // TryExecute performs a non-streaming exec.
    53  func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    54  	joinVars := make(map[string]*querypb.BindVariable)
    55  	lresult, err := vcursor.ExecutePrimitive(ctx, jn.Left, bindVars, wantfields)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	result := &sqltypes.Result{}
    60  	if len(lresult.Rows) == 0 && wantfields {
    61  		for k := range jn.Vars {
    62  			joinVars[k] = sqltypes.NullBindVariable
    63  		}
    64  		rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  		result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
    69  		return result, nil
    70  	}
    71  	for _, lrow := range lresult.Rows {
    72  		for k, col := range jn.Vars {
    73  			joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
    74  		}
    75  		rresult, err := vcursor.ExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields)
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  		if wantfields {
    80  			wantfields = false
    81  			result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
    82  		}
    83  		for _, rrow := range rresult.Rows {
    84  			result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols))
    85  		}
    86  		if jn.Opcode == LeftJoin && len(rresult.Rows) == 0 {
    87  			result.Rows = append(result.Rows, joinRows(lrow, nil, jn.Cols))
    88  		}
    89  		if vcursor.ExceedsMaxMemoryRows(len(result.Rows)) {
    90  			return nil, fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows())
    91  		}
    92  	}
    93  	return result, nil
    94  }
    95  
    96  // TryStreamExecute performs a streaming exec.
    97  func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    98  	joinVars := make(map[string]*querypb.BindVariable)
    99  	err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
   100  		for _, lrow := range lresult.Rows {
   101  			for k, col := range jn.Vars {
   102  				joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
   103  			}
   104  			rowSent := false
   105  			err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields, func(rresult *sqltypes.Result) error {
   106  				result := &sqltypes.Result{}
   107  				if wantfields {
   108  					// This code is currently unreachable because the first result
   109  					// will always be just the field info, which will cause the outer
   110  					// wantfields code path to be executed. But this may change in the future.
   111  					wantfields = false
   112  					result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
   113  				}
   114  				for _, rrow := range rresult.Rows {
   115  					result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols))
   116  				}
   117  				if len(rresult.Rows) != 0 {
   118  					rowSent = true
   119  				}
   120  				return callback(result)
   121  			})
   122  			if err != nil {
   123  				return err
   124  			}
   125  			if jn.Opcode == LeftJoin && !rowSent {
   126  				result := &sqltypes.Result{}
   127  				result.Rows = [][]sqltypes.Value{joinRows(
   128  					lrow,
   129  					nil,
   130  					jn.Cols,
   131  				)}
   132  				return callback(result)
   133  			}
   134  		}
   135  		if wantfields {
   136  			wantfields = false
   137  			for k := range jn.Vars {
   138  				joinVars[k] = sqltypes.NullBindVariable
   139  			}
   140  			result := &sqltypes.Result{}
   141  			rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
   142  			if err != nil {
   143  				return err
   144  			}
   145  			result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
   146  			return callback(result)
   147  		}
   148  		return nil
   149  	})
   150  	return err
   151  }
   152  
   153  // GetFields fetches the field info.
   154  func (jn *Join) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   155  	joinVars := make(map[string]*querypb.BindVariable)
   156  	lresult, err := jn.Left.GetFields(ctx, vcursor, bindVars)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	result := &sqltypes.Result{}
   161  	for k := range jn.Vars {
   162  		joinVars[k] = sqltypes.NullBindVariable
   163  	}
   164  	rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols)
   169  	return result, nil
   170  }
   171  
   172  // Inputs returns the input primitives for this join
   173  func (jn *Join) Inputs() []Primitive {
   174  	return []Primitive{jn.Left, jn.Right}
   175  }
   176  
   177  func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field {
   178  	fields := make([]*querypb.Field, len(cols))
   179  	for i, index := range cols {
   180  		if index < 0 {
   181  			fields[i] = lfields[-index-1]
   182  			continue
   183  		}
   184  		fields[i] = rfields[index-1]
   185  	}
   186  	return fields
   187  }
   188  
   189  func joinRows(lrow, rrow []sqltypes.Value, cols []int) []sqltypes.Value {
   190  	row := make([]sqltypes.Value, len(cols))
   191  	for i, index := range cols {
   192  		if index < 0 {
   193  			row[i] = lrow[-index-1]
   194  			continue
   195  		}
   196  		// rrow can be nil on left joins
   197  		if rrow != nil {
   198  			row[i] = rrow[index-1]
   199  		}
   200  	}
   201  	return row
   202  }
   203  
   204  // JoinOpcode is a number representing the opcode
   205  // for the Join primitive.
   206  type JoinOpcode int
   207  
   208  // This is the list of JoinOpcode values.
   209  const (
   210  	InnerJoin = JoinOpcode(iota)
   211  	LeftJoin
   212  )
   213  
   214  func (code JoinOpcode) String() string {
   215  	if code == InnerJoin {
   216  		return "Join"
   217  	}
   218  	return "LeftJoin"
   219  }
   220  
   221  // MarshalJSON serializes the JoinOpcode as a JSON string.
   222  // It's used for testing and diagnostics.
   223  func (code JoinOpcode) MarshalJSON() ([]byte, error) {
   224  	return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil
   225  }
   226  
   227  // RouteType returns a description of the query routing type used by the primitive
   228  func (jn *Join) RouteType() string {
   229  	return "Join"
   230  }
   231  
   232  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
   233  func (jn *Join) GetKeyspaceName() string {
   234  	if jn.Left.GetKeyspaceName() == jn.Right.GetKeyspaceName() {
   235  		return jn.Left.GetKeyspaceName()
   236  	}
   237  	return jn.Left.GetKeyspaceName() + "_" + jn.Right.GetKeyspaceName()
   238  }
   239  
   240  // GetTableName specifies the table that this primitive routes to.
   241  func (jn *Join) GetTableName() string {
   242  	return jn.Left.GetTableName() + "_" + jn.Right.GetTableName()
   243  }
   244  
   245  // NeedsTransaction implements the Primitive interface
   246  func (jn *Join) NeedsTransaction() bool {
   247  	return jn.Right.NeedsTransaction() || jn.Left.NeedsTransaction()
   248  }
   249  
   250  func combineVars(bv1, bv2 map[string]*querypb.BindVariable) map[string]*querypb.BindVariable {
   251  	out := make(map[string]*querypb.BindVariable)
   252  	for k, v := range bv1 {
   253  		out[k] = v
   254  	}
   255  	for k, v := range bv2 {
   256  		out[k] = v
   257  	}
   258  	return out
   259  }
   260  
   261  func (jn *Join) description() PrimitiveDescription {
   262  	other := map[string]any{
   263  		"TableName":         jn.GetTableName(),
   264  		"JoinColumnIndexes": jn.joinColsDescription(),
   265  	}
   266  	if len(jn.Vars) > 0 {
   267  		other["JoinVars"] = orderedStringIntMap(jn.Vars)
   268  	}
   269  	return PrimitiveDescription{
   270  		OperatorType: "Join",
   271  		Variant:      jn.Opcode.String(),
   272  		Other:        other,
   273  	}
   274  }
   275  
   276  func (jn *Join) joinColsDescription() string {
   277  	var joinCols []string
   278  	for _, col := range jn.Cols {
   279  		if col < 0 {
   280  			joinCols = append(joinCols, fmt.Sprintf("L:%d", -col-1))
   281  		} else {
   282  			joinCols = append(joinCols, fmt.Sprintf("R:%d", col-1))
   283  		}
   284  	}
   285  	joinColTxt := strings.Join(joinCols, ",")
   286  	return joinColTxt
   287  }