github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/masterslave/master_slave_db_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package masterslave
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin"
    25  
    26  	"github.com/ecodeclub/eorm/internal/datasource"
    27  	"github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves"
    28  
    29  	"github.com/stretchr/testify/require"
    30  
    31  	_ "github.com/mattn/go-sqlite3"
    32  	"github.com/stretchr/testify/suite"
    33  
    34  	"github.com/DATA-DOG/go-sqlmock"
    35  	"github.com/stretchr/testify/assert"
    36  )
    37  
    38  func ExampleMasterSlavesDB_Close() {
    39  	masterDB, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    40  	slaveDB1, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    41  	slaveDB2, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    42  	sl, _ := roundrobin.NewSlaves(slaveDB1, slaveDB2)
    43  	ms := NewMasterSlavesDB(masterDB, MasterSlavesWithSlaves(sl))
    44  	err := ms.Close()
    45  	if err == nil {
    46  		fmt.Println("close")
    47  	}
    48  
    49  	// Output:
    50  	// close
    51  }
    52  
    53  func TestMasterSlavesDB_BeginTx(t *testing.T) {
    54  	mockDB, mock, err := sqlmock.New()
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	defer func() { _ = mockDB.Close() }()
    59  
    60  	db := NewMasterSlavesDB(mockDB)
    61  
    62  	// Begin 失败
    63  	mock.ExpectBegin().WillReturnError(errors.New("begin failed"))
    64  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    65  	assert.Equal(t, errors.New("begin failed"), err)
    66  	assert.Nil(t, tx)
    67  
    68  	mock.ExpectBegin()
    69  	tx, err = db.BeginTx(context.Background(), &sql.TxOptions{})
    70  	assert.Nil(t, err)
    71  	assert.NotNil(t, tx)
    72  }
    73  
    74  func ExampleMasterSlavesDB_BeginTx() {
    75  	sqlite3db, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    76  	db := NewMasterSlavesDB(sqlite3db)
    77  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    78  	if err == nil {
    79  		fmt.Println("Begin")
    80  	}
    81  	err = tx.Commit()
    82  	if err == nil {
    83  		fmt.Println("Commit")
    84  	}
    85  	// Output:
    86  	// Begin
    87  	// Commit
    88  }
    89  
    90  type MasterSlaveSuite struct {
    91  	suite.Suite
    92  	mockMasterDB *sql.DB
    93  	mockMaster   sqlmock.Sqlmock
    94  	mockSlave1DB *sql.DB
    95  	mockSlave1   sqlmock.Sqlmock
    96  	mockSlave2DB *sql.DB
    97  	mockSlave2   sqlmock.Sqlmock
    98  	mockSlave3DB *sql.DB
    99  	mockSlave3   sqlmock.Sqlmock
   100  }
   101  
   102  func (ms *MasterSlaveSuite) SetupTest() {
   103  	t := ms.T()
   104  	ms.initMock(t)
   105  }
   106  
   107  func (ms *MasterSlaveSuite) TearDownTest() {
   108  	_ = ms.mockMasterDB.Close()
   109  	_ = ms.mockSlave1DB.Close()
   110  	_ = ms.mockSlave2DB.Close()
   111  	_ = ms.mockSlave3DB.Close()
   112  }
   113  
   114  func (ms *MasterSlaveSuite) initMock(t *testing.T) {
   115  	var err error
   116  	ms.mockMasterDB, ms.mockMaster, err = sqlmock.New()
   117  	if err != nil {
   118  		t.Fatal(err)
   119  	}
   120  	ms.mockSlave1DB, ms.mockSlave1, err = sqlmock.New()
   121  	if err != nil {
   122  		t.Fatal(err)
   123  	}
   124  	ms.mockSlave2DB, ms.mockSlave2, err = sqlmock.New()
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	ms.mockSlave3DB, ms.mockSlave3, err = sqlmock.New()
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  }
   133  
   134  func (ms *MasterSlaveSuite) TestMasterSlaveDbQuery() {
   135  	// 通过select不同的数据表示访问不同的db
   136  	ms.mockMaster.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("master"))
   137  	ms.mockSlave1.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_1"))
   138  	ms.mockSlave2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_2"))
   139  	ms.mockSlave3.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_3"))
   140  
   141  	testCasesQuery := []struct {
   142  		name     string
   143  		ctx      context.Context
   144  		query    datasource.Query
   145  		reqCnt   int
   146  		slaves   slaves.Slaves
   147  		wantResp []string
   148  		wantErr  error
   149  	}{
   150  		{
   151  			name:   "select default use slave",
   152  			ctx:    context.Background(),
   153  			reqCnt: 3,
   154  			query: datasource.Query{
   155  				SQL: "SELECT `first_name` FROM `test_model`",
   156  			},
   157  			slaves:   ms.newSlaves(ms.mockSlave1DB, ms.mockSlave2DB, ms.mockSlave3DB),
   158  			wantResp: []string{"slave1_1", "slave1_2", "slave1_3"},
   159  		},
   160  		{
   161  			name:   "use master",
   162  			reqCnt: 1,
   163  			ctx:    UseMaster(context.Background()),
   164  			query: datasource.Query{
   165  				SQL: "SELECT `first_name` FROM `test_model`",
   166  			},
   167  			slaves:   ms.newSlaves(ms.mockSlave1DB, ms.mockSlave2DB, ms.mockSlave3DB),
   168  			wantResp: []string{"master"},
   169  		},
   170  	}
   171  
   172  	for _, tc := range testCasesQuery {
   173  		ms.T().Run(tc.name, func(t *testing.T) {
   174  			db := NewMasterSlavesDB(ms.mockMasterDB, MasterSlavesWithSlaves(tc.slaves))
   175  			//  TODO
   176  			//db, ok := source.(*masterSlavesDB)
   177  			//assert.True(t, ok)
   178  			var resp []string
   179  			for i := 1; i <= tc.reqCnt; i++ {
   180  				rows, queryErr := db.Query(tc.ctx, tc.query)
   181  				assert.Equal(t, queryErr, tc.wantErr)
   182  				if queryErr != nil {
   183  					return
   184  				}
   185  				assert.NotNil(t, rows)
   186  				ok := rows.Next()
   187  				assert.True(t, ok)
   188  
   189  				val := new(string)
   190  				err := rows.Scan(val)
   191  				assert.Nil(t, err)
   192  				if err != nil {
   193  					return
   194  				}
   195  				assert.NotNil(t, val)
   196  				resp = append(resp, *val)
   197  			}
   198  			assert.ElementsMatch(t, tc.wantResp, resp)
   199  		})
   200  	}
   201  }
   202  
   203  func (ms *MasterSlaveSuite) TestMasterSlaveDbExec() {
   204  	// 使用 sql.Result.LastInsertId 表示请求的是 master或者slave
   205  	ms.mockSlave1.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(2, 1))
   206  	ms.mockSlave2.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(3, 1))
   207  	ms.mockSlave3.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(4, 1))
   208  
   209  	testCasesExec := []struct {
   210  		name              string
   211  		ctx               context.Context
   212  		query             datasource.Query
   213  		reqCnt            int
   214  		slaves            slaves.Slaves
   215  		wantRowsAffected  []int64
   216  		wantLastInsertIds []int64
   217  		wantErr           error
   218  	}{
   219  		{
   220  			name:   "null slave",
   221  			ctx:    context.Background(),
   222  			reqCnt: 1,
   223  			query: datasource.Query{
   224  				SQL:  "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(?,?,?,?)",
   225  				Args: []any{1, 2, 3, 4},
   226  			},
   227  			wantRowsAffected:  []int64{1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 1 次
   228  			wantLastInsertIds: []int64{1},
   229  		},
   230  		{
   231  			name:   "3 salves",
   232  			ctx:    context.Background(),
   233  			reqCnt: 3,
   234  			query: datasource.Query{
   235  				SQL:  "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(?,?,?,?)",
   236  				Args: []any{1, 2, 3, 4},
   237  			},
   238  			slaves:            ms.newSlaves(ms.mockSlave1DB, ms.mockSlave2DB, ms.mockSlave3DB),
   239  			wantRowsAffected:  []int64{1, 1, 1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 3 次
   240  			wantLastInsertIds: []int64{1, 1, 1},
   241  		},
   242  		{
   243  			name:   "use master with 3 slaves",
   244  			reqCnt: 1,
   245  			ctx:    UseMaster(context.Background()),
   246  			query: datasource.Query{
   247  				SQL:  "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(?,?,?,?)",
   248  				Args: []any{1, 2, 3, 4},
   249  			},
   250  			slaves:            ms.newSlaves(ms.mockSlave1DB, ms.mockSlave2DB, ms.mockSlave3DB),
   251  			wantRowsAffected:  []int64{1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 1 次
   252  			wantLastInsertIds: []int64{1},
   253  		},
   254  	}
   255  
   256  	for _, tc := range testCasesExec {
   257  		ms.T().Run(tc.name, func(t *testing.T) {
   258  			db := NewMasterSlavesDB(ms.mockMasterDB, MasterSlavesWithSlaves(tc.slaves))
   259  			//  TODO
   260  			//db, ok := source.(*masterSlavesDB)
   261  			//assert.True(t, ok)
   262  			var resAffectID []int64
   263  			var resLastID []int64
   264  			for i := 1; i <= tc.reqCnt; i++ {
   265  				ms.mockMaster.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(1, 1))
   266  				res, err := db.Exec(tc.ctx, tc.query)
   267  				assert.Nil(t, err)
   268  				afID, er := res.RowsAffected()
   269  				if er != nil {
   270  					continue
   271  				}
   272  				lastID, er := res.LastInsertId()
   273  				if er != nil {
   274  					continue
   275  				}
   276  				resAffectID = append(resAffectID, afID)
   277  				resLastID = append(resLastID, lastID)
   278  			}
   279  			assert.ElementsMatch(t, tc.wantRowsAffected, resAffectID)
   280  			assert.ElementsMatch(t, tc.wantLastInsertIds, resLastID)
   281  
   282  		})
   283  	}
   284  }
   285  
   286  func (ms *MasterSlaveSuite) newSlaves(dbs ...*sql.DB) slaves.Slaves {
   287  	res, err := roundrobin.NewSlaves(dbs...)
   288  	require.NoError(ms.T(), err)
   289  	return res
   290  }
   291  
   292  func TestMasterSlave(t *testing.T) {
   293  	suite.Run(t, &MasterSlaveSuite{})
   294  }