github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/groupby_merger/aggregator_merger_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package groupby_merger
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"testing"
    22  
    23  	"github.com/ecodeclub/eorm/internal/rows"
    24  
    25  	"github.com/ecodeclub/eorm/internal/merger"
    26  
    27  	"github.com/DATA-DOG/go-sqlmock"
    28  	"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
    29  	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"github.com/stretchr/testify/suite"
    33  )
    34  
    35  var (
    36  	nextMockErr   error = errors.New("rows: MockNextErr")
    37  	aggregatorErr error = errors.New("aggregator: MockAggregatorErr")
    38  )
    39  
    40  type MergerSuite struct {
    41  	suite.Suite
    42  	mockDB01 *sql.DB
    43  	mock01   sqlmock.Sqlmock
    44  	mockDB02 *sql.DB
    45  	mock02   sqlmock.Sqlmock
    46  	mockDB03 *sql.DB
    47  	mock03   sqlmock.Sqlmock
    48  	mockDB04 *sql.DB
    49  	mock04   sqlmock.Sqlmock
    50  }
    51  
    52  func (ms *MergerSuite) SetupTest() {
    53  	t := ms.T()
    54  	ms.initMock(t)
    55  }
    56  
    57  func (ms *MergerSuite) TearDownTest() {
    58  	_ = ms.mockDB01.Close()
    59  	_ = ms.mockDB02.Close()
    60  	_ = ms.mockDB03.Close()
    61  	_ = ms.mockDB04.Close()
    62  }
    63  
    64  func (ms *MergerSuite) initMock(t *testing.T) {
    65  	var err error
    66  	ms.mockDB01, ms.mock01, err = sqlmock.New()
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	ms.mockDB02, ms.mock02, err = sqlmock.New()
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	ms.mockDB03, ms.mock03, err = sqlmock.New()
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	ms.mockDB04, ms.mock04, err = sqlmock.New()
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  }
    83  
    84  func TestMerger(t *testing.T) {
    85  	suite.Run(t, &MergerSuite{})
    86  }
    87  
    88  func (ms *MergerSuite) TestAggregatorMerger_Merge() {
    89  	testcases := []struct {
    90  		name           string
    91  		aggregators    []aggregator.Aggregator
    92  		rowsList       []rows.Rows
    93  		GroupByColumns []merger.ColumnInfo
    94  		wantErr        error
    95  		ctx            func() (context.Context, context.CancelFunc)
    96  	}{
    97  		{
    98  			name: "正常案例",
    99  			aggregators: []aggregator.Aggregator{
   100  				aggregator.NewCount(merger.NewColumnInfo(2, "id")),
   101  			},
   102  			GroupByColumns: []merger.ColumnInfo{
   103  				merger.NewColumnInfo(0, "county"),
   104  				merger.NewColumnInfo(1, "gender"),
   105  			},
   106  			rowsList: func() []rows.Rows {
   107  				query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`"
   108  				cols := []string{"county", "gender", "SUM(id)"}
   109  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30))
   110  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60))
   111  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80))
   112  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   113  				rowsList := make([]rows.Rows, 0, len(dbs))
   114  				for _, db := range dbs {
   115  					row, err := db.QueryContext(context.Background(), query)
   116  					require.NoError(ms.T(), err)
   117  					rowsList = append(rowsList, row)
   118  				}
   119  				return rowsList
   120  			}(),
   121  
   122  			ctx: func() (context.Context, context.CancelFunc) {
   123  				ctx, cancel := context.WithCancel(context.Background())
   124  				return ctx, cancel
   125  			},
   126  		},
   127  		{
   128  			name: "超时",
   129  			aggregators: []aggregator.Aggregator{
   130  				aggregator.NewCount(merger.NewColumnInfo(1, "id")),
   131  			},
   132  			GroupByColumns: []merger.ColumnInfo{
   133  				merger.NewColumnInfo(0, "user_name"),
   134  			},
   135  			rowsList: func() []rows.Rows {
   136  				query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`"
   137  				cols := []string{"user_name", "SUM(id)"}
   138  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20))
   139  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20))
   140  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20))
   141  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   142  				rowsList := make([]rows.Rows, 0, len(dbs))
   143  				for _, db := range dbs {
   144  					row, err := db.QueryContext(context.Background(), query)
   145  					require.NoError(ms.T(), err)
   146  					rowsList = append(rowsList, row)
   147  				}
   148  				return rowsList
   149  			}(),
   150  			ctx: func() (context.Context, context.CancelFunc) {
   151  				ctx, cancel := context.WithTimeout(context.Background(), 0)
   152  				return ctx, cancel
   153  			},
   154  			wantErr: context.DeadlineExceeded,
   155  		},
   156  		{
   157  			name: "rowsList为空",
   158  			aggregators: []aggregator.Aggregator{
   159  				aggregator.NewCount(merger.NewColumnInfo(1, "id")),
   160  			},
   161  			GroupByColumns: []merger.ColumnInfo{
   162  				merger.NewColumnInfo(0, "user_name"),
   163  			},
   164  			rowsList: func() []rows.Rows {
   165  				return []rows.Rows{}
   166  			}(),
   167  			ctx: func() (context.Context, context.CancelFunc) {
   168  				ctx, cancel := context.WithCancel(context.Background())
   169  				return ctx, cancel
   170  			},
   171  			wantErr: errs.ErrMergerEmptyRows,
   172  		},
   173  		{
   174  			name: "rowsList中有nil",
   175  			aggregators: []aggregator.Aggregator{
   176  				aggregator.NewCount(merger.NewColumnInfo(1, "id")),
   177  			},
   178  			GroupByColumns: []merger.ColumnInfo{
   179  				merger.NewColumnInfo(0, "user_name"),
   180  			},
   181  			rowsList: func() []rows.Rows {
   182  				return []rows.Rows{nil}
   183  			}(),
   184  			ctx: func() (context.Context, context.CancelFunc) {
   185  				ctx, cancel := context.WithCancel(context.Background())
   186  				return ctx, cancel
   187  			},
   188  			wantErr: errs.ErrMergerRowsIsNull,
   189  		},
   190  		{
   191  			name: "rowsList中有sql.Rows返回错误",
   192  			aggregators: []aggregator.Aggregator{
   193  				aggregator.NewCount(merger.NewColumnInfo(1, "id")),
   194  			},
   195  			GroupByColumns: []merger.ColumnInfo{
   196  				merger.NewColumnInfo(0, "user_name"),
   197  			},
   198  			rowsList: func() []rows.Rows {
   199  				query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`"
   200  				cols := []string{"user_name", "SUM(id)"}
   201  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr))
   202  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20))
   203  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20))
   204  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   205  				rowsList := make([]rows.Rows, 0, len(dbs))
   206  				for _, db := range dbs {
   207  					row, err := db.QueryContext(context.Background(), query)
   208  					require.NoError(ms.T(), err)
   209  					rowsList = append(rowsList, row)
   210  				}
   211  				return rowsList
   212  			}(),
   213  			ctx: func() (context.Context, context.CancelFunc) {
   214  				ctx, cancel := context.WithCancel(context.Background())
   215  				return ctx, cancel
   216  			},
   217  			wantErr: nextMockErr,
   218  		},
   219  	}
   220  	for _, tc := range testcases {
   221  		ms.T().Run(tc.name, func(t *testing.T) {
   222  			merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns)
   223  			ctx, cancel := tc.ctx()
   224  			groupByRows, err := merger.Merge(ctx, tc.rowsList)
   225  			cancel()
   226  			assert.Equal(t, tc.wantErr, err)
   227  			if err != nil {
   228  				return
   229  			}
   230  			require.NotNil(t, groupByRows)
   231  		})
   232  	}
   233  }
   234  
   235  func (ms *MergerSuite) TestAggregatorRows_NextAndScan() {
   236  	testcases := []struct {
   237  		name           string
   238  		aggregators    []aggregator.Aggregator
   239  		rowsList       []rows.Rows
   240  		wantVal        [][]any
   241  		gotVal         [][]any
   242  		GroupByColumns []merger.ColumnInfo
   243  		wantErr        error
   244  	}{
   245  		{
   246  			name: "同一组数据在不同的sql.Rows中",
   247  			aggregators: []aggregator.Aggregator{
   248  				aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")),
   249  			},
   250  			GroupByColumns: []merger.ColumnInfo{
   251  				merger.NewColumnInfo(0, "user_name"),
   252  			},
   253  			rowsList: func() []rows.Rows {
   254  				query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`"
   255  				cols := []string{"user_name", "SUM(id)"}
   256  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20))
   257  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20))
   258  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20))
   259  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   260  				rowsList := make([]rows.Rows, 0, len(dbs))
   261  				for _, db := range dbs {
   262  					row, err := db.QueryContext(context.Background(), query)
   263  					require.NoError(ms.T(), err)
   264  					rowsList = append(rowsList, row)
   265  				}
   266  				return rowsList
   267  			}(),
   268  			wantVal: [][]any{
   269  				{"zwl", int64(30)},
   270  				{"dm", int64(40)},
   271  				{"xz", int64(10)},
   272  			},
   273  			gotVal: [][]any{
   274  				{"", int64(0)},
   275  				{"", int64(0)},
   276  				{"", int64(0)},
   277  			},
   278  		},
   279  		{
   280  			name: "同一组数据在同一个sql.Rows中",
   281  			aggregators: []aggregator.Aggregator{
   282  				aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")),
   283  			},
   284  			GroupByColumns: []merger.ColumnInfo{
   285  				merger.NewColumnInfo(0, "user_name"),
   286  			},
   287  			rowsList: func() []rows.Rows {
   288  				query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`"
   289  				cols := []string{"user_name", "SUM(id)"}
   290  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20))
   291  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20))
   292  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20))
   293  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   294  				rowsList := make([]rows.Rows, 0, len(dbs))
   295  				for _, db := range dbs {
   296  					row, err := db.QueryContext(context.Background(), query)
   297  					require.NoError(ms.T(), err)
   298  					rowsList = append(rowsList, row)
   299  				}
   300  				return rowsList
   301  			}(),
   302  			wantVal: [][]any{
   303  				{"zwl", int64(10)},
   304  				{"xm", int64(20)},
   305  				{"xz", int64(10)},
   306  				{"xx", int64(20)},
   307  				{"dm", int64(20)},
   308  			},
   309  			gotVal: [][]any{
   310  				{"", int64(0)},
   311  				{"", int64(0)},
   312  				{"", int64(0)},
   313  				{"", int64(0)},
   314  				{"", int64(0)},
   315  			},
   316  		},
   317  		{
   318  			name: "多个分组列",
   319  			aggregators: []aggregator.Aggregator{
   320  				aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")),
   321  			},
   322  			GroupByColumns: []merger.ColumnInfo{
   323  				merger.NewColumnInfo(0, "county"),
   324  				merger.NewColumnInfo(1, "gender"),
   325  			},
   326  			rowsList: func() []rows.Rows {
   327  				query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`"
   328  				cols := []string{"county", "gender", "SUM(id)"}
   329  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30))
   330  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60))
   331  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80))
   332  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   333  				rowsList := make([]rows.Rows, 0, len(dbs))
   334  				for _, db := range dbs {
   335  					row, err := db.QueryContext(context.Background(), query)
   336  					require.NoError(ms.T(), err)
   337  					rowsList = append(rowsList, row)
   338  				}
   339  				return rowsList
   340  			}(),
   341  			wantVal: [][]any{
   342  				{
   343  					"hangzhou",
   344  					"male",
   345  					int64(10),
   346  				},
   347  				{
   348  					"hangzhou",
   349  					"female",
   350  					int64(80),
   351  				},
   352  				{
   353  					"shanghai",
   354  					"female",
   355  					int64(160),
   356  				},
   357  				{
   358  					"shanghai",
   359  					"male",
   360  					int64(110),
   361  				},
   362  			},
   363  			gotVal: [][]any{
   364  				{"", "", int64(0)},
   365  				{"", "", int64(0)},
   366  				{"", "", int64(0)},
   367  				{"", "", int64(0)},
   368  			},
   369  		},
   370  		{
   371  			name: "多个聚合函数",
   372  			aggregators: []aggregator.Aggregator{
   373  				aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")),
   374  				aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"),
   375  			},
   376  			GroupByColumns: []merger.ColumnInfo{
   377  				merger.NewColumnInfo(0, "county"),
   378  				merger.NewColumnInfo(1, "gender"),
   379  			},
   380  
   381  			rowsList: func() []rows.Rows {
   382  				query := "SELECT `county`,`gender`,SUM(`id`),SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`"
   383  				cols := []string{"county", "gender", "SUM(id)", "SUM(age)", "COUNT(age)"}
   384  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 100, 2).AddRow("hangzhou", "female", 20, 120, 3).AddRow("shanghai", "female", 30, 90, 3))
   385  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 120, 5).AddRow("shanghai", "female", 50, 120, 4).AddRow("hangzhou", "female", 60, 150, 3))
   386  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 100, 5).AddRow("shanghai", "female", 80, 150, 5))
   387  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   388  				rowsList := make([]rows.Rows, 0, len(dbs))
   389  				for _, db := range dbs {
   390  					row, err := db.QueryContext(context.Background(), query)
   391  					require.NoError(ms.T(), err)
   392  					rowsList = append(rowsList, row)
   393  				}
   394  				return rowsList
   395  			}(),
   396  			wantVal: [][]any{
   397  				{
   398  					"hangzhou",
   399  					"male",
   400  					int64(10),
   401  					float64(50),
   402  				},
   403  				{
   404  					"hangzhou",
   405  					"female",
   406  					int64(80),
   407  					float64(45),
   408  				},
   409  				{
   410  					"shanghai",
   411  					"female",
   412  					int64(160),
   413  					float64(30),
   414  				},
   415  				{
   416  					"shanghai",
   417  					"male",
   418  					int64(110),
   419  					float64(22),
   420  				},
   421  			},
   422  			gotVal: [][]any{
   423  				{"", "", int64(0), float64(0)},
   424  				{"", "", int64(0), float64(0)},
   425  				{"", "", int64(0), float64(0)},
   426  				{"", "", int64(0), float64(0)},
   427  			},
   428  		},
   429  	}
   430  	for _, tc := range testcases {
   431  		ms.T().Run(tc.name, func(t *testing.T) {
   432  			merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns)
   433  			groupByRows, err := merger.Merge(context.Background(), tc.rowsList)
   434  			require.NoError(t, err)
   435  
   436  			idx := 0
   437  			for groupByRows.Next() {
   438  				if idx >= len(tc.gotVal) {
   439  					break
   440  				}
   441  				tmp := make([]any, 0, len(tc.gotVal[0]))
   442  				for i := range tc.gotVal[idx] {
   443  					tmp = append(tmp, &tc.gotVal[idx][i])
   444  				}
   445  				err := groupByRows.Scan(tmp...)
   446  				require.NoError(t, err)
   447  				idx++
   448  			}
   449  			require.NoError(t, groupByRows.Err())
   450  			assert.Equal(t, tc.wantVal, tc.gotVal)
   451  		})
   452  	}
   453  }
   454  
   455  func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() {
   456  	ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) {
   457  		cols := []string{"userid", "SUM(id)"}
   458  		query := "SELECT userid,SUM(id) FROM `t1`"
   459  		ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20))
   460  		r, err := ms.mockDB01.QueryContext(context.Background(), query)
   461  		require.NoError(t, err)
   462  		rowsList := []rows.Rows{r}
   463  		merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")})
   464  		rows, err := merger.Merge(context.Background(), rowsList)
   465  		require.NoError(t, err)
   466  		userid := 0
   467  		sumId := 0
   468  		err = rows.Scan(&userid, &sumId)
   469  		assert.Equal(t, errs.ErrMergerScanNotNext, err)
   470  	})
   471  	ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) {
   472  		cols := []string{"userid", "SUM(id)"}
   473  		query := "SELECT userid,SUM(id) FROM `t1`"
   474  		ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20))
   475  		r, err := ms.mockDB01.QueryContext(context.Background(), query)
   476  		require.NoError(t, err)
   477  		rowsList := []rows.Rows{r}
   478  		merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")})
   479  		rows, err := merger.Merge(context.Background(), rowsList)
   480  		require.NoError(t, err)
   481  		userid := 0
   482  		sumId := 0
   483  		rows.Next()
   484  		err = rows.Scan(&userid, &sumId)
   485  		assert.Equal(t, aggregatorErr, err)
   486  	})
   487  
   488  }
   489  
   490  func (ms *MergerSuite) TestAggregatorRows_NextAndErr() {
   491  	testcases := []struct {
   492  		name           string
   493  		rowsList       func() []rows.Rows
   494  		wantErr        error
   495  		aggregators    []aggregator.Aggregator
   496  		GroupByColumns []merger.ColumnInfo
   497  	}{
   498  		{
   499  			name: "有一个aggregator返回error",
   500  			rowsList: func() []rows.Rows {
   501  				cols := []string{"username", "COUNT(id)"}
   502  				query := "SELECT username,COUNT(`id`) FROM `t1`"
   503  				ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1))
   504  				ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("daming", 2))
   505  				ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4))
   506  				ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5))
   507  				dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04}
   508  				rowsList := make([]rows.Rows, 0, len(dbs))
   509  				for _, db := range dbs {
   510  					row, err := db.QueryContext(context.Background(), query)
   511  					require.NoError(ms.T(), err)
   512  					rowsList = append(rowsList, row)
   513  				}
   514  				return rowsList
   515  			},
   516  			aggregators: func() []aggregator.Aggregator {
   517  				return []aggregator.Aggregator{
   518  					&mockAggregate{},
   519  				}
   520  			}(),
   521  			GroupByColumns: []merger.ColumnInfo{
   522  				merger.NewColumnInfo(0, "username"),
   523  			},
   524  			wantErr: aggregatorErr,
   525  		},
   526  	}
   527  	for _, tc := range testcases {
   528  		ms.T().Run(tc.name, func(t *testing.T) {
   529  			merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns)
   530  			rows, err := merger.Merge(context.Background(), tc.rowsList())
   531  			require.NoError(t, err)
   532  			for rows.Next() {
   533  			}
   534  			count := int64(0)
   535  			name := ""
   536  			err = rows.Scan(&name, &count)
   537  			assert.Equal(t, tc.wantErr, err)
   538  			assert.Equal(t, tc.wantErr, rows.Err())
   539  		})
   540  	}
   541  }
   542  
   543  func (ms *MergerSuite) TestAggregatorRows_Columns() {
   544  	cols := []string{"userid", "SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"}
   545  	query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`),`userid` FROM `t1`"
   546  	ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10, "zwl"))
   547  	ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11, "dm"))
   548  	ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12, "xm"))
   549  	aggregators := []aggregator.Aggregator{
   550  		aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"),
   551  		aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")),
   552  		aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")),
   553  		aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")),
   554  		aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")),
   555  	}
   556  	groupbyColumns := []merger.ColumnInfo{
   557  		merger.NewColumnInfo(6, "userid"),
   558  	}
   559  	merger := NewAggregatorMerger(aggregators, groupbyColumns)
   560  	dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
   561  	rowsList := make([]rows.Rows, 0, len(dbs))
   562  	for _, db := range dbs {
   563  		row, err := db.QueryContext(context.Background(), query)
   564  		require.NoError(ms.T(), err)
   565  		rowsList = append(rowsList, row)
   566  	}
   567  
   568  	rows, err := merger.Merge(context.Background(), rowsList)
   569  	require.NoError(ms.T(), err)
   570  	wantCols := []string{"userid", "AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"}
   571  	ms.T().Run("Next没有迭代完", func(t *testing.T) {
   572  		for rows.Next() {
   573  			columns, err := rows.Columns()
   574  			require.NoError(t, err)
   575  			assert.Equal(t, wantCols, columns)
   576  		}
   577  		require.NoError(t, rows.Err())
   578  	})
   579  	ms.T().Run("Next迭代完", func(t *testing.T) {
   580  		require.False(t, rows.Next())
   581  		require.NoError(t, rows.Err())
   582  		_, err := rows.Columns()
   583  		assert.Equal(t, errs.ErrMergerRowsClosed, err)
   584  	})
   585  }
   586  
   587  type mockAggregate struct {
   588  	cols [][]any
   589  }
   590  
   591  func (m *mockAggregate) Aggregate(cols [][]any) (any, error) {
   592  	m.cols = cols
   593  	return nil, aggregatorErr
   594  }
   595  
   596  func (*mockAggregate) ColumnName() string {
   597  	return "mockAggregate"
   598  }
   599  
   600  func TestAggregatorRows_NextResultSet(t *testing.T) {
   601  	assert.False(t, (&AggregatorRows{}).NextResultSet())
   602  }