github.com/jackc/pgx/v5@v5.5.5/rows.go (about) 1 package pgx 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "strings" 9 "time" 10 11 "github.com/jackc/pgx/v5/pgconn" 12 "github.com/jackc/pgx/v5/pgtype" 13 ) 14 15 // Rows is the result set returned from *Conn.Query. Rows must be closed before 16 // the *Conn can be used again. Rows are closed by explicitly calling Close(), 17 // calling Next() until it returns false, or when a fatal error occurs. 18 // 19 // Once a Rows is closed the only methods that may be called are Close(), Err(), 20 // and CommandTag(). 21 // 22 // Rows is an interface instead of a struct to allow tests to mock Query. However, 23 // adding a method to an interface is technically a breaking change. Because of this 24 // the Rows interface is partially excluded from semantic version requirements. 25 // Methods will not be removed or changed, but new methods may be added. 26 type Rows interface { 27 // Close closes the rows, making the connection ready for use again. It is safe 28 // to call Close after rows is already closed. 29 Close() 30 31 // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by 32 // calling Close or by Next returning false). If it is called early it may return nil even if there was an error 33 // executing the query. 34 Err() error 35 36 // CommandTag returns the command tag from this query. It is only available after Rows is closed. 37 CommandTag() pgconn.CommandTag 38 39 // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur 40 // when there was an error executing the query. 41 FieldDescriptions() []pgconn.FieldDescription 42 43 // Next prepares the next row for reading. It returns true if there is another 44 // row and false if no more rows are available or a fatal error has occurred. 45 // It automatically closes rows when all rows are read. 46 // 47 // Callers should check rows.Err() after rows.Next() returns false to detect 48 // whether result-set reading ended prematurely due to an error. See 49 // Conn.Query for details. 50 // 51 // For simpler error handling, consider using the higher-level pgx v5 52 // CollectRows() and ForEachRow() helpers instead. 53 Next() bool 54 55 // Scan reads the values from the current row into dest values positionally. 56 // dest can include pointers to core types, values implementing the Scanner 57 // interface, and nil. nil will skip the value entirely. It is an error to 58 // call Scan without first calling Next() and checking that it returned true. 59 Scan(dest ...any) error 60 61 // Values returns the decoded row values. As with Scan(), it is an error to 62 // call Values without first calling Next() and checking that it returned 63 // true. 64 Values() ([]any, error) 65 66 // RawValues returns the unparsed bytes of the row values. The returned data is only valid until the next Next 67 // call or the Rows is closed. 68 RawValues() [][]byte 69 70 // Conn returns the underlying *Conn on which the query was executed. This may return nil if Rows did not come from a 71 // *Conn (e.g. if it was created by RowsFromResultReader) 72 Conn() *Conn 73 } 74 75 // Row is a convenience wrapper over Rows that is returned by QueryRow. 76 // 77 // Row is an interface instead of a struct to allow tests to mock QueryRow. However, 78 // adding a method to an interface is technically a breaking change. Because of this 79 // the Row interface is partially excluded from semantic version requirements. 80 // Methods will not be removed or changed, but new methods may be added. 81 type Row interface { 82 // Scan works the same as Rows. with the following exceptions. If no 83 // rows were found it returns ErrNoRows. If multiple rows are returned it 84 // ignores all but the first. 85 Scan(dest ...any) error 86 } 87 88 // RowScanner scans an entire row at a time into the RowScanner. 89 type RowScanner interface { 90 // ScanRows scans the row. 91 ScanRow(rows Rows) error 92 } 93 94 // connRow implements the Row interface for Conn.QueryRow. 95 type connRow baseRows 96 97 func (r *connRow) Scan(dest ...any) (err error) { 98 rows := (*baseRows)(r) 99 100 if rows.Err() != nil { 101 return rows.Err() 102 } 103 104 for _, d := range dest { 105 if _, ok := d.(*pgtype.DriverBytes); ok { 106 rows.Close() 107 return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") 108 } 109 } 110 111 if !rows.Next() { 112 if rows.Err() == nil { 113 return ErrNoRows 114 } 115 return rows.Err() 116 } 117 118 rows.Scan(dest...) 119 rows.Close() 120 return rows.Err() 121 } 122 123 // baseRows implements the Rows interface for Conn.Query. 124 type baseRows struct { 125 typeMap *pgtype.Map 126 resultReader *pgconn.ResultReader 127 128 values [][]byte 129 130 commandTag pgconn.CommandTag 131 err error 132 closed bool 133 134 scanPlans []pgtype.ScanPlan 135 scanTypes []reflect.Type 136 137 conn *Conn 138 multiResultReader *pgconn.MultiResultReader 139 140 queryTracer QueryTracer 141 batchTracer BatchTracer 142 ctx context.Context 143 startTime time.Time 144 sql string 145 args []any 146 rowCount int 147 } 148 149 func (rows *baseRows) FieldDescriptions() []pgconn.FieldDescription { 150 return rows.resultReader.FieldDescriptions() 151 } 152 153 func (rows *baseRows) Close() { 154 if rows.closed { 155 return 156 } 157 158 rows.closed = true 159 160 if rows.resultReader != nil { 161 var closeErr error 162 rows.commandTag, closeErr = rows.resultReader.Close() 163 if rows.err == nil { 164 rows.err = closeErr 165 } 166 } 167 168 if rows.multiResultReader != nil { 169 closeErr := rows.multiResultReader.Close() 170 if rows.err == nil { 171 rows.err = closeErr 172 } 173 } 174 175 if rows.err != nil && rows.conn != nil && rows.sql != "" { 176 if sc := rows.conn.statementCache; sc != nil { 177 sc.Invalidate(rows.sql) 178 } 179 180 if sc := rows.conn.descriptionCache; sc != nil { 181 sc.Invalidate(rows.sql) 182 } 183 } 184 185 if rows.batchTracer != nil { 186 rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) 187 } else if rows.queryTracer != nil { 188 rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) 189 } 190 } 191 192 func (rows *baseRows) CommandTag() pgconn.CommandTag { 193 return rows.commandTag 194 } 195 196 func (rows *baseRows) Err() error { 197 return rows.err 198 } 199 200 // fatal signals an error occurred after the query was sent to the server. It 201 // closes the rows automatically. 202 func (rows *baseRows) fatal(err error) { 203 if rows.err != nil { 204 return 205 } 206 207 rows.err = err 208 rows.Close() 209 } 210 211 func (rows *baseRows) Next() bool { 212 if rows.closed { 213 return false 214 } 215 216 if rows.resultReader.NextRow() { 217 rows.rowCount++ 218 rows.values = rows.resultReader.Values() 219 return true 220 } else { 221 rows.Close() 222 return false 223 } 224 } 225 226 func (rows *baseRows) Scan(dest ...any) error { 227 m := rows.typeMap 228 fieldDescriptions := rows.FieldDescriptions() 229 values := rows.values 230 231 if len(fieldDescriptions) != len(values) { 232 err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) 233 rows.fatal(err) 234 return err 235 } 236 237 if len(dest) == 1 { 238 if rc, ok := dest[0].(RowScanner); ok { 239 err := rc.ScanRow(rows) 240 if err != nil { 241 rows.fatal(err) 242 } 243 return err 244 } 245 } 246 247 if len(fieldDescriptions) != len(dest) { 248 err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) 249 rows.fatal(err) 250 return err 251 } 252 253 if rows.scanPlans == nil { 254 rows.scanPlans = make([]pgtype.ScanPlan, len(values)) 255 rows.scanTypes = make([]reflect.Type, len(values)) 256 for i := range dest { 257 rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) 258 rows.scanTypes[i] = reflect.TypeOf(dest[i]) 259 } 260 } 261 262 for i, dst := range dest { 263 if dst == nil { 264 continue 265 } 266 267 if rows.scanTypes[i] != reflect.TypeOf(dst) { 268 rows.scanPlans[i] = m.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) 269 rows.scanTypes[i] = reflect.TypeOf(dest[i]) 270 } 271 272 err := rows.scanPlans[i].Scan(values[i], dst) 273 if err != nil { 274 err = ScanArgError{ColumnIndex: i, Err: err} 275 rows.fatal(err) 276 return err 277 } 278 } 279 280 return nil 281 } 282 283 func (rows *baseRows) Values() ([]any, error) { 284 if rows.closed { 285 return nil, errors.New("rows is closed") 286 } 287 288 values := make([]any, 0, len(rows.FieldDescriptions())) 289 290 for i := range rows.FieldDescriptions() { 291 buf := rows.values[i] 292 fd := &rows.FieldDescriptions()[i] 293 294 if buf == nil { 295 values = append(values, nil) 296 continue 297 } 298 299 if dt, ok := rows.typeMap.TypeForOID(fd.DataTypeOID); ok { 300 value, err := dt.Codec.DecodeValue(rows.typeMap, fd.DataTypeOID, fd.Format, buf) 301 if err != nil { 302 rows.fatal(err) 303 } 304 values = append(values, value) 305 } else { 306 switch fd.Format { 307 case TextFormatCode: 308 values = append(values, string(buf)) 309 case BinaryFormatCode: 310 newBuf := make([]byte, len(buf)) 311 copy(newBuf, buf) 312 values = append(values, newBuf) 313 default: 314 rows.fatal(errors.New("unknown format code")) 315 } 316 } 317 318 if rows.Err() != nil { 319 return nil, rows.Err() 320 } 321 } 322 323 return values, rows.Err() 324 } 325 326 func (rows *baseRows) RawValues() [][]byte { 327 return rows.values 328 } 329 330 func (rows *baseRows) Conn() *Conn { 331 return rows.conn 332 } 333 334 type ScanArgError struct { 335 ColumnIndex int 336 Err error 337 } 338 339 func (e ScanArgError) Error() string { 340 return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) 341 } 342 343 func (e ScanArgError) Unwrap() error { 344 return e.Err 345 } 346 347 // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. 348 // 349 // typeMap - OID to Go type mapping. 350 // fieldDescriptions - OID and format of values 351 // values - the raw data as returned from the PostgreSQL server 352 // dest - the destination that values will be decoded into 353 func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, values [][]byte, dest ...any) error { 354 if len(fieldDescriptions) != len(values) { 355 return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) 356 } 357 if len(fieldDescriptions) != len(dest) { 358 return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) 359 } 360 361 for i, d := range dest { 362 if d == nil { 363 continue 364 } 365 366 err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) 367 if err != nil { 368 return ScanArgError{ColumnIndex: i, Err: err} 369 } 370 } 371 372 return nil 373 } 374 375 // RowsFromResultReader returns a Rows that will read from values resultReader and decode with typeMap. It can be used 376 // to read from the lower level pgconn interface. 377 func RowsFromResultReader(typeMap *pgtype.Map, resultReader *pgconn.ResultReader) Rows { 378 return &baseRows{ 379 typeMap: typeMap, 380 resultReader: resultReader, 381 } 382 } 383 384 // ForEachRow iterates through rows. For each row it scans into the elements of scans and calls fn. If any row 385 // fails to scan or fn returns an error the query will be aborted and the error will be returned. Rows will be closed 386 // when ForEachRow returns. 387 func ForEachRow(rows Rows, scans []any, fn func() error) (pgconn.CommandTag, error) { 388 defer rows.Close() 389 390 for rows.Next() { 391 err := rows.Scan(scans...) 392 if err != nil { 393 return pgconn.CommandTag{}, err 394 } 395 396 err = fn() 397 if err != nil { 398 return pgconn.CommandTag{}, err 399 } 400 } 401 402 if err := rows.Err(); err != nil { 403 return pgconn.CommandTag{}, err 404 } 405 406 return rows.CommandTag(), nil 407 } 408 409 // CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. 410 type CollectableRow interface { 411 FieldDescriptions() []pgconn.FieldDescription 412 Scan(dest ...any) error 413 Values() ([]any, error) 414 RawValues() [][]byte 415 } 416 417 // RowToFunc is a function that scans or otherwise converts row to a T. 418 type RowToFunc[T any] func(row CollectableRow) (T, error) 419 420 // AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. 421 func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { 422 defer rows.Close() 423 424 for rows.Next() { 425 value, err := fn(rows) 426 if err != nil { 427 return nil, err 428 } 429 slice = append(slice, value) 430 } 431 432 if err := rows.Err(); err != nil { 433 return nil, err 434 } 435 436 return slice, nil 437 } 438 439 // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. 440 func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { 441 return AppendRows([]T{}, rows, fn) 442 } 443 444 // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. 445 // CollectOneRow is to CollectRows as QueryRow is to Query. 446 func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { 447 defer rows.Close() 448 449 var value T 450 var err error 451 452 if !rows.Next() { 453 if err = rows.Err(); err != nil { 454 return value, err 455 } 456 return value, ErrNoRows 457 } 458 459 value, err = fn(rows) 460 if err != nil { 461 return value, err 462 } 463 464 rows.Close() 465 return value, rows.Err() 466 } 467 468 // CollectExactlyOneRow calls fn for the first row in rows and returns the result. 469 // - If no rows are found returns an error where errors.Is(ErrNoRows) is true. 470 // - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true. 471 func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { 472 defer rows.Close() 473 474 var ( 475 err error 476 value T 477 ) 478 479 if !rows.Next() { 480 if err = rows.Err(); err != nil { 481 return value, err 482 } 483 484 return value, ErrNoRows 485 } 486 487 value, err = fn(rows) 488 if err != nil { 489 return value, err 490 } 491 492 if rows.Next() { 493 var zero T 494 495 return zero, ErrTooManyRows 496 } 497 498 return value, rows.Err() 499 } 500 501 // RowTo returns a T scanned from row. 502 func RowTo[T any](row CollectableRow) (T, error) { 503 var value T 504 err := row.Scan(&value) 505 return value, err 506 } 507 508 // RowTo returns a the address of a T scanned from row. 509 func RowToAddrOf[T any](row CollectableRow) (*T, error) { 510 var value T 511 err := row.Scan(&value) 512 return &value, err 513 } 514 515 // RowToMap returns a map scanned from row. 516 func RowToMap(row CollectableRow) (map[string]any, error) { 517 var value map[string]any 518 err := row.Scan((*mapRowScanner)(&value)) 519 return value, err 520 } 521 522 type mapRowScanner map[string]any 523 524 func (rs *mapRowScanner) ScanRow(rows Rows) error { 525 values, err := rows.Values() 526 if err != nil { 527 return err 528 } 529 530 *rs = make(mapRowScanner, len(values)) 531 532 for i := range values { 533 (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] 534 } 535 536 return nil 537 } 538 539 // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row 540 // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be 541 // ignored. 542 func RowToStructByPos[T any](row CollectableRow) (T, error) { 543 var value T 544 err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) 545 return value, err 546 } 547 548 // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a 549 // public fields as row has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then 550 // the field will be ignored. 551 func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { 552 var value T 553 err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) 554 return &value, err 555 } 556 557 type positionalStructRowScanner struct { 558 ptrToStruct any 559 } 560 561 func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { 562 dst := rs.ptrToStruct 563 dstValue := reflect.ValueOf(dst) 564 if dstValue.Kind() != reflect.Ptr { 565 return fmt.Errorf("dst not a pointer") 566 } 567 568 dstElemValue := dstValue.Elem() 569 scanTargets := rs.appendScanTargets(dstElemValue, nil) 570 571 if len(rows.RawValues()) > len(scanTargets) { 572 return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) 573 } 574 575 return rows.Scan(scanTargets...) 576 } 577 578 func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { 579 dstElemType := dstElemValue.Type() 580 581 if scanTargets == nil { 582 scanTargets = make([]any, 0, dstElemType.NumField()) 583 } 584 585 for i := 0; i < dstElemType.NumField(); i++ { 586 sf := dstElemType.Field(i) 587 // Handle anonymous struct embedding, but do not try to handle embedded pointers. 588 if sf.Anonymous && sf.Type.Kind() == reflect.Struct { 589 scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) 590 } else if sf.PkgPath == "" { 591 dbTag, _ := sf.Tag.Lookup(structTagKey) 592 if dbTag == "-" { 593 // Field is ignored, skip it. 594 continue 595 } 596 scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) 597 } 598 } 599 600 return scanTargets 601 } 602 603 // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public 604 // fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database 605 // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. 606 func RowToStructByName[T any](row CollectableRow) (T, error) { 607 var value T 608 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) 609 return value, err 610 } 611 612 // RowToAddrOfStructByName returns the address of a T scanned from row. T must be a struct. T must have the same number 613 // of named public fields as row has fields. The row and T fields will be matched by name. The match is 614 // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" 615 // then the field will be ignored. 616 func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { 617 var value T 618 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) 619 return &value, err 620 } 621 622 // RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public 623 // fields as row has fields. The row and T fields will be matched by name. The match is case-insensitive. The database 624 // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. 625 func RowToStructByNameLax[T any](row CollectableRow) (T, error) { 626 var value T 627 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) 628 return value, err 629 } 630 631 // RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or 632 // equal number of named public fields as row has fields. The row and T fields will be matched by name. The match is 633 // case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" 634 // then the field will be ignored. 635 func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { 636 var value T 637 err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) 638 return &value, err 639 } 640 641 type namedStructRowScanner struct { 642 ptrToStruct any 643 lax bool 644 } 645 646 func (rs *namedStructRowScanner) ScanRow(rows Rows) error { 647 dst := rs.ptrToStruct 648 dstValue := reflect.ValueOf(dst) 649 if dstValue.Kind() != reflect.Ptr { 650 return fmt.Errorf("dst not a pointer") 651 } 652 653 dstElemValue := dstValue.Elem() 654 scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) 655 if err != nil { 656 return err 657 } 658 659 for i, t := range scanTargets { 660 if t == nil { 661 return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) 662 } 663 } 664 665 return rows.Scan(scanTargets...) 666 } 667 668 const structTagKey = "db" 669 670 func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { 671 i = -1 672 for i, desc := range fldDescs { 673 674 // Snake case support. 675 field = strings.ReplaceAll(field, "_", "") 676 descName := strings.ReplaceAll(desc.Name, "_", "") 677 678 if strings.EqualFold(descName, field) { 679 return i 680 } 681 } 682 return 683 } 684 685 func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { 686 var err error 687 dstElemType := dstElemValue.Type() 688 689 if scanTargets == nil { 690 scanTargets = make([]any, len(fldDescs)) 691 } 692 693 for i := 0; i < dstElemType.NumField(); i++ { 694 sf := dstElemType.Field(i) 695 if sf.PkgPath != "" && !sf.Anonymous { 696 // Field is unexported, skip it. 697 continue 698 } 699 // Handle anonymous struct embedding, but do not try to handle embedded pointers. 700 if sf.Anonymous && sf.Type.Kind() == reflect.Struct { 701 scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) 702 if err != nil { 703 return nil, err 704 } 705 } else { 706 dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) 707 if dbTagPresent { 708 dbTag, _, _ = strings.Cut(dbTag, ",") 709 } 710 if dbTag == "-" { 711 // Field is ignored, skip it. 712 continue 713 } 714 colName := dbTag 715 if !dbTagPresent { 716 colName = sf.Name 717 } 718 fpos := fieldPosByName(fldDescs, colName) 719 if fpos == -1 { 720 if rs.lax { 721 continue 722 } 723 return nil, fmt.Errorf("cannot find field %s in returned row", colName) 724 } 725 if fpos >= len(scanTargets) && !rs.lax { 726 return nil, fmt.Errorf("cannot find field %s in returned row", colName) 727 } 728 scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() 729 } 730 } 731 732 return scanTargets, err 733 }