gitee.com/go-genie/sqlx@v1.0.3/scanner/struct.go

     1  package scanner
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  
    10  	"gitee.com/go-genie/sqlx/builder"
    11  	"gitee.com/go-genie/sqlx/scanner/nullable"
    12  	reflectx "gitee.com/go-genie/xx/reflect"
    13  )
    14  
    15  type RowScanner interface {
    16  	Scan(values ...interface{}) error
    17  }
    18  
    19  type WithColumnReceivers interface {
    20  	ColumnReceivers() map[string]interface{}
    21  }
    22  
    23  func scanTo(ctx context.Context, rows *sql.Rows, v interface{}) error {
    24  	tpe := reflect.TypeOf(v)
    25  
    26  	if tpe.Kind() != reflect.Ptr {
    27  		return fmt.Errorf("scanTo target must be a ptr value, but got %T", v)
    28  	}
    29  
    30  	if s, ok := v.(sql.Scanner); ok {
    31  		return rows.Scan(s)
    32  	}
    33  
    34  	tpe = reflectx.Deref(tpe)
    35  
    36  	switch tpe.Kind() {
    37  	case reflect.Struct:
    38  		columns, err := rows.Columns()
    39  		if err != nil {
    40  			return err
    41  		}
    42  
    43  		n := len(columns)
    44  		if n < 1 {
    45  			return nil
    46  		}
    47  
    48  		dest := make([]interface{}, n)
    49  		holder := placeholder()
    50  
    51  		if withColumnReceivers, ok := v.(WithColumnReceivers); ok {
    52  			columnReceivers := withColumnReceivers.ColumnReceivers()
    53  
    54  			for i, columnName := range columns {
    55  				if cr, ok := columnReceivers[strings.ToLower(columnName)]; ok {
    56  					dest[i] = nullable.NewNullIgnoreScanner(cr)
    57  				} else {
    58  					dest[i] = holder
    59  				}
    60  			}
    61  
    62  			return rows.Scan(dest...)
    63  		}
    64  
    65  		columnIndexes := map[string]int{}
    66  
    67  		for i, columnName := range columns {
    68  			columnIndexes[strings.ToLower(columnName)] = i
    69  			dest[i] = holder
    70  		}
    71  
    72  		builder.ForEachStructFieldValue(ctx, v, func(sf *builder.StructFieldValue) {
    73  			if sf.TableName != "" {
    74  				if i, ok := columnIndexes[sf.TableName+"__"+sf.Field.Name]; ok && i > -1 {
    75  					dest[i] = nullable.NewNullIgnoreScanner(sf.Value.Addr().Interface())
    76  				}
    77  			}
    78  
    79  			if i, ok := columnIndexes[sf.Field.Name]; ok && i > -1 {
    80  				dest[i] = nullable.NewNullIgnoreScanner(sf.Value.Addr().Interface())
    81  			}
    82  		})
    83  
    84  		return rows.Scan(dest...)
    85  	default:
    86  		return rows.Scan(nullable.NewNullIgnoreScanner(v))
    87  	}
    88  }
    89  
    90  func placeholder() sql.Scanner {
    91  	p := emptyScanner(0)
    92  	return &p
    93  }
    94  
    95  type emptyScanner int
    96  
    97  func (e *emptyScanner) Scan(value interface{}) error {
    98  	return nil
    99  }