github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/groupby_merger/aggregator_merger_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package groupby_merger 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "testing" 22 23 "github.com/ecodeclub/eorm/internal/rows" 24 25 "github.com/ecodeclub/eorm/internal/merger" 26 27 "github.com/DATA-DOG/go-sqlmock" 28 "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" 29 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 "github.com/stretchr/testify/suite" 33 ) 34 35 var ( 36 nextMockErr error = errors.New("rows: MockNextErr") 37 aggregatorErr error = errors.New("aggregator: MockAggregatorErr") 38 ) 39 40 type MergerSuite struct { 41 suite.Suite 42 mockDB01 *sql.DB 43 mock01 sqlmock.Sqlmock 44 mockDB02 *sql.DB 45 mock02 sqlmock.Sqlmock 46 mockDB03 *sql.DB 47 mock03 sqlmock.Sqlmock 48 mockDB04 *sql.DB 49 mock04 sqlmock.Sqlmock 50 } 51 52 func (ms *MergerSuite) SetupTest() { 53 t := ms.T() 54 ms.initMock(t) 55 } 56 57 func (ms *MergerSuite) TearDownTest() { 58 _ = ms.mockDB01.Close() 59 _ = ms.mockDB02.Close() 60 _ = ms.mockDB03.Close() 61 _ = ms.mockDB04.Close() 62 } 63 64 func (ms *MergerSuite) initMock(t *testing.T) { 65 var err error 66 ms.mockDB01, ms.mock01, err = sqlmock.New() 67 if err != nil { 68 t.Fatal(err) 69 } 70 ms.mockDB02, ms.mock02, err = sqlmock.New() 71 if err != nil { 72 t.Fatal(err) 73 } 74 ms.mockDB03, ms.mock03, err = sqlmock.New() 75 if err != nil { 76 t.Fatal(err) 77 } 78 ms.mockDB04, ms.mock04, err = sqlmock.New() 79 if err != nil { 80 t.Fatal(err) 81 } 82 } 83 84 func TestMerger(t *testing.T) { 85 suite.Run(t, &MergerSuite{}) 86 } 87 88 func (ms *MergerSuite) TestAggregatorMerger_Merge() { 89 testcases := []struct { 90 name string 91 aggregators []aggregator.Aggregator 92 rowsList []rows.Rows 93 GroupByColumns []merger.ColumnInfo 94 wantErr error 95 ctx func() (context.Context, context.CancelFunc) 96 }{ 97 { 98 name: "正常案例", 99 aggregators: []aggregator.Aggregator{ 100 aggregator.NewCount(merger.NewColumnInfo(2, "id")), 101 }, 102 GroupByColumns: []merger.ColumnInfo{ 103 merger.NewColumnInfo(0, "county"), 104 merger.NewColumnInfo(1, "gender"), 105 }, 106 rowsList: func() []rows.Rows { 107 query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" 108 cols := []string{"county", "gender", "SUM(id)"} 109 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) 110 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) 111 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) 112 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 113 rowsList := make([]rows.Rows, 0, len(dbs)) 114 for _, db := range dbs { 115 row, err := db.QueryContext(context.Background(), query) 116 require.NoError(ms.T(), err) 117 rowsList = append(rowsList, row) 118 } 119 return rowsList 120 }(), 121 122 ctx: func() (context.Context, context.CancelFunc) { 123 ctx, cancel := context.WithCancel(context.Background()) 124 return ctx, cancel 125 }, 126 }, 127 { 128 name: "超时", 129 aggregators: []aggregator.Aggregator{ 130 aggregator.NewCount(merger.NewColumnInfo(1, "id")), 131 }, 132 GroupByColumns: []merger.ColumnInfo{ 133 merger.NewColumnInfo(0, "user_name"), 134 }, 135 rowsList: func() []rows.Rows { 136 query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" 137 cols := []string{"user_name", "SUM(id)"} 138 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) 139 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) 140 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) 141 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 142 rowsList := make([]rows.Rows, 0, len(dbs)) 143 for _, db := range dbs { 144 row, err := db.QueryContext(context.Background(), query) 145 require.NoError(ms.T(), err) 146 rowsList = append(rowsList, row) 147 } 148 return rowsList 149 }(), 150 ctx: func() (context.Context, context.CancelFunc) { 151 ctx, cancel := context.WithTimeout(context.Background(), 0) 152 return ctx, cancel 153 }, 154 wantErr: context.DeadlineExceeded, 155 }, 156 { 157 name: "rowsList为空", 158 aggregators: []aggregator.Aggregator{ 159 aggregator.NewCount(merger.NewColumnInfo(1, "id")), 160 }, 161 GroupByColumns: []merger.ColumnInfo{ 162 merger.NewColumnInfo(0, "user_name"), 163 }, 164 rowsList: func() []rows.Rows { 165 return []rows.Rows{} 166 }(), 167 ctx: func() (context.Context, context.CancelFunc) { 168 ctx, cancel := context.WithCancel(context.Background()) 169 return ctx, cancel 170 }, 171 wantErr: errs.ErrMergerEmptyRows, 172 }, 173 { 174 name: "rowsList中有nil", 175 aggregators: []aggregator.Aggregator{ 176 aggregator.NewCount(merger.NewColumnInfo(1, "id")), 177 }, 178 GroupByColumns: []merger.ColumnInfo{ 179 merger.NewColumnInfo(0, "user_name"), 180 }, 181 rowsList: func() []rows.Rows { 182 return []rows.Rows{nil} 183 }(), 184 ctx: func() (context.Context, context.CancelFunc) { 185 ctx, cancel := context.WithCancel(context.Background()) 186 return ctx, cancel 187 }, 188 wantErr: errs.ErrMergerRowsIsNull, 189 }, 190 { 191 name: "rowsList中有sql.Rows返回错误", 192 aggregators: []aggregator.Aggregator{ 193 aggregator.NewCount(merger.NewColumnInfo(1, "id")), 194 }, 195 GroupByColumns: []merger.ColumnInfo{ 196 merger.NewColumnInfo(0, "user_name"), 197 }, 198 rowsList: func() []rows.Rows { 199 query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" 200 cols := []string{"user_name", "SUM(id)"} 201 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr)) 202 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) 203 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) 204 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 205 rowsList := make([]rows.Rows, 0, len(dbs)) 206 for _, db := range dbs { 207 row, err := db.QueryContext(context.Background(), query) 208 require.NoError(ms.T(), err) 209 rowsList = append(rowsList, row) 210 } 211 return rowsList 212 }(), 213 ctx: func() (context.Context, context.CancelFunc) { 214 ctx, cancel := context.WithCancel(context.Background()) 215 return ctx, cancel 216 }, 217 wantErr: nextMockErr, 218 }, 219 } 220 for _, tc := range testcases { 221 ms.T().Run(tc.name, func(t *testing.T) { 222 merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) 223 ctx, cancel := tc.ctx() 224 groupByRows, err := merger.Merge(ctx, tc.rowsList) 225 cancel() 226 assert.Equal(t, tc.wantErr, err) 227 if err != nil { 228 return 229 } 230 require.NotNil(t, groupByRows) 231 }) 232 } 233 } 234 235 func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { 236 testcases := []struct { 237 name string 238 aggregators []aggregator.Aggregator 239 rowsList []rows.Rows 240 wantVal [][]any 241 gotVal [][]any 242 GroupByColumns []merger.ColumnInfo 243 wantErr error 244 }{ 245 { 246 name: "同一组数据在不同的sql.Rows中", 247 aggregators: []aggregator.Aggregator{ 248 aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), 249 }, 250 GroupByColumns: []merger.ColumnInfo{ 251 merger.NewColumnInfo(0, "user_name"), 252 }, 253 rowsList: func() []rows.Rows { 254 query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" 255 cols := []string{"user_name", "SUM(id)"} 256 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) 257 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) 258 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) 259 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 260 rowsList := make([]rows.Rows, 0, len(dbs)) 261 for _, db := range dbs { 262 row, err := db.QueryContext(context.Background(), query) 263 require.NoError(ms.T(), err) 264 rowsList = append(rowsList, row) 265 } 266 return rowsList 267 }(), 268 wantVal: [][]any{ 269 {"zwl", int64(30)}, 270 {"dm", int64(40)}, 271 {"xz", int64(10)}, 272 }, 273 gotVal: [][]any{ 274 {"", int64(0)}, 275 {"", int64(0)}, 276 {"", int64(0)}, 277 }, 278 }, 279 { 280 name: "同一组数据在同一个sql.Rows中", 281 aggregators: []aggregator.Aggregator{ 282 aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), 283 }, 284 GroupByColumns: []merger.ColumnInfo{ 285 merger.NewColumnInfo(0, "user_name"), 286 }, 287 rowsList: func() []rows.Rows { 288 query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" 289 cols := []string{"user_name", "SUM(id)"} 290 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20)) 291 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20)) 292 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) 293 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 294 rowsList := make([]rows.Rows, 0, len(dbs)) 295 for _, db := range dbs { 296 row, err := db.QueryContext(context.Background(), query) 297 require.NoError(ms.T(), err) 298 rowsList = append(rowsList, row) 299 } 300 return rowsList 301 }(), 302 wantVal: [][]any{ 303 {"zwl", int64(10)}, 304 {"xm", int64(20)}, 305 {"xz", int64(10)}, 306 {"xx", int64(20)}, 307 {"dm", int64(20)}, 308 }, 309 gotVal: [][]any{ 310 {"", int64(0)}, 311 {"", int64(0)}, 312 {"", int64(0)}, 313 {"", int64(0)}, 314 {"", int64(0)}, 315 }, 316 }, 317 { 318 name: "多个分组列", 319 aggregators: []aggregator.Aggregator{ 320 aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), 321 }, 322 GroupByColumns: []merger.ColumnInfo{ 323 merger.NewColumnInfo(0, "county"), 324 merger.NewColumnInfo(1, "gender"), 325 }, 326 rowsList: func() []rows.Rows { 327 query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" 328 cols := []string{"county", "gender", "SUM(id)"} 329 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) 330 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) 331 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) 332 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 333 rowsList := make([]rows.Rows, 0, len(dbs)) 334 for _, db := range dbs { 335 row, err := db.QueryContext(context.Background(), query) 336 require.NoError(ms.T(), err) 337 rowsList = append(rowsList, row) 338 } 339 return rowsList 340 }(), 341 wantVal: [][]any{ 342 { 343 "hangzhou", 344 "male", 345 int64(10), 346 }, 347 { 348 "hangzhou", 349 "female", 350 int64(80), 351 }, 352 { 353 "shanghai", 354 "female", 355 int64(160), 356 }, 357 { 358 "shanghai", 359 "male", 360 int64(110), 361 }, 362 }, 363 gotVal: [][]any{ 364 {"", "", int64(0)}, 365 {"", "", int64(0)}, 366 {"", "", int64(0)}, 367 {"", "", int64(0)}, 368 }, 369 }, 370 { 371 name: "多个聚合函数", 372 aggregators: []aggregator.Aggregator{ 373 aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), 374 aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"), 375 }, 376 GroupByColumns: []merger.ColumnInfo{ 377 merger.NewColumnInfo(0, "county"), 378 merger.NewColumnInfo(1, "gender"), 379 }, 380 381 rowsList: func() []rows.Rows { 382 query := "SELECT `county`,`gender`,SUM(`id`),SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`" 383 cols := []string{"county", "gender", "SUM(id)", "SUM(age)", "COUNT(age)"} 384 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 100, 2).AddRow("hangzhou", "female", 20, 120, 3).AddRow("shanghai", "female", 30, 90, 3)) 385 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 120, 5).AddRow("shanghai", "female", 50, 120, 4).AddRow("hangzhou", "female", 60, 150, 3)) 386 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 100, 5).AddRow("shanghai", "female", 80, 150, 5)) 387 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 388 rowsList := make([]rows.Rows, 0, len(dbs)) 389 for _, db := range dbs { 390 row, err := db.QueryContext(context.Background(), query) 391 require.NoError(ms.T(), err) 392 rowsList = append(rowsList, row) 393 } 394 return rowsList 395 }(), 396 wantVal: [][]any{ 397 { 398 "hangzhou", 399 "male", 400 int64(10), 401 float64(50), 402 }, 403 { 404 "hangzhou", 405 "female", 406 int64(80), 407 float64(45), 408 }, 409 { 410 "shanghai", 411 "female", 412 int64(160), 413 float64(30), 414 }, 415 { 416 "shanghai", 417 "male", 418 int64(110), 419 float64(22), 420 }, 421 }, 422 gotVal: [][]any{ 423 {"", "", int64(0), float64(0)}, 424 {"", "", int64(0), float64(0)}, 425 {"", "", int64(0), float64(0)}, 426 {"", "", int64(0), float64(0)}, 427 }, 428 }, 429 } 430 for _, tc := range testcases { 431 ms.T().Run(tc.name, func(t *testing.T) { 432 merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) 433 groupByRows, err := merger.Merge(context.Background(), tc.rowsList) 434 require.NoError(t, err) 435 436 idx := 0 437 for groupByRows.Next() { 438 if idx >= len(tc.gotVal) { 439 break 440 } 441 tmp := make([]any, 0, len(tc.gotVal[0])) 442 for i := range tc.gotVal[idx] { 443 tmp = append(tmp, &tc.gotVal[idx][i]) 444 } 445 err := groupByRows.Scan(tmp...) 446 require.NoError(t, err) 447 idx++ 448 } 449 require.NoError(t, groupByRows.Err()) 450 assert.Equal(t, tc.wantVal, tc.gotVal) 451 }) 452 } 453 } 454 455 func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { 456 ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { 457 cols := []string{"userid", "SUM(id)"} 458 query := "SELECT userid,SUM(id) FROM `t1`" 459 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) 460 r, err := ms.mockDB01.QueryContext(context.Background(), query) 461 require.NoError(t, err) 462 rowsList := []rows.Rows{r} 463 merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) 464 rows, err := merger.Merge(context.Background(), rowsList) 465 require.NoError(t, err) 466 userid := 0 467 sumId := 0 468 err = rows.Scan(&userid, &sumId) 469 assert.Equal(t, errs.ErrMergerScanNotNext, err) 470 }) 471 ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { 472 cols := []string{"userid", "SUM(id)"} 473 query := "SELECT userid,SUM(id) FROM `t1`" 474 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) 475 r, err := ms.mockDB01.QueryContext(context.Background(), query) 476 require.NoError(t, err) 477 rowsList := []rows.Rows{r} 478 merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) 479 rows, err := merger.Merge(context.Background(), rowsList) 480 require.NoError(t, err) 481 userid := 0 482 sumId := 0 483 rows.Next() 484 err = rows.Scan(&userid, &sumId) 485 assert.Equal(t, aggregatorErr, err) 486 }) 487 488 } 489 490 func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { 491 testcases := []struct { 492 name string 493 rowsList func() []rows.Rows 494 wantErr error 495 aggregators []aggregator.Aggregator 496 GroupByColumns []merger.ColumnInfo 497 }{ 498 { 499 name: "有一个aggregator返回error", 500 rowsList: func() []rows.Rows { 501 cols := []string{"username", "COUNT(id)"} 502 query := "SELECT username,COUNT(`id`) FROM `t1`" 503 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1)) 504 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("daming", 2)) 505 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4)) 506 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5)) 507 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 508 rowsList := make([]rows.Rows, 0, len(dbs)) 509 for _, db := range dbs { 510 row, err := db.QueryContext(context.Background(), query) 511 require.NoError(ms.T(), err) 512 rowsList = append(rowsList, row) 513 } 514 return rowsList 515 }, 516 aggregators: func() []aggregator.Aggregator { 517 return []aggregator.Aggregator{ 518 &mockAggregate{}, 519 } 520 }(), 521 GroupByColumns: []merger.ColumnInfo{ 522 merger.NewColumnInfo(0, "username"), 523 }, 524 wantErr: aggregatorErr, 525 }, 526 } 527 for _, tc := range testcases { 528 ms.T().Run(tc.name, func(t *testing.T) { 529 merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) 530 rows, err := merger.Merge(context.Background(), tc.rowsList()) 531 require.NoError(t, err) 532 for rows.Next() { 533 } 534 count := int64(0) 535 name := "" 536 err = rows.Scan(&name, &count) 537 assert.Equal(t, tc.wantErr, err) 538 assert.Equal(t, tc.wantErr, rows.Err()) 539 }) 540 } 541 } 542 543 func (ms *MergerSuite) TestAggregatorRows_Columns() { 544 cols := []string{"userid", "SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} 545 query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`),`userid` FROM `t1`" 546 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10, "zwl")) 547 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11, "dm")) 548 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12, "xm")) 549 aggregators := []aggregator.Aggregator{ 550 aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), 551 aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), 552 aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), 553 aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), 554 aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), 555 } 556 groupbyColumns := []merger.ColumnInfo{ 557 merger.NewColumnInfo(6, "userid"), 558 } 559 merger := NewAggregatorMerger(aggregators, groupbyColumns) 560 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 561 rowsList := make([]rows.Rows, 0, len(dbs)) 562 for _, db := range dbs { 563 row, err := db.QueryContext(context.Background(), query) 564 require.NoError(ms.T(), err) 565 rowsList = append(rowsList, row) 566 } 567 568 rows, err := merger.Merge(context.Background(), rowsList) 569 require.NoError(ms.T(), err) 570 wantCols := []string{"userid", "AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} 571 ms.T().Run("Next没有迭代完", func(t *testing.T) { 572 for rows.Next() { 573 columns, err := rows.Columns() 574 require.NoError(t, err) 575 assert.Equal(t, wantCols, columns) 576 } 577 require.NoError(t, rows.Err()) 578 }) 579 ms.T().Run("Next迭代完", func(t *testing.T) { 580 require.False(t, rows.Next()) 581 require.NoError(t, rows.Err()) 582 _, err := rows.Columns() 583 assert.Equal(t, errs.ErrMergerRowsClosed, err) 584 }) 585 } 586 587 type mockAggregate struct { 588 cols [][]any 589 } 590 591 func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { 592 m.cols = cols 593 return nil, aggregatorErr 594 } 595 596 func (*mockAggregate) ColumnName() string { 597 return "mockAggregate" 598 } 599 600 func TestAggregatorRows_NextResultSet(t *testing.T) { 601 assert.False(t, (&AggregatorRows{}).NextResultSet()) 602 }