github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/pagedmerger/merger_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package pagedmerger 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "testing" 23 24 "github.com/ecodeclub/eorm/internal/rows" 25 26 "github.com/DATA-DOG/go-sqlmock" 27 "github.com/ecodeclub/eorm/internal/merger" 28 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 29 "github.com/ecodeclub/eorm/internal/merger/sortmerger" 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 "github.com/stretchr/testify/suite" 33 "go.uber.org/multierr" 34 ) 35 36 var ( 37 offsetMockErr error = errors.New("rows: MockOffsetErr") 38 limitMockErr error = errors.New("rows: MockLimitErr") 39 ) 40 41 func newCloseMockErr(dbName string) error { 42 return fmt.Errorf("rows: %s MockCloseErr", dbName) 43 } 44 45 type MergerSuite struct { 46 suite.Suite 47 mockDB01 *sql.DB 48 mock01 sqlmock.Sqlmock 49 mockDB02 *sql.DB 50 mock02 sqlmock.Sqlmock 51 mockDB03 *sql.DB 52 mock03 sqlmock.Sqlmock 53 mockDB04 *sql.DB 54 mock04 sqlmock.Sqlmock 55 } 56 57 func (ms *MergerSuite) SetupTest() { 58 t := ms.T() 59 ms.initMock(t) 60 } 61 62 func (ms *MergerSuite) TearDownTest() { 63 _ = ms.mockDB01.Close() 64 _ = ms.mockDB02.Close() 65 _ = ms.mockDB03.Close() 66 _ = ms.mockDB04.Close() 67 } 68 69 func (ms *MergerSuite) initMock(t *testing.T) { 70 var err error 71 ms.mockDB01, ms.mock01, err = sqlmock.New() 72 if err != nil { 73 t.Fatal(err) 74 } 75 ms.mockDB02, ms.mock02, err = sqlmock.New() 76 if err != nil { 77 t.Fatal(err) 78 } 79 ms.mockDB03, ms.mock03, err = sqlmock.New() 80 if err != nil { 81 t.Fatal(err) 82 } 83 ms.mockDB04, ms.mock04, err = sqlmock.New() 84 if err != nil { 85 t.Fatal(err) 86 } 87 } 88 func (ms *MergerSuite) TestMerger_New() { 89 testcases := []struct { 90 name string 91 limit int 92 offset int 93 wantErr error 94 }{ 95 { 96 name: "limit 小于0", 97 limit: -1, 98 offset: 10, 99 wantErr: errs.ErrMergerInvalidLimitOrOffset, 100 }, 101 { 102 name: "limit 等于0", 103 limit: 0, 104 offset: 10, 105 wantErr: errs.ErrMergerInvalidLimitOrOffset, 106 }, 107 { 108 name: "offset 小于0", 109 limit: 0, 110 offset: -1, 111 wantErr: errs.ErrMergerInvalidLimitOrOffset, 112 }, 113 { 114 name: "limit 大于等于0,offset大于等于0", 115 limit: 10, 116 offset: 10, 117 }, 118 } 119 for _, tc := range testcases { 120 ms.T().Run(tc.name, func(t *testing.T) { 121 m, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 122 require.NoError(t, err) 123 limitMerger, err := NewMerger(m, tc.offset, tc.limit) 124 assert.Equal(t, tc.wantErr, err) 125 if err != nil { 126 return 127 } 128 require.NotNil(t, limitMerger) 129 }) 130 } 131 } 132 133 func (ms *MergerSuite) TestMerger_Merge() { 134 testcases := []struct { 135 name string 136 getMerger func() (merger.Merger, error) 137 GetRowsList func() []rows.Rows 138 wantErr error 139 ctx func() (context.Context, context.CancelFunc) 140 limit int 141 offset int 142 }{ 143 { 144 name: "limitMerger里的Merger的Merge出错", 145 getMerger: func() (merger.Merger, error) { 146 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 147 }, 148 GetRowsList: func() []rows.Rows { 149 return []rows.Rows{} 150 }, 151 wantErr: errs.ErrMergerEmptyRows, 152 ctx: func() (context.Context, context.CancelFunc) { 153 return context.WithCancel(context.Background()) 154 }, 155 limit: 1, 156 offset: 0, 157 }, 158 { 159 name: "初始化游标出错", 160 getMerger: func() (merger.Merger, error) { 161 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 162 }, 163 GetRowsList: func() []rows.Rows { 164 cols := []string{"id", "name", "address"} 165 query := "SELECT * FROM `t1`" 166 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn").RowError(1, offsetMockErr)) 167 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 168 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) 169 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 170 rowsList := make([]rows.Rows, 0, len(dbs)) 171 for _, db := range dbs { 172 row, err := db.QueryContext(context.Background(), query) 173 require.NoError(ms.T(), err) 174 rowsList = append(rowsList, row) 175 } 176 return rowsList 177 }, 178 wantErr: offsetMockErr, 179 ctx: func() (context.Context, context.CancelFunc) { 180 return context.WithCancel(context.Background()) 181 }, 182 limit: 10, 183 offset: 5, 184 }, 185 { 186 name: "offset的值超过返回的数据行数", 187 getMerger: func() (merger.Merger, error) { 188 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 189 }, 190 GetRowsList: func() []rows.Rows { 191 cols := []string{"id", "name", "address"} 192 query := "SELECT * FROM `t1`" 193 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) 194 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 195 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) 196 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 197 rowsList := make([]rows.Rows, 0, len(dbs)) 198 for _, db := range dbs { 199 row, err := db.QueryContext(context.Background(), query) 200 require.NoError(ms.T(), err) 201 rowsList = append(rowsList, row) 202 } 203 return rowsList 204 }, 205 ctx: func() (context.Context, context.CancelFunc) { 206 return context.WithCancel(context.Background()) 207 }, 208 limit: 10, 209 offset: 10, 210 }, 211 { 212 name: "超时", 213 getMerger: func() (merger.Merger, error) { 214 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 215 }, 216 GetRowsList: func() []rows.Rows { 217 cols := []string{"id", "name", "address"} 218 query := "SELECT * FROM `t1`" 219 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) 220 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 221 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) 222 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 223 rowsList := make([]rows.Rows, 0, len(dbs)) 224 for _, db := range dbs { 225 row, err := db.QueryContext(context.Background(), query) 226 require.NoError(ms.T(), err) 227 rowsList = append(rowsList, row) 228 } 229 return rowsList 230 }, 231 ctx: func() (context.Context, context.CancelFunc) { 232 return context.WithTimeout(context.Background(), 0) 233 }, 234 wantErr: context.DeadlineExceeded, 235 limit: 5, 236 offset: 0, 237 }, 238 } 239 for _, tc := range testcases { 240 ms.T().Run(tc.name, func(t *testing.T) { 241 merger, err := tc.getMerger() 242 require.NoError(t, err) 243 limitMerger, err := NewMerger(merger, tc.offset, tc.limit) 244 require.NoError(t, err) 245 require.NoError(t, err) 246 ctx, cancel := tc.ctx() 247 rows, err := limitMerger.Merge(ctx, tc.GetRowsList()) 248 cancel() 249 assert.Equal(t, tc.wantErr, err) 250 if err != nil { 251 return 252 } 253 require.NotNil(t, rows) 254 255 }) 256 } 257 } 258 259 func (ms *MergerSuite) TestMerger_NextAndScan() { 260 testcases := []struct { 261 name string 262 getMerger func() (merger.Merger, error) 263 GetRowsList func() []rows.Rows 264 wantVal []TestModel 265 limit int 266 offset int 267 }{ 268 { 269 name: "limit的行数超过了返回的总行数,", 270 getMerger: func() (merger.Merger, error) { 271 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 272 }, 273 GetRowsList: func() []rows.Rows { 274 cols := []string{"id", "name", "address"} 275 query := "SELECT * FROM `t1`" 276 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 277 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 278 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) 279 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 280 rowsList := make([]rows.Rows, 0, len(dbs)) 281 for _, db := range dbs { 282 row, err := db.QueryContext(context.Background(), query) 283 require.NoError(ms.T(), err) 284 rowsList = append(rowsList, row) 285 } 286 return rowsList 287 }, 288 wantVal: []TestModel{ 289 { 290 Id: 2, 291 Name: "a", 292 Address: "cn", 293 }, 294 { 295 Id: 3, 296 Name: "alex", 297 Address: "cn", 298 }, 299 { 300 Id: 4, 301 Name: "x", 302 Address: "cn", 303 }, 304 { 305 Id: 5, 306 Name: "bruce", 307 Address: "cn", 308 }, 309 { 310 Id: 7, 311 Name: "b", 312 Address: "cn", 313 }, 314 }, 315 limit: 100, 316 offset: 1, 317 }, 318 { 319 name: "limit 行数小于返回的总行数", 320 getMerger: func() (merger.Merger, error) { 321 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 322 }, 323 GetRowsList: func() []rows.Rows { 324 cols := []string{"id", "name", "address"} 325 query := "SELECT * FROM `t1`" 326 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 327 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 328 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) 329 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 330 rowsList := make([]rows.Rows, 0, len(dbs)) 331 for _, db := range dbs { 332 row, err := db.QueryContext(context.Background(), query) 333 require.NoError(ms.T(), err) 334 rowsList = append(rowsList, row) 335 } 336 return rowsList 337 }, 338 wantVal: []TestModel{ 339 { 340 Id: 2, 341 Name: "a", 342 Address: "cn", 343 }, 344 { 345 Id: 3, 346 Name: "alex", 347 Address: "cn", 348 }, 349 }, 350 limit: 2, 351 offset: 1, 352 }, 353 { 354 name: "offset超过sqlRows列表返回的总行数", 355 getMerger: func() (merger.Merger, error) { 356 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 357 }, 358 GetRowsList: func() []rows.Rows { 359 cols := []string{"id", "name", "address"} 360 query := "SELECT * FROM `t1`" 361 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 362 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 363 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) 364 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 365 rowsList := make([]rows.Rows, 0, len(dbs)) 366 for _, db := range dbs { 367 row, err := db.QueryContext(context.Background(), query) 368 require.NoError(ms.T(), err) 369 rowsList = append(rowsList, row) 370 } 371 return rowsList 372 }, 373 wantVal: []TestModel{}, 374 limit: 2, 375 offset: 100, 376 }, 377 { 378 name: "offset 的值为0", 379 getMerger: func() (merger.Merger, error) { 380 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 381 }, 382 GetRowsList: func() []rows.Rows { 383 cols := []string{"id", "name", "address"} 384 query := "SELECT * FROM `t1`" 385 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 386 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 387 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) 388 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 389 rowsList := make([]rows.Rows, 0, len(dbs)) 390 for _, db := range dbs { 391 row, err := db.QueryContext(context.Background(), query) 392 require.NoError(ms.T(), err) 393 rowsList = append(rowsList, row) 394 } 395 return rowsList 396 }, 397 wantVal: []TestModel{ 398 { 399 Id: 1, 400 Name: "abex", 401 Address: "cn", 402 }, 403 { 404 Id: 2, 405 Name: "a", 406 Address: "cn", 407 }, 408 { 409 Id: 3, 410 Name: "alex", 411 Address: "cn", 412 }, 413 { 414 Id: 4, 415 Name: "x", 416 Address: "cn", 417 }, 418 { 419 Id: 5, 420 Name: "bruce", 421 Address: "cn", 422 }, 423 { 424 Id: 7, 425 Name: "b", 426 Address: "cn", 427 }, 428 }, 429 limit: 10, 430 offset: 0, 431 }, 432 } 433 for _, tc := range testcases { 434 ms.T().Run(tc.name, func(t *testing.T) { 435 merger, err := tc.getMerger() 436 require.NoError(t, err) 437 limitMerger, err := NewMerger(merger, tc.offset, tc.limit) 438 require.NoError(t, err) 439 rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) 440 require.NoError(t, err) 441 res := make([]TestModel, 0, len(tc.wantVal)) 442 for rows.Next() { 443 var model TestModel 444 err = rows.Scan(&model.Id, &model.Name, &model.Address) 445 require.NoError(t, err) 446 res = append(res, model) 447 } 448 require.True(t, rows.(*Rows).closed) 449 require.NoError(t, rows.Err()) 450 assert.Equal(t, tc.wantVal, res) 451 }) 452 } 453 } 454 455 func (ms *MergerSuite) TestRows_NextAndErr() { 456 testcases := []struct { 457 name string 458 getMerger func() (merger.Merger, error) 459 GetRowsList func() []rows.Rows 460 wantErr error 461 limit int 462 offset int 463 }{ 464 { 465 name: "有sql.Rows返回错误", 466 getMerger: func() (merger.Merger, error) { 467 return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 468 }, 469 GetRowsList: func() []rows.Rows { 470 cols := []string{"id", "name", "address"} 471 query := "SELECT * FROM `t1`" 472 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) 473 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) 474 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn").RowError(1, limitMockErr)) 475 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 476 rowsList := make([]rows.Rows, 0, len(dbs)) 477 for _, db := range dbs { 478 row, err := db.QueryContext(context.Background(), query) 479 require.NoError(ms.T(), err) 480 rowsList = append(rowsList, row) 481 } 482 return rowsList 483 }, 484 limit: 10, 485 offset: 1, 486 wantErr: limitMockErr, 487 }, 488 } 489 for _, tc := range testcases { 490 ms.T().Run(tc.name, func(t *testing.T) { 491 merger, err := tc.getMerger() 492 require.NoError(t, err) 493 limitMerger, err := NewMerger(merger, tc.offset, tc.limit) 494 require.NoError(t, err) 495 rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) 496 require.NoError(t, err) 497 for rows.Next() { 498 } 499 require.True(t, rows.(*Rows).closed) 500 assert.Equal(t, tc.wantErr, rows.Err()) 501 }) 502 } 503 } 504 505 func (ms *MergerSuite) TestRows_ScanAndErr() { 506 ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { 507 cols := []string{"id"} 508 query := "SELECT * FROM `t1`" 509 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) 510 r, err := ms.mockDB01.QueryContext(context.Background(), query) 511 require.NoError(t, err) 512 rowsList := []rows.Rows{r} 513 merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 514 require.NoError(t, err) 515 limitMerger, err := NewMerger(merger, 0, 1) 516 require.NoError(t, err) 517 rows, err := limitMerger.Merge(context.Background(), rowsList) 518 require.NoError(t, err) 519 id := 0 520 err = rows.Scan(&id) 521 require.Error(t, err) 522 }) 523 ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { 524 cols := []string{"id"} 525 query := "SELECT * FROM `t1`" 526 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(2).RowError(1, limitMockErr)) 527 r, err := ms.mockDB01.QueryContext(context.Background(), query) 528 require.NoError(t, err) 529 rowsList := []rows.Rows{r} 530 merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 531 require.NoError(t, err) 532 limitMerger, err := NewMerger(merger, 0, 1) 533 require.NoError(t, err) 534 rows, err := limitMerger.Merge(context.Background(), rowsList) 535 require.NoError(t, err) 536 for rows.Next() { 537 } 538 id := 0 539 err = rows.Scan(&id) 540 assert.Equal(t, limitMockErr, err) 541 }) 542 } 543 544 func (ms *MergerSuite) TestRows_Close() { 545 cols := []string{"id"} 546 query := "SELECT * FROM `t1`" 547 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) 548 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("5").CloseError(newCloseMockErr("db02"))) 549 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) 550 merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 551 require.NoError(ms.T(), err) 552 limitMerger, err := NewMerger(merger, 1, 6) 553 require.NoError(ms.T(), err) 554 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 555 rowsList := make([]rows.Rows, 0, len(dbs)) 556 for _, db := range dbs { 557 row, err := db.QueryContext(context.Background(), query) 558 require.NoError(ms.T(), err) 559 rowsList = append(rowsList, row) 560 } 561 rows, err := limitMerger.Merge(context.Background(), rowsList) 562 require.NoError(ms.T(), err) 563 // 判断当前是可以正常读取的 564 require.True(ms.T(), rows.Next()) 565 var id int 566 err = rows.Scan(&id) 567 require.NoError(ms.T(), err) 568 err = rows.Close() 569 ms.T().Run("close返回error", func(t *testing.T) { 570 assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) 571 }) 572 ms.T().Run("close之后Next返回false", func(t *testing.T) { 573 for i := 0; i < len(rowsList); i++ { 574 require.False(ms.T(), rowsList[i].Next()) 575 } 576 require.False(ms.T(), rows.Next()) 577 }) 578 ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { 579 var id int 580 err := rows.Scan(&id) 581 assert.Equal(t, errs.ErrMergerRowsClosed, err) 582 }) 583 ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { 584 _, err := rows.Columns() 585 require.Error(t, err) 586 }) 587 ms.T().Run("close多次是等效的", func(t *testing.T) { 588 for i := 0; i < 4; i++ { 589 err = rows.Close() 590 require.NoError(t, err) 591 } 592 }) 593 } 594 595 func (ms *MergerSuite) TestRows_Columns() { 596 cols := []string{"id"} 597 query := "SELECT * FROM `t1`" 598 ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) 599 ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) 600 ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) 601 merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC)) 602 require.NoError(ms.T(), err) 603 limitMerger, err := NewMerger(merger, 0, 10) 604 require.NoError(ms.T(), err) 605 dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} 606 rowsList := make([]rows.Rows, 0, len(dbs)) 607 for _, db := range dbs { 608 row, err := db.QueryContext(context.Background(), query) 609 require.NoError(ms.T(), err) 610 rowsList = append(rowsList, row) 611 } 612 rows, err := limitMerger.Merge(context.Background(), rowsList) 613 require.NoError(ms.T(), err) 614 columns, err := rows.Columns() 615 require.NoError(ms.T(), err) 616 assert.Equal(ms.T(), cols, columns) 617 } 618 619 func TestMerger(t *testing.T) { 620 suite.Run(t, &MergerSuite{}) 621 } 622 623 type TestModel struct { 624 Id int 625 Name string 626 Address string 627 } 628 629 func TestRows_NextResultSet(t *testing.T) { 630 assert.False(t, (&Rows{}).NextResultSet()) 631 }