vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/pullout_subquery.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "fmt" 22 23 "vitess.io/vitess/go/sqltypes" 24 "vitess.io/vitess/go/vt/vterrors" 25 26 querypb "vitess.io/vitess/go/vt/proto/query" 27 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 28 ) 29 30 var _ Primitive = (*PulloutSubquery)(nil) 31 32 // PulloutSubquery executes a "pulled out" subquery and stores 33 // the results in a bind variable. 34 type PulloutSubquery struct { 35 Opcode PulloutOpcode 36 37 // SubqueryResult and HasValues are used to send in the bindvar used in the query to the underlying primitive 38 SubqueryResult string 39 HasValues string 40 41 Subquery Primitive 42 Underlying Primitive 43 } 44 45 // Inputs returns the input primitives for this join 46 func (ps *PulloutSubquery) Inputs() []Primitive { 47 return []Primitive{ps.Subquery, ps.Underlying} 48 } 49 50 // RouteType returns a description of the query routing type used by the primitive 51 func (ps *PulloutSubquery) RouteType() string { 52 return ps.Opcode.String() 53 } 54 55 // GetKeyspaceName specifies the Keyspace that this primitive routes to. 56 func (ps *PulloutSubquery) GetKeyspaceName() string { 57 return ps.Underlying.GetKeyspaceName() 58 } 59 60 // GetTableName specifies the table that this primitive routes to. 61 func (ps *PulloutSubquery) GetTableName() string { 62 return ps.Underlying.GetTableName() 63 } 64 65 // TryExecute satisfies the Primitive interface. 66 func (ps *PulloutSubquery) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 67 combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars) 68 if err != nil { 69 return nil, err 70 } 71 return vcursor.ExecutePrimitive(ctx, ps.Underlying, combinedVars, wantfields) 72 } 73 74 // TryStreamExecute performs a streaming exec. 75 func (ps *PulloutSubquery) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 76 combinedVars, err := ps.execSubquery(ctx, vcursor, bindVars) 77 if err != nil { 78 return err 79 } 80 return vcursor.StreamExecutePrimitive(ctx, ps.Underlying, combinedVars, wantfields, callback) 81 } 82 83 // GetFields fetches the field info. 84 func (ps *PulloutSubquery) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 85 combinedVars := make(map[string]*querypb.BindVariable, len(bindVars)+1) 86 for k, v := range bindVars { 87 combinedVars[k] = v 88 } 89 switch ps.Opcode { 90 case PulloutValue: 91 combinedVars[ps.SubqueryResult] = sqltypes.NullBindVariable 92 case PulloutIn, PulloutNotIn: 93 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0) 94 combinedVars[ps.SubqueryResult] = &querypb.BindVariable{ 95 Type: querypb.Type_TUPLE, 96 Values: []*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(0))}, 97 } 98 case PulloutExists: 99 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0) 100 } 101 return ps.Underlying.GetFields(ctx, vcursor, combinedVars) 102 } 103 104 // NeedsTransaction implements the Primitive interface 105 func (ps *PulloutSubquery) NeedsTransaction() bool { 106 return ps.Subquery.NeedsTransaction() || ps.Underlying.NeedsTransaction() 107 } 108 109 var ( 110 errSqRow = vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "subquery returned more than one row") 111 errSqColumn = vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "subquery returned more than one column") 112 ) 113 114 func (ps *PulloutSubquery) execSubquery(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (map[string]*querypb.BindVariable, error) { 115 subqueryBindVars := make(map[string]*querypb.BindVariable, len(bindVars)) 116 for k, v := range bindVars { 117 subqueryBindVars[k] = v 118 } 119 result, err := vcursor.ExecutePrimitive(ctx, ps.Subquery, subqueryBindVars, false) 120 if err != nil { 121 return nil, err 122 } 123 combinedVars := make(map[string]*querypb.BindVariable, len(bindVars)+1) 124 for k, v := range bindVars { 125 combinedVars[k] = v 126 } 127 switch ps.Opcode { 128 case PulloutValue: 129 switch len(result.Rows) { 130 case 0: 131 combinedVars[ps.SubqueryResult] = sqltypes.NullBindVariable 132 case 1: 133 if len(result.Rows[0]) != 1 { 134 return nil, errSqColumn 135 } 136 combinedVars[ps.SubqueryResult] = sqltypes.ValueBindVariable(result.Rows[0][0]) 137 default: 138 return nil, errSqRow 139 } 140 case PulloutIn, PulloutNotIn: 141 switch len(result.Rows) { 142 case 0: 143 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0) 144 // Add a bogus value. It will not be checked. 145 combinedVars[ps.SubqueryResult] = &querypb.BindVariable{ 146 Type: querypb.Type_TUPLE, 147 Values: []*querypb.Value{sqltypes.ValueToProto(sqltypes.NewInt64(0))}, 148 } 149 default: 150 if len(result.Rows[0]) != 1 { 151 return nil, errSqColumn 152 } 153 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(1) 154 values := &querypb.BindVariable{ 155 Type: querypb.Type_TUPLE, 156 Values: make([]*querypb.Value, len(result.Rows)), 157 } 158 for i, v := range result.Rows { 159 values.Values[i] = sqltypes.ValueToProto(v[0]) 160 } 161 combinedVars[ps.SubqueryResult] = values 162 } 163 case PulloutExists: 164 switch len(result.Rows) { 165 case 0: 166 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(0) 167 default: 168 combinedVars[ps.HasValues] = sqltypes.Int64BindVariable(1) 169 } 170 } 171 return combinedVars, nil 172 } 173 174 func (ps *PulloutSubquery) description() PrimitiveDescription { 175 other := map[string]any{} 176 var pulloutVars []string 177 if ps.HasValues != "" { 178 pulloutVars = append(pulloutVars, ps.HasValues) 179 } 180 if ps.SubqueryResult != "" { 181 pulloutVars = append(pulloutVars, ps.SubqueryResult) 182 } 183 if len(pulloutVars) > 0 { 184 other["PulloutVars"] = pulloutVars 185 } 186 return PrimitiveDescription{ 187 OperatorType: "Subquery", 188 Variant: ps.Opcode.String(), 189 Other: other, 190 } 191 } 192 193 // PulloutOpcode is a number representing the opcode 194 // for the PulloutSubquery primitive. 195 type PulloutOpcode int 196 197 // This is the list of PulloutOpcode values. 198 const ( 199 PulloutValue = PulloutOpcode(iota) 200 PulloutIn 201 PulloutNotIn 202 PulloutExists 203 ) 204 205 var pulloutName = map[PulloutOpcode]string{ 206 PulloutValue: "PulloutValue", 207 PulloutIn: "PulloutIn", 208 PulloutNotIn: "PulloutNotIn", 209 PulloutExists: "PulloutExists", 210 } 211 212 func (code PulloutOpcode) String() string { 213 return pulloutName[code] 214 } 215 216 // MarshalJSON serializes the PulloutOpcode as a JSON string. 217 // It's used for testing and diagnostics. 218 func (code PulloutOpcode) MarshalJSON() ([]byte, error) { 219 return ([]byte)(fmt.Sprintf("\"%s\"", code.String())), nil 220 }