vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/fake_primitive_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package engine 18 19 import ( 20 "context" 21 "fmt" 22 "reflect" 23 "strings" 24 "testing" 25 26 "vitess.io/vitess/go/sqltypes" 27 28 querypb "vitess.io/vitess/go/vt/proto/query" 29 ) 30 31 // fakePrimitive fakes a primitive. For every call, it sends the 32 // next result from the results. If the next result is nil, it 33 // returns sendErr. For streaming calls, it sends the field info 34 // first and two rows at a time till all rows are sent. 35 type fakePrimitive struct { 36 results []*sqltypes.Result 37 curResult int 38 // sendErr is sent at the end of the stream if it's set. 39 sendErr error 40 41 log []string 42 43 allResultsInOneCall bool 44 } 45 46 func (f *fakePrimitive) Inputs() []Primitive { 47 return []Primitive{} 48 } 49 50 var _ Primitive = (*fakePrimitive)(nil) 51 52 func (f *fakePrimitive) rewind() { 53 f.curResult = 0 54 f.log = nil 55 } 56 57 func (f *fakePrimitive) RouteType() string { 58 return "Fake" 59 } 60 61 func (f *fakePrimitive) GetKeyspaceName() string { 62 return "fakeKs" 63 } 64 65 func (f *fakePrimitive) GetTableName() string { 66 return "fakeTable" 67 } 68 69 func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 70 f.log = append(f.log, fmt.Sprintf("Execute %v %v", printBindVars(bindVars), wantfields)) 71 if f.results == nil { 72 return nil, f.sendErr 73 } 74 75 r := f.results[f.curResult] 76 f.curResult++ 77 if r == nil { 78 return nil, f.sendErr 79 } 80 return r, nil 81 } 82 83 func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { 84 f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) 85 if f.results == nil { 86 return f.sendErr 87 } 88 89 readMoreResults := true 90 for readMoreResults && f.curResult < len(f.results) { 91 readMoreResults = f.allResultsInOneCall 92 r := f.results[f.curResult] 93 f.curResult++ 94 if r == nil { 95 return f.sendErr 96 } 97 if wantfields { 98 if err := callback(&sqltypes.Result{Fields: r.Fields}); err != nil { 99 return err 100 } 101 } 102 result := &sqltypes.Result{} 103 for i := 0; i < len(r.Rows); i++ { 104 result.Rows = append(result.Rows, r.Rows[i]) 105 // Send only two rows at a time. 106 if i%2 == 1 { 107 if err := callback(result); err != nil { 108 return err 109 } 110 result = &sqltypes.Result{} 111 } 112 } 113 if len(result.Rows) != 0 { 114 if err := callback(result); err != nil { 115 return err 116 } 117 } 118 } 119 120 return nil 121 } 122 func (f *fakePrimitive) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { 123 f.log = append(f.log, fmt.Sprintf("GetFields %v", printBindVars(bindVars))) 124 return f.TryExecute(ctx, vcursor, bindVars, true /* wantfields */) 125 } 126 127 func (f *fakePrimitive) ExpectLog(t *testing.T, want []string) { 128 t.Helper() 129 if !reflect.DeepEqual(f.log, want) { 130 t.Errorf("vc.log got:\n%v\nwant:\n%v", strings.Join(f.log, "\n"), strings.Join(want, "\n")) 131 } 132 } 133 134 func (f *fakePrimitive) NeedsTransaction() bool { 135 return false 136 } 137 138 func wrapStreamExecute(prim Primitive, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { 139 var result *sqltypes.Result 140 err := prim.TryStreamExecute(context.Background(), vcursor, bindVars, wantfields, func(r *sqltypes.Result) error { 141 if result == nil { 142 result = r 143 } else { 144 result.Rows = append(result.Rows, r.Rows...) 145 } 146 return nil 147 }) 148 return result, err 149 } 150 151 func (f *fakePrimitive) description() PrimitiveDescription { 152 return PrimitiveDescription{OperatorType: "fake"} 153 }