vitess.io/vitess@v0.16.2/go/vt/wrangler/fake_dbclient_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wrangler
    18  
    19  import (
    20  	"fmt"
    21  	"regexp"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  
    28  	"vitess.io/vitess/go/vt/log"
    29  
    30  	"vitess.io/vitess/go/sqltypes"
    31  )
    32  
    33  func verifyQueries(t *testing.T, dcs []*fakeDBClient) {
    34  	t.Helper()
    35  	for _, dc := range dcs {
    36  		dc.verifyQueries(t)
    37  	}
    38  }
    39  
    40  type dbResults struct {
    41  	index   int
    42  	results []*dbResult
    43  	err     error
    44  }
    45  
    46  type dbResult struct {
    47  	result *sqltypes.Result
    48  	err    error
    49  }
    50  
    51  func (dbrs *dbResults) next(query string) (*sqltypes.Result, error) {
    52  	if dbrs.exhausted() {
    53  		log.Infof(fmt.Sprintf("Unexpected query >%s<", query))
    54  		return nil, fmt.Errorf("code executed this query, but the test did not expect it: %s", query)
    55  	}
    56  	i := dbrs.index
    57  	dbrs.index++
    58  	return dbrs.results[i].result, dbrs.results[i].err
    59  }
    60  
    61  func (dbrs *dbResults) exhausted() bool {
    62  	return dbrs.index == len(dbrs.results)
    63  }
    64  
    65  // fakeDBClient fakes a binlog_player.DBClient.
    66  type fakeDBClient struct {
    67  	mu         sync.Mutex
    68  	name       string
    69  	queries    map[string]*dbResults
    70  	queriesRE  map[string]*dbResults
    71  	invariants map[string]*sqltypes.Result
    72  }
    73  
    74  // NewfakeDBClient returns a new DBClientMock.
    75  func newFakeDBClient(name string) *fakeDBClient {
    76  	return &fakeDBClient{
    77  		name:      name,
    78  		queries:   make(map[string]*dbResults),
    79  		queriesRE: make(map[string]*dbResults),
    80  		invariants: map[string]*sqltypes.Result{
    81  			"use _vt": {},
    82  			"select * from _vt.vreplication where db_name='db'":         {},
    83  			"select id, type, state, message from _vt.vreplication_log": {},
    84  			"insert into _vt.vreplication_log":                          {},
    85  			"SELECT db_name FROM _vt.vreplication LIMIT 0":              {},
    86  		},
    87  	}
    88  }
    89  
    90  func (dc *fakeDBClient) addQuery(query string, result *sqltypes.Result, err error) {
    91  	dc.mu.Lock()
    92  	defer dc.mu.Unlock()
    93  	if testMode == "debug" {
    94  		log.Infof("%s::addQuery %s\n\n", dc.id(), query)
    95  	}
    96  	dbr := &dbResult{result: result, err: err}
    97  	if dbrs, ok := dc.queries[query]; ok {
    98  		dbrs.results = append(dbrs.results, dbr)
    99  		return
   100  	}
   101  	dc.queries[query] = &dbResults{results: []*dbResult{dbr}, err: err}
   102  }
   103  
   104  func (dc *fakeDBClient) addQueryRE(query string, result *sqltypes.Result, err error) {
   105  	dc.mu.Lock()
   106  	defer dc.mu.Unlock()
   107  	if testMode == "debug" {
   108  		log.Infof("%s::addQueryRE %s\n\n", dc.id(), query)
   109  	}
   110  	dbr := &dbResult{result: result, err: err}
   111  	if dbrs, ok := dc.queriesRE[query]; ok {
   112  		dbrs.results = append(dbrs.results, dbr)
   113  		return
   114  	}
   115  	dc.queriesRE[query] = &dbResults{results: []*dbResult{dbr}, err: err}
   116  }
   117  
   118  func (dc *fakeDBClient) getInvariant(query string) *sqltypes.Result {
   119  	dc.mu.Lock()
   120  	defer dc.mu.Unlock()
   121  	return dc.invariants[query]
   122  }
   123  
   124  // note: addInvariant will replace a previous result for a query with the provided one: this is used in the tests
   125  func (dc *fakeDBClient) addInvariant(query string, result *sqltypes.Result) {
   126  	dc.mu.Lock()
   127  	defer dc.mu.Unlock()
   128  	if testMode == "debug" {
   129  		log.Infof("%s::addInvariant %s\n\n", dc.id(), query)
   130  	}
   131  	dc.invariants[query] = result
   132  }
   133  
   134  // DBName is part of the DBClient interface
   135  func (dc *fakeDBClient) DBName() string {
   136  	return "db"
   137  }
   138  
   139  // Connect is part of the DBClient interface
   140  func (dc *fakeDBClient) Connect() error {
   141  	return nil
   142  }
   143  
   144  // Begin is part of the DBClient interface
   145  func (dc *fakeDBClient) Begin() error {
   146  	return nil
   147  }
   148  
   149  // Commit is part of the DBClient interface
   150  func (dc *fakeDBClient) Commit() error {
   151  	return nil
   152  }
   153  
   154  // Rollback is part of the DBClient interface
   155  func (dc *fakeDBClient) Rollback() error {
   156  	return nil
   157  }
   158  
   159  // Close is part of the DBClient interface
   160  func (dc *fakeDBClient) Close() {
   161  }
   162  
   163  func (dc *fakeDBClient) id() string {
   164  	return fmt.Sprintf("FakeDBClient(%s)", dc.name)
   165  }
   166  
   167  // ExecuteFetch is part of the DBClient interface
   168  func (dc *fakeDBClient) ExecuteFetch(query string, maxrows int) (*sqltypes.Result, error) {
   169  	dc.mu.Lock()
   170  	defer dc.mu.Unlock()
   171  	qr, err := dc.executeFetch(query, maxrows)
   172  	if testMode == "debug" {
   173  		log.Infof("%s::ExecuteFetch for >>>%s<<< returns >>>%v<<< error >>>%+v<<< ", dc.id(), query, qr, err)
   174  	}
   175  	return qr, err
   176  }
   177  
   178  // ExecuteFetch is part of the DBClient interface
   179  func (dc *fakeDBClient) executeFetch(query string, maxrows int) (*sqltypes.Result, error) {
   180  	if dbrs := dc.queries[query]; dbrs != nil {
   181  		return dbrs.next(query)
   182  	}
   183  	for re, dbrs := range dc.queriesRE {
   184  		if regexp.MustCompile(re).MatchString(query) {
   185  			return dbrs.next(query)
   186  		}
   187  	}
   188  	if result := dc.invariants[query]; result != nil {
   189  		return result, nil
   190  	}
   191  	for q, result := range dc.invariants { //supports allowing just a prefix of an expected query
   192  		if strings.Contains(query, q) {
   193  			return result, nil
   194  		}
   195  	}
   196  
   197  	log.Infof("Missing query: >>>>>>>>>>>>>>>>>>%s<<<<<<<<<<<<<<<", query)
   198  	return nil, fmt.Errorf("unexpected query: %s", query)
   199  }
   200  
   201  func (dc *fakeDBClient) verifyQueries(t *testing.T) {
   202  	dc.mu.Lock()
   203  	defer dc.mu.Unlock()
   204  	t.Helper()
   205  	for query, dbrs := range dc.queries {
   206  		if !dbrs.exhausted() {
   207  			assert.FailNowf(t, "expected query did not get executed during the test", query)
   208  		}
   209  	}
   210  	for query, dbrs := range dc.queriesRE {
   211  		if !dbrs.exhausted() {
   212  			assert.FailNowf(t, "expected regex query did not get executed during the test", query)
   213  		}
   214  	}
   215  }