vitess.io/vitess@v0.16.2/go/vt/binlog/binlogplayer/mock_dbclient.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package binlogplayer
    18  
    19  import (
    20  	"regexp"
    21  	"strings"
    22  	"testing"
    23  	"time"
    24  
    25  	"vitess.io/vitess/go/sqltypes"
    26  )
    27  
    28  const mockClientUNameFiltered = "Filtered"
    29  const mockClientUNameDba = "Dba"
    30  
// MockDBClient mocks a DBClient.
// It must be configured to expect requests in a specific order.
type MockDBClient struct {
	// t is the test handle used for logging and for failing the test
	// when an unexpected query arrives or Wait times out.
	t             *testing.T
	// UName is set by the constructors to "Filtered" or "Dba".
	UName         string
	// expect is the ordered list of expected requests, appended to by
	// ExpectRequest and ExpectRequestRE.
	expect        []*mockExpect
	// currentResult is the index into expect of the next request that
	// ExecuteFetch will match against.
	currentResult int
	// done is closed by ExecuteFetch once the last queued expectation
	// has been consumed; ExpectRequest/ExpectRequestRE replace it with
	// a fresh channel if it was already closed.
	done          chan struct{}
	// invariants maps a query fragment to a fixed result. ExecuteFetch
	// returns that result for any query containing the fragment,
	// without consuming an expectation.
	invariants    map[string]*sqltypes.Result
}
    41  
// mockExpect describes one expected request and the response the mock
// returns for it.
type mockExpect struct {
	// query is the exact query text, or the regular expression source
	// when re is set.
	query  string
	// re is nil for exact-match expectations; otherwise the incoming
	// query must match it.
	re     *regexp.Regexp
	// result and err are returned by ExecuteFetch when the expectation
	// is satisfied.
	result *sqltypes.Result
	err    error
}
    48  
    49  // NewMockDBClient returns a new DBClientMock with the default "Filtered" UName.
    50  func NewMockDBClient(t *testing.T) *MockDBClient {
    51  	return &MockDBClient{
    52  		t:     t,
    53  		UName: mockClientUNameFiltered,
    54  		done:  make(chan struct{}),
    55  		invariants: map[string]*sqltypes.Result{
    56  			"CREATE TABLE IF NOT EXISTS _vt.vreplication_log":           {},
    57  			"select id, type, state, message from _vt.vreplication_log": {},
    58  			"insert into _vt.vreplication_log":                          {},
    59  		},
    60  	}
    61  }
    62  
    63  // NewMockDbaClient returns a new DBClientMock with the default "Dba" UName.
    64  func NewMockDbaClient(t *testing.T) *MockDBClient {
    65  	return &MockDBClient{
    66  		t:     t,
    67  		UName: mockClientUNameDba,
    68  		done:  make(chan struct{}),
    69  	}
    70  }
    71  
    72  // ExpectRequest adds an expected result to the mock.
    73  // This function should not be called conncurrently with other commands.
    74  func (dc *MockDBClient) ExpectRequest(query string, result *sqltypes.Result, err error) {
    75  	select {
    76  	case <-dc.done:
    77  		dc.done = make(chan struct{})
    78  	default:
    79  	}
    80  	dc.expect = append(dc.expect, &mockExpect{
    81  		query:  query,
    82  		result: result,
    83  		err:    err,
    84  	})
    85  }
    86  
    87  // ExpectRequestRE adds an expected result to the mock.
    88  // queryRE is a regular expression.
    89  // This function should not be called conncurrently with other commands.
    90  func (dc *MockDBClient) ExpectRequestRE(queryRE string, result *sqltypes.Result, err error) {
    91  	select {
    92  	case <-dc.done:
    93  		dc.done = make(chan struct{})
    94  	default:
    95  	}
    96  	dc.expect = append(dc.expect, &mockExpect{
    97  		query:  queryRE,
    98  		re:     regexp.MustCompile(queryRE),
    99  		result: result,
   100  		err:    err,
   101  	})
   102  }
   103  
   104  // Wait waits for all expected requests to be executed.
   105  // dc.t.Fatalf is executed on 1 second timeout. Wait should
   106  // not be called concurrently with ExpectRequest.
   107  func (dc *MockDBClient) Wait() {
   108  	dc.t.Helper()
   109  	select {
   110  	case <-dc.done:
   111  		return
   112  	case <-time.After(5 * time.Second):
   113  		dc.t.Fatalf("timeout waiting for requests, want: %v", dc.expect[dc.currentResult].query)
   114  	}
   115  }
   116  
   117  // DBName is part of the DBClient interface
   118  func (dc *MockDBClient) DBName() string {
   119  	return "db"
   120  }
   121  
   122  // Connect is part of the DBClient interface
   123  func (dc *MockDBClient) Connect() error {
   124  	return nil
   125  }
   126  
   127  // Begin is part of the DBClient interface
   128  func (dc *MockDBClient) Begin() error {
   129  	_, err := dc.ExecuteFetch("begin", 1)
   130  	return err
   131  }
   132  
   133  // Commit is part of the DBClient interface
   134  func (dc *MockDBClient) Commit() error {
   135  	_, err := dc.ExecuteFetch("commit", 1)
   136  	return err
   137  }
   138  
   139  // Rollback is part of the DBClient interface
   140  func (dc *MockDBClient) Rollback() error {
   141  	_, err := dc.ExecuteFetch("rollback", 1)
   142  	return err
   143  }
   144  
   145  // Close is part of the DBClient interface
   146  func (dc *MockDBClient) Close() {
   147  }
   148  
   149  // ExecuteFetch is part of the DBClient interface
   150  func (dc *MockDBClient) ExecuteFetch(query string, maxrows int) (qr *sqltypes.Result, err error) {
   151  	dc.t.Helper()
   152  	dc.t.Logf("DBClient query: %v", query)
   153  
   154  	for q, result := range dc.invariants {
   155  		if strings.Contains(query, q) {
   156  			return result, nil
   157  		}
   158  	}
   159  
   160  	if dc.currentResult >= len(dc.expect) {
   161  		dc.t.Fatalf("DBClientMock: query: %s, no more requests are expected", query)
   162  	}
   163  	result := dc.expect[dc.currentResult]
   164  	if result.re == nil {
   165  		if query != result.query {
   166  			dc.t.Fatalf("DBClientMock: query: %s, want %s", query, result.query)
   167  		}
   168  	} else {
   169  		if !result.re.MatchString(query) {
   170  			dc.t.Fatalf("DBClientMock: query: %s, must match %s", query, result.query)
   171  		}
   172  	}
   173  	dc.currentResult++
   174  	if dc.currentResult >= len(dc.expect) {
   175  		close(dc.done)
   176  	}
   177  	return result.result, result.err
   178  }