github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/sortmerger/merger_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sortmerger
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"testing"
    23  
    24  	_ "github.com/mattn/go-sqlite3"
    25  
    26  	"github.com/ecodeclub/eorm/internal/rows"
    27  
    28  	"github.com/DATA-DOG/go-sqlmock"
    29  	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"github.com/stretchr/testify/suite"
    33  	"go.uber.org/multierr"
    34  )
    35  
    36  var (
    37  	nextMockErr error = errors.New("rows: MockNextErr")
    38  )
    39  
    40  func newCloseMockErr(dbName string) error {
    41  	return fmt.Errorf("rows: %s MockCloseErr", dbName)
    42  }
    43  
    44  type MergerSuite struct {
    45  	suite.Suite
    46  	mockDB01 *sql.DB
    47  	mock01   sqlmock.Sqlmock
    48  	mockDB02 *sql.DB
    49  	mock02   sqlmock.Sqlmock
    50  	mockDB03 *sql.DB
    51  	mock03   sqlmock.Sqlmock
    52  	mockDB04 *sql.DB
    53  	mock04   sqlmock.Sqlmock
    54  }
    55  
    56  func (ms *MergerSuite) SetupTest() {
    57  	t := ms.T()
    58  	ms.initMock(t)
    59  }
    60  
    61  func (ms *MergerSuite) TearDownTest() {
    62  	_ = ms.mockDB01.Close()
    63  	_ = ms.mockDB02.Close()
    64  	_ = ms.mockDB03.Close()
    65  	_ = ms.mockDB04.Close()
    66  }
    67  
    68  func (ms *MergerSuite) initMock(t *testing.T) {
    69  	var err error
    70  	ms.mockDB01, ms.mock01, err = sqlmock.New()
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	ms.mockDB02, ms.mock02, err = sqlmock.New()
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	ms.mockDB03, ms.mock03, err = sqlmock.New()
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	ms.mockDB04, ms.mock04, err = sqlmock.New()
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  }
    87  
    88  func (ms *MergerSuite) TestMerger_New() {
    89  	testcases := []struct {
    90  		name           string
    91  		wantErr        error
    92  		wantSortColumn func() []SortColumn
    93  		sortCols       []SortColumn
    94  	}{
    95  		{
    96  			name: "正常案例",
    97  			wantSortColumn: func() []SortColumn {
    98  				sortCol := NewSortColumn("id", ASC)
    99  				return []SortColumn{sortCol}
   100  			},
   101  			sortCols: []SortColumn{
   102  				NewSortColumn("id", ASC),
   103  			},
   104  		},
   105  		{
   106  			name:     "空的排序列表",
   107  			sortCols: []SortColumn{},
   108  			wantErr:  errs.ErrEmptySortColumns,
   109  		},
   110  		{
   111  			name: "排序列重复",
   112  			sortCols: []SortColumn{
   113  				NewSortColumn("id", ASC),
   114  				NewSortColumn("id", DESC),
   115  			},
   116  			wantErr: errs.NewRepeatSortColumn("id"),
   117  		},
   118  	}
   119  	for _, tc := range testcases {
   120  		ms.T().Run(tc.name, func(t *testing.T) {
   121  			mer, err := NewMerger(tc.sortCols...)
   122  			assert.Equal(t, tc.wantErr, err)
   123  			if err != nil {
   124  				return
   125  			}
   126  			assert.Equal(t, tc.wantSortColumn(), mer.sortColumns.columns)
   127  		})
   128  	}
   129  }
   130  
   131  func (ms *MergerSuite) TestMerger_Merge() {
   132  	testcases := []struct {
   133  		name    string
   134  		merger  func() (*Merger, error)
   135  		ctx     func() (context.Context, context.CancelFunc)
   136  		wantErr error
   137  		sqlRows func() []rows.Rows
   138  	}{
   139  		{
   140  			name: "sqlRows字段不同",
   141  			merger: func() (*Merger, error) {
   142  				return NewMerger(NewSortColumn("id", ASC))
   143  			},
   144  			ctx: func() (context.Context, context.CancelFunc) {
   145  				return context.WithCancel(context.Background())
   146  			},
   147  			sqlRows: func() []rows.Rows {
   148  				query := "SELECT * FROM `t1`"
   149  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn"))
   150  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   151  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02}
   152  				rowsList := make([]rows.Rows, 0, len(dbs))
   153  				for _, db := range dbs {
   154  					row, err := db.QueryContext(context.Background(), query)
   155  					require.NoError(ms.T(), err)
   156  					rowsList = append(rowsList, row)
   157  				}
   158  				return rowsList
   159  			},
   160  			wantErr: errs.ErrMergerRowsDiff,
   161  		},
   162  		{
   163  			name: "sqlRows字段不同_少一个字段",
   164  			merger: func() (*Merger, error) {
   165  				return NewMerger(NewSortColumn("id", ASC))
   166  			},
   167  			ctx: func() (context.Context, context.CancelFunc) {
   168  				return context.WithCancel(context.Background())
   169  			},
   170  			sqlRows: func() []rows.Rows {
   171  				query := "SELECT * FROM `t1`"
   172  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn"))
   173  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x"))
   174  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02}
   175  				rowsList := make([]rows.Rows, 0, len(dbs))
   176  				for _, db := range dbs {
   177  					row, err := db.QueryContext(context.Background(), query)
   178  					require.NoError(ms.T(), err)
   179  					rowsList = append(rowsList, row)
   180  				}
   181  				return rowsList
   182  			},
   183  			wantErr: errs.ErrMergerRowsDiff,
   184  		},
   185  		{
   186  			name: "超时",
   187  			merger: func() (*Merger, error) {
   188  				return NewMerger(NewSortColumn("id", ASC))
   189  			},
   190  			ctx: func() (context.Context, context.CancelFunc) {
   191  				ctx, cancel := context.WithTimeout(context.Background(), 0)
   192  				return ctx, cancel
   193  			},
   194  			wantErr: context.DeadlineExceeded,
   195  			sqlRows: func() []rows.Rows {
   196  				query := "SELECT * FROM `t1`;"
   197  				cols := []string{"id"}
   198  				res := make([]rows.Rows, 0, 1)
   199  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   200  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   201  				res = append(res, rows)
   202  				return res
   203  			},
   204  		},
   205  		{
   206  			name: "sqlRows列表为空",
   207  			ctx: func() (context.Context, context.CancelFunc) {
   208  				return context.WithCancel(context.Background())
   209  			},
   210  			merger: func() (*Merger, error) {
   211  				return NewMerger(NewSortColumn("id", ASC))
   212  			},
   213  			sqlRows: func() []rows.Rows {
   214  				return []rows.Rows{}
   215  			},
   216  			wantErr: errs.ErrMergerEmptyRows,
   217  		},
   218  		{
   219  			name: "sqlRows列表有nil",
   220  			merger: func() (*Merger, error) {
   221  				return NewMerger(NewSortColumn("id", ASC))
   222  			},
   223  			ctx: func() (context.Context, context.CancelFunc) {
   224  				return context.WithCancel(context.Background())
   225  			},
   226  			sqlRows: func() []rows.Rows {
   227  				return []rows.Rows{nil}
   228  			},
   229  			wantErr: errs.ErrMergerRowsIsNull,
   230  		},
   231  		{
   232  			name: "数据库列集: id;排序列集: age",
   233  			merger: func() (*Merger, error) {
   234  				return NewMerger(NewSortColumn("age", ASC))
   235  			},
   236  			sqlRows: func() []rows.Rows {
   237  				query := "SELECT * FROM `t1`;"
   238  				cols := []string{"id"}
   239  				res := make([]rows.Rows, 0, 1)
   240  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   241  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   242  				res = append(res, rows)
   243  				return res
   244  			},
   245  			ctx: func() (context.Context, context.CancelFunc) {
   246  				return context.WithCancel(context.Background())
   247  			},
   248  			wantErr: errs.NewInvalidSortColumn("age"),
   249  		},
   250  		{
   251  			name: "数据库列集: id;排序列集: id,age",
   252  			merger: func() (*Merger, error) {
   253  				return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC))
   254  			},
   255  			sqlRows: func() []rows.Rows {
   256  				query := "SELECT * FROM `t1`;"
   257  				cols := []string{"id"}
   258  				res := make([]rows.Rows, 0, 1)
   259  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   260  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   261  				res = append(res, rows)
   262  				return res
   263  			},
   264  			ctx: func() (context.Context, context.CancelFunc) {
   265  				return context.WithCancel(context.Background())
   266  			},
   267  			wantErr: errs.NewInvalidSortColumn("age"),
   268  		},
   269  		{
   270  			name: "数据库列集: id,name,address;排序列集: age",
   271  			merger: func() (*Merger, error) {
   272  				return NewMerger(NewSortColumn("age", ASC))
   273  			},
   274  			sqlRows: func() []rows.Rows {
   275  				query := "SELECT * FROM `t1`;"
   276  				cols := []string{"id", "name", "address"}
   277  				res := make([]rows.Rows, 0, 1)
   278  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh"))
   279  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   280  				res = append(res, rows)
   281  				return res
   282  			},
   283  			ctx: func() (context.Context, context.CancelFunc) {
   284  				return context.WithCancel(context.Background())
   285  			},
   286  			wantErr: errs.NewInvalidSortColumn("age"),
   287  		},
   288  		{
   289  			name: "数据库列集: id,name,address;排序列集: id,age,name",
   290  			merger: func() (*Merger, error) {
   291  				return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC), NewSortColumn("name", ASC))
   292  			},
   293  			sqlRows: func() []rows.Rows {
   294  				query := "SELECT * FROM `t1`;"
   295  				cols := []string{"id", "name", "address"}
   296  				res := make([]rows.Rows, 0, 1)
   297  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh"))
   298  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   299  				res = append(res, rows)
   300  				return res
   301  			},
   302  			ctx: func() (context.Context, context.CancelFunc) {
   303  				return context.WithCancel(context.Background())
   304  			},
   305  			wantErr: errs.NewInvalidSortColumn("age"),
   306  		},
   307  		{
   308  			name: "数据库列集: id,name,address;排序列集: id,name,age",
   309  			merger: func() (*Merger, error) {
   310  				return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC), NewSortColumn("age", ASC))
   311  			},
   312  			sqlRows: func() []rows.Rows {
   313  				query := "SELECT * FROM `t1`;"
   314  				cols := []string{"id", "name", "address"}
   315  				res := make([]rows.Rows, 0, 1)
   316  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh"))
   317  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   318  				res = append(res, rows)
   319  				return res
   320  			},
   321  			ctx: func() (context.Context, context.CancelFunc) {
   322  				return context.WithCancel(context.Background())
   323  			},
   324  			wantErr: errs.NewInvalidSortColumn("age"),
   325  		},
   326  		{
   327  			name: "数据库列集: id ;排序列集: id",
   328  			merger: func() (*Merger, error) {
   329  				return NewMerger(NewSortColumn("id", ASC))
   330  			},
   331  			sqlRows: func() []rows.Rows {
   332  				query := "SELECT * FROM `t1`;"
   333  				cols := []string{"id"}
   334  				res := make([]rows.Rows, 0, 1)
   335  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   336  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   337  				res = append(res, rows)
   338  				return res
   339  			},
   340  			ctx: func() (context.Context, context.CancelFunc) {
   341  				return context.WithCancel(context.Background())
   342  			},
   343  		},
   344  		{
   345  			name: "数据库列集: id,age;排序列集: id,age",
   346  			merger: func() (*Merger, error) {
   347  				return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC))
   348  			},
   349  			sqlRows: func() []rows.Rows {
   350  				query := "SELECT * FROM `t1`;"
   351  				cols := []string{"id", "age"}
   352  				res := make([]rows.Rows, 0, 1)
   353  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 18))
   354  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   355  				res = append(res, rows)
   356  				return res
   357  			},
   358  			ctx: func() (context.Context, context.CancelFunc) {
   359  				return context.WithCancel(context.Background())
   360  			},
   361  		},
   362  		{
   363  			name: "数据库列集: id,name,address;排序列集: id,name",
   364  			merger: func() (*Merger, error) {
   365  				return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC))
   366  			},
   367  			sqlRows: func() []rows.Rows {
   368  				query := "SELECT * FROM `t1`;"
   369  				cols := []string{"id", "name", "address"}
   370  				res := make([]rows.Rows, 0, 1)
   371  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh"))
   372  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   373  				res = append(res, rows)
   374  				return res
   375  			},
   376  			ctx: func() (context.Context, context.CancelFunc) {
   377  				return context.WithCancel(context.Background())
   378  			},
   379  		},
   380  		{
   381  			name: "初始化Rows错误",
   382  			merger: func() (*Merger, error) {
   383  				return NewMerger(NewSortColumn("id", ASC))
   384  			},
   385  			sqlRows: func() []rows.Rows {
   386  				query := "SELECT * FROM `t1`;"
   387  				cols := []string{"id", "name", "address"}
   388  				res := make([]rows.Rows, 0, 1)
   389  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh").RowError(0, nextMockErr))
   390  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   391  				res = append(res, rows)
   392  				return res
   393  			},
   394  			wantErr: nextMockErr,
   395  			ctx: func() (context.Context, context.CancelFunc) {
   396  				return context.WithCancel(context.Background())
   397  			},
   398  		},
   399  	}
   400  	for _, tc := range testcases {
   401  		ms.T().Run(tc.name, func(t *testing.T) {
   402  			merger, err := tc.merger()
   403  			require.NoError(ms.T(), err)
   404  			ctx, cancel := tc.ctx()
   405  			rows, err := merger.Merge(ctx, tc.sqlRows())
   406  			cancel()
   407  			assert.Equal(t, tc.wantErr, err)
   408  			if err != nil {
   409  				return
   410  			}
   411  			require.NotNil(t, rows)
   412  		})
   413  
   414  	}
   415  }
   416  
   417  func (ms *MergerSuite) TestRows_NextAndScan() {
   418  	testCases := []struct {
   419  		name        string
   420  		sqlRows     func() []rows.Rows
   421  		wantVal     []TestModel
   422  		sortColumns []SortColumn
   423  		wantErr     error
   424  	}{
   425  		{
   426  			name: "完全交叉读,sqlRows返回行数相同",
   427  			sqlRows: func() []rows.Rows {
   428  				cols := []string{"id", "name", "address"}
   429  				query := "SELECT * FROM `t1`"
   430  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn"))
   431  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   432  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn"))
   433  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   434  				rowsList := make([]rows.Rows, 0, len(dbs))
   435  				for _, db := range dbs {
   436  					row, err := db.QueryContext(context.Background(), query)
   437  					require.NoError(ms.T(), err)
   438  					rowsList = append(rowsList, row)
   439  				}
   440  				return rowsList
   441  			},
   442  			wantVal: []TestModel{
   443  				{
   444  					Id:      1,
   445  					Name:    "abex",
   446  					Address: "cn",
   447  				},
   448  				{
   449  					Id:      2,
   450  					Name:    "a",
   451  					Address: "cn",
   452  				},
   453  				{
   454  					Id:      3,
   455  					Name:    "alex",
   456  					Address: "cn",
   457  				},
   458  				{
   459  					Id:      4,
   460  					Name:    "x",
   461  					Address: "cn",
   462  				},
   463  				{
   464  					Id:      5,
   465  					Name:    "bruce",
   466  					Address: "cn",
   467  				},
   468  				{
   469  					Id:      7,
   470  					Name:    "b",
   471  					Address: "cn",
   472  				},
   473  			},
   474  			sortColumns: []SortColumn{
   475  				NewSortColumn("id", ASC),
   476  			},
   477  		},
   478  		{
   479  			name: "完全交叉读,sqlRows返回行数部分不同",
   480  			sqlRows: func() []rows.Rows {
   481  				cols := []string{"id", "name", "address"}
   482  				query := "SELECT * FROM `t1`"
   483  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(6, "x", "cn").AddRow(1, "x", "cn"))
   484  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(8, "alex", "cn").AddRow(4, "bruce", "cn"))
   485  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(9, "a", "cn").AddRow(5, "abex", "cn"))
   486  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   487  				rowsList := make([]rows.Rows, 0, len(dbs))
   488  				for _, db := range dbs {
   489  					row, err := db.QueryContext(context.Background(), query)
   490  					require.NoError(ms.T(), err)
   491  					rowsList = append(rowsList, row)
   492  				}
   493  				return rowsList
   494  			},
   495  			wantVal: []TestModel{
   496  				{
   497  					Id:      9,
   498  					Name:    "a",
   499  					Address: "cn",
   500  				},
   501  				{
   502  					Id:      8,
   503  					Name:    "alex",
   504  					Address: "cn",
   505  				},
   506  				{
   507  					Id:      7,
   508  					Name:    "b",
   509  					Address: "cn",
   510  				},
   511  				{
   512  					Id:      6,
   513  					Name:    "x",
   514  					Address: "cn",
   515  				},
   516  				{
   517  					Id:      5,
   518  					Name:    "abex",
   519  					Address: "cn",
   520  				},
   521  				{
   522  					Id:      4,
   523  					Name:    "bruce",
   524  					Address: "cn",
   525  				},
   526  				{
   527  					Id:      1,
   528  					Name:    "x",
   529  					Address: "cn",
   530  				},
   531  			},
   532  			sortColumns: []SortColumn{
   533  				NewSortColumn("id", DESC),
   534  			},
   535  		},
   536  		{
   537  			// 包含一个sqlRows返回的行数为0,在前面
   538  			name: "完全交叉读,sqlRows返回行数完全不同",
   539  			sqlRows: func() []rows.Rows {
   540  				cols := []string{"id", "name", "address"}
   541  				query := "SELECT * FROM `t1`"
   542  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "c", "cn").AddRow(2, "bruce", "cn").AddRow(2, "zwl", "cn"))
   543  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "alex", "cn").AddRow(3, "x", "cn"))
   544  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "c", "cn").AddRow(3, "b", "cn").AddRow(5, "c", "cn").AddRow(7, "c", "cn"))
   545  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   546  				dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03}
   547  				rowsList := make([]rows.Rows, 0, len(dbs))
   548  				for _, db := range dbs {
   549  					row, err := db.QueryContext(context.Background(), query)
   550  					require.NoError(ms.T(), err)
   551  					rowsList = append(rowsList, row)
   552  				}
   553  				return rowsList
   554  			},
   555  			wantVal: []TestModel{
   556  				{
   557  					Id:      1,
   558  					Name:    "alex",
   559  					Address: "cn",
   560  				},
   561  				{
   562  					Id:      1,
   563  					Name:    "c",
   564  					Address: "cn",
   565  				},
   566  				{
   567  					Id:      2,
   568  					Name:    "bruce",
   569  					Address: "cn",
   570  				},
   571  				{
   572  					Id:      2,
   573  					Name:    "c",
   574  					Address: "cn",
   575  				},
   576  				{
   577  					Id:      2,
   578  					Name:    "zwl",
   579  					Address: "cn",
   580  				},
   581  				{
   582  					Id:      3,
   583  					Name:    "b",
   584  					Address: "cn",
   585  				},
   586  				{
   587  					Id:      3,
   588  					Name:    "x",
   589  					Address: "cn",
   590  				},
   591  				{
   592  					Id:      5,
   593  					Name:    "c",
   594  					Address: "cn",
   595  				},
   596  				{
   597  					Id:      7,
   598  					Name:    "c",
   599  					Address: "cn",
   600  				},
   601  			},
   602  			sortColumns: []SortColumn{
   603  				NewSortColumn("id", ASC),
   604  				NewSortColumn("name", ASC),
   605  			},
   606  		},
   607  		{
   608  			name: "部分交叉读,sqlRows返回行数相同",
   609  			sqlRows: func() []rows.Rows {
   610  				cols := []string{"id", "name", "address"}
   611  				query := "SELECT * FROM `t1`"
   612  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn"))
   613  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(5, "bruce", "cn"))
   614  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(7, "b", "cn"))
   615  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   616  				rowsList := make([]rows.Rows, 0, len(dbs))
   617  				for _, db := range dbs {
   618  					row, err := db.QueryContext(context.Background(), query)
   619  					require.NoError(ms.T(), err)
   620  					rowsList = append(rowsList, row)
   621  				}
   622  				return rowsList
   623  			},
   624  			wantVal: []TestModel{
   625  				{
   626  					Id:      1,
   627  					Name:    "abex",
   628  					Address: "cn",
   629  				},
   630  				{
   631  					Id:      2,
   632  					Name:    "a",
   633  					Address: "cn",
   634  				},
   635  				{
   636  					Id:      3,
   637  					Name:    "alex",
   638  					Address: "cn",
   639  				},
   640  				{
   641  					Id:      4,
   642  					Name:    "x",
   643  					Address: "cn",
   644  				},
   645  				{
   646  					Id:      5,
   647  					Name:    "bruce",
   648  					Address: "cn",
   649  				},
   650  				{
   651  					Id:      7,
   652  					Name:    "b",
   653  					Address: "cn",
   654  				},
   655  			},
   656  			sortColumns: []SortColumn{
   657  				NewSortColumn("id", ASC),
   658  			},
   659  		},
   660  		{
   661  			name: "部分交叉读,sqlRows返回行数部分相同",
   662  			sqlRows: func() []rows.Rows {
   663  				cols := []string{"id", "name", "address"}
   664  				query := "SELECT * FROM `t1`"
   665  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn"))
   666  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   667  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(8, "b", "cn"))
   668  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   669  				rowsList := make([]rows.Rows, 0, len(dbs))
   670  				for _, db := range dbs {
   671  					row, err := db.QueryContext(context.Background(), query)
   672  					require.NoError(ms.T(), err)
   673  					rowsList = append(rowsList, row)
   674  				}
   675  				return rowsList
   676  			},
   677  			wantVal: []TestModel{
   678  				{
   679  					Id:      1,
   680  					Name:    "abex",
   681  					Address: "cn",
   682  				},
   683  				{
   684  					Id:      2,
   685  					Name:    "a",
   686  					Address: "cn",
   687  				},
   688  				{
   689  					Id:      3,
   690  					Name:    "alex",
   691  					Address: "cn",
   692  				},
   693  				{
   694  					Id:      4,
   695  					Name:    "x",
   696  					Address: "cn",
   697  				},
   698  				{
   699  					Id:      5,
   700  					Name:    "bruce",
   701  					Address: "cn",
   702  				},
   703  				{
   704  					Id:      7,
   705  					Name:    "b",
   706  					Address: "cn",
   707  				},
   708  				{
   709  					Id:      8,
   710  					Name:    "b",
   711  					Address: "cn",
   712  				},
   713  			},
   714  			sortColumns: []SortColumn{
   715  				NewSortColumn("id", ASC),
   716  			},
   717  		},
   718  		{
   719  			// 包含一个sqlRows返回的行数为0,在中间
   720  			name: "部分交叉读,sqlRows返回行数完全不同",
   721  			sqlRows: func() []rows.Rows {
   722  				cols := []string{"id", "name", "address"}
   723  				query := "SELECT * FROM `t1`"
   724  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn"))
   725  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   726  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn"))
   727  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   728  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03}
   729  				rowsList := make([]rows.Rows, 0, len(dbs))
   730  				for _, db := range dbs {
   731  					row, err := db.QueryContext(context.Background(), query)
   732  					require.NoError(ms.T(), err)
   733  					rowsList = append(rowsList, row)
   734  				}
   735  				return rowsList
   736  			},
   737  			wantVal: []TestModel{
   738  				{
   739  					Id:      1,
   740  					Name:    "abex",
   741  					Address: "cn",
   742  				},
   743  				{
   744  					Id:      2,
   745  					Name:    "a",
   746  					Address: "cn",
   747  				},
   748  				{
   749  					Id:      3,
   750  					Name:    "alex",
   751  					Address: "cn",
   752  				},
   753  				{
   754  					Id:      4,
   755  					Name:    "x",
   756  					Address: "cn",
   757  				},
   758  				{
   759  					Id:      5,
   760  					Name:    "bruce",
   761  					Address: "cn",
   762  				},
   763  				{
   764  					Id:      7,
   765  					Name:    "b",
   766  					Address: "cn",
   767  				},
   768  			},
   769  			sortColumns: []SortColumn{
   770  				NewSortColumn("id", ASC),
   771  			},
   772  		},
   773  		{
   774  			name: "顺序读,sqlRows返回行数相同",
   775  			sqlRows: func() []rows.Rows {
   776  				cols := []string{"id", "name", "address"}
   777  				query := "SELECT * FROM `t1`"
   778  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn"))
   779  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   780  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn").AddRow(7, "b", "cn"))
   781  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   782  				rowsList := make([]rows.Rows, 0, len(dbs))
   783  				for _, db := range dbs {
   784  					row, err := db.QueryContext(context.Background(), query)
   785  					require.NoError(ms.T(), err)
   786  					rowsList = append(rowsList, row)
   787  				}
   788  				return rowsList
   789  			},
   790  			wantVal: []TestModel{
   791  				{
   792  					Id:      1,
   793  					Name:    "abex",
   794  					Address: "cn",
   795  				},
   796  				{
   797  					Id:      2,
   798  					Name:    "a",
   799  					Address: "cn",
   800  				},
   801  				{
   802  					Id:      3,
   803  					Name:    "alex",
   804  					Address: "cn",
   805  				},
   806  				{
   807  					Id:      4,
   808  					Name:    "x",
   809  					Address: "cn",
   810  				},
   811  				{
   812  					Id:      5,
   813  					Name:    "bruce",
   814  					Address: "cn",
   815  				},
   816  				{
   817  					Id:      7,
   818  					Name:    "b",
   819  					Address: "cn",
   820  				},
   821  			},
   822  			sortColumns: []SortColumn{
   823  				NewSortColumn("id", ASC),
   824  			},
   825  		},
   826  		{
   827  			name: "顺序读,sqlRows返回行数部分不同",
   828  			sqlRows: func() []rows.Rows {
   829  				cols := []string{"id", "name", "address"}
   830  				query := "SELECT * FROM `t1`"
   831  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn"))
   832  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn"))
   833  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn"))
   834  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   835  				rowsList := make([]rows.Rows, 0, len(dbs))
   836  				for _, db := range dbs {
   837  					row, err := db.QueryContext(context.Background(), query)
   838  					require.NoError(ms.T(), err)
   839  					rowsList = append(rowsList, row)
   840  				}
   841  				return rowsList
   842  			},
   843  
   844  			wantVal: []TestModel{
   845  				{
   846  					Id:      1,
   847  					Name:    "abex",
   848  					Address: "cn",
   849  				},
   850  				{
   851  					Id:      2,
   852  					Name:    "a",
   853  					Address: "cn",
   854  				},
   855  				{
   856  					Id:      3,
   857  					Name:    "alex",
   858  					Address: "cn",
   859  				},
   860  				{
   861  					Id:      4,
   862  					Name:    "x",
   863  					Address: "cn",
   864  				},
   865  				{
   866  					Id:      5,
   867  					Name:    "bruce",
   868  					Address: "cn",
   869  				},
   870  			},
   871  			sortColumns: []SortColumn{
   872  				NewSortColumn("id", ASC),
   873  			},
   874  		},
   875  		{
   876  			// 包含一个sqlRows返回的行数为0,在后面
   877  			name: "顺序读,sqlRows返回行数完全不同",
   878  			sqlRows: func() []rows.Rows {
   879  				cols := []string{"id", "name", "address"}
   880  				query := "SELECT * FROM `t1`"
   881  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn"))
   882  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(3, "alex", "cn"))
   883  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(5, "bruce", "cn").AddRow(7, "b", "cn"))
   884  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   885  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   886  				rowsList := make([]rows.Rows, 0, len(dbs))
   887  				for _, db := range dbs {
   888  					row, err := db.QueryContext(context.Background(), query)
   889  					require.NoError(ms.T(), err)
   890  					rowsList = append(rowsList, row)
   891  				}
   892  				return rowsList
   893  			},
   894  			wantVal: []TestModel{
   895  				{
   896  					Id:      1,
   897  					Name:    "abex",
   898  					Address: "cn",
   899  				},
   900  				{
   901  					Id:      2,
   902  					Name:    "a",
   903  					Address: "cn",
   904  				},
   905  				{
   906  					Id:      3,
   907  					Name:    "alex",
   908  					Address: "cn",
   909  				},
   910  				{
   911  					Id:      4,
   912  					Name:    "x",
   913  					Address: "cn",
   914  				},
   915  				{
   916  					Id:      5,
   917  					Name:    "bruce",
   918  					Address: "cn",
   919  				},
   920  				{
   921  					Id:      7,
   922  					Name:    "b",
   923  					Address: "cn",
   924  				},
   925  			},
   926  			sortColumns: []SortColumn{
   927  				NewSortColumn("id", ASC),
   928  			},
   929  		},
   930  
   931  		{
   932  			name: "所有sqlRows返回的行数均为空",
   933  			sqlRows: func() []rows.Rows {
   934  				cols := []string{"id", "name", "address"}
   935  				query := "SELECT * FROM `t1`"
   936  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   937  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   938  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   939  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   940  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   941  				rowsList := make([]rows.Rows, 0, len(dbs))
   942  				for _, db := range dbs {
   943  					row, err := db.QueryContext(context.Background(), query)
   944  					require.NoError(ms.T(), err)
   945  					rowsList = append(rowsList, row)
   946  				}
   947  				return rowsList
   948  			},
   949  			wantVal: []TestModel{},
   950  			sortColumns: []SortColumn{
   951  				NewSortColumn("id", ASC),
   952  				NewSortColumn("name", ASC),
   953  			},
   954  		},
   955  		{
   956  			name: "排序列返回的顺序和数据库里的字段顺序不一致",
   957  			sqlRows: func() []rows.Rows {
   958  				cols := []string{"id", "name", "address"}
   959  				query := "SELECT * FROM `t1`"
   960  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "hz").AddRow(3, "b", "hz").AddRow(2, "b", "cs"))
   961  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "a", "cs").AddRow(1, "a", "cs").AddRow(3, "e", "cn"))
   962  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "d", "hm").AddRow(5, "k", "xx").AddRow(4, "k", "xz"))
   963  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   964  				rowsList := make([]rows.Rows, 0, len(dbs))
   965  				for _, db := range dbs {
   966  					row, err := db.QueryContext(context.Background(), query)
   967  					require.NoError(ms.T(), err)
   968  					rowsList = append(rowsList, row)
   969  				}
   970  				return rowsList
   971  			},
   972  			wantVal: []TestModel{
   973  				{
   974  					Id:      3,
   975  					Name:    "a",
   976  					Address: "cs",
   977  				},
   978  				{
   979  					Id:      2,
   980  					Name:    "a",
   981  					Address: "hz",
   982  				},
   983  				{
   984  					Id:      1,
   985  					Name:    "a",
   986  					Address: "cs",
   987  				},
   988  				{
   989  					Id:      3,
   990  					Name:    "b",
   991  					Address: "hz",
   992  				},
   993  				{
   994  					Id:      2,
   995  					Name:    "b",
   996  					Address: "cs",
   997  				},
   998  				{
   999  					Id:      2,
  1000  					Name:    "d",
  1001  					Address: "hm",
  1002  				},
  1003  				{
  1004  					Id:      3,
  1005  					Name:    "e",
  1006  					Address: "cn",
  1007  				},
  1008  				{
  1009  					Id:      5,
  1010  					Name:    "k",
  1011  					Address: "xx",
  1012  				},
  1013  				{
  1014  					Id:      4,
  1015  					Name:    "k",
  1016  					Address: "xz",
  1017  				},
  1018  			},
  1019  			sortColumns: []SortColumn{
  1020  				NewSortColumn("name", ASC),
  1021  				NewSortColumn("id", DESC),
  1022  			},
  1023  		},
  1024  	}
  1025  	for _, tc := range testCases {
  1026  		ms.T().Run(tc.name, func(t *testing.T) {
  1027  			merger, err := NewMerger(tc.sortColumns...)
  1028  			require.NoError(t, err)
  1029  			rows, err := merger.Merge(context.Background(), tc.sqlRows())
  1030  			require.NoError(t, err)
  1031  			res := make([]TestModel, 0, len(tc.wantVal))
  1032  			for rows.Next() {
  1033  				t := TestModel{}
  1034  				err := rows.Scan(&t.Id, &t.Name, &t.Address)
  1035  				require.NoError(ms.T(), err)
  1036  				res = append(res, t)
  1037  			}
  1038  			require.True(t, rows.(*Rows).closed)
  1039  			assert.NoError(t, rows.Err())
  1040  			assert.Equal(t, tc.wantVal, res)
  1041  		})
  1042  	}
  1043  
  1044  }
  1045  
  1046  func (ms *MergerSuite) TestRows_Columns() {
  1047  	cols := []string{"id"}
  1048  	query := "SELECT * FROM `t1`"
  1049  	ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1"))
  1050  	ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2"))
  1051  	ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4"))
  1052  	merger, err := NewMerger(NewSortColumn("id", DESC))
  1053  	require.NoError(ms.T(), err)
  1054  	dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
  1055  	rowsList := make([]rows.Rows, 0, len(dbs))
  1056  	for _, db := range dbs {
  1057  		row, err := db.QueryContext(context.Background(), query)
  1058  		require.NoError(ms.T(), err)
  1059  		rowsList = append(rowsList, row)
  1060  	}
  1061  
  1062  	rows, err := merger.Merge(context.Background(), rowsList)
  1063  	require.NoError(ms.T(), err)
  1064  	ms.T().Run("Next没有迭代完", func(t *testing.T) {
  1065  		for rows.Next() {
  1066  			columns, err := rows.Columns()
  1067  			require.NoError(t, err)
  1068  			assert.Equal(t, cols, columns)
  1069  		}
  1070  		require.NoError(t, rows.Err())
  1071  	})
  1072  	ms.T().Run("Next迭代完", func(t *testing.T) {
  1073  		require.False(t, rows.Next())
  1074  		require.NoError(t, rows.Err())
  1075  		_, err := rows.Columns()
  1076  		assert.Equal(t, errs.ErrMergerRowsClosed, err)
  1077  
  1078  	})
  1079  
  1080  }
  1081  
  1082  func (ms *MergerSuite) TestRows_Close() {
  1083  	cols := []string{"id"}
  1084  	query := "SELECT * FROM `t1`"
  1085  	ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1"))
  1086  	ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02")))
  1087  	ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03")))
  1088  	merger, err := NewMerger(NewSortColumn("id", DESC))
  1089  	require.NoError(ms.T(), err)
  1090  	dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
  1091  	rowsList := make([]rows.Rows, 0, len(dbs))
  1092  	for _, db := range dbs {
  1093  		row, err := db.QueryContext(context.Background(), query)
  1094  		require.NoError(ms.T(), err)
  1095  		rowsList = append(rowsList, row)
  1096  	}
  1097  	rows, err := merger.Merge(context.Background(), rowsList)
  1098  	require.NoError(ms.T(), err)
  1099  	// 判断当前是可以正常读取的
  1100  	require.True(ms.T(), rows.Next())
  1101  	var id int
  1102  	err = rows.Scan(&id)
  1103  	require.NoError(ms.T(), err)
  1104  	err = rows.Close()
  1105  	ms.T().Run("close返回multierror", func(t *testing.T) {
  1106  		assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err)
  1107  	})
  1108  	ms.T().Run("close之后Next返回false", func(t *testing.T) {
  1109  		for i := 0; i < len(rowsList); i++ {
  1110  			require.False(ms.T(), rowsList[i].Next())
  1111  		}
  1112  		require.False(ms.T(), rows.Next())
  1113  	})
  1114  	ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) {
  1115  		var id int
  1116  		err := rows.Scan(&id)
  1117  		assert.Equal(t, errs.ErrMergerRowsClosed, err)
  1118  	})
  1119  	ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) {
  1120  		_, err := rows.Columns()
  1121  		require.Error(t, err)
  1122  	})
  1123  	ms.T().Run("close多次是等效的", func(t *testing.T) {
  1124  		for i := 0; i < 4; i++ {
  1125  			err = rows.Close()
  1126  			require.NoError(t, err)
  1127  		}
  1128  	})
  1129  }
  1130  
  1131  // 测试Next迭代过程中遇到错误
  1132  func (ms *MergerSuite) TestRows_NextAndErr() {
  1133  	testcases := []struct {
  1134  		name        string
  1135  		rowsList    func() []rows.Rows
  1136  		wantErr     error
  1137  		sortColumns []SortColumn
  1138  	}{
  1139  		{
  1140  			name: "sqlRows列表中有一个返回error",
  1141  			rowsList: func() []rows.Rows {
  1142  				cols := []string{"id"}
  1143  				query := "SELECT * FROM `t1`"
  1144  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1"))
  1145  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2"))
  1146  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr))
  1147  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5"))
  1148  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
  1149  				rowsList := make([]rows.Rows, 0, len(dbs))
  1150  				for _, db := range dbs {
  1151  					row, err := db.QueryContext(context.Background(), query)
  1152  					require.NoError(ms.T(), err)
  1153  					rowsList = append(rowsList, row)
  1154  				}
  1155  				return rowsList
  1156  			},
  1157  			sortColumns: []SortColumn{
  1158  				NewSortColumn("id", ASC),
  1159  			},
  1160  			wantErr: nextMockErr,
  1161  		},
  1162  	}
  1163  	for _, tc := range testcases {
  1164  		ms.T().Run(tc.name, func(t *testing.T) {
  1165  			merger, err := NewMerger(tc.sortColumns...)
  1166  			require.NoError(t, err)
  1167  			rows, err := merger.Merge(context.Background(), tc.rowsList())
  1168  			require.NoError(t, err)
  1169  			for rows.Next() {
  1170  			}
  1171  			assert.Equal(t, tc.wantErr, rows.Err())
  1172  		})
  1173  	}
  1174  }
  1175  
  1176  // Scan方法的一些边界情况的测试
  1177  func (ms *MergerSuite) TestRows_ScanErr() {
  1178  	ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) {
  1179  		cols := []string{"id", "name", "address"}
  1180  		query := "SELECT * FROM `t1`"
  1181  		ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn"))
  1182  		r, err := ms.mockDB01.QueryContext(context.Background(), query)
  1183  		require.NoError(t, err)
  1184  		rowsList := []rows.Rows{r}
  1185  		merger, err := NewMerger(NewSortColumn("id", DESC))
  1186  		require.NoError(t, err)
  1187  		rows, err := merger.Merge(context.Background(), rowsList)
  1188  		require.NoError(t, err)
  1189  		model := TestModel{}
  1190  		err = rows.Scan(&model.Id, &model.Name, &model.Address)
  1191  		assert.Equal(t, errs.ErrMergerScanNotNext, err)
  1192  	})
  1193  	ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) {
  1194  		cols := []string{"id", "name", "address"}
  1195  		query := "SELECT * FROM `t1`"
  1196  		ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn").RowError(1, nextMockErr))
  1197  		r, err := ms.mockDB01.QueryContext(context.Background(), query)
  1198  		require.NoError(t, err)
  1199  		rowsList := []rows.Rows{r}
  1200  		merger, err := NewMerger(NewSortColumn("id", DESC))
  1201  		require.NoError(t, err)
  1202  		rows, err := merger.Merge(context.Background(), rowsList)
  1203  		require.NoError(t, err)
  1204  		for rows.Next() {
  1205  		}
  1206  		var model TestModel
  1207  		err = rows.Scan(&model.Id, &model.Name, &model.Address)
  1208  		assert.Equal(t, nextMockErr, err)
  1209  	})
  1210  
  1211  }
  1212  
// TestModel is the (id, name, address) row that the MergerSuite tests scan
// out of merged rows, in that column order.
type TestModel struct {
	Id      int
	Name    string
	Address string
}
  1218  
  1219  func TestMerger(t *testing.T) {
  1220  	suite.Run(t, &MergerSuite{})
  1221  	suite.Run(t, &NullableMergerSuite{})
  1222  }
  1223  
// NullableMergerSuite runs the sort merger against real SQLite databases.
// Its tests insert rows with some columns omitted, so NULL values flow
// through the merger.
type NullableMergerSuite struct {
	suite.Suite
	// Three in-memory SQLite databases, each with its own t1 table. They are
	// opened in SetupSuite and closed in TearDownSuite.
	db01 *sql.DB
	db02 *sql.DB
	db03 *sql.DB
}
  1230  
  1231  func (ms *NullableMergerSuite) SetupSuite() {
  1232  	t := ms.T()
  1233  	query := "CREATE TABLE t1 (\n      id int primary key,\n      `age`  int,\n    \t`name` varchar(20)\n  );\n"
  1234  	db01, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory")
  1235  	if err != nil {
  1236  		t.Fatal(err)
  1237  	}
  1238  	ms.db01 = db01
  1239  	_, err = db01.ExecContext(context.Background(), query)
  1240  	if err != nil {
  1241  		t.Fatal(err)
  1242  	}
  1243  	db02, err := sql.Open("sqlite3", "file:test02.db?cache=shared&mode=memory")
  1244  	if err != nil {
  1245  		t.Fatal(err)
  1246  	}
  1247  	ms.db02 = db02
  1248  	_, err = db02.ExecContext(context.Background(), query)
  1249  	if err != nil {
  1250  		t.Fatal(err)
  1251  	}
  1252  	db03, err := sql.Open("sqlite3", "file:test03.db?cache=shared&mode=memory")
  1253  	if err != nil {
  1254  		t.Fatal(err)
  1255  	}
  1256  	ms.db03 = db03
  1257  	_, err = db03.ExecContext(context.Background(), query)
  1258  	if err != nil {
  1259  		t.Fatal(err)
  1260  	}
  1261  }
  1262  
  1263  func (ms *NullableMergerSuite) TearDownSuite() {
  1264  	_ = ms.db01.Close()
  1265  	_ = ms.db02.Close()
  1266  	_ = ms.db03.Close()
  1267  }
  1268  
  1269  func (ms *NullableMergerSuite) TestRows_Nullable() {
  1270  	testcases := []struct {
  1271  		name        string
  1272  		rowsList    func() []rows.Rows
  1273  		sortColumns []SortColumn
  1274  		wantErr     error
  1275  		afterFunc   func()
  1276  		wantVal     []Nullable
  1277  	}{
  1278  		{
  1279  			name: "多个nullable类型排序 age asc,name desc",
  1280  			rowsList: func() []rows.Rows {
  1281  				db1InsertSql := []string{
  1282  					"insert into t1  (id,  name) values (1,  'zwl')",
  1283  					"insert into t1  (id, age, name) values (2, 10, 'zwl')",
  1284  					"insert into t1  (id, age, name) values (3, 20, 'zwl')",
  1285  					"insert into t1  (id, age) values (4, 20)",
  1286  				}
  1287  				for _, sql := range db1InsertSql {
  1288  					_, err := ms.db01.ExecContext(context.Background(), sql)
  1289  					require.NoError(ms.T(), err)
  1290  				}
  1291  				db2InsertSql := []string{
  1292  					"insert into t1  (id, age, name) values (5, 5, 'zwl')",
  1293  					"insert into t1  (id, age, name) values (6, 20, 'dm')",
  1294  				}
  1295  				for _, sql := range db2InsertSql {
  1296  					_, err := ms.db02.ExecContext(context.Background(), sql)
  1297  					require.NoError(ms.T(), err)
  1298  				}
  1299  				db3InsertSql := []string{
  1300  					"insert into t1  (id, name) values (7, 'xq')",
  1301  					"insert into t1  (id, age) values (8, 5)",
  1302  					"insert into t1  (id, age,name) values (9, 10,'xx')",
  1303  				}
  1304  				for _, sql := range db3InsertSql {
  1305  					_, err := ms.db03.ExecContext(context.Background(), sql)
  1306  					require.NoError(ms.T(), err)
  1307  				}
  1308  				dbs := []*sql.DB{ms.db01, ms.db02, ms.db03}
  1309  				rowsList := make([]rows.Rows, 0, len(dbs))
  1310  				query := "SELECT `id`, `age`,`name` FROM `t1` order by age asc,name desc"
  1311  				for _, db := range dbs {
  1312  					rows, err := db.QueryContext(context.Background(), query)
  1313  					require.NoError(ms.T(), err)
  1314  					rowsList = append(rowsList, rows)
  1315  				}
  1316  				return rowsList
  1317  			},
  1318  			sortColumns: []SortColumn{
  1319  				NewSortColumn("age", ASC),
  1320  				NewSortColumn("name", DESC),
  1321  			},
  1322  			afterFunc: func() {
  1323  				dbs := []*sql.DB{ms.db01, ms.db02, ms.db03}
  1324  				for _, db := range dbs {
  1325  					_, err := db.Exec("DELETE FROM t1;")
  1326  					require.NoError(ms.T(), err)
  1327  				}
  1328  			},
  1329  			wantVal: func() []Nullable {
  1330  				return []Nullable{
  1331  					{
  1332  						Id:   sql.NullInt64{Valid: true, Int64: 1},
  1333  						Age:  sql.NullInt64{Valid: false, Int64: 0},
  1334  						Name: sql.NullString{Valid: true, String: "zwl"},
  1335  					},
  1336  					{
  1337  						Id:   sql.NullInt64{Valid: true, Int64: 7},
  1338  						Age:  sql.NullInt64{Valid: false, Int64: 0},
  1339  						Name: sql.NullString{Valid: true, String: "xq"},
  1340  					},
  1341  					{
  1342  						Id:   sql.NullInt64{Valid: true, Int64: 5},
  1343  						Age:  sql.NullInt64{Valid: true, Int64: 5},
  1344  						Name: sql.NullString{Valid: true, String: "zwl"},
  1345  					},
  1346  					{
  1347  						Id:   sql.NullInt64{Valid: true, Int64: 8},
  1348  						Age:  sql.NullInt64{Valid: true, Int64: 5},
  1349  						Name: sql.NullString{Valid: false, String: ""},
  1350  					},
  1351  					{
  1352  						Id:   sql.NullInt64{Valid: true, Int64: 2},
  1353  						Age:  sql.NullInt64{Valid: true, Int64: 10},
  1354  						Name: sql.NullString{Valid: true, String: "zwl"},
  1355  					},
  1356  					{
  1357  						Id:   sql.NullInt64{Valid: true, Int64: 9},
  1358  						Age:  sql.NullInt64{Valid: true, Int64: 10},
  1359  						Name: sql.NullString{Valid: true, String: "xx"},
  1360  					},
  1361  					{
  1362  						Id:   sql.NullInt64{Valid: true, Int64: 3},
  1363  						Age:  sql.NullInt64{Valid: true, Int64: 20},
  1364  						Name: sql.NullString{Valid: true, String: "zwl"},
  1365  					},
  1366  					{
  1367  						Id:   sql.NullInt64{Valid: true, Int64: 6},
  1368  						Age:  sql.NullInt64{Valid: true, Int64: 20},
  1369  						Name: sql.NullString{Valid: true, String: "dm"},
  1370  					},
  1371  					{
  1372  						Id:   sql.NullInt64{Valid: true, Int64: 4},
  1373  						Age:  sql.NullInt64{Valid: true, Int64: 20},
  1374  						Name: sql.NullString{Valid: false, String: ""},
  1375  					},
  1376  				}
  1377  			}(),
  1378  		},
  1379  	}
  1380  	for _, tc := range testcases {
  1381  		ms.T().Run(tc.name, func(t *testing.T) {
  1382  			merger, err := NewMerger(tc.sortColumns...)
  1383  			require.NoError(t, err)
  1384  			rows, err := merger.Merge(context.Background(), tc.rowsList())
  1385  			require.NoError(t, err)
  1386  			res := make([]Nullable, 0, len(tc.wantVal))
  1387  			for rows.Next() {
  1388  				nullT := Nullable{}
  1389  				err := rows.Scan(&nullT.Id, &nullT.Age, &nullT.Name)
  1390  				require.NoError(ms.T(), err)
  1391  				res = append(res, nullT)
  1392  			}
  1393  			require.True(t, rows.(*Rows).closed)
  1394  			assert.NoError(t, rows.Err())
  1395  			assert.Equal(t, tc.wantVal, res)
  1396  			tc.afterFunc()
  1397  		})
  1398  	}
  1399  }
  1400  
// Nullable is the (id, age, name) row scanned in TestRows_Nullable. Every
// field is a sql.Null* type, so a NULL column is represented as Valid == false.
type Nullable struct {
	Id   sql.NullInt64
	Age  sql.NullInt64
	Name sql.NullString
}
  1406  
  1407  func TestRows_NextResultSet(t *testing.T) {
  1408  	assert.False(t, (&Rows{}).NextResultSet())
  1409  }