github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_insert_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "regexp" 23 "testing" 24 25 "github.com/DATA-DOG/go-sqlmock" 26 "github.com/ecodeclub/eorm/internal/datasource" 27 "github.com/ecodeclub/eorm/internal/datasource/cluster" 28 "github.com/ecodeclub/eorm/internal/datasource/masterslave" 29 "github.com/ecodeclub/eorm/internal/datasource/shardingsource" 30 "github.com/ecodeclub/eorm/internal/errs" 31 "github.com/ecodeclub/eorm/internal/model" 32 "github.com/ecodeclub/eorm/internal/sharding" 33 "github.com/ecodeclub/eorm/internal/sharding/hash" 34 "github.com/stretchr/testify/assert" 35 "github.com/stretchr/testify/require" 36 "github.com/stretchr/testify/suite" 37 "go.uber.org/multierr" 38 ) 39 40 func newMockErr(dbName string) error { 41 return fmt.Errorf("mock error for %s", dbName) 42 } 43 44 type OrderInsert struct { 45 UserId int `eorm:"primary_key"` 46 OrderId int64 47 Content string 48 Account float64 49 } 50 51 func TestShardingInsert_Build(t *testing.T) { 52 r := model.NewMetaRegistry() 53 dbBase, tableBase, dsBase := 2, 3, 2 54 dbPattern, tablePattern, dsPattern := "order_db_%d", "order_tab_%d", "%d.db.cluster.company.com:3306" 55 _, err := r.Register(&OrderInsert{}, 56 model.WithTableShardingAlgorithm(&hash.Hash{ 57 ShardingKey: "UserId", 58 DBPattern: &hash.Pattern{Name: dbPattern, Base: dbBase}, 59 TablePattern: &hash.Pattern{Name: tablePattern, Base: tableBase}, 60 DsPattern: &hash.Pattern{Name: dsPattern, Base: dsBase}, 61 })) 62 require.NoError(t, err) 63 m := map[string]*masterslave.MasterSlavesDB{ 64 "order_db_0": MasterSlavesMemoryDB(), 65 "order_db_1": MasterSlavesMemoryDB(), 66 "order_db_2": MasterSlavesMemoryDB(), 67 } 68 clusterDB := cluster.NewClusterDB(m) 69 ds := map[string]datasource.DataSource{ 70 "0.db.cluster.company.com:3306": clusterDB, 71 "1.db.cluster.company.com:3306": clusterDB, 72 } 73 shardingDB, err := OpenDS("sqlite3", 74 shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) 75 require.NoError(t, err) 76 testCases := []struct { 77 name string 78 builder sharding.QueryBuilder 79 wantQs []sharding.Query 80 wantErr error 81 }{ 82 { 83 name: "插入一个元素", 84 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 85 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 86 }), 87 wantQs: []sharding.Query{ 88 { 89 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_1`"), 90 Args: []any{1, int64(1), "1", 1.0}, 91 DB: "order_db_1", 92 Datasource: "1.db.cluster.company.com:3306", 93 }, 94 }, 95 }, 96 { 97 name: "插入多个元素", 98 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 99 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 100 {UserId: 2, OrderId: 2, Content: "2", Account: 2.0}, 101 {UserId: 3, OrderId: 3, Content: "3", Account: 3.0}, 102 {UserId: 4, OrderId: 4, Content: "4", Account: 4.0}, 103 }), 104 wantQs: []sharding.Query{ 105 { 106 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_0`", "`order_tab_1`"), 107 Args: []any{4, int64(4), "4", 4.0}, 108 DB: "order_db_0", 109 Datasource: "0.db.cluster.company.com:3306", 110 }, 111 { 112 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_0`", "`order_tab_2`"), 113 Args: []any{2, int64(2), "2", 2.0}, 114 DB: "order_db_0", 115 Datasource: "0.db.cluster.company.com:3306", 116 }, 117 { 118 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_0`"), 119 Args: []any{3, int64(3), "3", 3.0}, 120 DB: "order_db_1", 121 Datasource: "1.db.cluster.company.com:3306", 122 }, 123 { 124 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_1`"), 125 Args: []any{1, int64(1), "1", 1.0}, 126 DB: "order_db_1", 127 Datasource: "1.db.cluster.company.com:3306", 128 }, 129 }, 130 }, 131 { 132 name: "插入多个元素, 但是不同的元素会被分配到同一个库", 133 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 134 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 135 {UserId: 7, OrderId: 7, Content: "7", Account: 7.0}, 136 }), 137 wantQs: []sharding.Query{ 138 { 139 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_1`", "`order_tab_1`"), 140 Args: []any{1, int64(1), "1", 1.0, 7, int64(7), "7", 7.0}, 141 DB: "order_db_1", 142 Datasource: "1.db.cluster.company.com:3306", 143 }, 144 }, 145 }, 146 { 147 name: "插入多个元素, 有不同的元素会被分配到同一个库表", 148 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 149 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 150 {UserId: 7, OrderId: 7, Content: "7", Account: 7.0}, 151 {UserId: 2, OrderId: 2, Content: "2", Account: 2.0}, 152 {UserId: 8, OrderId: 8, Content: "8", Account: 8.0}, 153 {UserId: 3, OrderId: 3, Content: "3", Account: 3.0}, 154 }), 155 wantQs: []sharding.Query{ 156 157 { 158 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_0`", "`order_tab_2`"), 159 Args: []any{2, int64(2), "2", 2.0, 8, int64(8), "8", 8.0}, 160 DB: "order_db_0", 161 Datasource: "0.db.cluster.company.com:3306", 162 }, 163 { 164 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_0`"), 165 Args: []any{3, int64(3), "3", 3.0}, 166 DB: "order_db_1", 167 Datasource: "1.db.cluster.company.com:3306", 168 }, 169 { 170 SQL: fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_1`", "`order_tab_1`"), 171 Args: []any{1, int64(1), "1", 1.0, 7, int64(7), "7", 7.0}, 172 DB: "order_db_1", 173 Datasource: "1.db.cluster.company.com:3306", 174 }, 175 }, 176 }, 177 { 178 name: "插入时,插入的列没有包含分库分表的列", 179 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 180 {OrderId: 1, Content: "1", Account: 1.0}, 181 }).Columns([]string{"OrderId", "Content", "Account"}), 182 wantErr: errs.ErrInsertShardingKeyNotFound, 183 }, 184 { 185 name: "插入时,忽略主键,但主键为shardingKey报错", 186 builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 187 {OrderId: 1, Content: "1", Account: 1.0}, 188 }).IgnorePK(), 189 wantErr: errs.ErrInsertShardingKeyNotFound, 190 }, 191 { 192 name: "values中没有元素报错", 193 builder: NewShardingInsert[OrderInsert](shardingDB), 194 wantErr: errors.New("插入0行"), 195 }, 196 } 197 for _, tc := range testCases { 198 t.Run(tc.name, func(t *testing.T) { 199 qs, err := tc.builder.Build(context.Background()) 200 require.Equal(t, tc.wantErr, err) 201 if err != nil { 202 return 203 } 204 assert.ElementsMatch(t, tc.wantQs, qs) 205 }) 206 } 207 } 208 209 type ShardingInsertSuite struct { 210 suite.Suite 211 mock01 sqlmock.Sqlmock 212 mockDB01 *sql.DB 213 mock02 sqlmock.Sqlmock 214 mockDB02 *sql.DB 215 } 216 217 func (s *ShardingInsertSuite) SetupSuite() { 218 t := s.T() 219 var err error 220 s.mockDB01, s.mock01, err = sqlmock.New() 221 if err != nil { 222 t.Fatal(err) 223 } 224 s.mockDB02, s.mock02, err = sqlmock.New() 225 if err != nil { 226 t.Fatal(err) 227 } 228 229 } 230 231 func (s *ShardingInsertSuite) TearDownTest() { 232 _ = s.mockDB01.Close() 233 _ = s.mockDB02.Close() 234 } 235 236 func (s *ShardingInsertSuite) TestShardingInsert_Exec() { 237 r := model.NewMetaRegistry() 238 dbBase, tableBase := 2, 3 239 dbPattern, tablePattern, dsPattern := "order_db_%d", "order_tab_%d", "0.db.cluster.company.com:3306" 240 _, err := r.Register(&OrderInsert{}, 241 model.WithTableShardingAlgorithm(&hash.Hash{ 242 ShardingKey: "UserId", 243 DBPattern: &hash.Pattern{Name: dbPattern, Base: dbBase}, 244 TablePattern: &hash.Pattern{Name: tablePattern, Base: tableBase}, 245 DsPattern: &hash.Pattern{Name: dsPattern, NotSharding: true}, 246 })) 247 require.NoError(s.T(), err) 248 249 m := map[string]*masterslave.MasterSlavesDB{ 250 "order_db_0": MasterSlavesMockDB(s.mockDB01), 251 "order_db_1": MasterSlavesMockDB(s.mockDB02), 252 } 253 clusterDB := cluster.NewClusterDB(m) 254 ds := map[string]datasource.DataSource{ 255 "0.db.cluster.company.com:3306": clusterDB, 256 } 257 shardingDB, err := OpenDS("mysql", 258 shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r)) 259 require.NoError(s.T(), err) 260 testcases := []struct { 261 name string 262 si *ShardingInserter[OrderInsert] 263 mockDb func() 264 wantErr error 265 wantAffectedRows int64 266 }{ 267 { 268 name: "跨表插入全部成功", 269 si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 270 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 271 {UserId: 2, OrderId: 2, Content: "2", Account: 2.0}, 272 {UserId: 3, OrderId: 3, Content: "3", Account: 3.0}, 273 }), 274 mockDb: func() { 275 s.mock02.MatchExpectationsInOrder(false) 276 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnResult(sqlmock.NewResult(1, 1)) 277 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnResult(sqlmock.NewResult(1, 1)) 278 s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnResult(sqlmock.NewResult(1, 1)) 279 }, 280 wantAffectedRows: 3, 281 }, 282 { 283 name: "部分插入失败", 284 si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 285 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 286 {UserId: 2, OrderId: 2, Content: "2", Account: 2.0}, 287 {UserId: 3, OrderId: 3, Content: "3", Account: 3.0}, 288 }), 289 mockDb: func() { 290 s.mock02.MatchExpectationsInOrder(false) 291 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnError(newMockErr("db01")) 292 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnResult(sqlmock.NewResult(1, 1)) 293 s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnResult(sqlmock.NewResult(1, 1)) 294 }, 295 wantErr: multierr.Combine(newMockErr("db01")), 296 }, 297 { 298 name: "全部插入失败", 299 si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{ 300 {UserId: 1, OrderId: 1, Content: "1", Account: 1.0}, 301 {UserId: 2, OrderId: 2, Content: "2", Account: 2.0}, 302 {UserId: 3, OrderId: 3, Content: "3", Account: 3.0}, 303 }), 304 mockDb: func() { 305 s.mock02.MatchExpectationsInOrder(false) 306 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnError(newMockErr("db")) 307 s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnError(newMockErr("db")) 308 s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnError(newMockErr("db")) 309 }, 310 wantErr: multierr.Combine(newMockErr("db"), newMockErr("db"), newMockErr("db")), 311 }, 312 } 313 for _, tc := range testcases { 314 s.T().Run(tc.name, func(t *testing.T) { 315 tc.mockDb() 316 res := tc.si.Exec(context.Background()) 317 require.Equal(t, tc.wantErr, res.Err()) 318 if res.Err() != nil { 319 return 320 } 321 322 affectRows, err := res.RowsAffected() 323 require.NoError(t, err) 324 assert.Equal(t, tc.wantAffectedRows, affectRows) 325 }) 326 } 327 } 328 329 func TestShardingInsertSuite(t *testing.T) { 330 suite.Run(t, &ShardingInsertSuite{}) 331 } 332 333 func MasterSlavesMockDB(db *sql.DB) *masterslave.MasterSlavesDB { 334 return masterslave.NewMasterSlavesDB(db) 335 }