github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/convert.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  	"strconv"
    10  	"time"
    11  )
    12  
// Adapted from database/sql/convert.go in the Go standard library.
    14  
    15  var errNilPtr = errors.New("destination pointer is nil")
    16  
    17  func convertAssign(dest, src interface{}) error {
    18  	// Common cases, without reflect.
    19  	switch s := src.(type) {
    20  	case string:
    21  		switch d := dest.(type) {
    22  		case *string:
    23  			if d == nil {
    24  				return errNilPtr
    25  			}
    26  			*d = s
    27  			return nil
    28  		case *[]byte:
    29  			if d == nil {
    30  				return errNilPtr
    31  			}
    32  			*d = []byte(s)
    33  			return nil
    34  		}
    35  	case []byte:
    36  		switch d := dest.(type) {
    37  		case *string:
    38  			if d == nil {
    39  				return errNilPtr
    40  			}
    41  			*d = string(s)
    42  			return nil
    43  		case *interface{}:
    44  			if d == nil {
    45  				return errNilPtr
    46  			}
    47  			*d = cloneBytes(s)
    48  			return nil
    49  		case *[]byte:
    50  			if d == nil {
    51  				return errNilPtr
    52  			}
    53  			*d = cloneBytes(s)
    54  			return nil
    55  		case *sql.RawBytes:
    56  			if d == nil {
    57  				return errNilPtr
    58  			}
    59  			*d = s
    60  			return nil
    61  		}
    62  	case time.Time:
    63  		switch d := dest.(type) {
    64  		case *string:
    65  			*d = s.Format(time.RFC3339Nano)
    66  			return nil
    67  		case *[]byte:
    68  			if d == nil {
    69  				return errNilPtr
    70  			}
    71  			*d = []byte(s.Format(time.RFC3339Nano))
    72  			return nil
    73  		}
    74  	case nil:
    75  		switch d := dest.(type) {
    76  		case *interface{}:
    77  			if d == nil {
    78  				return errNilPtr
    79  			}
    80  			*d = nil
    81  			return nil
    82  		case *[]byte:
    83  			if d == nil {
    84  				return errNilPtr
    85  			}
    86  			*d = nil
    87  			return nil
    88  		case *sql.RawBytes:
    89  			if d == nil {
    90  				return errNilPtr
    91  			}
    92  			*d = nil
    93  			return nil
    94  		}
    95  	}
    96  
    97  	var sv reflect.Value
    98  
    99  	switch d := dest.(type) {
   100  	case *string:
   101  		sv = reflect.ValueOf(src)
   102  		switch sv.Kind() {
   103  		case reflect.Bool,
   104  			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   105  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   106  			reflect.Float32, reflect.Float64:
   107  			*d = asString(src)
   108  			return nil
   109  		}
   110  	case *[]byte:
   111  		sv = reflect.ValueOf(src)
   112  		if b, ok := asBytes(nil, sv); ok {
   113  			*d = b
   114  			return nil
   115  		}
   116  	case *sql.RawBytes:
   117  		sv = reflect.ValueOf(src)
   118  		if b, ok := asBytes([]byte(*d)[:0], sv); ok {
   119  			*d = sql.RawBytes(b)
   120  			return nil
   121  		}
   122  	case *bool:
   123  		bv, err := driver.Bool.ConvertValue(src)
   124  		if err == nil {
   125  			*d = bv.(bool)
   126  		}
   127  		return err
   128  	case *interface{}:
   129  		*d = src
   130  		return nil
   131  	}
   132  
   133  	if scanner, ok := dest.(sql.Scanner); ok {
   134  		return scanner.Scan(src)
   135  	}
   136  
   137  	dpv := reflect.ValueOf(dest)
   138  	if dpv.Kind() != reflect.Ptr {
   139  		return errors.New("destination not a pointer")
   140  	}
   141  	if dpv.IsNil() {
   142  		return errNilPtr
   143  	}
   144  
   145  	if !sv.IsValid() {
   146  		sv = reflect.ValueOf(src)
   147  	}
   148  
   149  	dv := reflect.Indirect(dpv)
   150  	if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
   151  		switch b := src.(type) {
   152  		case []byte:
   153  			dv.Set(reflect.ValueOf(cloneBytes(b)))
   154  		default:
   155  			dv.Set(sv)
   156  		}
   157  		return nil
   158  	}
   159  
   160  	if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
   161  		dv.Set(sv.Convert(dv.Type()))
   162  		return nil
   163  	}
   164  
   165  	// The following conversions use a string value as an intermediate representation
   166  	// to convert between various numeric types.
   167  	//
   168  	// This also allows scanning into user defined types such as "type Int int64".
   169  	// For symmetry, also check for string destination types.
   170  	switch dv.Kind() {
   171  	case reflect.Ptr:
   172  		if src == nil {
   173  			dv.Set(reflect.Zero(dv.Type()))
   174  			return nil
   175  		} else {
   176  			dv.Set(reflect.New(dv.Type().Elem()))
   177  			return convertAssign(dv.Interface(), src)
   178  		}
   179  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   180  		s := asString(src)
   181  		i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
   182  		if err != nil {
   183  			err = strconvErr(err)
   184  			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
   185  		}
   186  		dv.SetInt(i64)
   187  		return nil
   188  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   189  		s := asString(src)
   190  		u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
   191  		if err != nil {
   192  			err = strconvErr(err)
   193  			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
   194  		}
   195  		dv.SetUint(u64)
   196  		return nil
   197  	case reflect.Float32, reflect.Float64:
   198  		s := asString(src)
   199  		f64, err := strconv.ParseFloat(s, dv.Type().Bits())
   200  		if err != nil {
   201  			err = strconvErr(err)
   202  			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
   203  		}
   204  		dv.SetFloat(f64)
   205  		return nil
   206  	case reflect.String:
   207  		switch v := src.(type) {
   208  		case string:
   209  			dv.SetString(v)
   210  			return nil
   211  		case []byte:
   212  			dv.SetString(string(v))
   213  			return nil
   214  		}
   215  	}
   216  
   217  	return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
   218  }
   219  
   220  func strconvErr(err error) error {
   221  	if ne, ok := err.(*strconv.NumError); ok {
   222  		return ne.Err
   223  	}
   224  	return err
   225  }
   226  
   227  func cloneBytes(b []byte) []byte {
   228  	if b == nil {
   229  		return nil
   230  	} else {
   231  		c := make([]byte, len(b))
   232  		copy(c, b)
   233  		return c
   234  	}
   235  }
   236  
   237  func asString(src interface{}) string {
   238  	switch v := src.(type) {
   239  	case string:
   240  		return v
   241  	case []byte:
   242  		return string(v)
   243  	}
   244  	rv := reflect.ValueOf(src)
   245  	switch rv.Kind() {
   246  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   247  		return strconv.FormatInt(rv.Int(), 10)
   248  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   249  		return strconv.FormatUint(rv.Uint(), 10)
   250  	case reflect.Float64:
   251  		return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
   252  	case reflect.Float32:
   253  		return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
   254  	case reflect.Bool:
   255  		return strconv.FormatBool(rv.Bool())
   256  	}
   257  	return fmt.Sprintf("%v", src)
   258  }
   259  
   260  func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
   261  	switch rv.Kind() {
   262  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   263  		return strconv.AppendInt(buf, rv.Int(), 10), true
   264  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   265  		return strconv.AppendUint(buf, rv.Uint(), 10), true
   266  	case reflect.Float32:
   267  		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
   268  	case reflect.Float64:
   269  		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
   270  	case reflect.Bool:
   271  		return strconv.AppendBool(buf, rv.Bool()), true
   272  	case reflect.String:
   273  		s := rv.String()
   274  		return append(buf, s...), true
   275  	}
   276  	return
   277  }