github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/shardingsource/sharding_datasource_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package shardingsource 16 17 import ( 18 "context" 19 "database/sql" 20 "fmt" 21 "testing" 22 23 "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves/roundrobin" 24 25 "github.com/ecodeclub/eorm/internal/datasource/masterslave" 26 "github.com/ecodeclub/eorm/internal/datasource/masterslave/slaves" 27 28 "github.com/ecodeclub/eorm/internal/errs" 29 30 "github.com/DATA-DOG/go-sqlmock" 31 "github.com/ecodeclub/eorm/internal/datasource" 32 "github.com/ecodeclub/eorm/internal/datasource/cluster" 33 "github.com/ecodeclub/eorm/internal/sharding" 34 _ "github.com/mattn/go-sqlite3" 35 "github.com/stretchr/testify/assert" 36 "github.com/stretchr/testify/require" 37 "github.com/stretchr/testify/suite" 38 ) 39 40 func ExampleShardingDataSource_Close() { 41 db, _ := sql.Open("sqlite3", "file:test.db?cache=shared&mode=memory") 42 cl := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ 43 "db0": masterslave.NewMasterSlavesDB(db), 44 }) 45 ds := NewShardingDataSource(map[string]datasource.DataSource{ 46 "source0": cl, 47 }) 48 err := ds.Close() 49 if err == nil { 50 fmt.Println("close") 51 } 52 53 // Output: 54 // close 55 } 56 57 type ShardingDataSourceSuite struct { 58 suite.Suite 59 datasource.DataSource 60 mockMaster1DB *sql.DB 61 mockMaster sqlmock.Sqlmock 62 63 mockSlave1DB *sql.DB 64 mockSlave1 sqlmock.Sqlmock 65 66 mockSlave2DB *sql.DB 67 mockSlave2 sqlmock.Sqlmock 68 69 mockSlave3DB *sql.DB 70 mockSlave3 sqlmock.Sqlmock 71 72 mockMaster2DB *sql.DB 73 mockMaster2 sqlmock.Sqlmock 74 75 mockSlave4DB *sql.DB 76 mockSlave4 sqlmock.Sqlmock 77 78 mockSlave5DB *sql.DB 79 mockSlave5 sqlmock.Sqlmock 80 81 mockSlave6DB *sql.DB 82 mockSlave6 sqlmock.Sqlmock 83 } 84 85 func (c *ShardingDataSourceSuite) SetupTest() { 86 t := c.T() 87 c.initMock(t) 88 } 89 90 func (c *ShardingDataSourceSuite) TearDownTest() { 91 _ = c.mockMaster1DB.Close() 92 _ = c.mockSlave1DB.Close() 93 _ = c.mockSlave2DB.Close() 94 _ = c.mockSlave3DB.Close() 95 96 _ = c.mockMaster2DB.Close() 97 _ = c.mockSlave4DB.Close() 98 _ = c.mockSlave5DB.Close() 99 _ = c.mockSlave6DB.Close() 100 } 101 102 func (c *ShardingDataSourceSuite) initMock(t *testing.T) { 103 var err error 104 c.mockMaster1DB, c.mockMaster, err = sqlmock.New() 105 if err != nil { 106 t.Fatal(err) 107 } 108 c.mockSlave1DB, c.mockSlave1, err = sqlmock.New() 109 if err != nil { 110 t.Fatal(err) 111 } 112 c.mockSlave2DB, c.mockSlave2, err = sqlmock.New() 113 if err != nil { 114 t.Fatal(err) 115 } 116 c.mockSlave3DB, c.mockSlave3, err = sqlmock.New() 117 if err != nil { 118 t.Fatal(err) 119 } 120 121 c.mockMaster2DB, c.mockMaster2, err = sqlmock.New() 122 if err != nil { 123 t.Fatal(err) 124 } 125 c.mockSlave4DB, c.mockSlave4, err = sqlmock.New() 126 if err != nil { 127 t.Fatal(err) 128 } 129 c.mockSlave5DB, c.mockSlave5, err = sqlmock.New() 130 if err != nil { 131 t.Fatal(err) 132 } 133 c.mockSlave6DB, c.mockSlave6, err = sqlmock.New() 134 if err != nil { 135 t.Fatal(err) 136 } 137 138 db1 := masterslave.NewMasterSlavesDB(c.mockMaster1DB, masterslave.MasterSlavesWithSlaves( 139 c.newSlaves(c.mockSlave1DB, c.mockSlave2DB, c.mockSlave3DB))) 140 141 db2 := masterslave.NewMasterSlavesDB(c.mockMaster2DB, masterslave.MasterSlavesWithSlaves( 142 c.newSlaves(c.mockSlave4DB, c.mockSlave5DB, c.mockSlave6DB))) 143 144 clusterDB1 := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{"db_0": db1}) 145 clusterDB2 := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{"db_0": db2}) 146 147 c.DataSource = NewShardingDataSource(map[string]datasource.DataSource{ 148 "0.db.cluster.company.com:3306": clusterDB1, 149 "1.db.cluster.company.com:3306": clusterDB2, 150 }) 151 152 } 153 154 func (c *ShardingDataSourceSuite) TestClusterDbQuery() { 155 // 通过select不同的数据表示访问不同的db 156 c.mockMaster.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 master")) 157 c.mockSlave1.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_1")) 158 c.mockSlave2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_2")) 159 c.mockSlave3.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster0 slave1_3")) 160 161 c.mockMaster2.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 master")) 162 c.mockSlave4.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_1")) 163 c.mockSlave5.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_2")) 164 c.mockSlave6.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("cluster1 slave1_3")) 165 166 testCasesQuery := []struct { 167 name string 168 reqCnt int 169 ctx context.Context 170 query sharding.Query 171 wantResp []string 172 wantErr error 173 }{ 174 { 175 name: "not found target DataSource", 176 ctx: context.Background(), 177 reqCnt: 1, 178 query: sharding.Query{ 179 SQL: "SELECT `first_name` FROM `test_model`", 180 DB: "db_0", 181 Datasource: "2.db.cluster.company.com:3306", 182 }, 183 wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"), 184 }, 185 { 186 name: "cluster0 select default use slave", 187 ctx: context.Background(), 188 reqCnt: 3, 189 query: sharding.Query{ 190 SQL: "SELECT `first_name` FROM `test_model`", 191 DB: "db_0", 192 Datasource: "0.db.cluster.company.com:3306", 193 }, 194 wantResp: []string{"cluster0 slave1_1", "cluster0 slave1_2", "cluster0 slave1_3"}, 195 }, 196 { 197 name: "cluster1 select default use slave", 198 ctx: context.Background(), 199 reqCnt: 3, 200 query: sharding.Query{ 201 SQL: "SELECT `first_name` FROM `test_model`", 202 DB: "db_0", 203 Datasource: "1.db.cluster.company.com:3306", 204 }, 205 wantResp: []string{"cluster1 slave1_1", "cluster1 slave1_2", "cluster1 slave1_3"}, 206 }, 207 { 208 name: "cluster0 use master", 209 reqCnt: 1, 210 ctx: masterslave.UseMaster(context.Background()), 211 query: sharding.Query{ 212 SQL: "SELECT `first_name` FROM `test_model`", 213 DB: "db_0", 214 Datasource: "0.db.cluster.company.com:3306", 215 }, 216 wantResp: []string{"cluster0 master"}, 217 }, 218 { 219 name: "cluster1 use master", 220 reqCnt: 1, 221 ctx: masterslave.UseMaster(context.Background()), 222 query: sharding.Query{ 223 SQL: "SELECT `first_name` FROM `test_model`", 224 DB: "db_0", 225 Datasource: "1.db.cluster.company.com:3306", 226 }, 227 wantResp: []string{"cluster1 master"}, 228 }, 229 } 230 231 for _, tc := range testCasesQuery { 232 c.T().Run(tc.name, func(t *testing.T) { 233 var resp []string 234 for i := 1; i <= tc.reqCnt; i++ { 235 rows, queryErr := c.DataSource.Query(tc.ctx, tc.query) 236 assert.Equal(t, queryErr, tc.wantErr) 237 if queryErr != nil { 238 return 239 } 240 assert.NotNil(t, rows) 241 ok := rows.Next() 242 assert.True(t, ok) 243 244 val := new(string) 245 err := rows.Scan(val) 246 assert.Nil(t, err) 247 if err != nil { 248 return 249 } 250 assert.NotNil(t, val) 251 252 resp = append(resp, *val) 253 } 254 assert.ElementsMatch(t, tc.wantResp, resp) 255 }) 256 } 257 } 258 259 func (c *ShardingDataSourceSuite) TestClusterDbExec() { 260 // 使用 sql.Result.LastInsertId 表示请求的是 master或者slave 261 c.mockMaster.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 262 c.mockMaster2.ExpectExec("^INSERT INTO (.+)").WillReturnResult(sqlmock.NewResult(2, 1)) 263 264 testCasesExec := []struct { 265 name string 266 reqCnt int 267 ctx context.Context 268 slaves slaves.Slaves 269 query sharding.Query 270 wantRowsAffected []int64 271 wantLastInsertIds []int64 272 wantErr error 273 }{ 274 { 275 name: "not found target DataSource", 276 ctx: context.Background(), 277 reqCnt: 1, 278 query: sharding.Query{ 279 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 280 DB: "db_0", 281 Datasource: "2.db.cluster.company.com:3306", 282 }, 283 wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"), 284 }, 285 { 286 name: "cluster0 exec", 287 reqCnt: 1, 288 ctx: masterslave.UseMaster(context.Background()), 289 query: sharding.Query{ 290 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 291 DB: "db_0", 292 Datasource: "0.db.cluster.company.com:3306", 293 }, 294 wantRowsAffected: []int64{1}, 295 wantLastInsertIds: []int64{1}, 296 }, 297 { 298 name: "cluster1 exec", 299 reqCnt: 1, 300 ctx: masterslave.UseMaster(context.Background()), 301 query: sharding.Query{ 302 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 303 DB: "db_0", 304 Datasource: "1.db.cluster.company.com:3306", 305 }, 306 wantRowsAffected: []int64{1}, 307 wantLastInsertIds: []int64{2}, 308 }, 309 } 310 311 for _, tc := range testCasesExec { 312 c.T().Run(tc.name, func(t *testing.T) { 313 var resAffectID []int64 314 var resLastID []int64 315 for i := 1; i <= tc.reqCnt; i++ { 316 res, execErr := c.DataSource.Exec(tc.ctx, tc.query) 317 assert.Equal(t, execErr, tc.wantErr) 318 if execErr != nil { 319 return 320 } 321 afID, er := res.RowsAffected() 322 if er != nil { 323 continue 324 } 325 lastID, er := res.LastInsertId() 326 if er != nil { 327 continue 328 } 329 resAffectID = append(resAffectID, afID) 330 resLastID = append(resLastID, lastID) 331 } 332 assert.ElementsMatch(t, tc.wantRowsAffected, resAffectID) 333 assert.ElementsMatch(t, tc.wantLastInsertIds, resLastID) 334 }) 335 } 336 } 337 338 func (c *ShardingDataSourceSuite) newSlaves(dbs ...*sql.DB) slaves.Slaves { 339 res, err := roundrobin.NewSlaves(dbs...) 340 require.NoError(c.T(), err) 341 return res 342 } 343 344 func TestShardingDataSourceSuite(t *testing.T) { 345 suite.Run(t, &ShardingDataSourceSuite{}) 346 }