github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/aggregatemerger/merger_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregatemerger
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/ecodeclub/eorm/internal/rows"
    25  	_ "github.com/mattn/go-sqlite3"
    26  
    27  	"github.com/ecodeclub/eorm/internal/merger"
    28  
    29  	"github.com/DATA-DOG/go-sqlmock"
    30  	"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
    31  	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"github.com/stretchr/testify/suite"
    35  	"go.uber.org/multierr"
    36  )
    37  
    38  var (
    39  	nextMockErr   = errors.New("rows: MockNextErr")
    40  	aggregatorErr = errors.New("aggregator: MockAggregatorErr")
    41  )
    42  
    43  func newCloseMockErr(dbName string) error {
    44  	return fmt.Errorf("rows: %s MockCloseErr", dbName)
    45  }
    46  
// MergerSuite is the testify suite shared by the aggregate merger tests.
// mockDB01..mockDB04 are sqlmock-backed databases, each driven by the
// matching sqlmock controller mock01..mock04. db05 is a real in-memory
// SQLite connection holding table t1. All of them are opened by initMock
// (called from SetupTest) and closed by TearDownTest.
type MergerSuite struct {
	suite.Suite
	mockDB01 *sql.DB
	mock01   sqlmock.Sqlmock
	mockDB02 *sql.DB
	mock02   sqlmock.Sqlmock
	mockDB03 *sql.DB
	mock03   sqlmock.Sqlmock
	mockDB04 *sql.DB
	mock04   sqlmock.Sqlmock
	db05     *sql.DB
}
    59  
    60  func (ms *MergerSuite) SetupTest() {
    61  	t := ms.T()
    62  	ms.initMock(t)
    63  }
    64  
    65  func (ms *MergerSuite) TearDownTest() {
    66  	_ = ms.mockDB01.Close()
    67  	_ = ms.mockDB02.Close()
    68  	_ = ms.mockDB03.Close()
    69  	_ = ms.mockDB04.Close()
    70  	_ = ms.db05.Close()
    71  }
    72  
    73  func (ms *MergerSuite) initMock(t *testing.T) {
    74  	var err error
    75  	query := "CREATE TABLE t1" +
    76  		"(" +
    77  		"   id INT PRIMARY KEY     NOT NULL," +
    78  		"   grade            INT     NOT NULL" +
    79  		");"
    80  	ms.mockDB01, ms.mock01, err = sqlmock.New()
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	ms.mockDB02, ms.mock02, err = sqlmock.New()
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  	ms.mockDB03, ms.mock03, err = sqlmock.New()
    89  	if err != nil {
    90  		t.Fatal(err)
    91  	}
    92  	ms.mockDB04, ms.mock04, err = sqlmock.New()
    93  	if err != nil {
    94  		t.Fatal(err)
    95  	}
    96  	db05, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory")
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	ms.db05 = db05
   101  	_, err = db05.ExecContext(context.Background(), query)
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  }
   106  
   107  func TestMerger(t *testing.T) {
   108  	suite.Run(t, &MergerSuite{})
   109  }
   110  
   111  func (ms *MergerSuite) TestRows_NextAndScan() {
   112  	testcases := []struct {
   113  		name        string
   114  		sqlRows     func() []rows.Rows
   115  		wantVal     []any
   116  		aggregators func() []aggregator.Aggregator
   117  		gotVal      []any
   118  		wantErr     error
   119  	}{
   120  		{
   121  			name: "sqlite的ColumnType 使用了多级指针",
   122  			sqlRows: func() []rows.Rows {
   123  				query1 := "insert into `t1` values (1,10),(2,20),(3,30)"
   124  				_, err := ms.db05.ExecContext(context.Background(), query1)
   125  				require.NoError(ms.T(), err)
   126  				cols := []string{"SUM(id)"}
   127  				query := "SELECT SUM(`id`) FROM `t1`"
   128  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   129  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   130  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   131  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.db05}
   132  				rowsList := make([]rows.Rows, 0, len(dbs))
   133  				for _, db := range dbs {
   134  					row, err := db.QueryContext(context.Background(), query)
   135  					require.NoError(ms.T(), err)
   136  					rowsList = append(rowsList, row)
   137  				}
   138  				return rowsList
   139  			},
   140  			wantVal: []any{int64(66)},
   141  			gotVal: func() []any {
   142  				return []any{
   143  					0,
   144  				}
   145  			}(),
   146  			aggregators: func() []aggregator.Aggregator {
   147  				return []aggregator.Aggregator{
   148  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   149  				}
   150  			},
   151  		},
   152  		{
   153  			name: "SUM(id)",
   154  			sqlRows: func() []rows.Rows {
   155  				cols := []string{"SUM(id)"}
   156  				query := "SELECT SUM(`id`) FROM `t1`"
   157  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   158  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   159  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   160  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   161  				rowsList := make([]rows.Rows, 0, len(dbs))
   162  				for _, db := range dbs {
   163  					row, err := db.QueryContext(context.Background(), query)
   164  					require.NoError(ms.T(), err)
   165  					rowsList = append(rowsList, row)
   166  				}
   167  				return rowsList
   168  			},
   169  			wantVal: []any{int64(60)},
   170  			gotVal: func() []any {
   171  				return []any{
   172  					0,
   173  				}
   174  			}(),
   175  			aggregators: func() []aggregator.Aggregator {
   176  				return []aggregator.Aggregator{
   177  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   178  				}
   179  			},
   180  		},
   181  
   182  		{
   183  			name: "MAX(id)",
   184  			sqlRows: func() []rows.Rows {
   185  				cols := []string{"MAX(id)"}
   186  				query := "SELECT MAX(`id`) FROM `t1`"
   187  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   188  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   189  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   190  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   191  				rowsList := make([]rows.Rows, 0, len(dbs))
   192  				for _, db := range dbs {
   193  					row, err := db.QueryContext(context.Background(), query)
   194  					require.NoError(ms.T(), err)
   195  					rowsList = append(rowsList, row)
   196  				}
   197  				return rowsList
   198  			},
   199  			wantVal: []any{int64(30)},
   200  			gotVal: func() []any {
   201  				return []any{
   202  					0,
   203  				}
   204  			}(),
   205  			aggregators: func() []aggregator.Aggregator {
   206  				return []aggregator.Aggregator{
   207  					aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")),
   208  				}
   209  			},
   210  		},
   211  		{
   212  			name: "MIN(id)",
   213  			sqlRows: func() []rows.Rows {
   214  				cols := []string{"MIN(id)"}
   215  				query := "SELECT MIN(`id`) FROM `t1`"
   216  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   217  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   218  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   219  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   220  				rowsList := make([]rows.Rows, 0, len(dbs))
   221  				for _, db := range dbs {
   222  					row, err := db.QueryContext(context.Background(), query)
   223  					require.NoError(ms.T(), err)
   224  					rowsList = append(rowsList, row)
   225  				}
   226  				return rowsList
   227  			},
   228  			wantVal: []any{int64(10)},
   229  			gotVal: func() []any {
   230  				return []any{
   231  					0,
   232  				}
   233  			}(),
   234  			aggregators: func() []aggregator.Aggregator {
   235  				return []aggregator.Aggregator{
   236  					aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")),
   237  				}
   238  			},
   239  		},
   240  		{
   241  			name: "COUNT(id)",
   242  			sqlRows: func() []rows.Rows {
   243  				cols := []string{"COUNT(id)"}
   244  				query := "SELECT COUNT(`id`) FROM `t1`"
   245  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   246  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   247  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   248  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   249  				rowsList := make([]rows.Rows, 0, len(dbs))
   250  				for _, db := range dbs {
   251  					row, err := db.QueryContext(context.Background(), query)
   252  					require.NoError(ms.T(), err)
   253  					rowsList = append(rowsList, row)
   254  				}
   255  				return rowsList
   256  			},
   257  			wantVal: []any{int64(40)},
   258  			gotVal: func() []any {
   259  				return []any{
   260  					0,
   261  				}
   262  			}(),
   263  			aggregators: func() []aggregator.Aggregator {
   264  				return []aggregator.Aggregator{
   265  					aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
   266  				}
   267  			},
   268  		},
   269  		{
   270  			name: "AVG(grade)",
   271  			sqlRows: func() []rows.Rows {
   272  				cols := []string{"SUM(grade)", "COUNT(grade)"}
   273  				query := "SELECT SUM(`grade`),COUNT(`grade`) FROM `t1`"
   274  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10))
   275  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20))
   276  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10))
   277  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   278  				rowsList := make([]rows.Rows, 0, len(dbs))
   279  				for _, db := range dbs {
   280  					row, err := db.QueryContext(context.Background(), query)
   281  					require.NoError(ms.T(), err)
   282  					rowsList = append(rowsList, row)
   283  				}
   284  				return rowsList
   285  			},
   286  			wantVal: []any{
   287  				float64(150),
   288  			},
   289  			gotVal: func() []any {
   290  				return []any{
   291  					float64(0),
   292  				}
   293  			}(),
   294  			aggregators: func() []aggregator.Aggregator {
   295  				return []aggregator.Aggregator{
   296  					aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"),
   297  				}
   298  			},
   299  		},
   300  		// 下面为多个聚合函数组合的情况
   301  
   302  		// 1.每种聚合函数出现一次
   303  		{
   304  			name: "COUNT(id),MAX(id),MIN(id),SUM(id),AVG(grade)",
   305  			sqlRows: func() []rows.Rows {
   306  				cols := []string{"COUNT(id)", "MAX(id)", "MIN(id)", "SUM(id)", "SUM(grade)", "COUNT(grade)"}
   307  				query := "SELECT COUNT(`id`),MAX(`id`),MIN(`id`),SUM(`id`),SUM(`grade`),COUNT(`student`) FROM `t1`"
   308  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 20, 1, 100, 2000, 20))
   309  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20, 30, 0, 200, 800, 10))
   310  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 40, 2, 300, 1800, 20))
   311  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   312  				rowsList := make([]rows.Rows, 0, len(dbs))
   313  				for _, db := range dbs {
   314  					row, err := db.QueryContext(context.Background(), query)
   315  					require.NoError(ms.T(), err)
   316  					rowsList = append(rowsList, row)
   317  				}
   318  				return rowsList
   319  			},
   320  			wantVal: []any{
   321  				int64(40), int64(40), int64(0), int64(600), float64(4600) / float64(50),
   322  			},
   323  			gotVal: func() []any {
   324  				return []any{
   325  					0, 0, 0, 0, float64(0),
   326  				}
   327  			}(),
   328  			aggregators: func() []aggregator.Aggregator {
   329  				return []aggregator.Aggregator{
   330  					aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
   331  					aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")),
   332  					aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")),
   333  					aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")),
   334  					aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"),
   335  				}
   336  			},
   337  		},
   338  		// 2. 聚合函数出现一次或多次,会有相同的聚合函数类型,且相同的聚合函数类型会有连续出现,和不连续出现。
   339  		// 两个avg会包含sum列在前,和sum列在后的状态。并且有完全相同的列出现
   340  		{
   341  			name: "AVG(grade),SUM(grade),AVG(grade),MIN(id),MIN(userid),MAX(id),COUNT(id)",
   342  			sqlRows: func() []rows.Rows {
   343  				cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "MIN(id)", "MIN(userid)", "MAX(id)", "COUNT(id)"}
   344  				query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`id`),MIN(`userid`),MAX(`id`),COUNT(`id`) FROM `t1`"
   345  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20, 2000, 20, 2000, 10, 20, 200, 200))
   346  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1000, 10, 1000, 10, 1000, 20, 30, 300, 300))
   347  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(800, 10, 800, 10, 800, 5, 6, 100, 200))
   348  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   349  				rowsList := make([]rows.Rows, 0, len(dbs))
   350  				for _, db := range dbs {
   351  					row, err := db.QueryContext(context.Background(), query)
   352  					require.NoError(ms.T(), err)
   353  					rowsList = append(rowsList, row)
   354  				}
   355  				return rowsList
   356  			},
   357  			wantVal: []any{
   358  				float64(3800) / float64(40), int64(3800), float64(3800) / float64(40), int64(5), int64(6), int64(300), int64(700),
   359  			},
   360  			gotVal: func() []any {
   361  				return []any{
   362  					float64(0), 0, float64(0), 0, 0, 0, 0,
   363  				}
   364  			}(),
   365  			aggregators: func() []aggregator.Aggregator {
   366  				return []aggregator.Aggregator{
   367  					aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"),
   368  					aggregator.NewSum(merger.NewColumnInfo(2, "SUM(grade)")),
   369  					aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"),
   370  					aggregator.NewMin(merger.NewColumnInfo(5, "MIN(id)")),
   371  					aggregator.NewMin(merger.NewColumnInfo(6, "MIN(userid)")),
   372  					aggregator.NewMax(merger.NewColumnInfo(7, "MAX(id)")),
   373  					aggregator.NewCount(merger.NewColumnInfo(8, "COUNT(id)")),
   374  				}
   375  			},
   376  		},
   377  
   378  		// 下面为RowList为有元素返回的行数为空
   379  
   380  		// 1. Rows 列表中有一个Rows返回行数为空,在前面会返回错误
   381  		{
   382  			name: "RowsList有一个Rows为空,在前面",
   383  			sqlRows: func() []rows.Rows {
   384  				cols := []string{"SUM(id)"}
   385  				query := "SELECT SUM(`id`) FROM `t1`"
   386  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   387  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   388  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   389  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   390  				dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03}
   391  				rowsList := make([]rows.Rows, 0, len(dbs))
   392  				for _, db := range dbs {
   393  					row, err := db.QueryContext(context.Background(), query)
   394  					require.NoError(ms.T(), err)
   395  					rowsList = append(rowsList, row)
   396  				}
   397  				return rowsList
   398  			},
   399  			wantVal: []any{60},
   400  			gotVal: func() []any {
   401  				return []any{
   402  					0,
   403  				}
   404  			}(),
   405  			wantErr: errs.ErrMergerAggregateHasEmptyRows,
   406  			aggregators: func() []aggregator.Aggregator {
   407  				return []aggregator.Aggregator{
   408  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   409  				}
   410  			},
   411  		},
   412  		// 2. Rows 列表中有一个Rows返回行数为空,在中间会返回错误
   413  		{
   414  			name: "RowsList有一个Rows为空,在中间",
   415  			sqlRows: func() []rows.Rows {
   416  				cols := []string{"SUM(id)"}
   417  				query := "SELECT SUM(`id`) FROM `t1`"
   418  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   419  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   420  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   421  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   422  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB04, ms.mockDB02, ms.mockDB03}
   423  				rowsList := make([]rows.Rows, 0, len(dbs))
   424  				for _, db := range dbs {
   425  					row, err := db.QueryContext(context.Background(), query)
   426  					require.NoError(ms.T(), err)
   427  					rowsList = append(rowsList, row)
   428  				}
   429  				return rowsList
   430  			},
   431  			wantVal: []any{60},
   432  			gotVal: func() []any {
   433  				return []any{
   434  					0,
   435  				}
   436  			}(),
   437  			wantErr: errs.ErrMergerAggregateHasEmptyRows,
   438  			aggregators: func() []aggregator.Aggregator {
   439  				return []aggregator.Aggregator{
   440  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   441  				}
   442  			},
   443  		},
   444  		// 3. Rows 列表中有一个Rows返回行数为空,在后面会返回错误
   445  		{
   446  			name: "RowsList有一个Rows为空,在最后",
   447  			sqlRows: func() []rows.Rows {
   448  				cols := []string{"SUM(id)"}
   449  				query := "SELECT SUM(`id`) FROM `t1`"
   450  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10))
   451  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20))
   452  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30))
   453  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   454  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   455  				rowsList := make([]rows.Rows, 0, len(dbs))
   456  				for _, db := range dbs {
   457  					row, err := db.QueryContext(context.Background(), query)
   458  					require.NoError(ms.T(), err)
   459  					rowsList = append(rowsList, row)
   460  				}
   461  				return rowsList
   462  			},
   463  			wantVal: []any{60},
   464  			gotVal: func() []any {
   465  				return []any{
   466  					0,
   467  				}
   468  			}(),
   469  			wantErr: errs.ErrMergerAggregateHasEmptyRows,
   470  			aggregators: func() []aggregator.Aggregator {
   471  				return []aggregator.Aggregator{
   472  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   473  				}
   474  			},
   475  		},
   476  		// 4. Rows 列表中全部Rows返回的行数为空,不会返回错误
   477  		{
   478  			name: "RowsList全部为空",
   479  			sqlRows: func() []rows.Rows {
   480  				cols := []string{"SUM(id)"}
   481  				query := "SELECT SUM(`id`) FROM `t1`"
   482  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   483  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   484  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   485  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols))
   486  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   487  				rowsList := make([]rows.Rows, 0, len(dbs))
   488  				for _, db := range dbs {
   489  					row, err := db.QueryContext(context.Background(), query)
   490  					require.NoError(ms.T(), err)
   491  					rowsList = append(rowsList, row)
   492  				}
   493  				return rowsList
   494  			},
   495  			wantErr: errs.ErrMergerAggregateHasEmptyRows,
   496  			aggregators: func() []aggregator.Aggregator {
   497  				return []aggregator.Aggregator{
   498  					aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
   499  				}
   500  			},
   501  		},
   502  	}
   503  	for _, tc := range testcases {
   504  		ms.T().Run(tc.name, func(t *testing.T) {
   505  			m := NewMerger(tc.aggregators()...)
   506  			rows, err := m.Merge(context.Background(), tc.sqlRows())
   507  			require.NoError(t, err)
   508  			for rows.Next() {
   509  				kk := make([]any, 0, len(tc.gotVal))
   510  				for i := 0; i < len(tc.gotVal); i++ {
   511  					kk = append(kk, &tc.gotVal[i])
   512  				}
   513  				err = rows.Scan(kk...)
   514  				require.NoError(t, err)
   515  			}
   516  			assert.Equal(t, tc.wantErr, rows.Err())
   517  			if rows.Err() != nil {
   518  				return
   519  			}
   520  			assert.Equal(t, tc.wantVal, tc.gotVal)
   521  		})
   522  	}
   523  }
   524  
   525  func (ms *MergerSuite) TestRows_NextAndErr() {
   526  	testcases := []struct {
   527  		name        string
   528  		rowsList    func() []rows.Rows
   529  		wantErr     error
   530  		aggregators []aggregator.Aggregator
   531  	}{
   532  		{
   533  			name: "sqlRows列表中有一个返回error",
   534  			rowsList: func() []rows.Rows {
   535  				cols := []string{"COUNT(id)"}
   536  				query := "SELECT COUNT(`id`) FROM `t1`"
   537  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   538  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2))
   539  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4).RowError(0, nextMockErr))
   540  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5))
   541  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   542  				rowsList := make([]rows.Rows, 0, len(dbs))
   543  				for _, db := range dbs {
   544  					row, err := db.QueryContext(context.Background(), query)
   545  					require.NoError(ms.T(), err)
   546  					rowsList = append(rowsList, row)
   547  				}
   548  				return rowsList
   549  			},
   550  			aggregators: func() []aggregator.Aggregator {
   551  				return []aggregator.Aggregator{
   552  					aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
   553  				}
   554  			}(),
   555  			wantErr: nextMockErr,
   556  		},
   557  		{
   558  			name: "有一个aggregator返回error",
   559  			rowsList: func() []rows.Rows {
   560  				cols := []string{"COUNT(id)"}
   561  				query := "SELECT COUNT(`id`) FROM `t1`"
   562  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   563  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2))
   564  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4))
   565  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5))
   566  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   567  				rowsList := make([]rows.Rows, 0, len(dbs))
   568  				for _, db := range dbs {
   569  					row, err := db.QueryContext(context.Background(), query)
   570  					require.NoError(ms.T(), err)
   571  					rowsList = append(rowsList, row)
   572  				}
   573  				return rowsList
   574  			},
   575  			aggregators: func() []aggregator.Aggregator {
   576  				return []aggregator.Aggregator{
   577  					&mockAggregate{},
   578  				}
   579  			}(),
   580  			wantErr: aggregatorErr,
   581  		},
   582  	}
   583  	for _, tc := range testcases {
   584  		ms.T().Run(tc.name, func(t *testing.T) {
   585  			merger := NewMerger(tc.aggregators...)
   586  			rows, err := merger.Merge(context.Background(), tc.rowsList())
   587  			require.NoError(t, err)
   588  			for rows.Next() {
   589  			}
   590  			count := int64(0)
   591  			err = rows.Scan(&count)
   592  			assert.Equal(t, tc.wantErr, err)
   593  			assert.Equal(t, tc.wantErr, rows.Err())
   594  		})
   595  	}
   596  }
   597  
   598  func (ms *MergerSuite) TestRows_Close() {
   599  	cols := []string{"SUM(id)"}
   600  	query := "SELECT SUM(`id`) FROM `t1`"
   601  	ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   602  	ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02")))
   603  	ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03")))
   604  	merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
   605  	dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   606  	rowsList := make([]rows.Rows, 0, len(dbs))
   607  	for _, db := range dbs {
   608  		row, err := db.QueryContext(context.Background(), query)
   609  		require.NoError(ms.T(), err)
   610  		rowsList = append(rowsList, row)
   611  	}
   612  	rows, err := merger.Merge(context.Background(), rowsList)
   613  	require.NoError(ms.T(), err)
   614  	// 判断当前是可以正常读取的
   615  	require.True(ms.T(), rows.Next())
   616  	var id int
   617  	err = rows.Scan(&id)
   618  	require.NoError(ms.T(), err)
   619  	err = rows.Close()
   620  	ms.T().Run("close返回multierror", func(t *testing.T) {
   621  		assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err)
   622  	})
   623  	ms.T().Run("close之后Next返回false", func(t *testing.T) {
   624  		for i := 0; i < len(rowsList); i++ {
   625  			require.False(ms.T(), rowsList[i].Next())
   626  		}
   627  		require.False(ms.T(), rows.Next())
   628  	})
   629  	ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) {
   630  		var id int
   631  		err := rows.Scan(&id)
   632  		assert.Equal(t, errs.ErrMergerRowsClosed, err)
   633  	})
   634  	ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) {
   635  		_, err := rows.Columns()
   636  		require.Error(t, err)
   637  	})
   638  	ms.T().Run("close多次是等效的", func(t *testing.T) {
   639  		for i := 0; i < 4; i++ {
   640  			err = rows.Close()
   641  			require.NoError(t, err)
   642  		}
   643  	})
   644  }
   645  
   646  func (ms *MergerSuite) TestRows_Columns() {
   647  	cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"}
   648  	query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`) FROM `t1`"
   649  	ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10))
   650  	ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11))
   651  	ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12))
   652  	aggregators := []aggregator.Aggregator{
   653  		aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"),
   654  		aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")),
   655  		aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")),
   656  		aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")),
   657  		aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")),
   658  	}
   659  	merger := NewMerger(aggregators...)
   660  	dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   661  	rowsList := make([]rows.Rows, 0, len(dbs))
   662  	for _, db := range dbs {
   663  		row, err := db.QueryContext(context.Background(), query)
   664  		require.NoError(ms.T(), err)
   665  		rowsList = append(rowsList, row)
   666  	}
   667  
   668  	rows, err := merger.Merge(context.Background(), rowsList)
   669  	require.NoError(ms.T(), err)
   670  	wantCols := []string{"AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"}
   671  	ms.T().Run("Next没有迭代完", func(t *testing.T) {
   672  		for rows.Next() {
   673  			columns, err := rows.Columns()
   674  			require.NoError(t, err)
   675  			assert.Equal(t, wantCols, columns)
   676  		}
   677  		require.NoError(t, rows.Err())
   678  	})
   679  	ms.T().Run("Next迭代完", func(t *testing.T) {
   680  		require.False(t, rows.Next())
   681  		require.NoError(t, rows.Err())
   682  		_, err := rows.Columns()
   683  		assert.Equal(t, errs.ErrMergerRowsClosed, err)
   684  	})
   685  }
   686  
   687  func (ms *MergerSuite) TestMerger_Merge() {
   688  	testcases := []struct {
   689  		name    string
   690  		merger  func() *Merger
   691  		ctx     func() (context.Context, context.CancelFunc)
   692  		wantErr error
   693  		sqlRows func() []rows.Rows
   694  	}{
   695  		{
   696  			name: "超时",
   697  			merger: func() *Merger {
   698  				return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
   699  			},
   700  			ctx: func() (context.Context, context.CancelFunc) {
   701  				ctx, cancel := context.WithTimeout(context.Background(), 0)
   702  				return ctx, cancel
   703  			},
   704  			wantErr: context.DeadlineExceeded,
   705  			sqlRows: func() []rows.Rows {
   706  				query := "SELECT  SUM(`id`) FROM `t1`;"
   707  				cols := []string{"SUM(id)"}
   708  				res := make([]rows.Rows, 0, 1)
   709  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
   710  				rows, _ := ms.mockDB01.QueryContext(context.Background(), query)
   711  				res = append(res, rows)
   712  				return res
   713  			},
   714  		},
   715  		{
   716  			name: "sqlRows列表元素个数为0",
   717  			merger: func() *Merger {
   718  				return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
   719  			},
   720  			ctx: func() (context.Context, context.CancelFunc) {
   721  				ctx, cancel := context.WithCancel(context.Background())
   722  				return ctx, cancel
   723  			},
   724  			wantErr: errs.ErrMergerEmptyRows,
   725  			sqlRows: func() []rows.Rows {
   726  				return []rows.Rows{}
   727  			},
   728  		},
   729  		{
   730  			name: "sqlRows列表有nil",
   731  			merger: func() *Merger {
   732  				return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
   733  			},
   734  			ctx: func() (context.Context, context.CancelFunc) {
   735  				ctx, cancel := context.WithCancel(context.Background())
   736  				return ctx, cancel
   737  			},
   738  			wantErr: errs.ErrMergerRowsIsNull,
   739  			sqlRows: func() []rows.Rows {
   740  				return []rows.Rows{nil}
   741  			},
   742  		},
   743  	}
   744  	for _, tc := range testcases {
   745  		ms.T().Run(tc.name, func(t *testing.T) {
   746  			ctx, cancel := tc.ctx()
   747  			m := tc.merger()
   748  			r, err := m.Merge(ctx, tc.sqlRows())
   749  			cancel()
   750  			assert.Equal(t, tc.wantErr, err)
   751  			if err != nil {
   752  				return
   753  			}
   754  			require.NotNil(t, r)
   755  		})
   756  	}
   757  }
   758  
   759  type mockAggregate struct {
   760  	cols [][]any
   761  }
   762  
   763  func (m *mockAggregate) Aggregate(cols [][]any) (any, error) {
   764  	m.cols = cols
   765  	return nil, aggregatorErr
   766  }
   767  
   768  func (*mockAggregate) ColumnName() string {
   769  	return "mockAggregate"
   770  }
   771  
   772  func TestRows_NextResultSet(t *testing.T) {
   773  	assert.False(t, (&Rows{}).NextResultSet())
   774  }