github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/aggregatemerger/merger_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregatemerger 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "testing" 23 24 "github.com/ecodeclub/eorm/internal/rows" 25 _ "github.com/mattn/go-sqlite3" 26 27 "github.com/ecodeclub/eorm/internal/merger" 28 29 "github.com/DATA-DOG/go-sqlmock" 30 "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" 31 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "github.com/stretchr/testify/suite" 35 "go.uber.org/multierr" 36 ) 37 38 var ( 39 nextMockErr = errors.New("rows: MockNextErr") 40 aggregatorErr = errors.New("aggregator: MockAggregatorErr") 41 ) 42 43 func newCloseMockErr(dbName string) error { 44 return fmt.Errorf("rows: %s MockCloseErr", dbName) 45 } 46 47 type MergerSuite struct { 48 suite.Suite 49 mockDB01 *sql.DB 50 mock01 sqlmock.Sqlmock 51 mockDB02 *sql.DB 52 mock02 sqlmock.Sqlmock 53 mockDB03 *sql.DB 54 mock03 sqlmock.Sqlmock 55 mockDB04 *sql.DB 56 mock04 sqlmock.Sqlmock 57 db05 *sql.DB 58 } 59 60 func (ms *MergerSuite) SetupTest() { 61 t := ms.T() 62 ms.initMock(t) 63 } 64 65 func (ms *MergerSuite) TearDownTest() { 66 _ = ms.mockDB01.Close() 67 _ = ms.mockDB02.Close() 68 _ = ms.mockDB03.Close() 69 _ = ms.mockDB04.Close() 70 _ = ms.db05.Close() 71 } 72 73 func (ms *MergerSuite) initMock(t *testing.T) { 74 var err error 75 query := "CREATE TABLE t1" + 76 "(" + 77 " id INT PRIMARY KEY NOT NULL," + 78 " grade INT NOT NULL" + 79 ");" 80 ms.mockDB01, ms.mock01, err = sqlmock.New() 81 if err != nil { 82 t.Fatal(err) 83 } 84 ms.mockDB02, ms.mock02, err = sqlmock.New() 85 if err != nil { 86 t.Fatal(err) 87 } 88 ms.mockDB03, ms.mock03, err = sqlmock.New() 89 if err != nil { 90 t.Fatal(err) 91 } 92 ms.mockDB04, ms.mock04, err = sqlmock.New() 93 if err != nil { 94 t.Fatal(err) 95 } 96 db05, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") 97 if err != nil { 98 t.Fatal(err) 99 } 100 ms.db05 = db05 101 _, err = db05.ExecContext(context.Background(), query) 102 if err != nil { 103 t.Fatal(err) 104 } 105 } 106 107 func TestMerger(t *testing.T) { 108 suite.Run(t, &MergerSuite{}) 109 } 110 111 func (ms *MergerSuite) TestRows_NextAndScan() { 112 testcases := []struct { 113 name string 114 sqlRows func() []rows.Rows 115 wantVal []any 116 aggregators func() []aggregator.Aggregator 117 gotVal []any 118 wantErr error 119 }{ 120 { 121 name: "sqlite的ColumnType 使用了多级指针", 122 sqlRows: func() []rows.Rows { 123 query1 := "insert into `t1` values (1,10),(2,20),(3,30)" 124 _, err := ms.db05.ExecContext(context.Background(), query1) 125 require.NoError(ms.T(), err) 126 cols := []string{"SUM(id)"} 127 query := "SELECT SUM(`id`) FROM `t1`" 128 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 129 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 130 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 131 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.db05} 132 rowsList := make([]rows.Rows, 0, len(dbs)) 133 for _, db := range dbs { 134 row, err := db.QueryContext(context.Background(), query) 135 require.NoError(ms.T(), err) 136 rowsList = append(rowsList, row) 137 } 138 return rowsList 139 }, 140 wantVal: []any{int64(66)}, 141 gotVal: func() []any { 142 return []any{ 143 0, 144 } 145 }(), 146 aggregators: func() []aggregator.Aggregator { 147 return []aggregator.Aggregator{ 148 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 149 } 150 }, 151 }, 152 { 153 name: "SUM(id)", 154 sqlRows: func() []rows.Rows { 155 cols := []string{"SUM(id)"} 156 query := "SELECT SUM(`id`) FROM `t1`" 157 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 158 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 159 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 160 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 161 rowsList := make([]rows.Rows, 0, len(dbs)) 162 for _, db := range dbs { 163 row, err := db.QueryContext(context.Background(), query) 164 require.NoError(ms.T(), err) 165 rowsList = append(rowsList, row) 166 } 167 return rowsList 168 }, 169 wantVal: []any{int64(60)}, 170 gotVal: func() []any { 171 return []any{ 172 0, 173 } 174 }(), 175 aggregators: func() []aggregator.Aggregator { 176 return []aggregator.Aggregator{ 177 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 178 } 179 }, 180 }, 181 182 { 183 name: "MAX(id)", 184 sqlRows: func() []rows.Rows { 185 cols := []string{"MAX(id)"} 186 query := "SELECT MAX(`id`) FROM `t1`" 187 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 188 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 189 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 190 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 191 rowsList := make([]rows.Rows, 0, len(dbs)) 192 for _, db := range dbs { 193 row, err := db.QueryContext(context.Background(), query) 194 require.NoError(ms.T(), err) 195 rowsList = append(rowsList, row) 196 } 197 return rowsList 198 }, 199 wantVal: []any{int64(30)}, 200 gotVal: func() []any { 201 return []any{ 202 0, 203 } 204 }(), 205 aggregators: func() []aggregator.Aggregator { 206 return []aggregator.Aggregator{ 207 aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")), 208 } 209 }, 210 }, 211 { 212 name: "MIN(id)", 213 sqlRows: func() []rows.Rows { 214 cols := []string{"MIN(id)"} 215 query := "SELECT MIN(`id`) FROM `t1`" 216 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 217 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 218 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 219 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 220 rowsList := make([]rows.Rows, 0, len(dbs)) 221 for _, db := range dbs { 222 row, err := db.QueryContext(context.Background(), query) 223 require.NoError(ms.T(), err) 224 rowsList = append(rowsList, row) 225 } 226 return rowsList 227 }, 228 wantVal: []any{int64(10)}, 229 gotVal: func() []any { 230 return []any{ 231 0, 232 } 233 }(), 234 aggregators: func() []aggregator.Aggregator { 235 return []aggregator.Aggregator{ 236 aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")), 237 } 238 }, 239 }, 240 { 241 name: "COUNT(id)", 242 sqlRows: func() []rows.Rows { 243 cols := []string{"COUNT(id)"} 244 query := "SELECT COUNT(`id`) FROM `t1`" 245 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 246 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 247 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 248 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 249 rowsList := make([]rows.Rows, 0, len(dbs)) 250 for _, db := range dbs { 251 row, err := db.QueryContext(context.Background(), query) 252 require.NoError(ms.T(), err) 253 rowsList = append(rowsList, row) 254 } 255 return rowsList 256 }, 257 wantVal: []any{int64(40)}, 258 gotVal: func() []any { 259 return []any{ 260 0, 261 } 262 }(), 263 aggregators: func() []aggregator.Aggregator { 264 return []aggregator.Aggregator{ 265 aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), 266 } 267 }, 268 }, 269 { 270 name: "AVG(grade)", 271 sqlRows: func() []rows.Rows { 272 cols := []string{"SUM(grade)", "COUNT(grade)"} 273 query := "SELECT SUM(`grade`),COUNT(`grade`) FROM `t1`" 274 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) 275 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20)) 276 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 10)) 277 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 278 rowsList := make([]rows.Rows, 0, len(dbs)) 279 for _, db := range dbs { 280 row, err := db.QueryContext(context.Background(), query) 281 require.NoError(ms.T(), err) 282 rowsList = append(rowsList, row) 283 } 284 return rowsList 285 }, 286 wantVal: []any{ 287 float64(150), 288 }, 289 gotVal: func() []any { 290 return []any{ 291 float64(0), 292 } 293 }(), 294 aggregators: func() []aggregator.Aggregator { 295 return []aggregator.Aggregator{ 296 aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), 297 } 298 }, 299 }, 300 // 下面为多个聚合函数组合的情况 301 302 // 1.每种聚合函数出现一次 303 { 304 name: "COUNT(id),MAX(id),MIN(id),SUM(id),AVG(grade)", 305 sqlRows: func() []rows.Rows { 306 cols := []string{"COUNT(id)", "MAX(id)", "MIN(id)", "SUM(id)", "SUM(grade)", "COUNT(grade)"} 307 query := "SELECT COUNT(`id`),MAX(`id`),MIN(`id`),SUM(`id`),SUM(`grade`),COUNT(`student`) FROM `t1`" 308 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 20, 1, 100, 2000, 20)) 309 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20, 30, 0, 200, 800, 10)) 310 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10, 40, 2, 300, 1800, 20)) 311 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 312 rowsList := make([]rows.Rows, 0, len(dbs)) 313 for _, db := range dbs { 314 row, err := db.QueryContext(context.Background(), query) 315 require.NoError(ms.T(), err) 316 rowsList = append(rowsList, row) 317 } 318 return rowsList 319 }, 320 wantVal: []any{ 321 int64(40), int64(40), int64(0), int64(600), float64(4600) / float64(50), 322 }, 323 gotVal: func() []any { 324 return []any{ 325 0, 0, 0, 0, float64(0), 326 } 327 }(), 328 aggregators: func() []aggregator.Aggregator { 329 return []aggregator.Aggregator{ 330 aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), 331 aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")), 332 aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")), 333 aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")), 334 aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), 335 } 336 }, 337 }, 338 // 2. 聚合函数出现一次或多次,会有相同的聚合函数类型,且相同的聚合函数类型会有连续出现,和不连续出现。 339 // 两个avg会包含sum列在前,和sum列在后的状态。并且有完全相同的列出现 340 { 341 name: "AVG(grade),SUM(grade),AVG(grade),MIN(id),MIN(userid),MAX(id),COUNT(id)", 342 sqlRows: func() []rows.Rows { 343 cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(grade)", "COUNT(grade)", "SUM(grade)", "MIN(id)", "MIN(userid)", "MAX(id)", "COUNT(id)"} 344 query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`grade`),COUNT(`grade`),SUM(`grade`),MIN(`id`),MIN(`userid`),MAX(`id`),COUNT(`id`) FROM `t1`" 345 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2000, 20, 2000, 20, 2000, 10, 20, 200, 200)) 346 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1000, 10, 1000, 10, 1000, 20, 30, 300, 300)) 347 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(800, 10, 800, 10, 800, 5, 6, 100, 200)) 348 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 349 rowsList := make([]rows.Rows, 0, len(dbs)) 350 for _, db := range dbs { 351 row, err := db.QueryContext(context.Background(), query) 352 require.NoError(ms.T(), err) 353 rowsList = append(rowsList, row) 354 } 355 return rowsList 356 }, 357 wantVal: []any{ 358 float64(3800) / float64(40), int64(3800), float64(3800) / float64(40), int64(5), int64(6), int64(300), int64(700), 359 }, 360 gotVal: func() []any { 361 return []any{ 362 float64(0), 0, float64(0), 0, 0, 0, 0, 363 } 364 }(), 365 aggregators: func() []aggregator.Aggregator { 366 return []aggregator.Aggregator{ 367 aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), 368 aggregator.NewSum(merger.NewColumnInfo(2, "SUM(grade)")), 369 aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), 370 aggregator.NewMin(merger.NewColumnInfo(5, "MIN(id)")), 371 aggregator.NewMin(merger.NewColumnInfo(6, "MIN(userid)")), 372 aggregator.NewMax(merger.NewColumnInfo(7, "MAX(id)")), 373 aggregator.NewCount(merger.NewColumnInfo(8, "COUNT(id)")), 374 } 375 }, 376 }, 377 378 // 下面为RowList为有元素返回的行数为空 379 380 // 1. Rows 列表中有一个Rows返回行数为空,在前面会返回错误 381 { 382 name: "RowsList有一个Rows为空,在前面", 383 sqlRows: func() []rows.Rows { 384 cols := []string{"SUM(id)"} 385 query := "SELECT SUM(`id`) FROM `t1`" 386 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 387 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 388 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 389 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 390 dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} 391 rowsList := make([]rows.Rows, 0, len(dbs)) 392 for _, db := range dbs { 393 row, err := db.QueryContext(context.Background(), query) 394 require.NoError(ms.T(), err) 395 rowsList = append(rowsList, row) 396 } 397 return rowsList 398 }, 399 wantVal: []any{60}, 400 gotVal: func() []any { 401 return []any{ 402 0, 403 } 404 }(), 405 wantErr: errs.ErrMergerAggregateHasEmptyRows, 406 aggregators: func() []aggregator.Aggregator { 407 return []aggregator.Aggregator{ 408 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 409 } 410 }, 411 }, 412 // 2. Rows 列表中有一个Rows返回行数为空,在中间会返回错误 413 { 414 name: "RowsList有一个Rows为空,在中间", 415 sqlRows: func() []rows.Rows { 416 cols := []string{"SUM(id)"} 417 query := "SELECT SUM(`id`) FROM `t1`" 418 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 419 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 420 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 421 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 422 dbs := []*sql.DB{ms.mockDB01, ms.mockDB04, ms.mockDB02, ms.mockDB03} 423 rowsList := make([]rows.Rows, 0, len(dbs)) 424 for _, db := range dbs { 425 row, err := db.QueryContext(context.Background(), query) 426 require.NoError(ms.T(), err) 427 rowsList = append(rowsList, row) 428 } 429 return rowsList 430 }, 431 wantVal: []any{60}, 432 gotVal: func() []any { 433 return []any{ 434 0, 435 } 436 }(), 437 wantErr: errs.ErrMergerAggregateHasEmptyRows, 438 aggregators: func() []aggregator.Aggregator { 439 return []aggregator.Aggregator{ 440 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 441 } 442 }, 443 }, 444 // 3. Rows 列表中有一个Rows返回行数为空,在后面会返回错误 445 { 446 name: "RowsList有一个Rows为空,在最后", 447 sqlRows: func() []rows.Rows { 448 cols := []string{"SUM(id)"} 449 query := "SELECT SUM(`id`) FROM `t1`" 450 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(10)) 451 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(20)) 452 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(30)) 453 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 454 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 455 rowsList := make([]rows.Rows, 0, len(dbs)) 456 for _, db := range dbs { 457 row, err := db.QueryContext(context.Background(), query) 458 require.NoError(ms.T(), err) 459 rowsList = append(rowsList, row) 460 } 461 return rowsList 462 }, 463 wantVal: []any{60}, 464 gotVal: func() []any { 465 return []any{ 466 0, 467 } 468 }(), 469 wantErr: errs.ErrMergerAggregateHasEmptyRows, 470 aggregators: func() []aggregator.Aggregator { 471 return []aggregator.Aggregator{ 472 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 473 } 474 }, 475 }, 476 // 4. Rows 列表中全部Rows返回的行数为空,不会返回错误 477 { 478 name: "RowsList全部为空", 479 sqlRows: func() []rows.Rows { 480 cols := []string{"SUM(id)"} 481 query := "SELECT SUM(`id`) FROM `t1`" 482 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 483 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 484 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 485 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 486 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 487 rowsList := make([]rows.Rows, 0, len(dbs)) 488 for _, db := range dbs { 489 row, err := db.QueryContext(context.Background(), query) 490 require.NoError(ms.T(), err) 491 rowsList = append(rowsList, row) 492 } 493 return rowsList 494 }, 495 wantErr: errs.ErrMergerAggregateHasEmptyRows, 496 aggregators: func() []aggregator.Aggregator { 497 return []aggregator.Aggregator{ 498 aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), 499 } 500 }, 501 }, 502 } 503 for _, tc := range testcases { 504 ms.T().Run(tc.name, func(t *testing.T) { 505 m := NewMerger(tc.aggregators()...) 506 rows, err := m.Merge(context.Background(), tc.sqlRows()) 507 require.NoError(t, err) 508 for rows.Next() { 509 kk := make([]any, 0, len(tc.gotVal)) 510 for i := 0; i < len(tc.gotVal); i++ { 511 kk = append(kk, &tc.gotVal[i]) 512 } 513 err = rows.Scan(kk...) 514 require.NoError(t, err) 515 } 516 assert.Equal(t, tc.wantErr, rows.Err()) 517 if rows.Err() != nil { 518 return 519 } 520 assert.Equal(t, tc.wantVal, tc.gotVal) 521 }) 522 } 523 } 524 525 func (ms *MergerSuite) TestRows_NextAndErr() { 526 testcases := []struct { 527 name string 528 rowsList func() []rows.Rows 529 wantErr error 530 aggregators []aggregator.Aggregator 531 }{ 532 { 533 name: "sqlRows列表中有一个返回error", 534 rowsList: func() []rows.Rows { 535 cols := []string{"COUNT(id)"} 536 query := "SELECT COUNT(`id`) FROM `t1`" 537 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 538 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) 539 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4).RowError(0, nextMockErr)) 540 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) 541 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 542 rowsList := make([]rows.Rows, 0, len(dbs)) 543 for _, db := range dbs { 544 row, err := db.QueryContext(context.Background(), query) 545 require.NoError(ms.T(), err) 546 rowsList = append(rowsList, row) 547 } 548 return rowsList 549 }, 550 aggregators: func() []aggregator.Aggregator { 551 return []aggregator.Aggregator{ 552 aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), 553 } 554 }(), 555 wantErr: nextMockErr, 556 }, 557 { 558 name: "有一个aggregator返回error", 559 rowsList: func() []rows.Rows { 560 cols := []string{"COUNT(id)"} 561 query := "SELECT COUNT(`id`) FROM `t1`" 562 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 563 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2)) 564 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4)) 565 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5)) 566 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 567 rowsList := make([]rows.Rows, 0, len(dbs)) 568 for _, db := range dbs { 569 row, err := db.QueryContext(context.Background(), query) 570 require.NoError(ms.T(), err) 571 rowsList = append(rowsList, row) 572 } 573 return rowsList 574 }, 575 aggregators: func() []aggregator.Aggregator { 576 return []aggregator.Aggregator{ 577 &mockAggregate{}, 578 } 579 }(), 580 wantErr: aggregatorErr, 581 }, 582 } 583 for _, tc := range testcases { 584 ms.T().Run(tc.name, func(t *testing.T) { 585 merger := NewMerger(tc.aggregators...) 586 rows, err := merger.Merge(context.Background(), tc.rowsList()) 587 require.NoError(t, err) 588 for rows.Next() { 589 } 590 count := int64(0) 591 err = rows.Scan(&count) 592 assert.Equal(t, tc.wantErr, err) 593 assert.Equal(t, tc.wantErr, rows.Err()) 594 }) 595 } 596 } 597 598 func (ms *MergerSuite) TestRows_Close() { 599 cols := []string{"SUM(id)"} 600 query := "SELECT SUM(`id`) FROM `t1`" 601 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 602 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02"))) 603 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) 604 merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) 605 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 606 rowsList := make([]rows.Rows, 0, len(dbs)) 607 for _, db := range dbs { 608 row, err := db.QueryContext(context.Background(), query) 609 require.NoError(ms.T(), err) 610 rowsList = append(rowsList, row) 611 } 612 rows, err := merger.Merge(context.Background(), rowsList) 613 require.NoError(ms.T(), err) 614 // 判断当前是可以正常读取的 615 require.True(ms.T(), rows.Next()) 616 var id int 617 err = rows.Scan(&id) 618 require.NoError(ms.T(), err) 619 err = rows.Close() 620 ms.T().Run("close返回multierror", func(t *testing.T) { 621 assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) 622 }) 623 ms.T().Run("close之后Next返回false", func(t *testing.T) { 624 for i := 0; i < len(rowsList); i++ { 625 require.False(ms.T(), rowsList[i].Next()) 626 } 627 require.False(ms.T(), rows.Next()) 628 }) 629 ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { 630 var id int 631 err := rows.Scan(&id) 632 assert.Equal(t, errs.ErrMergerRowsClosed, err) 633 }) 634 ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { 635 _, err := rows.Columns() 636 require.Error(t, err) 637 }) 638 ms.T().Run("close多次是等效的", func(t *testing.T) { 639 for i := 0; i < 4; i++ { 640 err = rows.Close() 641 require.NoError(t, err) 642 } 643 }) 644 } 645 646 func (ms *MergerSuite) TestRows_Columns() { 647 cols := []string{"SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} 648 query := "SELECT SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`) FROM `t1`" 649 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10)) 650 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11)) 651 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12)) 652 aggregators := []aggregator.Aggregator{ 653 aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), 654 aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), 655 aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), 656 aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), 657 aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), 658 } 659 merger := NewMerger(aggregators...) 660 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 661 rowsList := make([]rows.Rows, 0, len(dbs)) 662 for _, db := range dbs { 663 row, err := db.QueryContext(context.Background(), query) 664 require.NoError(ms.T(), err) 665 rowsList = append(rowsList, row) 666 } 667 668 rows, err := merger.Merge(context.Background(), rowsList) 669 require.NoError(ms.T(), err) 670 wantCols := []string{"AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} 671 ms.T().Run("Next没有迭代完", func(t *testing.T) { 672 for rows.Next() { 673 columns, err := rows.Columns() 674 require.NoError(t, err) 675 assert.Equal(t, wantCols, columns) 676 } 677 require.NoError(t, rows.Err()) 678 }) 679 ms.T().Run("Next迭代完", func(t *testing.T) { 680 require.False(t, rows.Next()) 681 require.NoError(t, rows.Err()) 682 _, err := rows.Columns() 683 assert.Equal(t, errs.ErrMergerRowsClosed, err) 684 }) 685 } 686 687 func (ms *MergerSuite) TestMerger_Merge() { 688 testcases := []struct { 689 name string 690 merger func() *Merger 691 ctx func() (context.Context, context.CancelFunc) 692 wantErr error 693 sqlRows func() []rows.Rows 694 }{ 695 { 696 name: "超时", 697 merger: func() *Merger { 698 return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) 699 }, 700 ctx: func() (context.Context, context.CancelFunc) { 701 ctx, cancel := context.WithTimeout(context.Background(), 0) 702 return ctx, cancel 703 }, 704 wantErr: context.DeadlineExceeded, 705 sqlRows: func() []rows.Rows { 706 query := "SELECT SUM(`id`) FROM `t1`;" 707 cols := []string{"SUM(id)"} 708 res := make([]rows.Rows, 0, 1) 709 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 710 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 711 res = append(res, rows) 712 return res 713 }, 714 }, 715 { 716 name: "sqlRows列表元素个数为0", 717 merger: func() *Merger { 718 return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) 719 }, 720 ctx: func() (context.Context, context.CancelFunc) { 721 ctx, cancel := context.WithCancel(context.Background()) 722 return ctx, cancel 723 }, 724 wantErr: errs.ErrMergerEmptyRows, 725 sqlRows: func() []rows.Rows { 726 return []rows.Rows{} 727 }, 728 }, 729 { 730 name: "sqlRows列表有nil", 731 merger: func() *Merger { 732 return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) 733 }, 734 ctx: func() (context.Context, context.CancelFunc) { 735 ctx, cancel := context.WithCancel(context.Background()) 736 return ctx, cancel 737 }, 738 wantErr: errs.ErrMergerRowsIsNull, 739 sqlRows: func() []rows.Rows { 740 return []rows.Rows{nil} 741 }, 742 }, 743 } 744 for _, tc := range testcases { 745 ms.T().Run(tc.name, func(t *testing.T) { 746 ctx, cancel := tc.ctx() 747 m := tc.merger() 748 r, err := m.Merge(ctx, tc.sqlRows()) 749 cancel() 750 assert.Equal(t, tc.wantErr, err) 751 if err != nil { 752 return 753 } 754 require.NotNil(t, r) 755 }) 756 } 757 } 758 759 type mockAggregate struct { 760 cols [][]any 761 } 762 763 func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { 764 m.cols = cols 765 return nil, aggregatorErr 766 } 767 768 func (*mockAggregate) ColumnName() string { 769 return "mockAggregate" 770 } 771 772 func TestRows_NextResultSet(t *testing.T) { 773 assert.False(t, (&Rows{}).NextResultSet()) 774 }