vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/join.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "fmt" 22 "strings" 23 24 "vitess.io/vitess/go/sqltypes" 25 querypb "vitess.io/vitess/go/vt/proto/query" 26 ) 27 28 var _ Primitive = (*Join)(nil) 29 30 // Join specifies the parameters for a join primitive. 31 type Join struct { 32 Opcode JoinOpcode 33 // Left and Right are the LHS and RHS primitives 34 // of the Join. They can be any primitive. 35 Left, Right Primitive `json:",omitempty"` 36 37 // Cols defines which columns from the left 38 // or right results should be used to build the 39 // return result. For results coming from the 40 // left query, the index values go as -1, -2, etc. 41 // For the right query, they're 1, 2, etc. 42 // If Cols is {-1, -2, 1, 2}, it means that 43 // the returned result will be {Left0, Left1, Right0, Right1}. 44 Cols []int `json:",omitempty"` 45 46 // Vars defines the list of joinVars that need to 47 // be built from the LHS result before invoking 48 // the RHS subqquery. 49 Vars map[string]int `json:",omitempty"` 50 } 51 52 // TryExecute performs a non-streaming exec. 53 func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 54 joinVars := make(map[string]*querypb.BindVariable) 55 lresult, err := vcursor.ExecutePrimitive(ctx, jn.Left, bindVars, wantfields) 56 if err != nil { 57 return nil, err 58 } 59 result := &sqltypes.Result{} 60 if len(lresult.Rows) == 0 && wantfields { 61 for k := range jn.Vars { 62 joinVars[k] = sqltypes.NullBindVariable 63 } 64 rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) 65 if err != nil { 66 return nil, err 67 } 68 result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) 69 return result, nil 70 } 71 for _, lrow := range lresult.Rows { 72 for k, col := range jn.Vars { 73 joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) 74 } 75 rresult, err := vcursor.ExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields) 76 if err != nil { 77 return nil, err 78 } 79 if wantfields { 80 wantfields = false 81 result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) 82 } 83 for _, rrow := range rresult.Rows { 84 result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols)) 85 } 86 if jn.Opcode == LeftJoin && len(rresult.Rows) == 0 { 87 result.Rows = append(result.Rows, joinRows(lrow, nil, jn.Cols)) 88 } 89 if vcursor.ExceedsMaxMemoryRows(len(result.Rows)) { 90 return nil, fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows()) 91 } 92 } 93 return result, nil 94 } 95 96 // TryStreamExecute performs a streaming exec. 97 func (jn *Join) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 98 joinVars := make(map[string]*querypb.BindVariable) 99 err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { 100 for _, lrow := range lresult.Rows { 101 for k, col := range jn.Vars { 102 joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) 103 } 104 rowSent := false 105 err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), wantfields, func(rresult *sqltypes.Result) error { 106 result := &sqltypes.Result{} 107 if wantfields { 108 // This code is currently unreachable because the first result 109 // will always be just the field info, which will cause the outer 110 // wantfields code path to be executed. But this may change in the future. 111 wantfields = false 112 result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) 113 } 114 for _, rrow := range rresult.Rows { 115 result.Rows = append(result.Rows, joinRows(lrow, rrow, jn.Cols)) 116 } 117 if len(rresult.Rows) != 0 { 118 rowSent = true 119 } 120 return callback(result) 121 }) 122 if err != nil { 123 return err 124 } 125 if jn.Opcode == LeftJoin && !rowSent { 126 result := &sqltypes.Result{} 127 result.Rows = [][]sqltypes.Value{joinRows( 128 lrow, 129 nil, 130 jn.Cols, 131 )} 132 return callback(result) 133 } 134 } 135 if wantfields { 136 wantfields = false 137 for k := range jn.Vars { 138 joinVars[k] = sqltypes.NullBindVariable 139 } 140 result := &sqltypes.Result{} 141 rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) 142 if err != nil { 143 return err 144 } 145 result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) 146 return callback(result) 147 } 148 return nil 149 }) 150 return err 151 } 152 153 // GetFields fetches the field info. 154 func (jn *Join) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 155 joinVars := make(map[string]*querypb.BindVariable) 156 lresult, err := jn.Left.GetFields(ctx, vcursor, bindVars) 157 if err != nil { 158 return nil, err 159 } 160 result := &sqltypes.Result{} 161 for k := range jn.Vars { 162 joinVars[k] = sqltypes.NullBindVariable 163 } 164 rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) 165 if err != nil { 166 return nil, err 167 } 168 result.Fields = joinFields(lresult.Fields, rresult.Fields, jn.Cols) 169 return result, nil 170 } 171 172 // Inputs returns the input primitives for this join 173 func (jn *Join) Inputs() []Primitive { 174 return []Primitive{jn.Left, jn.Right} 175 } 176 177 func joinFields(lfields, rfields []*querypb.Field, cols []int) []*querypb.Field { 178 fields := make([]*querypb.Field, len(cols)) 179 for i, index := range cols { 180 if index < 0 { 181 fields[i] = lfields[-index-1] 182 continue 183 } 184 fields[i] = rfields[index-1] 185 } 186 return fields 187 } 188 189 func joinRows(lrow, rrow []sqltypes.Value, cols []int) []sqltypes.Value { 190 row := make([]sqltypes.Value, len(cols)) 191 for i, index := range cols { 192 if index < 0 { 193 row[i] = lrow[-index-1] 194 continue 195 } 196 // rrow can be nil on left joins 197 if rrow != nil { 198 row[i] = rrow[index-1] 199 } 200 } 201 return row 202 } 203 204 // JoinOpcode is a number representing the opcode 205 // for the Join primitive. 206 type JoinOpcode int 207 208 // This is the list of JoinOpcode values. 209 const ( 210 InnerJoin = JoinOpcode(iota) 211 LeftJoin 212 ) 213 214 func (code JoinOpcode) String() string { 215 if code == InnerJoin { 216 return "Join" 217 } 218 return "LeftJoin" 219 } 220 221 // MarshalJSON serializes the JoinOpcode as a JSON string. 222 // It's used for testing and diagnostics. 223 func (code JoinOpcode) MarshalJSON() ([]byte, error) { 224 return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil 225 } 226 227 // RouteType returns a description of the query routing type used by the primitive 228 func (jn *Join) RouteType() string { 229 return "Join" 230 } 231 232 // GetKeyspaceName specifies the Keyspace that this primitive routes to. 233 func (jn *Join) GetKeyspaceName() string { 234 if jn.Left.GetKeyspaceName() == jn.Right.GetKeyspaceName() { 235 return jn.Left.GetKeyspaceName() 236 } 237 return jn.Left.GetKeyspaceName() + "_" + jn.Right.GetKeyspaceName() 238 } 239 240 // GetTableName specifies the table that this primitive routes to. 241 func (jn *Join) GetTableName() string { 242 return jn.Left.GetTableName() + "_" + jn.Right.GetTableName() 243 } 244 245 // NeedsTransaction implements the Primitive interface 246 func (jn *Join) NeedsTransaction() bool { 247 return jn.Right.NeedsTransaction() || jn.Left.NeedsTransaction() 248 } 249 250 func combineVars(bv1, bv2 map[string]*querypb.BindVariable) map[string]*querypb.BindVariable { 251 out := make(map[string]*querypb.BindVariable) 252 for k, v := range bv1 { 253 out[k] = v 254 } 255 for k, v := range bv2 { 256 out[k] = v 257 } 258 return out 259 } 260 261 func (jn *Join) description() PrimitiveDescription { 262 other := map[string]any{ 263 "TableName": jn.GetTableName(), 264 "JoinColumnIndexes": jn.joinColsDescription(), 265 } 266 if len(jn.Vars) > 0 { 267 other["JoinVars"] = orderedStringIntMap(jn.Vars) 268 } 269 return PrimitiveDescription{ 270 OperatorType: "Join", 271 Variant: jn.Opcode.String(), 272 Other: other, 273 } 274 } 275 276 func (jn *Join) joinColsDescription() string { 277 var joinCols []string 278 for _, col := range jn.Cols { 279 if col < 0 { 280 joinCols = append(joinCols, fmt.Sprintf("L:%d", -col-1)) 281 } else { 282 joinCols = append(joinCols, fmt.Sprintf("R:%d", col-1)) 283 } 284 } 285 joinColTxt := strings.Join(joinCols, ",") 286 return joinColTxt 287 }