github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/sortmerger/merger_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sortmerger 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "testing" 23 24 _ "github.com/mattn/go-sqlite3" 25 26 "github.com/ecodeclub/eorm/internal/rows" 27 28 "github.com/DATA-DOG/go-sqlmock" 29 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 "github.com/stretchr/testify/suite" 33 "go.uber.org/multierr" 34 ) 35 36 var ( 37 nextMockErr error = errors.New("rows: MockNextErr") 38 ) 39 40 func newCloseMockErr(dbName string) error { 41 return fmt.Errorf("rows: %s MockCloseErr", dbName) 42 } 43 44 type MergerSuite struct { 45 suite.Suite 46 mockDB01 *sql.DB 47 mock01 sqlmock.Sqlmock 48 mockDB02 *sql.DB 49 mock02 sqlmock.Sqlmock 50 mockDB03 *sql.DB 51 mock03 sqlmock.Sqlmock 52 mockDB04 *sql.DB 53 mock04 sqlmock.Sqlmock 54 } 55 56 func (ms *MergerSuite) SetupTest() { 57 t := ms.T() 58 ms.initMock(t) 59 } 60 61 func (ms *MergerSuite) TearDownTest() { 62 _ = ms.mockDB01.Close() 63 _ = ms.mockDB02.Close() 64 _ = ms.mockDB03.Close() 65 _ = ms.mockDB04.Close() 66 } 67 68 func (ms *MergerSuite) initMock(t *testing.T) { 69 var err error 70 ms.mockDB01, ms.mock01, err = sqlmock.New() 71 if err != nil { 72 t.Fatal(err) 73 } 74 ms.mockDB02, ms.mock02, err = sqlmock.New() 75 if err != nil { 76 t.Fatal(err) 77 } 78 ms.mockDB03, ms.mock03, err = sqlmock.New() 79 if err != nil { 80 t.Fatal(err) 81 } 82 ms.mockDB04, ms.mock04, err = sqlmock.New() 83 if err != nil { 84 t.Fatal(err) 85 } 86 } 87 88 func (ms *MergerSuite) TestMerger_New() { 89 testcases := []struct { 90 name string 91 wantErr error 92 wantSortColumn func() []SortColumn 93 sortCols []SortColumn 94 }{ 95 { 96 name: "正常案例", 97 wantSortColumn: func() []SortColumn { 98 sortCol := NewSortColumn("id", ASC) 99 return []SortColumn{sortCol} 100 }, 101 sortCols: []SortColumn{ 102 NewSortColumn("id", ASC), 103 }, 104 }, 105 { 106 name: "空的排序列表", 107 sortCols: []SortColumn{}, 108 wantErr: errs.ErrEmptySortColumns, 109 }, 110 { 111 name: "排序列重复", 112 sortCols: []SortColumn{ 113 NewSortColumn("id", ASC), 114 NewSortColumn("id", DESC), 115 }, 116 wantErr: errs.NewRepeatSortColumn("id"), 117 }, 118 } 119 for _, tc := range testcases { 120 ms.T().Run(tc.name, func(t *testing.T) { 121 mer, err := NewMerger(tc.sortCols...) 122 assert.Equal(t, tc.wantErr, err) 123 if err != nil { 124 return 125 } 126 assert.Equal(t, tc.wantSortColumn(), mer.sortColumns.columns) 127 }) 128 } 129 } 130 131 func (ms *MergerSuite) TestMerger_Merge() { 132 testcases := []struct { 133 name string 134 merger func() (*Merger, error) 135 ctx func() (context.Context, context.CancelFunc) 136 wantErr error 137 sqlRows func() []rows.Rows 138 }{ 139 { 140 name: "sqlRows字段不同", 141 merger: func() (*Merger, error) { 142 return NewMerger(NewSortColumn("id", ASC)) 143 }, 144 ctx: func() (context.Context, context.CancelFunc) { 145 return context.WithCancel(context.Background()) 146 }, 147 sqlRows: func() []rows.Rows { 148 query := "SELECT * FROM `t1`" 149 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 150 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "email"}).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 151 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} 152 rowsList := make([]rows.Rows, 0, len(dbs)) 153 for _, db := range dbs { 154 row, err := db.QueryContext(context.Background(), query) 155 require.NoError(ms.T(), err) 156 rowsList = append(rowsList, row) 157 } 158 return rowsList 159 }, 160 wantErr: errs.ErrMergerRowsDiff, 161 }, 162 { 163 name: "sqlRows字段不同_少一个字段", 164 merger: func() (*Merger, error) { 165 return NewMerger(NewSortColumn("id", ASC)) 166 }, 167 ctx: func() (context.Context, context.CancelFunc) { 168 return context.WithCancel(context.Background()) 169 }, 170 sqlRows: func() []rows.Rows { 171 query := "SELECT * FROM `t1`" 172 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "address"}).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 173 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(3, "alex").AddRow(4, "x")) 174 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02} 175 rowsList := make([]rows.Rows, 0, len(dbs)) 176 for _, db := range dbs { 177 row, err := db.QueryContext(context.Background(), query) 178 require.NoError(ms.T(), err) 179 rowsList = append(rowsList, row) 180 } 181 return rowsList 182 }, 183 wantErr: errs.ErrMergerRowsDiff, 184 }, 185 { 186 name: "超时", 187 merger: func() (*Merger, error) { 188 return NewMerger(NewSortColumn("id", ASC)) 189 }, 190 ctx: func() (context.Context, context.CancelFunc) { 191 ctx, cancel := context.WithTimeout(context.Background(), 0) 192 return ctx, cancel 193 }, 194 wantErr: context.DeadlineExceeded, 195 sqlRows: func() []rows.Rows { 196 query := "SELECT * FROM `t1`;" 197 cols := []string{"id"} 198 res := make([]rows.Rows, 0, 1) 199 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 200 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 201 res = append(res, rows) 202 return res 203 }, 204 }, 205 { 206 name: "sqlRows列表为空", 207 ctx: func() (context.Context, context.CancelFunc) { 208 return context.WithCancel(context.Background()) 209 }, 210 merger: func() (*Merger, error) { 211 return NewMerger(NewSortColumn("id", ASC)) 212 }, 213 sqlRows: func() []rows.Rows { 214 return []rows.Rows{} 215 }, 216 wantErr: errs.ErrMergerEmptyRows, 217 }, 218 { 219 name: "sqlRows列表有nil", 220 merger: func() (*Merger, error) { 221 return NewMerger(NewSortColumn("id", ASC)) 222 }, 223 ctx: func() (context.Context, context.CancelFunc) { 224 return context.WithCancel(context.Background()) 225 }, 226 sqlRows: func() []rows.Rows { 227 return []rows.Rows{nil} 228 }, 229 wantErr: errs.ErrMergerRowsIsNull, 230 }, 231 { 232 name: "数据库列集: id;排序列集: age", 233 merger: func() (*Merger, error) { 234 return NewMerger(NewSortColumn("age", ASC)) 235 }, 236 sqlRows: func() []rows.Rows { 237 query := "SELECT * FROM `t1`;" 238 cols := []string{"id"} 239 res := make([]rows.Rows, 0, 1) 240 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 241 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 242 res = append(res, rows) 243 return res 244 }, 245 ctx: func() (context.Context, context.CancelFunc) { 246 return context.WithCancel(context.Background()) 247 }, 248 wantErr: errs.NewInvalidSortColumn("age"), 249 }, 250 { 251 name: "数据库列集: id;排序列集: id,age", 252 merger: func() (*Merger, error) { 253 return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC)) 254 }, 255 sqlRows: func() []rows.Rows { 256 query := "SELECT * FROM `t1`;" 257 cols := []string{"id"} 258 res := make([]rows.Rows, 0, 1) 259 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 260 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 261 res = append(res, rows) 262 return res 263 }, 264 ctx: func() (context.Context, context.CancelFunc) { 265 return context.WithCancel(context.Background()) 266 }, 267 wantErr: errs.NewInvalidSortColumn("age"), 268 }, 269 { 270 name: "数据库列集: id,name,address;排序列集: age", 271 merger: func() (*Merger, error) { 272 return NewMerger(NewSortColumn("age", ASC)) 273 }, 274 sqlRows: func() []rows.Rows { 275 query := "SELECT * FROM `t1`;" 276 cols := []string{"id", "name", "address"} 277 res := make([]rows.Rows, 0, 1) 278 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) 279 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 280 res = append(res, rows) 281 return res 282 }, 283 ctx: func() (context.Context, context.CancelFunc) { 284 return context.WithCancel(context.Background()) 285 }, 286 wantErr: errs.NewInvalidSortColumn("age"), 287 }, 288 { 289 name: "数据库列集: id,name,address;排序列集: id,age,name", 290 merger: func() (*Merger, error) { 291 return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC), NewSortColumn("name", ASC)) 292 }, 293 sqlRows: func() []rows.Rows { 294 query := "SELECT * FROM `t1`;" 295 cols := []string{"id", "name", "address"} 296 res := make([]rows.Rows, 0, 1) 297 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) 298 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 299 res = append(res, rows) 300 return res 301 }, 302 ctx: func() (context.Context, context.CancelFunc) { 303 return context.WithCancel(context.Background()) 304 }, 305 wantErr: errs.NewInvalidSortColumn("age"), 306 }, 307 { 308 name: "数据库列集: id,name,address;排序列集: id,name,age", 309 merger: func() (*Merger, error) { 310 return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC), NewSortColumn("age", ASC)) 311 }, 312 sqlRows: func() []rows.Rows { 313 query := "SELECT * FROM `t1`;" 314 cols := []string{"id", "name", "address"} 315 res := make([]rows.Rows, 0, 1) 316 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) 317 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 318 res = append(res, rows) 319 return res 320 }, 321 ctx: func() (context.Context, context.CancelFunc) { 322 return context.WithCancel(context.Background()) 323 }, 324 wantErr: errs.NewInvalidSortColumn("age"), 325 }, 326 { 327 name: "数据库列集: id ;排序列集: id", 328 merger: func() (*Merger, error) { 329 return NewMerger(NewSortColumn("id", ASC)) 330 }, 331 sqlRows: func() []rows.Rows { 332 query := "SELECT * FROM `t1`;" 333 cols := []string{"id"} 334 res := make([]rows.Rows, 0, 1) 335 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) 336 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 337 res = append(res, rows) 338 return res 339 }, 340 ctx: func() (context.Context, context.CancelFunc) { 341 return context.WithCancel(context.Background()) 342 }, 343 }, 344 { 345 name: "数据库列集: id,age;排序列集: id,age", 346 merger: func() (*Merger, error) { 347 return NewMerger(NewSortColumn("id", ASC), NewSortColumn("age", ASC)) 348 }, 349 sqlRows: func() []rows.Rows { 350 query := "SELECT * FROM `t1`;" 351 cols := []string{"id", "age"} 352 res := make([]rows.Rows, 0, 1) 353 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 18)) 354 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 355 res = append(res, rows) 356 return res 357 }, 358 ctx: func() (context.Context, context.CancelFunc) { 359 return context.WithCancel(context.Background()) 360 }, 361 }, 362 { 363 name: "数据库列集: id,name,address;排序列集: id,name", 364 merger: func() (*Merger, error) { 365 return NewMerger(NewSortColumn("id", ASC), NewSortColumn("name", ASC)) 366 }, 367 sqlRows: func() []rows.Rows { 368 query := "SELECT * FROM `t1`;" 369 cols := []string{"id", "name", "address"} 370 res := make([]rows.Rows, 0, 1) 371 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh")) 372 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 373 res = append(res, rows) 374 return res 375 }, 376 ctx: func() (context.Context, context.CancelFunc) { 377 return context.WithCancel(context.Background()) 378 }, 379 }, 380 { 381 name: "初始化Rows错误", 382 merger: func() (*Merger, error) { 383 return NewMerger(NewSortColumn("id", ASC)) 384 }, 385 sqlRows: func() []rows.Rows { 386 query := "SELECT * FROM `t1`;" 387 cols := []string{"id", "name", "address"} 388 res := make([]rows.Rows, 0, 1) 389 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "zwl", "sh").RowError(0, nextMockErr)) 390 rows, _ := ms.mockDB01.QueryContext(context.Background(), query) 391 res = append(res, rows) 392 return res 393 }, 394 wantErr: nextMockErr, 395 ctx: func() (context.Context, context.CancelFunc) { 396 return context.WithCancel(context.Background()) 397 }, 398 }, 399 } 400 for _, tc := range testcases { 401 ms.T().Run(tc.name, func(t *testing.T) { 402 merger, err := tc.merger() 403 require.NoError(ms.T(), err) 404 ctx, cancel := tc.ctx() 405 rows, err := merger.Merge(ctx, tc.sqlRows()) 406 cancel() 407 assert.Equal(t, tc.wantErr, err) 408 if err != nil { 409 return 410 } 411 require.NotNil(t, rows) 412 }) 413 414 } 415 } 416 417 func (ms *MergerSuite) TestRows_NextAndScan() { 418 testCases := []struct { 419 name string 420 sqlRows func() []rows.Rows 421 wantVal []TestModel 422 sortColumns []SortColumn 423 wantErr error 424 }{ 425 { 426 name: "完全交叉读,sqlRows返回行数相同", 427 sqlRows: func() []rows.Rows { 428 cols := []string{"id", "name", "address"} 429 query := "SELECT * FROM `t1`" 430 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 431 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 432 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) 433 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 434 rowsList := make([]rows.Rows, 0, len(dbs)) 435 for _, db := range dbs { 436 row, err := db.QueryContext(context.Background(), query) 437 require.NoError(ms.T(), err) 438 rowsList = append(rowsList, row) 439 } 440 return rowsList 441 }, 442 wantVal: []TestModel{ 443 { 444 Id: 1, 445 Name: "abex", 446 Address: "cn", 447 }, 448 { 449 Id: 2, 450 Name: "a", 451 Address: "cn", 452 }, 453 { 454 Id: 3, 455 Name: "alex", 456 Address: "cn", 457 }, 458 { 459 Id: 4, 460 Name: "x", 461 Address: "cn", 462 }, 463 { 464 Id: 5, 465 Name: "bruce", 466 Address: "cn", 467 }, 468 { 469 Id: 7, 470 Name: "b", 471 Address: "cn", 472 }, 473 }, 474 sortColumns: []SortColumn{ 475 NewSortColumn("id", ASC), 476 }, 477 }, 478 { 479 name: "完全交叉读,sqlRows返回行数部分不同", 480 sqlRows: func() []rows.Rows { 481 cols := []string{"id", "name", "address"} 482 query := "SELECT * FROM `t1`" 483 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(6, "x", "cn").AddRow(1, "x", "cn")) 484 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(8, "alex", "cn").AddRow(4, "bruce", "cn")) 485 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(9, "a", "cn").AddRow(5, "abex", "cn")) 486 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 487 rowsList := make([]rows.Rows, 0, len(dbs)) 488 for _, db := range dbs { 489 row, err := db.QueryContext(context.Background(), query) 490 require.NoError(ms.T(), err) 491 rowsList = append(rowsList, row) 492 } 493 return rowsList 494 }, 495 wantVal: []TestModel{ 496 { 497 Id: 9, 498 Name: "a", 499 Address: "cn", 500 }, 501 { 502 Id: 8, 503 Name: "alex", 504 Address: "cn", 505 }, 506 { 507 Id: 7, 508 Name: "b", 509 Address: "cn", 510 }, 511 { 512 Id: 6, 513 Name: "x", 514 Address: "cn", 515 }, 516 { 517 Id: 5, 518 Name: "abex", 519 Address: "cn", 520 }, 521 { 522 Id: 4, 523 Name: "bruce", 524 Address: "cn", 525 }, 526 { 527 Id: 1, 528 Name: "x", 529 Address: "cn", 530 }, 531 }, 532 sortColumns: []SortColumn{ 533 NewSortColumn("id", DESC), 534 }, 535 }, 536 { 537 // 包含一个sqlRows返回的行数为0,在前面 538 name: "完全交叉读,sqlRows返回行数完全不同", 539 sqlRows: func() []rows.Rows { 540 cols := []string{"id", "name", "address"} 541 query := "SELECT * FROM `t1`" 542 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "c", "cn").AddRow(2, "bruce", "cn").AddRow(2, "zwl", "cn")) 543 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "alex", "cn").AddRow(3, "x", "cn")) 544 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "c", "cn").AddRow(3, "b", "cn").AddRow(5, "c", "cn").AddRow(7, "c", "cn")) 545 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 546 dbs := []*sql.DB{ms.mockDB04, ms.mockDB01, ms.mockDB02, ms.mockDB03} 547 rowsList := make([]rows.Rows, 0, len(dbs)) 548 for _, db := range dbs { 549 row, err := db.QueryContext(context.Background(), query) 550 require.NoError(ms.T(), err) 551 rowsList = append(rowsList, row) 552 } 553 return rowsList 554 }, 555 wantVal: []TestModel{ 556 { 557 Id: 1, 558 Name: "alex", 559 Address: "cn", 560 }, 561 { 562 Id: 1, 563 Name: "c", 564 Address: "cn", 565 }, 566 { 567 Id: 2, 568 Name: "bruce", 569 Address: "cn", 570 }, 571 { 572 Id: 2, 573 Name: "c", 574 Address: "cn", 575 }, 576 { 577 Id: 2, 578 Name: "zwl", 579 Address: "cn", 580 }, 581 { 582 Id: 3, 583 Name: "b", 584 Address: "cn", 585 }, 586 { 587 Id: 3, 588 Name: "x", 589 Address: "cn", 590 }, 591 { 592 Id: 5, 593 Name: "c", 594 Address: "cn", 595 }, 596 { 597 Id: 7, 598 Name: "c", 599 Address: "cn", 600 }, 601 }, 602 sortColumns: []SortColumn{ 603 NewSortColumn("id", ASC), 604 NewSortColumn("name", ASC), 605 }, 606 }, 607 { 608 name: "部分交叉读,sqlRows返回行数相同", 609 sqlRows: func() []rows.Rows { 610 cols := []string{"id", "name", "address"} 611 query := "SELECT * FROM `t1`" 612 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) 613 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(5, "bruce", "cn")) 614 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(7, "b", "cn")) 615 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 616 rowsList := make([]rows.Rows, 0, len(dbs)) 617 for _, db := range dbs { 618 row, err := db.QueryContext(context.Background(), query) 619 require.NoError(ms.T(), err) 620 rowsList = append(rowsList, row) 621 } 622 return rowsList 623 }, 624 wantVal: []TestModel{ 625 { 626 Id: 1, 627 Name: "abex", 628 Address: "cn", 629 }, 630 { 631 Id: 2, 632 Name: "a", 633 Address: "cn", 634 }, 635 { 636 Id: 3, 637 Name: "alex", 638 Address: "cn", 639 }, 640 { 641 Id: 4, 642 Name: "x", 643 Address: "cn", 644 }, 645 { 646 Id: 5, 647 Name: "bruce", 648 Address: "cn", 649 }, 650 { 651 Id: 7, 652 Name: "b", 653 Address: "cn", 654 }, 655 }, 656 sortColumns: []SortColumn{ 657 NewSortColumn("id", ASC), 658 }, 659 }, 660 { 661 name: "部分交叉读,sqlRows返回行数部分相同", 662 sqlRows: func() []rows.Rows { 663 cols := []string{"id", "name", "address"} 664 query := "SELECT * FROM `t1`" 665 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) 666 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 667 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn").AddRow(8, "b", "cn")) 668 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 669 rowsList := make([]rows.Rows, 0, len(dbs)) 670 for _, db := range dbs { 671 row, err := db.QueryContext(context.Background(), query) 672 require.NoError(ms.T(), err) 673 rowsList = append(rowsList, row) 674 } 675 return rowsList 676 }, 677 wantVal: []TestModel{ 678 { 679 Id: 1, 680 Name: "abex", 681 Address: "cn", 682 }, 683 { 684 Id: 2, 685 Name: "a", 686 Address: "cn", 687 }, 688 { 689 Id: 3, 690 Name: "alex", 691 Address: "cn", 692 }, 693 { 694 Id: 4, 695 Name: "x", 696 Address: "cn", 697 }, 698 { 699 Id: 5, 700 Name: "bruce", 701 Address: "cn", 702 }, 703 { 704 Id: 7, 705 Name: "b", 706 Address: "cn", 707 }, 708 { 709 Id: 8, 710 Name: "b", 711 Address: "cn", 712 }, 713 }, 714 sortColumns: []SortColumn{ 715 NewSortColumn("id", ASC), 716 }, 717 }, 718 { 719 // 包含一个sqlRows返回的行数为0,在中间 720 name: "部分交叉读,sqlRows返回行数完全不同", 721 sqlRows: func() []rows.Rows { 722 cols := []string{"id", "name", "address"} 723 query := "SELECT * FROM `t1`" 724 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn").AddRow(5, "bruce", "cn")) 725 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 726 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(7, "b", "cn")) 727 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 728 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB04, ms.mockDB03} 729 rowsList := make([]rows.Rows, 0, len(dbs)) 730 for _, db := range dbs { 731 row, err := db.QueryContext(context.Background(), query) 732 require.NoError(ms.T(), err) 733 rowsList = append(rowsList, row) 734 } 735 return rowsList 736 }, 737 wantVal: []TestModel{ 738 { 739 Id: 1, 740 Name: "abex", 741 Address: "cn", 742 }, 743 { 744 Id: 2, 745 Name: "a", 746 Address: "cn", 747 }, 748 { 749 Id: 3, 750 Name: "alex", 751 Address: "cn", 752 }, 753 { 754 Id: 4, 755 Name: "x", 756 Address: "cn", 757 }, 758 { 759 Id: 5, 760 Name: "bruce", 761 Address: "cn", 762 }, 763 { 764 Id: 7, 765 Name: "b", 766 Address: "cn", 767 }, 768 }, 769 sortColumns: []SortColumn{ 770 NewSortColumn("id", ASC), 771 }, 772 }, 773 { 774 name: "顺序读,sqlRows返回行数相同", 775 sqlRows: func() []rows.Rows { 776 cols := []string{"id", "name", "address"} 777 query := "SELECT * FROM `t1`" 778 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) 779 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 780 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) 781 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 782 rowsList := make([]rows.Rows, 0, len(dbs)) 783 for _, db := range dbs { 784 row, err := db.QueryContext(context.Background(), query) 785 require.NoError(ms.T(), err) 786 rowsList = append(rowsList, row) 787 } 788 return rowsList 789 }, 790 wantVal: []TestModel{ 791 { 792 Id: 1, 793 Name: "abex", 794 Address: "cn", 795 }, 796 { 797 Id: 2, 798 Name: "a", 799 Address: "cn", 800 }, 801 { 802 Id: 3, 803 Name: "alex", 804 Address: "cn", 805 }, 806 { 807 Id: 4, 808 Name: "x", 809 Address: "cn", 810 }, 811 { 812 Id: 5, 813 Name: "bruce", 814 Address: "cn", 815 }, 816 { 817 Id: 7, 818 Name: "b", 819 Address: "cn", 820 }, 821 }, 822 sortColumns: []SortColumn{ 823 NewSortColumn("id", ASC), 824 }, 825 }, 826 { 827 name: "顺序读,sqlRows返回行数部分不同", 828 sqlRows: func() []rows.Rows { 829 cols := []string{"id", "name", "address"} 830 query := "SELECT * FROM `t1`" 831 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "a", "cn")) 832 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 833 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "bruce", "cn")) 834 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 835 rowsList := make([]rows.Rows, 0, len(dbs)) 836 for _, db := range dbs { 837 row, err := db.QueryContext(context.Background(), query) 838 require.NoError(ms.T(), err) 839 rowsList = append(rowsList, row) 840 } 841 return rowsList 842 }, 843 844 wantVal: []TestModel{ 845 { 846 Id: 1, 847 Name: "abex", 848 Address: "cn", 849 }, 850 { 851 Id: 2, 852 Name: "a", 853 Address: "cn", 854 }, 855 { 856 Id: 3, 857 Name: "alex", 858 Address: "cn", 859 }, 860 { 861 Id: 4, 862 Name: "x", 863 Address: "cn", 864 }, 865 { 866 Id: 5, 867 Name: "bruce", 868 Address: "cn", 869 }, 870 }, 871 sortColumns: []SortColumn{ 872 NewSortColumn("id", ASC), 873 }, 874 }, 875 { 876 // 包含一个sqlRows返回的行数为0,在后面 877 name: "顺序读,sqlRows返回行数完全不同", 878 sqlRows: func() []rows.Rows { 879 cols := []string{"id", "name", "address"} 880 query := "SELECT * FROM `t1`" 881 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn")) 882 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(3, "alex", "cn")) 883 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(4, "x", "cn").AddRow(5, "bruce", "cn").AddRow(7, "b", "cn")) 884 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 885 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 886 rowsList := make([]rows.Rows, 0, len(dbs)) 887 for _, db := range dbs { 888 row, err := db.QueryContext(context.Background(), query) 889 require.NoError(ms.T(), err) 890 rowsList = append(rowsList, row) 891 } 892 return rowsList 893 }, 894 wantVal: []TestModel{ 895 { 896 Id: 1, 897 Name: "abex", 898 Address: "cn", 899 }, 900 { 901 Id: 2, 902 Name: "a", 903 Address: "cn", 904 }, 905 { 906 Id: 3, 907 Name: "alex", 908 Address: "cn", 909 }, 910 { 911 Id: 4, 912 Name: "x", 913 Address: "cn", 914 }, 915 { 916 Id: 5, 917 Name: "bruce", 918 Address: "cn", 919 }, 920 { 921 Id: 7, 922 Name: "b", 923 Address: "cn", 924 }, 925 }, 926 sortColumns: []SortColumn{ 927 NewSortColumn("id", ASC), 928 }, 929 }, 930 931 { 932 name: "所有sqlRows返回的行数均为空", 933 sqlRows: func() []rows.Rows { 934 cols := []string{"id", "name", "address"} 935 query := "SELECT * FROM `t1`" 936 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 937 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 938 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 939 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols)) 940 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 941 rowsList := make([]rows.Rows, 0, len(dbs)) 942 for _, db := range dbs { 943 row, err := db.QueryContext(context.Background(), query) 944 require.NoError(ms.T(), err) 945 rowsList = append(rowsList, row) 946 } 947 return rowsList 948 }, 949 wantVal: []TestModel{}, 950 sortColumns: []SortColumn{ 951 NewSortColumn("id", ASC), 952 NewSortColumn("name", ASC), 953 }, 954 }, 955 { 956 name: "排序列返回的顺序和数据库里的字段顺序不一致", 957 sqlRows: func() []rows.Rows { 958 cols := []string{"id", "name", "address"} 959 query := "SELECT * FROM `t1`" 960 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "hz").AddRow(3, "b", "hz").AddRow(2, "b", "cs")) 961 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "a", "cs").AddRow(1, "a", "cs").AddRow(3, "e", "cn")) 962 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "d", "hm").AddRow(5, "k", "xx").AddRow(4, "k", "xz")) 963 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 964 rowsList := make([]rows.Rows, 0, len(dbs)) 965 for _, db := range dbs { 966 row, err := db.QueryContext(context.Background(), query) 967 require.NoError(ms.T(), err) 968 rowsList = append(rowsList, row) 969 } 970 return rowsList 971 }, 972 wantVal: []TestModel{ 973 { 974 Id: 3, 975 Name: "a", 976 Address: "cs", 977 }, 978 { 979 Id: 2, 980 Name: "a", 981 Address: "hz", 982 }, 983 { 984 Id: 1, 985 Name: "a", 986 Address: "cs", 987 }, 988 { 989 Id: 3, 990 Name: "b", 991 Address: "hz", 992 }, 993 { 994 Id: 2, 995 Name: "b", 996 Address: "cs", 997 }, 998 { 999 Id: 2, 1000 Name: "d", 1001 Address: "hm", 1002 }, 1003 { 1004 Id: 3, 1005 Name: "e", 1006 Address: "cn", 1007 }, 1008 { 1009 Id: 5, 1010 Name: "k", 1011 Address: "xx", 1012 }, 1013 { 1014 Id: 4, 1015 Name: "k", 1016 Address: "xz", 1017 }, 1018 }, 1019 sortColumns: []SortColumn{ 1020 NewSortColumn("name", ASC), 1021 NewSortColumn("id", DESC), 1022 }, 1023 }, 1024 } 1025 for _, tc := range testCases { 1026 ms.T().Run(tc.name, func(t *testing.T) { 1027 merger, err := NewMerger(tc.sortColumns...) 1028 require.NoError(t, err) 1029 rows, err := merger.Merge(context.Background(), tc.sqlRows()) 1030 require.NoError(t, err) 1031 res := make([]TestModel, 0, len(tc.wantVal)) 1032 for rows.Next() { 1033 t := TestModel{} 1034 err := rows.Scan(&t.Id, &t.Name, &t.Address) 1035 require.NoError(ms.T(), err) 1036 res = append(res, t) 1037 } 1038 require.True(t, rows.(*Rows).closed) 1039 assert.NoError(t, rows.Err()) 1040 assert.Equal(t, tc.wantVal, res) 1041 }) 1042 } 1043 1044 } 1045 1046 func (ms *MergerSuite) TestRows_Columns() { 1047 cols := []string{"id"} 1048 query := "SELECT * FROM `t1`" 1049 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) 1050 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) 1051 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) 1052 merger, err := NewMerger(NewSortColumn("id", DESC)) 1053 require.NoError(ms.T(), err) 1054 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 1055 rowsList := make([]rows.Rows, 0, len(dbs)) 1056 for _, db := range dbs { 1057 row, err := db.QueryContext(context.Background(), query) 1058 require.NoError(ms.T(), err) 1059 rowsList = append(rowsList, row) 1060 } 1061 1062 rows, err := merger.Merge(context.Background(), rowsList) 1063 require.NoError(ms.T(), err) 1064 ms.T().Run("Next没有迭代完", func(t *testing.T) { 1065 for rows.Next() { 1066 columns, err := rows.Columns() 1067 require.NoError(t, err) 1068 assert.Equal(t, cols, columns) 1069 } 1070 require.NoError(t, rows.Err()) 1071 }) 1072 ms.T().Run("Next迭代完", func(t *testing.T) { 1073 require.False(t, rows.Next()) 1074 require.NoError(t, rows.Err()) 1075 _, err := rows.Columns() 1076 assert.Equal(t, errs.ErrMergerRowsClosed, err) 1077 1078 }) 1079 1080 } 1081 1082 func (ms *MergerSuite) TestRows_Close() { 1083 cols := []string{"id"} 1084 query := "SELECT * FROM `t1`" 1085 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) 1086 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").CloseError(newCloseMockErr("db02"))) 1087 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) 1088 merger, err := NewMerger(NewSortColumn("id", DESC)) 1089 require.NoError(ms.T(), err) 1090 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 1091 rowsList := make([]rows.Rows, 0, len(dbs)) 1092 for _, db := range dbs { 1093 row, err := db.QueryContext(context.Background(), query) 1094 require.NoError(ms.T(), err) 1095 rowsList = append(rowsList, row) 1096 } 1097 rows, err := merger.Merge(context.Background(), rowsList) 1098 require.NoError(ms.T(), err) 1099 // 判断当前是可以正常读取的 1100 require.True(ms.T(), rows.Next()) 1101 var id int 1102 err = rows.Scan(&id) 1103 require.NoError(ms.T(), err) 1104 err = rows.Close() 1105 ms.T().Run("close返回multierror", func(t *testing.T) { 1106 assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) 1107 }) 1108 ms.T().Run("close之后Next返回false", func(t *testing.T) { 1109 for i := 0; i < len(rowsList); i++ { 1110 require.False(ms.T(), rowsList[i].Next()) 1111 } 1112 require.False(ms.T(), rows.Next()) 1113 }) 1114 ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { 1115 var id int 1116 err := rows.Scan(&id) 1117 assert.Equal(t, errs.ErrMergerRowsClosed, err) 1118 }) 1119 ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { 1120 _, err := rows.Columns() 1121 require.Error(t, err) 1122 }) 1123 ms.T().Run("close多次是等效的", func(t *testing.T) { 1124 for i := 0; i < 4; i++ { 1125 err = rows.Close() 1126 require.NoError(t, err) 1127 } 1128 }) 1129 } 1130 1131 // 测试Next迭代过程中遇到错误 1132 func (ms *MergerSuite) TestRows_NextAndErr() { 1133 testcases := []struct { 1134 name string 1135 rowsList func() []rows.Rows 1136 wantErr error 1137 sortColumns []SortColumn 1138 }{ 1139 { 1140 name: "sqlRows列表中有一个返回error", 1141 rowsList: func() []rows.Rows { 1142 cols := []string{"id"} 1143 query := "SELECT * FROM `t1`" 1144 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) 1145 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) 1146 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").RowError(1, nextMockErr)) 1147 ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("5")) 1148 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} 1149 rowsList := make([]rows.Rows, 0, len(dbs)) 1150 for _, db := range dbs { 1151 row, err := db.QueryContext(context.Background(), query) 1152 require.NoError(ms.T(), err) 1153 rowsList = append(rowsList, row) 1154 } 1155 return rowsList 1156 }, 1157 sortColumns: []SortColumn{ 1158 NewSortColumn("id", ASC), 1159 }, 1160 wantErr: nextMockErr, 1161 }, 1162 } 1163 for _, tc := range testcases { 1164 ms.T().Run(tc.name, func(t *testing.T) { 1165 merger, err := NewMerger(tc.sortColumns...) 1166 require.NoError(t, err) 1167 rows, err := merger.Merge(context.Background(), tc.rowsList()) 1168 require.NoError(t, err) 1169 for rows.Next() { 1170 } 1171 assert.Equal(t, tc.wantErr, rows.Err()) 1172 }) 1173 } 1174 } 1175 1176 // Scan方法的一些边界情况的测试 1177 func (ms *MergerSuite) TestRows_ScanErr() { 1178 ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { 1179 cols := []string{"id", "name", "address"} 1180 query := "SELECT * FROM `t1`" 1181 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 1182 r, err := ms.mockDB01.QueryContext(context.Background(), query) 1183 require.NoError(t, err) 1184 rowsList := []rows.Rows{r} 1185 merger, err := NewMerger(NewSortColumn("id", DESC)) 1186 require.NoError(t, err) 1187 rows, err := merger.Merge(context.Background(), rowsList) 1188 require.NoError(t, err) 1189 model := TestModel{} 1190 err = rows.Scan(&model.Id, &model.Name, &model.Address) 1191 assert.Equal(t, errs.ErrMergerScanNotNext, err) 1192 }) 1193 ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { 1194 cols := []string{"id", "name", "address"} 1195 query := "SELECT * FROM `t1`" 1196 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn").RowError(1, nextMockErr)) 1197 r, err := ms.mockDB01.QueryContext(context.Background(), query) 1198 require.NoError(t, err) 1199 rowsList := []rows.Rows{r} 1200 merger, err := NewMerger(NewSortColumn("id", DESC)) 1201 require.NoError(t, err) 1202 rows, err := merger.Merge(context.Background(), rowsList) 1203 require.NoError(t, err) 1204 for rows.Next() { 1205 } 1206 var model TestModel 1207 err = rows.Scan(&model.Id, &model.Name, &model.Address) 1208 assert.Equal(t, nextMockErr, err) 1209 }) 1210 1211 } 1212 1213 type TestModel struct { 1214 Id int 1215 Name string 1216 Address string 1217 } 1218 1219 func TestMerger(t *testing.T) { 1220 suite.Run(t, &MergerSuite{}) 1221 suite.Run(t, &NullableMergerSuite{}) 1222 } 1223 1224 type NullableMergerSuite struct { 1225 suite.Suite 1226 db01 *sql.DB 1227 db02 *sql.DB 1228 db03 *sql.DB 1229 } 1230 1231 func (ms *NullableMergerSuite) SetupSuite() { 1232 t := ms.T() 1233 query := "CREATE TABLE t1 (\n id int primary key,\n `age` int,\n \t`name` varchar(20)\n );\n" 1234 db01, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") 1235 if err != nil { 1236 t.Fatal(err) 1237 } 1238 ms.db01 = db01 1239 _, err = db01.ExecContext(context.Background(), query) 1240 if err != nil { 1241 t.Fatal(err) 1242 } 1243 db02, err := sql.Open("sqlite3", "file:test02.db?cache=shared&mode=memory") 1244 if err != nil { 1245 t.Fatal(err) 1246 } 1247 ms.db02 = db02 1248 _, err = db02.ExecContext(context.Background(), query) 1249 if err != nil { 1250 t.Fatal(err) 1251 } 1252 db03, err := sql.Open("sqlite3", "file:test03.db?cache=shared&mode=memory") 1253 if err != nil { 1254 t.Fatal(err) 1255 } 1256 ms.db03 = db03 1257 _, err = db03.ExecContext(context.Background(), query) 1258 if err != nil { 1259 t.Fatal(err) 1260 } 1261 } 1262 1263 func (ms *NullableMergerSuite) TearDownSuite() { 1264 _ = ms.db01.Close() 1265 _ = ms.db02.Close() 1266 _ = ms.db03.Close() 1267 } 1268 1269 func (ms *NullableMergerSuite) TestRows_Nullable() { 1270 testcases := []struct { 1271 name string 1272 rowsList func() []rows.Rows 1273 sortColumns []SortColumn 1274 wantErr error 1275 afterFunc func() 1276 wantVal []Nullable 1277 }{ 1278 { 1279 name: "多个nullable类型排序 age asc,name desc", 1280 rowsList: func() []rows.Rows { 1281 db1InsertSql := []string{ 1282 "insert into t1 (id, name) values (1, 'zwl')", 1283 "insert into t1 (id, age, name) values (2, 10, 'zwl')", 1284 "insert into t1 (id, age, name) values (3, 20, 'zwl')", 1285 "insert into t1 (id, age) values (4, 20)", 1286 } 1287 for _, sql := range db1InsertSql { 1288 _, err := ms.db01.ExecContext(context.Background(), sql) 1289 require.NoError(ms.T(), err) 1290 } 1291 db2InsertSql := []string{ 1292 "insert into t1 (id, age, name) values (5, 5, 'zwl')", 1293 "insert into t1 (id, age, name) values (6, 20, 'dm')", 1294 } 1295 for _, sql := range db2InsertSql { 1296 _, err := ms.db02.ExecContext(context.Background(), sql) 1297 require.NoError(ms.T(), err) 1298 } 1299 db3InsertSql := []string{ 1300 "insert into t1 (id, name) values (7, 'xq')", 1301 "insert into t1 (id, age) values (8, 5)", 1302 "insert into t1 (id, age,name) values (9, 10,'xx')", 1303 } 1304 for _, sql := range db3InsertSql { 1305 _, err := ms.db03.ExecContext(context.Background(), sql) 1306 require.NoError(ms.T(), err) 1307 } 1308 dbs := []*sql.DB{ms.db01, ms.db02, ms.db03} 1309 rowsList := make([]rows.Rows, 0, len(dbs)) 1310 query := "SELECT `id`, `age`,`name` FROM `t1` order by age asc,name desc" 1311 for _, db := range dbs { 1312 rows, err := db.QueryContext(context.Background(), query) 1313 require.NoError(ms.T(), err) 1314 rowsList = append(rowsList, rows) 1315 } 1316 return rowsList 1317 }, 1318 sortColumns: []SortColumn{ 1319 NewSortColumn("age", ASC), 1320 NewSortColumn("name", DESC), 1321 }, 1322 afterFunc: func() { 1323 dbs := []*sql.DB{ms.db01, ms.db02, ms.db03} 1324 for _, db := range dbs { 1325 _, err := db.Exec("DELETE FROM t1;") 1326 require.NoError(ms.T(), err) 1327 } 1328 }, 1329 wantVal: func() []Nullable { 1330 return []Nullable{ 1331 { 1332 Id: sql.NullInt64{Valid: true, Int64: 1}, 1333 Age: sql.NullInt64{Valid: false, Int64: 0}, 1334 Name: sql.NullString{Valid: true, String: "zwl"}, 1335 }, 1336 { 1337 Id: sql.NullInt64{Valid: true, Int64: 7}, 1338 Age: sql.NullInt64{Valid: false, Int64: 0}, 1339 Name: sql.NullString{Valid: true, String: "xq"}, 1340 }, 1341 { 1342 Id: sql.NullInt64{Valid: true, Int64: 5}, 1343 Age: sql.NullInt64{Valid: true, Int64: 5}, 1344 Name: sql.NullString{Valid: true, String: "zwl"}, 1345 }, 1346 { 1347 Id: sql.NullInt64{Valid: true, Int64: 8}, 1348 Age: sql.NullInt64{Valid: true, Int64: 5}, 1349 Name: sql.NullString{Valid: false, String: ""}, 1350 }, 1351 { 1352 Id: sql.NullInt64{Valid: true, Int64: 2}, 1353 Age: sql.NullInt64{Valid: true, Int64: 10}, 1354 Name: sql.NullString{Valid: true, String: "zwl"}, 1355 }, 1356 { 1357 Id: sql.NullInt64{Valid: true, Int64: 9}, 1358 Age: sql.NullInt64{Valid: true, Int64: 10}, 1359 Name: sql.NullString{Valid: true, String: "xx"}, 1360 }, 1361 { 1362 Id: sql.NullInt64{Valid: true, Int64: 3}, 1363 Age: sql.NullInt64{Valid: true, Int64: 20}, 1364 Name: sql.NullString{Valid: true, String: "zwl"}, 1365 }, 1366 { 1367 Id: sql.NullInt64{Valid: true, Int64: 6}, 1368 Age: sql.NullInt64{Valid: true, Int64: 20}, 1369 Name: sql.NullString{Valid: true, String: "dm"}, 1370 }, 1371 { 1372 Id: sql.NullInt64{Valid: true, Int64: 4}, 1373 Age: sql.NullInt64{Valid: true, Int64: 20}, 1374 Name: sql.NullString{Valid: false, String: ""}, 1375 }, 1376 } 1377 }(), 1378 }, 1379 } 1380 for _, tc := range testcases { 1381 ms.T().Run(tc.name, func(t *testing.T) { 1382 merger, err := NewMerger(tc.sortColumns...) 1383 require.NoError(t, err) 1384 rows, err := merger.Merge(context.Background(), tc.rowsList()) 1385 require.NoError(t, err) 1386 res := make([]Nullable, 0, len(tc.wantVal)) 1387 for rows.Next() { 1388 nullT := Nullable{} 1389 err := rows.Scan(&nullT.Id, &nullT.Age, &nullT.Name) 1390 require.NoError(ms.T(), err) 1391 res = append(res, nullT) 1392 } 1393 require.True(t, rows.(*Rows).closed) 1394 assert.NoError(t, rows.Err()) 1395 assert.Equal(t, tc.wantVal, res) 1396 tc.afterFunc() 1397 }) 1398 } 1399 } 1400 1401 type Nullable struct { 1402 Id sql.NullInt64 1403 Age sql.NullInt64 1404 Name sql.NullString 1405 } 1406 1407 func TestRows_NextResultSet(t *testing.T) { 1408 assert.False(t, (&Rows{}).NextResultSet()) 1409 }