vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/rename_fields.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 22 "vitess.io/vitess/go/sqltypes" 23 querypb "vitess.io/vitess/go/vt/proto/query" 24 "vitess.io/vitess/go/vt/vterrors" 25 ) 26 27 var _ Primitive = (*RenameFields)(nil) 28 29 // RenameFields is a primitive that renames the fields 30 type RenameFields struct { 31 Cols []string 32 Indices []int 33 Input Primitive 34 noTxNeeded 35 } 36 37 // NewRenameField creates a new rename field 38 func NewRenameField(cols []string, indices []int, input Primitive) (*RenameFields, error) { 39 if len(cols) != len(indices) { 40 return nil, vterrors.VT13001("number of columns does not match number of indices in RenameField primitive") 41 } 42 return &RenameFields{ 43 Cols: cols, 44 Indices: indices, 45 Input: input, 46 }, nil 47 } 48 49 // RouteType implements the primitive interface 50 func (r *RenameFields) RouteType() string { 51 return r.Input.RouteType() 52 } 53 54 // GetKeyspaceName implements the primitive interface 55 func (r *RenameFields) GetKeyspaceName() string { 56 return r.Input.GetKeyspaceName() 57 } 58 59 // GetTableName implements the primitive interface 60 func (r *RenameFields) GetTableName() string { 61 return r.Input.GetTableName() 62 } 63 64 // TryExecute implements the Primitive interface 65 func (r *RenameFields) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 66 qr, err := vcursor.ExecutePrimitive(ctx, r.Input, bindVars, wantfields) 67 if err != nil { 68 return nil, err 69 } 70 if wantfields { 71 r.renameFields(qr) 72 } 73 return qr, nil 74 } 75 76 func (r *RenameFields) renameFields(qr *sqltypes.Result) { 77 for ind, index := range r.Indices { 78 if index >= len(qr.Fields) { 79 continue 80 } 81 colName := r.Cols[ind] 82 qr.Fields[index].Name = colName 83 } 84 } 85 86 // TryStreamExecute implements the Primitive interface 87 func (r *RenameFields) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 88 if wantfields { 89 innerCallback := callback 90 callback = func(result *sqltypes.Result) error { 91 // Only the first callback will contain the fields. 92 // This check is to avoid going over the RenameFields indices when no fields are present in the result set. 93 if len(result.Fields) != 0 { 94 r.renameFields(result) 95 } 96 return innerCallback(result) 97 } 98 } 99 return vcursor.StreamExecutePrimitive(ctx, r.Input, bindVars, wantfields, callback) 100 } 101 102 // GetFields implements the primitive interface 103 func (r *RenameFields) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 104 qr, err := r.Input.GetFields(ctx, vcursor, bindVars) 105 if err != nil { 106 return nil, err 107 } 108 r.renameFields(qr) 109 return qr, nil 110 } 111 112 // Inputs implements the primitive interface 113 func (r *RenameFields) Inputs() []Primitive { 114 return []Primitive{r.Input} 115 } 116 117 // description implements the primitive interface 118 func (r *RenameFields) description() PrimitiveDescription { 119 return PrimitiveDescription{ 120 OperatorType: "RenameFields", 121 Other: map[string]any{ 122 "Indices": r.Indices, 123 "Columns": r.Cols, 124 }, 125 } 126 }