vitess.io/vitess@v0.16.2/go/vt/wrangler/fake_dbclient_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package wrangler 18 19 import ( 20 "fmt" 21 "regexp" 22 "strings" 23 "sync" 24 "testing" 25 26 "github.com/stretchr/testify/assert" 27 28 "vitess.io/vitess/go/vt/log" 29 30 "vitess.io/vitess/go/sqltypes" 31 ) 32 33 func verifyQueries(t *testing.T, dcs []*fakeDBClient) { 34 t.Helper() 35 for _, dc := range dcs { 36 dc.verifyQueries(t) 37 } 38 } 39 40 type dbResults struct { 41 index int 42 results []*dbResult 43 err error 44 } 45 46 type dbResult struct { 47 result *sqltypes.Result 48 err error 49 } 50 51 func (dbrs *dbResults) next(query string) (*sqltypes.Result, error) { 52 if dbrs.exhausted() { 53 log.Infof(fmt.Sprintf("Unexpected query >%s<", query)) 54 return nil, fmt.Errorf("code executed this query, but the test did not expect it: %s", query) 55 } 56 i := dbrs.index 57 dbrs.index++ 58 return dbrs.results[i].result, dbrs.results[i].err 59 } 60 61 func (dbrs *dbResults) exhausted() bool { 62 return dbrs.index == len(dbrs.results) 63 } 64 65 // fakeDBClient fakes a binlog_player.DBClient. 66 type fakeDBClient struct { 67 mu sync.Mutex 68 name string 69 queries map[string]*dbResults 70 queriesRE map[string]*dbResults 71 invariants map[string]*sqltypes.Result 72 } 73 74 // NewfakeDBClient returns a new DBClientMock. 75 func newFakeDBClient(name string) *fakeDBClient { 76 return &fakeDBClient{ 77 name: name, 78 queries: make(map[string]*dbResults), 79 queriesRE: make(map[string]*dbResults), 80 invariants: map[string]*sqltypes.Result{ 81 "use _vt": {}, 82 "select * from _vt.vreplication where db_name='db'": {}, 83 "select id, type, state, message from _vt.vreplication_log": {}, 84 "insert into _vt.vreplication_log": {}, 85 "SELECT db_name FROM _vt.vreplication LIMIT 0": {}, 86 }, 87 } 88 } 89 90 func (dc *fakeDBClient) addQuery(query string, result *sqltypes.Result, err error) { 91 dc.mu.Lock() 92 defer dc.mu.Unlock() 93 if testMode == "debug" { 94 log.Infof("%s::addQuery %s\n\n", dc.id(), query) 95 } 96 dbr := &dbResult{result: result, err: err} 97 if dbrs, ok := dc.queries[query]; ok { 98 dbrs.results = append(dbrs.results, dbr) 99 return 100 } 101 dc.queries[query] = &dbResults{results: []*dbResult{dbr}, err: err} 102 } 103 104 func (dc *fakeDBClient) addQueryRE(query string, result *sqltypes.Result, err error) { 105 dc.mu.Lock() 106 defer dc.mu.Unlock() 107 if testMode == "debug" { 108 log.Infof("%s::addQueryRE %s\n\n", dc.id(), query) 109 } 110 dbr := &dbResult{result: result, err: err} 111 if dbrs, ok := dc.queriesRE[query]; ok { 112 dbrs.results = append(dbrs.results, dbr) 113 return 114 } 115 dc.queriesRE[query] = &dbResults{results: []*dbResult{dbr}, err: err} 116 } 117 118 func (dc *fakeDBClient) getInvariant(query string) *sqltypes.Result { 119 dc.mu.Lock() 120 defer dc.mu.Unlock() 121 return dc.invariants[query] 122 } 123 124 // note: addInvariant will replace a previous result for a query with the provided one: this is used in the tests 125 func (dc *fakeDBClient) addInvariant(query string, result *sqltypes.Result) { 126 dc.mu.Lock() 127 defer dc.mu.Unlock() 128 if testMode == "debug" { 129 log.Infof("%s::addInvariant %s\n\n", dc.id(), query) 130 } 131 dc.invariants[query] = result 132 } 133 134 // DBName is part of the DBClient interface 135 func (dc *fakeDBClient) DBName() string { 136 return "db" 137 } 138 139 // Connect is part of the DBClient interface 140 func (dc *fakeDBClient) Connect() error { 141 return nil 142 } 143 144 // Begin is part of the DBClient interface 145 func (dc *fakeDBClient) Begin() error { 146 return nil 147 } 148 149 // Commit is part of the DBClient interface 150 func (dc *fakeDBClient) Commit() error { 151 return nil 152 } 153 154 // Rollback is part of the DBClient interface 155 func (dc *fakeDBClient) Rollback() error { 156 return nil 157 } 158 159 // Close is part of the DBClient interface 160 func (dc *fakeDBClient) Close() { 161 } 162 163 func (dc *fakeDBClient) id() string { 164 return fmt.Sprintf("FakeDBClient(%s)", dc.name) 165 } 166 167 // ExecuteFetch is part of the DBClient interface 168 func (dc *fakeDBClient) ExecuteFetch(query string, maxrows int) (*sqltypes.Result, error) { 169 dc.mu.Lock() 170 defer dc.mu.Unlock() 171 qr, err := dc.executeFetch(query, maxrows) 172 if testMode == "debug" { 173 log.Infof("%s::ExecuteFetch for >>>%s<<< returns >>>%v<<< error >>>%+v<<< ", dc.id(), query, qr, err) 174 } 175 return qr, err 176 } 177 178 // ExecuteFetch is part of the DBClient interface 179 func (dc *fakeDBClient) executeFetch(query string, maxrows int) (*sqltypes.Result, error) { 180 if dbrs := dc.queries[query]; dbrs != nil { 181 return dbrs.next(query) 182 } 183 for re, dbrs := range dc.queriesRE { 184 if regexp.MustCompile(re).MatchString(query) { 185 return dbrs.next(query) 186 } 187 } 188 if result := dc.invariants[query]; result != nil { 189 return result, nil 190 } 191 for q, result := range dc.invariants { //supports allowing just a prefix of an expected query 192 if strings.Contains(query, q) { 193 return result, nil 194 } 195 } 196 197 log.Infof("Missing query: >>>>>>>>>>>>>>>>>>%s<<<<<<<<<<<<<<<", query) 198 return nil, fmt.Errorf("unexpected query: %s", query) 199 } 200 201 func (dc *fakeDBClient) verifyQueries(t *testing.T) { 202 dc.mu.Lock() 203 defer dc.mu.Unlock() 204 t.Helper() 205 for query, dbrs := range dc.queries { 206 if !dbrs.exhausted() { 207 assert.FailNowf(t, "expected query did not get executed during the test", query) 208 } 209 } 210 for query, dbrs := range dc.queriesRE { 211 if !dbrs.exhausted() { 212 assert.FailNowf(t, "expected regex query did not get executed during the test", query) 213 } 214 } 215 }