github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/cluster/cluster_db_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cluster
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"testing"
    22  
    23  	"github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin"
    24  
    25  	"github.com/ecodeclub/eorm/internal/datasource/masterslave"
    26  	slaves2 "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves"
    27  
    28  	"github.com/ecodeclub/eorm/internal/errs"
    29  	_ "github.com/mattn/go-sqlite3"
    30  
    31  	"github.com/DATA-DOG/go-sqlmock"
    32  	"github.com/ecodeclub/eorm/internal/datasource"
    33  	"github.com/ecodeclub/eorm/internal/sharding"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"github.com/stretchr/testify/suite"
    37  )
    38  
    39  func Example_clusterDB_Close() {
    40  	db, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    41  	cl := NewClusterDB(map[string]*masterslave.MasterSlavesDB{
    42  		"db0": masterslave.NewMasterSlavesDB(db),
    43  	})
    44  	err := cl.Close()
    45  	if err == nil {
    46  		fmt.Println("close")
    47  	}
    48  
    49  	// Output:
    50  	// close
    51  }
    52  
    53  type ClusterSuite struct {
    54  	suite.Suite
    55  	datasource.DataSource
    56  	mockMasterDB *sql.DB
    57  	mockMaster   sqlmock.Sqlmock
    58  
    59  	mockSlave1DB *sql.DB
    60  	mockSlave1   sqlmock.Sqlmock
    61  
    62  	mockSlave2DB *sql.DB
    63  	mockSlave2   sqlmock.Sqlmock
    64  
    65  	mockSlave3DB *sql.DB
    66  	mockSlave3   sqlmock.Sqlmock
    67  }
    68  
    69  func (c *ClusterSuite) SetupTest() {
    70  	t := c.T()
    71  	c.initMock(t)
    72  }
    73  
    74  func (c *ClusterSuite) TearDownTest() {
    75  	_ = c.mockMasterDB.Close()
    76  	_ = c.mockSlave1DB.Close()
    77  	_ = c.mockSlave2DB.Close()
    78  	_ = c.mockSlave3DB.Close()
    79  }
    80  
    81  func (c *ClusterSuite) initMock(t *testing.T) {
    82  	var err error
    83  	c.mockMasterDB, c.mockMaster, err = sqlmock.New()
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	c.mockSlave1DB, c.mockSlave1, err = sqlmock.New()
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  	c.mockSlave2DB, c.mockSlave2, err = sqlmock.New()
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	c.mockSlave3DB, c.mockSlave3, err = sqlmock.New()
    96  	if err != nil {
    97  		t.Fatal(err)
    98  	}
    99  }
   100  
   101  func (c *ClusterSuite) TestClusterDbQuery() {
   102  	// 通过select不同的数据表示访问不同的db
   103  	c.mockMaster.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("master"))
   104  	c.mockSlave1.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_1"))
   105  	c.mockSlave2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_2"))
   106  	c.mockSlave3.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("slave1_3"))
   107  
   108  	db := masterslave.NewMasterSlavesDB(c.mockMasterDB, masterslave.MasterSlavesWithSlaves(
   109  		c.newSlaves(c.mockSlave1DB, c.mockSlave2DB, c.mockSlave3DB)))
   110  	testCasesQuery := []struct {
   111  		name     string
   112  		reqCnt   int
   113  		ctx      context.Context
   114  		query    sharding.Query
   115  		ms       map[string]*masterslave.MasterSlavesDB
   116  		wantResp []string
   117  		wantErr  error
   118  	}{
   119  		{
   120  			name:   "query not found target db",
   121  			ctx:    context.Background(),
   122  			reqCnt: 3,
   123  			query: sharding.Query{
   124  				SQL: "SELECT `first_name` FROM `test_model`",
   125  				DB:  "order_db_1",
   126  			},
   127  			ms: func() map[string]*masterslave.MasterSlavesDB {
   128  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
   129  				return masterSlaves
   130  			}(),
   131  			wantErr: errs.NewErrNotFoundTargetDB("order_db_1"),
   132  		},
   133  		{
   134  			name:   "select default use slave",
   135  			ctx:    context.Background(),
   136  			reqCnt: 3,
   137  			query: sharding.Query{
   138  				SQL: "SELECT `first_name` FROM `test_model`",
   139  				DB:  "order_db_0",
   140  			},
   141  			ms: func() map[string]*masterslave.MasterSlavesDB {
   142  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
   143  				return masterSlaves
   144  			}(),
   145  			wantResp: []string{"slave1_1", "slave1_2", "slave1_3"},
   146  		},
   147  		{
   148  			name:   "use master",
   149  			reqCnt: 1,
   150  			ctx:    masterslave.UseMaster(context.Background()),
   151  			query: sharding.Query{
   152  				SQL: "SELECT `first_name` FROM `test_model`",
   153  				DB:  "order_db_1",
   154  			},
   155  			ms: func() map[string]*masterslave.MasterSlavesDB {
   156  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_1": db}
   157  				return masterSlaves
   158  			}(),
   159  			wantResp: []string{"master"},
   160  		},
   161  	}
   162  
   163  	for _, tc := range testCasesQuery {
   164  		c.T().Run(tc.name, func(t *testing.T) {
   165  			clusDB := NewClusterDB(tc.ms)
   166  			var resp []string
   167  			for i := 1; i <= tc.reqCnt; i++ {
   168  				rows, queryErr := clusDB.Query(tc.ctx, tc.query)
   169  				assert.Equal(t, queryErr, tc.wantErr)
   170  				if queryErr != nil {
   171  					return
   172  				}
   173  				assert.NotNil(t, rows)
   174  				ok := rows.Next()
   175  				assert.True(t, ok)
   176  
   177  				val := new(string)
   178  				err := rows.Scan(val)
   179  				assert.Nil(t, err)
   180  				if err != nil {
   181  					return
   182  				}
   183  				assert.NotNil(t, val)
   184  
   185  				resp = append(resp, *val)
   186  			}
   187  			assert.ElementsMatch(t, tc.wantResp, resp)
   188  		})
   189  	}
   190  }
   191  
   192  func (c *ClusterSuite) TestClusterDbExec() {
   193  	// 使用 sql.Result.LastInsertId 表示请求的是 master或者slave
   194  	c.mockSlave1.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(2, 1))
   195  	c.mockSlave2.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(3, 1))
   196  	c.mockSlave3.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(4, 1))
   197  
   198  	testCasesExec := []struct {
   199  		name     string
   200  		reqCnt   int
   201  		ctx      context.Context
   202  		slaves   slaves2.Slaves
   203  		query    sharding.Query
   204  		ms       map[string]*masterslave.MasterSlavesDB
   205  		wantResp []int64
   206  		wantErr  error
   207  	}{
   208  		{
   209  			name:   "exec not found target db",
   210  			ctx:    context.Background(),
   211  			reqCnt: 1,
   212  			query: sharding.Query{
   213  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   214  				DB:  "order_db_1",
   215  			},
   216  			ms: func() map[string]*masterslave.MasterSlavesDB {
   217  				db := masterslave.NewMasterSlavesDB(c.mockMasterDB,
   218  					masterslave.MasterSlavesWithSlaves(c.newSlaves(nil)))
   219  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
   220  				return masterSlaves
   221  			}(),
   222  			wantErr: errs.NewErrNotFoundTargetDB("order_db_1"),
   223  		},
   224  		{
   225  			name:   "null slave",
   226  			ctx:    context.Background(),
   227  			reqCnt: 1,
   228  			query: sharding.Query{
   229  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   230  				DB:  "order_db_0",
   231  			},
   232  			ms: func() map[string]*masterslave.MasterSlavesDB {
   233  				db := masterslave.NewMasterSlavesDB(c.mockMasterDB,
   234  					masterslave.MasterSlavesWithSlaves(c.newSlaves(nil)))
   235  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
   236  				return masterSlaves
   237  			}(),
   238  			wantResp: []int64{1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 1 次
   239  		},
   240  		{
   241  			name:   "3 salves",
   242  			ctx:    context.Background(),
   243  			reqCnt: 3,
   244  			query: sharding.Query{
   245  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   246  				DB:  "order_db_1",
   247  			},
   248  			ms: func() map[string]*masterslave.MasterSlavesDB {
   249  				db := masterslave.NewMasterSlavesDB(c.mockMasterDB, masterslave.MasterSlavesWithSlaves(
   250  					c.newSlaves(c.mockSlave1DB, c.mockSlave2DB, c.mockSlave3DB)))
   251  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_1": db}
   252  				return masterSlaves
   253  			}(),
   254  			wantResp: []int64{1, 1, 1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 3 次
   255  		},
   256  		{
   257  			name:   "use master with 3 slaves",
   258  			reqCnt: 1,
   259  			ctx:    masterslave.UseMaster(context.Background()),
   260  			query: sharding.Query{
   261  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   262  				DB:  "order_db_2",
   263  			},
   264  			ms: func() map[string]*masterslave.MasterSlavesDB {
   265  				db := masterslave.NewMasterSlavesDB(c.mockMasterDB, masterslave.MasterSlavesWithSlaves(
   266  					c.newSlaves(c.mockSlave1DB, c.mockSlave2DB, c.mockSlave3DB)))
   267  				masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_2": db}
   268  				return masterSlaves
   269  			}(),
   270  			wantResp: []int64{1}, // 切片元素表示的是 lastInsertID, 这里表示请求 master db 1 次
   271  		},
   272  	}
   273  
   274  	for _, tc := range testCasesExec {
   275  		c.T().Run(tc.name, func(t *testing.T) {
   276  			db := NewClusterDB(tc.ms)
   277  			var resAffectID []int64
   278  			for i := 1; i <= tc.reqCnt; i++ {
   279  				c.mockMaster.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(1, 1))
   280  				res, execErr := db.Exec(tc.ctx, tc.query)
   281  				assert.Equal(t, execErr, tc.wantErr)
   282  				if execErr != nil {
   283  					return
   284  				}
   285  
   286  				afID, er := res.LastInsertId()
   287  				if er != nil {
   288  					continue
   289  				}
   290  				resAffectID = append(resAffectID, afID)
   291  			}
   292  			assert.ElementsMatch(t, tc.wantResp, resAffectID)
   293  		})
   294  	}
   295  }
   296  
   297  func (c *ClusterSuite) newSlaves(dbs ...*sql.DB) slaves2.Slaves {
   298  	res, err := roundrobin.NewSlaves(dbs...)
   299  	require.NoError(c.T(), err)
   300  	return res
   301  }
   302  
   303  func TestClusterDB(t *testing.T) {
   304  	suite.Run(t, &ClusterSuite{})
   305  }