github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/shardingsource/sharding_datasource_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package shardingsource
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"testing"
    22  
    23  	"github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin"
    24  
    25  	"github.com/ecodeclub/eorm/internal/datasource/masterslave"
    26  	"github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves"
    27  
    28  	"github.com/ecodeclub/eorm/internal/errs"
    29  
    30  	"github.com/DATA-DOG/go-sqlmock"
    31  	"github.com/ecodeclub/eorm/internal/datasource"
    32  	"github.com/ecodeclub/eorm/internal/datasource/cluster"
    33  	"github.com/ecodeclub/eorm/internal/sharding"
    34  	_ "github.com/mattn/go-sqlite3"
    35  	"github.com/stretchr/testify/assert"
    36  	"github.com/stretchr/testify/require"
    37  	"github.com/stretchr/testify/suite"
    38  )
    39  
    40  func ExampleShardingDataSource_Close() {
    41  	db, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory")
    42  	cl := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{
    43  		"db0": masterslave.NewMasterSlavesDB(db),
    44  	})
    45  	ds := NewShardingDataSource(map[string]datasource.DataSource{
    46  		"source0": cl,
    47  	})
    48  	err := ds.Close()
    49  	if err == nil {
    50  		fmt.Println("close")
    51  	}
    52  
    53  	// Output:
    54  	// close
    55  }
    56  
    57  type ShardingDataSourceSuite struct {
    58  	suite.Suite
    59  	datasource.DataSource
    60  	mockMaster1DB *sql.DB
    61  	mockMaster    sqlmock.Sqlmock
    62  
    63  	mockSlave1DB *sql.DB
    64  	mockSlave1   sqlmock.Sqlmock
    65  
    66  	mockSlave2DB *sql.DB
    67  	mockSlave2   sqlmock.Sqlmock
    68  
    69  	mockSlave3DB *sql.DB
    70  	mockSlave3   sqlmock.Sqlmock
    71  
    72  	mockMaster2DB *sql.DB
    73  	mockMaster2   sqlmock.Sqlmock
    74  
    75  	mockSlave4DB *sql.DB
    76  	mockSlave4   sqlmock.Sqlmock
    77  
    78  	mockSlave5DB *sql.DB
    79  	mockSlave5   sqlmock.Sqlmock
    80  
    81  	mockSlave6DB *sql.DB
    82  	mockSlave6   sqlmock.Sqlmock
    83  }
    84  
    85  func (c *ShardingDataSourceSuite) SetupTest() {
    86  	t := c.T()
    87  	c.initMock(t)
    88  }
    89  
    90  func (c *ShardingDataSourceSuite) TearDownTest() {
    91  	_ = c.mockMaster1DB.Close()
    92  	_ = c.mockSlave1DB.Close()
    93  	_ = c.mockSlave2DB.Close()
    94  	_ = c.mockSlave3DB.Close()
    95  
    96  	_ = c.mockMaster2DB.Close()
    97  	_ = c.mockSlave4DB.Close()
    98  	_ = c.mockSlave5DB.Close()
    99  	_ = c.mockSlave6DB.Close()
   100  }
   101  
   102  func (c *ShardingDataSourceSuite) initMock(t *testing.T) {
   103  	var err error
   104  	c.mockMaster1DB, c.mockMaster, err = sqlmock.New()
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	c.mockSlave1DB, c.mockSlave1, err = sqlmock.New()
   109  	if err != nil {
   110  		t.Fatal(err)
   111  	}
   112  	c.mockSlave2DB, c.mockSlave2, err = sqlmock.New()
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	c.mockSlave3DB, c.mockSlave3, err = sqlmock.New()
   117  	if err != nil {
   118  		t.Fatal(err)
   119  	}
   120  
   121  	c.mockMaster2DB, c.mockMaster2, err = sqlmock.New()
   122  	if err != nil {
   123  		t.Fatal(err)
   124  	}
   125  	c.mockSlave4DB, c.mockSlave4, err = sqlmock.New()
   126  	if err != nil {
   127  		t.Fatal(err)
   128  	}
   129  	c.mockSlave5DB, c.mockSlave5, err = sqlmock.New()
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  	c.mockSlave6DB, c.mockSlave6, err = sqlmock.New()
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  
   138  	db1 := masterslave.NewMasterSlavesDB(c.mockMaster1DB, masterslave.MasterSlavesWithSlaves(
   139  		c.newSlaves(c.mockSlave1DB, c.mockSlave2DB, c.mockSlave3DB)))
   140  
   141  	db2 := masterslave.NewMasterSlavesDB(c.mockMaster2DB, masterslave.MasterSlavesWithSlaves(
   142  		c.newSlaves(c.mockSlave4DB, c.mockSlave5DB, c.mockSlave6DB)))
   143  
   144  	clusterDB1 := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{"db_0": db1})
   145  	clusterDB2 := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{"db_0": db2})
   146  
   147  	c.DataSource = NewShardingDataSource(map[string]datasource.DataSource{
   148  		"0.db.cluster.company.com:3306": clusterDB1,
   149  		"1.db.cluster.company.com:3306": clusterDB2,
   150  	})
   151  
   152  }
   153  
   154  func (c *ShardingDataSourceSuite) TestClusterDbQuery() {
   155  	// 通过select不同的数据表示访问不同的db
   156  	c.mockMaster.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 master"))
   157  	c.mockSlave1.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_1"))
   158  	c.mockSlave2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_2"))
   159  	c.mockSlave3.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_3"))
   160  
   161  	c.mockMaster2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 master"))
   162  	c.mockSlave4.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_1"))
   163  	c.mockSlave5.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_2"))
   164  	c.mockSlave6.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_3"))
   165  
   166  	testCasesQuery := []struct {
   167  		name     string
   168  		reqCnt   int
   169  		ctx      context.Context
   170  		query    sharding.Query
   171  		wantResp []string
   172  		wantErr  error
   173  	}{
   174  		{
   175  			name:   "not found target DataSource",
   176  			ctx:    context.Background(),
   177  			reqCnt: 1,
   178  			query: sharding.Query{
   179  				SQL:        "SELECT `first_name` FROM `test_model`",
   180  				DB:         "db_0",
   181  				Datasource: "2.db.cluster.company.com:3306",
   182  			},
   183  			wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"),
   184  		},
   185  		{
   186  			name:   "cluster0 select default use slave",
   187  			ctx:    context.Background(),
   188  			reqCnt: 3,
   189  			query: sharding.Query{
   190  				SQL:        "SELECT `first_name` FROM `test_model`",
   191  				DB:         "db_0",
   192  				Datasource: "0.db.cluster.company.com:3306",
   193  			},
   194  			wantResp: []string{"cluster0 slave1_1", "cluster0 slave1_2", "cluster0 slave1_3"},
   195  		},
   196  		{
   197  			name:   "cluster1 select default use slave",
   198  			ctx:    context.Background(),
   199  			reqCnt: 3,
   200  			query: sharding.Query{
   201  				SQL:        "SELECT `first_name` FROM `test_model`",
   202  				DB:         "db_0",
   203  				Datasource: "1.db.cluster.company.com:3306",
   204  			},
   205  			wantResp: []string{"cluster1 slave1_1", "cluster1 slave1_2", "cluster1 slave1_3"},
   206  		},
   207  		{
   208  			name:   "cluster0 use master",
   209  			reqCnt: 1,
   210  			ctx:    masterslave.UseMaster(context.Background()),
   211  			query: sharding.Query{
   212  				SQL:        "SELECT `first_name` FROM `test_model`",
   213  				DB:         "db_0",
   214  				Datasource: "0.db.cluster.company.com:3306",
   215  			},
   216  			wantResp: []string{"cluster0 master"},
   217  		},
   218  		{
   219  			name:   "cluster1 use master",
   220  			reqCnt: 1,
   221  			ctx:    masterslave.UseMaster(context.Background()),
   222  			query: sharding.Query{
   223  				SQL:        "SELECT `first_name` FROM `test_model`",
   224  				DB:         "db_0",
   225  				Datasource: "1.db.cluster.company.com:3306",
   226  			},
   227  			wantResp: []string{"cluster1 master"},
   228  		},
   229  	}
   230  
   231  	for _, tc := range testCasesQuery {
   232  		c.T().Run(tc.name, func(t *testing.T) {
   233  			var resp []string
   234  			for i := 1; i <= tc.reqCnt; i++ {
   235  				rows, queryErr := c.DataSource.Query(tc.ctx, tc.query)
   236  				assert.Equal(t, queryErr, tc.wantErr)
   237  				if queryErr != nil {
   238  					return
   239  				}
   240  				assert.NotNil(t, rows)
   241  				ok := rows.Next()
   242  				assert.True(t, ok)
   243  
   244  				val := new(string)
   245  				err := rows.Scan(val)
   246  				assert.Nil(t, err)
   247  				if err != nil {
   248  					return
   249  				}
   250  				assert.NotNil(t, val)
   251  
   252  				resp = append(resp, *val)
   253  			}
   254  			assert.ElementsMatch(t, tc.wantResp, resp)
   255  		})
   256  	}
   257  }
   258  
   259  func (c *ShardingDataSourceSuite) TestClusterDbExec() {
   260  	// 使用 sql.Result.LastInsertId 表示请求的是 master或者slave
   261  	c.mockMaster.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(1, 1))
   262  	c.mockMaster2.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(2, 1))
   263  
   264  	testCasesExec := []struct {
   265  		name              string
   266  		reqCnt            int
   267  		ctx               context.Context
   268  		slaves            slaves.Slaves
   269  		query             sharding.Query
   270  		wantRowsAffected  []int64
   271  		wantLastInsertIds []int64
   272  		wantErr           error
   273  	}{
   274  		{
   275  			name:   "not found target DataSource",
   276  			ctx:    context.Background(),
   277  			reqCnt: 1,
   278  			query: sharding.Query{
   279  				SQL:        "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   280  				DB:         "db_0",
   281  				Datasource: "2.db.cluster.company.com:3306",
   282  			},
   283  			wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"),
   284  		},
   285  		{
   286  			name:   "cluster0 exec",
   287  			reqCnt: 1,
   288  			ctx:    masterslave.UseMaster(context.Background()),
   289  			query: sharding.Query{
   290  				SQL:        "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   291  				DB:         "db_0",
   292  				Datasource: "0.db.cluster.company.com:3306",
   293  			},
   294  			wantRowsAffected:  []int64{1},
   295  			wantLastInsertIds: []int64{1},
   296  		},
   297  		{
   298  			name:   "cluster1 exec",
   299  			reqCnt: 1,
   300  			ctx:    masterslave.UseMaster(context.Background()),
   301  			query: sharding.Query{
   302  				SQL:        "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   303  				DB:         "db_0",
   304  				Datasource: "1.db.cluster.company.com:3306",
   305  			},
   306  			wantRowsAffected:  []int64{1},
   307  			wantLastInsertIds: []int64{2},
   308  		},
   309  	}
   310  
   311  	for _, tc := range testCasesExec {
   312  		c.T().Run(tc.name, func(t *testing.T) {
   313  			var resAffectID []int64
   314  			var resLastID []int64
   315  			for i := 1; i <= tc.reqCnt; i++ {
   316  				res, execErr := c.DataSource.Exec(tc.ctx, tc.query)
   317  				assert.Equal(t, execErr, tc.wantErr)
   318  				if execErr != nil {
   319  					return
   320  				}
   321  				afID, er := res.RowsAffected()
   322  				if er != nil {
   323  					continue
   324  				}
   325  				lastID, er := res.LastInsertId()
   326  				if er != nil {
   327  					continue
   328  				}
   329  				resAffectID = append(resAffectID, afID)
   330  				resLastID = append(resLastID, lastID)
   331  			}
   332  			assert.ElementsMatch(t, tc.wantRowsAffected, resAffectID)
   333  			assert.ElementsMatch(t, tc.wantLastInsertIds, resLastID)
   334  		})
   335  	}
   336  }
   337  
   338  func (c *ShardingDataSourceSuite) newSlaves(dbs ...*sql.DB) slaves.Slaves {
   339  	res, err := roundrobin.NewSlaves(dbs...)
   340  	require.NoError(c.T(), err)
   341  	return res
   342  }
   343  
   344  func TestShardingDataSourceSuite(t *testing.T) {
   345  	suite.Run(t, &ShardingDataSourceSuite{})
   346  }