github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_insert_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"regexp"
    23  	"testing"
    24  
    25  	"github.com/DATA-DOG/go-sqlmock"
    26  	"github.com/ecodeclub/eorm/internal/datasource"
    27  	"github.com/ecodeclub/eorm/internal/datasource/cluster"
    28  	"github.com/ecodeclub/eorm/internal/datasource/masterslave"
    29  	"github.com/ecodeclub/eorm/internal/datasource/shardingsource"
    30  	"github.com/ecodeclub/eorm/internal/errs"
    31  	"github.com/ecodeclub/eorm/internal/model"
    32  	"github.com/ecodeclub/eorm/internal/sharding"
    33  	"github.com/ecodeclub/eorm/internal/sharding/hash"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"github.com/stretchr/testify/suite"
    37  	"go.uber.org/multierr"
    38  )
    39  
    40  func newMockErr(dbName string) error {
    41  	return fmt.Errorf("mock error for %s", dbName)
    42  }
    43  
// OrderInsert is the model inserted by the sharding insert tests. The tests
// register it with a hash sharding algorithm keyed on UserId. The expected SQL
// lists its columns as `user_id`, `order_id`, `content`, `account`, following
// the field order here.
// NOTE(review): reordering or renaming fields presumably changes the generated
// column list/names — confirm before touching.
type OrderInsert struct {
	// UserId is tagged as the primary key; the tests also use it as the sharding key.
	UserId  int `eorm:"primary_key"`
	OrderId int64
	Content string
	Account float64
}
    50  
    51  func TestShardingInsert_Build(t *testing.T) {
    52  	r := model.NewMetaRegistry()
    53  	dbBase, tableBase, dsBase := 2, 3, 2
    54  	dbPattern, tablePattern, dsPattern := "order_db_%d", "order_tab_%d", "%d.db.cluster.company.com:3306"
    55  	_, err := r.Register(&OrderInsert{},
    56  		model.WithTableShardingAlgorithm(&hash.Hash{
    57  			ShardingKey:  "UserId",
    58  			DBPattern:    &hash.Pattern{Name: dbPattern, Base: dbBase},
    59  			TablePattern: &hash.Pattern{Name: tablePattern, Base: tableBase},
    60  			DsPattern:    &hash.Pattern{Name: dsPattern, Base: dsBase},
    61  		}))
    62  	require.NoError(t, err)
    63  	m := map[string]*masterslave.MasterSlavesDB{
    64  		"order_db_0": MasterSlavesMemoryDB(),
    65  		"order_db_1": MasterSlavesMemoryDB(),
    66  		"order_db_2": MasterSlavesMemoryDB(),
    67  	}
    68  	clusterDB := cluster.NewClusterDB(m)
    69  	ds := map[string]datasource.DataSource{
    70  		"0.db.cluster.company.com:3306": clusterDB,
    71  		"1.db.cluster.company.com:3306": clusterDB,
    72  	}
    73  	shardingDB, err := OpenDS("sqlite3",
    74  		shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r))
    75  	require.NoError(t, err)
    76  	testCases := []struct {
    77  		name    string
    78  		builder sharding.QueryBuilder
    79  		wantQs  []sharding.Query
    80  		wantErr error
    81  	}{
    82  		{
    83  			name: "插入一个元素",
    84  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
    85  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
    86  			}),
    87  			wantQs: []sharding.Query{
    88  				{
    89  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_1`"),
    90  					Args:       []any{1, int64(1), "1", 1.0},
    91  					DB:         "order_db_1",
    92  					Datasource: "1.db.cluster.company.com:3306",
    93  				},
    94  			},
    95  		},
    96  		{
    97  			name: "插入多个元素",
    98  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
    99  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   100  				{UserId: 2, OrderId: 2, Content: "2", Account: 2.0},
   101  				{UserId: 3, OrderId: 3, Content: "3", Account: 3.0},
   102  				{UserId: 4, OrderId: 4, Content: "4", Account: 4.0},
   103  			}),
   104  			wantQs: []sharding.Query{
   105  				{
   106  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_0`", "`order_tab_1`"),
   107  					Args:       []any{4, int64(4), "4", 4.0},
   108  					DB:         "order_db_0",
   109  					Datasource: "0.db.cluster.company.com:3306",
   110  				},
   111  				{
   112  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_0`", "`order_tab_2`"),
   113  					Args:       []any{2, int64(2), "2", 2.0},
   114  					DB:         "order_db_0",
   115  					Datasource: "0.db.cluster.company.com:3306",
   116  				},
   117  				{
   118  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_0`"),
   119  					Args:       []any{3, int64(3), "3", 3.0},
   120  					DB:         "order_db_1",
   121  					Datasource: "1.db.cluster.company.com:3306",
   122  				},
   123  				{
   124  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_1`"),
   125  					Args:       []any{1, int64(1), "1", 1.0},
   126  					DB:         "order_db_1",
   127  					Datasource: "1.db.cluster.company.com:3306",
   128  				},
   129  			},
   130  		},
   131  		{
   132  			name: "插入多个元素, 但是不同的元素会被分配到同一个库",
   133  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   134  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   135  				{UserId: 7, OrderId: 7, Content: "7", Account: 7.0},
   136  			}),
   137  			wantQs: []sharding.Query{
   138  				{
   139  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_1`", "`order_tab_1`"),
   140  					Args:       []any{1, int64(1), "1", 1.0, 7, int64(7), "7", 7.0},
   141  					DB:         "order_db_1",
   142  					Datasource: "1.db.cluster.company.com:3306",
   143  				},
   144  			},
   145  		},
   146  		{
   147  			name: "插入多个元素, 有不同的元素会被分配到同一个库表",
   148  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   149  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   150  				{UserId: 7, OrderId: 7, Content: "7", Account: 7.0},
   151  				{UserId: 2, OrderId: 2, Content: "2", Account: 2.0},
   152  				{UserId: 8, OrderId: 8, Content: "8", Account: 8.0},
   153  				{UserId: 3, OrderId: 3, Content: "3", Account: 3.0},
   154  			}),
   155  			wantQs: []sharding.Query{
   156  
   157  				{
   158  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_0`", "`order_tab_2`"),
   159  					Args:       []any{2, int64(2), "2", 2.0, 8, int64(8), "8", 8.0},
   160  					DB:         "order_db_0",
   161  					Datasource: "0.db.cluster.company.com:3306",
   162  				},
   163  				{
   164  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);", "`order_db_1`", "`order_tab_0`"),
   165  					Args:       []any{3, int64(3), "3", 3.0},
   166  					DB:         "order_db_1",
   167  					Datasource: "1.db.cluster.company.com:3306",
   168  				},
   169  				{
   170  					SQL:        fmt.Sprintf("INSERT INTO %s.%s(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?),(?,?,?,?);", "`order_db_1`", "`order_tab_1`"),
   171  					Args:       []any{1, int64(1), "1", 1.0, 7, int64(7), "7", 7.0},
   172  					DB:         "order_db_1",
   173  					Datasource: "1.db.cluster.company.com:3306",
   174  				},
   175  			},
   176  		},
   177  		{
   178  			name: "插入时,插入的列没有包含分库分表的列",
   179  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   180  				{OrderId: 1, Content: "1", Account: 1.0},
   181  			}).Columns([]string{"OrderId", "Content", "Account"}),
   182  			wantErr: errs.ErrInsertShardingKeyNotFound,
   183  		},
   184  		{
   185  			name: "插入时,忽略主键,但主键为shardingKey报错",
   186  			builder: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   187  				{OrderId: 1, Content: "1", Account: 1.0},
   188  			}).IgnorePK(),
   189  			wantErr: errs.ErrInsertShardingKeyNotFound,
   190  		},
   191  		{
   192  			name:    "values中没有元素报错",
   193  			builder: NewShardingInsert[OrderInsert](shardingDB),
   194  			wantErr: errors.New("插入0行"),
   195  		},
   196  	}
   197  	for _, tc := range testCases {
   198  		t.Run(tc.name, func(t *testing.T) {
   199  			qs, err := tc.builder.Build(context.Background())
   200  			require.Equal(t, tc.wantErr, err)
   201  			if err != nil {
   202  				return
   203  			}
   204  			assert.ElementsMatch(t, tc.wantQs, qs)
   205  		})
   206  	}
   207  }
   208  
// ShardingInsertSuite groups the ShardingInserter.Exec tests that run against
// sqlmock-backed databases.
type ShardingInsertSuite struct {
	suite.Suite
	// mock01/mockDB01 back order_db_0 in TestShardingInsert_Exec.
	mock01   sqlmock.Sqlmock
	mockDB01 *sql.DB
	// mock02/mockDB02 back order_db_1 in TestShardingInsert_Exec.
	mock02   sqlmock.Sqlmock
	mockDB02 *sql.DB
}
   216  
   217  func (s *ShardingInsertSuite) SetupSuite() {
   218  	t := s.T()
   219  	var err error
   220  	s.mockDB01, s.mock01, err = sqlmock.New()
   221  	if err != nil {
   222  		t.Fatal(err)
   223  	}
   224  	s.mockDB02, s.mock02, err = sqlmock.New()
   225  	if err != nil {
   226  		t.Fatal(err)
   227  	}
   228  
   229  }
   230  
   231  func (s *ShardingInsertSuite) TearDownTest() {
   232  	_ = s.mockDB01.Close()
   233  	_ = s.mockDB02.Close()
   234  }
   235  
   236  func (s *ShardingInsertSuite) TestShardingInsert_Exec() {
   237  	r := model.NewMetaRegistry()
   238  	dbBase, tableBase := 2, 3
   239  	dbPattern, tablePattern, dsPattern := "order_db_%d", "order_tab_%d", "0.db.cluster.company.com:3306"
   240  	_, err := r.Register(&OrderInsert{},
   241  		model.WithTableShardingAlgorithm(&hash.Hash{
   242  			ShardingKey:  "UserId",
   243  			DBPattern:    &hash.Pattern{Name: dbPattern, Base: dbBase},
   244  			TablePattern: &hash.Pattern{Name: tablePattern, Base: tableBase},
   245  			DsPattern:    &hash.Pattern{Name: dsPattern, NotSharding: true},
   246  		}))
   247  	require.NoError(s.T(), err)
   248  
   249  	m := map[string]*masterslave.MasterSlavesDB{
   250  		"order_db_0": MasterSlavesMockDB(s.mockDB01),
   251  		"order_db_1": MasterSlavesMockDB(s.mockDB02),
   252  	}
   253  	clusterDB := cluster.NewClusterDB(m)
   254  	ds := map[string]datasource.DataSource{
   255  		"0.db.cluster.company.com:3306": clusterDB,
   256  	}
   257  	shardingDB, err := OpenDS("mysql",
   258  		shardingsource.NewShardingDataSource(ds), DBWithMetaRegistry(r))
   259  	require.NoError(s.T(), err)
   260  	testcases := []struct {
   261  		name             string
   262  		si               *ShardingInserter[OrderInsert]
   263  		mockDb           func()
   264  		wantErr          error
   265  		wantAffectedRows int64
   266  	}{
   267  		{
   268  			name: "跨表插入全部成功",
   269  			si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   270  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   271  				{UserId: 2, OrderId: 2, Content: "2", Account: 2.0},
   272  				{UserId: 3, OrderId: 3, Content: "3", Account: 3.0},
   273  			}),
   274  			mockDb: func() {
   275  				s.mock02.MatchExpectationsInOrder(false)
   276  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnResult(sqlmock.NewResult(1, 1))
   277  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnResult(sqlmock.NewResult(1, 1))
   278  				s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnResult(sqlmock.NewResult(1, 1))
   279  			},
   280  			wantAffectedRows: 3,
   281  		},
   282  		{
   283  			name: "部分插入失败",
   284  			si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   285  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   286  				{UserId: 2, OrderId: 2, Content: "2", Account: 2.0},
   287  				{UserId: 3, OrderId: 3, Content: "3", Account: 3.0},
   288  			}),
   289  			mockDb: func() {
   290  				s.mock02.MatchExpectationsInOrder(false)
   291  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnError(newMockErr("db01"))
   292  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnResult(sqlmock.NewResult(1, 1))
   293  				s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnResult(sqlmock.NewResult(1, 1))
   294  			},
   295  			wantErr: multierr.Combine(newMockErr("db01")),
   296  		},
   297  		{
   298  			name: "全部插入失败",
   299  			si: NewShardingInsert[OrderInsert](shardingDB).Values([]*OrderInsert{
   300  				{UserId: 1, OrderId: 1, Content: "1", Account: 1.0},
   301  				{UserId: 2, OrderId: 2, Content: "2", Account: 2.0},
   302  				{UserId: 3, OrderId: 3, Content: "3", Account: 3.0},
   303  			}),
   304  			mockDb: func() {
   305  				s.mock02.MatchExpectationsInOrder(false)
   306  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_1`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(1, int64(1), "1", 1.0).WillReturnError(newMockErr("db"))
   307  				s.mock02.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_1`.`order_tab_0`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(3, int64(3), "3", 3.0).WillReturnError(newMockErr("db"))
   308  				s.mock01.ExpectExec(regexp.QuoteMeta("INSERT INTO `order_db_0`.`order_tab_2`(`user_id`,`order_id`,`content`,`account`) VALUES(?,?,?,?);")).WithArgs(2, int64(2), "2", 2.0).WillReturnError(newMockErr("db"))
   309  			},
   310  			wantErr: multierr.Combine(newMockErr("db"), newMockErr("db"), newMockErr("db")),
   311  		},
   312  	}
   313  	for _, tc := range testcases {
   314  		s.T().Run(tc.name, func(t *testing.T) {
   315  			tc.mockDb()
   316  			res := tc.si.Exec(context.Background())
   317  			require.Equal(t, tc.wantErr, res.Err())
   318  			if res.Err() != nil {
   319  				return
   320  			}
   321  
   322  			affectRows, err := res.RowsAffected()
   323  			require.NoError(t, err)
   324  			assert.Equal(t, tc.wantAffectedRows, affectRows)
   325  		})
   326  	}
   327  }
   328  
   329  func TestShardingInsertSuite(t *testing.T) {
   330  	suite.Run(t, &ShardingInsertSuite{})
   331  }
   332  
   333  func MasterSlavesMockDB(db *sql.DB) *masterslave.MasterSlavesDB {
   334  	return masterslave.NewMasterSlavesDB(db)
   335  }