vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/semi_join.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "fmt" 22 "strings" 23 24 "vitess.io/vitess/go/sqltypes" 25 querypb "vitess.io/vitess/go/vt/proto/query" 26 ) 27 28 var _ Primitive = (*SemiJoin)(nil) 29 30 // SemiJoin specifies the parameters for a SemiJoin primitive. 31 type SemiJoin struct { 32 // Left and Right are the LHS and RHS primitives 33 // of the SemiJoin. They can be any primitive. 34 Left, Right Primitive `json:",omitempty"` 35 36 // Cols defines which columns from the left 37 // results should be used to build the 38 // return result. For results coming from the 39 // left query, the index values go as -1, -2, etc. 40 // If Cols is {-1, -2}, it means that 41 // the returned result will be {Left0, Left1}. 42 Cols []int `json:",omitempty"` 43 44 // Vars defines the list of SemiJoinVars that need to 45 // be built from the LHS result before invoking 46 // the RHS subqquery. 47 Vars map[string]int `json:",omitempty"` 48 } 49 50 // TryExecute performs a non-streaming exec. 51 func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 52 joinVars := make(map[string]*querypb.BindVariable) 53 lresult, err := vcursor.ExecutePrimitive(ctx, jn.Left, bindVars, wantfields) 54 if err != nil { 55 return nil, err 56 } 57 result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)} 58 for _, lrow := range lresult.Rows { 59 for k, col := range jn.Vars { 60 joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) 61 } 62 rresult, err := vcursor.ExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false) 63 if err != nil { 64 return nil, err 65 } 66 if len(rresult.Rows) > 0 { 67 result.Rows = append(result.Rows, projectRows(lrow, jn.Cols)) 68 } 69 } 70 return result, nil 71 } 72 73 // TryStreamExecute performs a streaming exec. 74 func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 75 joinVars := make(map[string]*querypb.BindVariable) 76 err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { 77 result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)} 78 for _, lrow := range lresult.Rows { 79 for k, col := range jn.Vars { 80 joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) 81 } 82 rowAdded := false 83 err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error { 84 if len(rresult.Rows) > 0 && !rowAdded { 85 result.Rows = append(result.Rows, projectRows(lrow, jn.Cols)) 86 rowAdded = true 87 } 88 return nil 89 }) 90 if err != nil { 91 return err 92 } 93 } 94 return callback(result) 95 }) 96 return err 97 } 98 99 // GetFields fetches the field info. 100 func (jn *SemiJoin) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 101 return jn.Left.GetFields(ctx, vcursor, bindVars) 102 } 103 104 // Inputs returns the input primitives for this SemiJoin 105 func (jn *SemiJoin) Inputs() []Primitive { 106 return []Primitive{jn.Left, jn.Right} 107 } 108 109 // RouteType returns a description of the query routing type used by the primitive 110 func (jn *SemiJoin) RouteType() string { 111 return "SemiJoin" 112 } 113 114 // GetKeyspaceName specifies the Keyspace that this primitive routes to. 115 func (jn *SemiJoin) GetKeyspaceName() string { 116 if jn.Left.GetKeyspaceName() == jn.Right.GetKeyspaceName() { 117 return jn.Left.GetKeyspaceName() 118 } 119 return jn.Left.GetKeyspaceName() + "_" + jn.Right.GetKeyspaceName() 120 } 121 122 // GetTableName specifies the table that this primitive routes to. 123 func (jn *SemiJoin) GetTableName() string { 124 return jn.Left.GetTableName() + "_" + jn.Right.GetTableName() 125 } 126 127 // NeedsTransaction implements the Primitive interface 128 func (jn *SemiJoin) NeedsTransaction() bool { 129 return jn.Right.NeedsTransaction() || jn.Left.NeedsTransaction() 130 } 131 132 func (jn *SemiJoin) description() PrimitiveDescription { 133 other := map[string]any{ 134 "TableName": jn.GetTableName(), 135 "ProjectedIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(jn.Cols)), ","), "[]"), 136 } 137 if len(jn.Vars) > 0 { 138 other["JoinVars"] = orderedStringIntMap(jn.Vars) 139 } 140 return PrimitiveDescription{ 141 OperatorType: "SemiJoin", 142 Other: other, 143 } 144 } 145 146 func projectFields(lfields []*querypb.Field, cols []int) []*querypb.Field { 147 if lfields == nil { 148 return nil 149 } 150 fields := make([]*querypb.Field, len(cols)) 151 for i, index := range cols { 152 fields[i] = lfields[-index-1] 153 } 154 return fields 155 } 156 157 func projectRows(lrow []sqltypes.Value, cols []int) []sqltypes.Value { 158 row := make([]sqltypes.Value, len(cols)) 159 for i, index := range cols { 160 if index < 0 { 161 row[i] = lrow[-index-1] 162 } 163 } 164 return row 165 }