github.com/go-playground/pkg/v5@v5.29.1/values/option/option_sql.go (about)

     1  //go:build go1.18 && !go1.22
     2  // +build go1.18,!go1.22
     3  
     4  package optionext
     5  
     6  import (
     7  	"database/sql"
     8  	"database/sql/driver"
     9  	"encoding/json"
    10  	"fmt"
    11  	"math"
    12  	"reflect"
    13  	"time"
    14  )
    15  
    16  var (
    17  	scanType      = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
    18  	byteSliceType = reflect.TypeOf(([]byte)(nil))
    19  	valuerType    = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
    20  	timeType      = reflect.TypeOf((*time.Time)(nil)).Elem()
    21  	stringType    = reflect.TypeOf((*string)(nil)).Elem()
    22  	int64Type     = reflect.TypeOf((*int64)(nil)).Elem()
    23  	float64Type   = reflect.TypeOf((*float64)(nil)).Elem()
    24  	boolType      = reflect.TypeOf((*bool)(nil)).Elem()
    25  )
    26  
    27  // Value implements the driver.Valuer interface.
    28  //
    29  // This honours the `driver.Valuer` interface if the value implements it.
    30  // It also supports custom types of the std types and treats all else as []byte
    31  func (o Option[T]) Value() (driver.Value, error) {
    32  	if o.IsNone() {
    33  		return nil, nil
    34  	}
    35  	val := reflect.ValueOf(o.value)
    36  
    37  	if val.Type().Implements(valuerType) {
    38  		return val.Interface().(driver.Valuer).Value()
    39  	}
    40  	switch val.Kind() {
    41  	case reflect.String:
    42  		return val.Convert(stringType).Interface(), nil
    43  	case reflect.Bool:
    44  		return val.Convert(boolType).Interface(), nil
    45  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    46  		return val.Convert(int64Type).Interface(), nil
    47  	case reflect.Float64:
    48  		return val.Convert(float64Type).Interface(), nil
    49  	case reflect.Slice, reflect.Array:
    50  		if val.Type().ConvertibleTo(byteSliceType) {
    51  			return val.Convert(byteSliceType).Interface(), nil
    52  		}
    53  		return json.Marshal(val.Interface())
    54  	case reflect.Struct:
    55  		if val.CanConvert(timeType) {
    56  			return val.Convert(timeType).Interface(), nil
    57  		}
    58  		return json.Marshal(val.Interface())
    59  	case reflect.Map:
    60  		return json.Marshal(val.Interface())
    61  	default:
    62  		return o.value, nil
    63  	}
    64  }
    65  
    66  // Scan implements the sql.Scanner interface.
    67  func (o *Option[T]) Scan(value any) error {
    68  
    69  	if value == nil {
    70  		*o = None[T]()
    71  		return nil
    72  	}
    73  
    74  	val := reflect.ValueOf(&o.value)
    75  
    76  	if val.Type().Implements(scanType) {
    77  		err := val.Interface().(sql.Scanner).Scan(value)
    78  		if err != nil {
    79  			return err
    80  		}
    81  		o.isSome = true
    82  		return nil
    83  	}
    84  
    85  	val = val.Elem()
    86  
    87  	switch val.Kind() {
    88  	case reflect.String:
    89  		var v sql.NullString
    90  		if err := v.Scan(value); err != nil {
    91  			return err
    92  		}
    93  		*o = Some(reflect.ValueOf(v.String).Convert(val.Type()).Interface().(T))
    94  	case reflect.Bool:
    95  		var v sql.NullBool
    96  		if err := v.Scan(value); err != nil {
    97  			return err
    98  		}
    99  		*o = Some(reflect.ValueOf(v.Bool).Convert(val.Type()).Interface().(T))
   100  	case reflect.Uint8:
   101  		var v sql.NullByte
   102  		if err := v.Scan(value); err != nil {
   103  			return err
   104  		}
   105  		*o = Some(reflect.ValueOf(v.Byte).Convert(val.Type()).Interface().(T))
   106  	case reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
   107  		v := reflect.ValueOf(value)
   108  		if v.Type().ConvertibleTo(val.Type()) {
   109  			*o = Some(reflect.ValueOf(v.Convert(val.Type()).Interface()).Interface().(T))
   110  		} else {
   111  			return fmt.Errorf("value %T not convertable to %T", value, o.value)
   112  		}
   113  	case reflect.Float32:
   114  		var v sql.NullFloat64
   115  		if err := v.Scan(value); err != nil {
   116  			return err
   117  		}
   118  		*o = Some(reflect.ValueOf(v.Float64).Convert(val.Type()).Interface().(T))
   119  	case reflect.Float64:
   120  		var v sql.NullFloat64
   121  		if err := v.Scan(value); err != nil {
   122  			return err
   123  		}
   124  		*o = Some(reflect.ValueOf(v.Float64).Convert(val.Type()).Interface().(T))
   125  	case reflect.Int:
   126  		var v sql.NullInt64
   127  		if err := v.Scan(value); err != nil {
   128  			return err
   129  		}
   130  		if v.Int64 > math.MaxInt || v.Int64 < math.MinInt {
   131  			return fmt.Errorf("value %d out of range for int", v.Int64)
   132  		}
   133  		*o = Some(reflect.ValueOf(v.Int64).Convert(val.Type()).Interface().(T))
   134  	case reflect.Int8:
   135  		var v sql.NullInt64
   136  		if err := v.Scan(value); err != nil {
   137  			return err
   138  		}
   139  		if v.Int64 > math.MaxInt8 || v.Int64 < math.MinInt8 {
   140  			return fmt.Errorf("value %d out of range for int8", v.Int64)
   141  		}
   142  		*o = Some(reflect.ValueOf(v.Int64).Convert(val.Type()).Interface().(T))
   143  	case reflect.Int16:
   144  		var v sql.NullInt16
   145  		if err := v.Scan(value); err != nil {
   146  			return err
   147  		}
   148  		*o = Some(reflect.ValueOf(v.Int16).Convert(val.Type()).Interface().(T))
   149  	case reflect.Int32:
   150  		var v sql.NullInt32
   151  		if err := v.Scan(value); err != nil {
   152  			return err
   153  		}
   154  		*o = Some(reflect.ValueOf(v.Int32).Convert(val.Type()).Interface().(T))
   155  	case reflect.Int64:
   156  		var v sql.NullInt64
   157  		if err := v.Scan(value); err != nil {
   158  			return err
   159  		}
   160  		*o = Some(reflect.ValueOf(v.Int64).Convert(val.Type()).Interface().(T))
   161  	case reflect.Interface:
   162  		*o = Some(reflect.ValueOf(value).Convert(val.Type()).Interface().(T))
   163  	case reflect.Struct:
   164  		if val.CanConvert(timeType) {
   165  			switch t := value.(type) {
   166  			case string:
   167  				tm, err := time.Parse(time.RFC3339Nano, t)
   168  				if err != nil {
   169  					return err
   170  				}
   171  				*o = Some(reflect.ValueOf(tm).Convert(val.Type()).Interface().(T))
   172  
   173  			case []byte:
   174  				tm, err := time.Parse(time.RFC3339Nano, string(t))
   175  				if err != nil {
   176  					return err
   177  				}
   178  				*o = Some(reflect.ValueOf(tm).Convert(val.Type()).Interface().(T))
   179  
   180  			default:
   181  				var v sql.NullTime
   182  				if err := v.Scan(value); err != nil {
   183  					return err
   184  				}
   185  				*o = Some(reflect.ValueOf(v.Time).Convert(val.Type()).Interface().(T))
   186  			}
   187  			return nil
   188  		}
   189  		fallthrough
   190  
   191  	default:
   192  		switch val.Kind() {
   193  		case reflect.Struct, reflect.Slice, reflect.Map:
   194  			v := reflect.ValueOf(value)
   195  
   196  			if v.Type().ConvertibleTo(byteSliceType) {
   197  				if val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 {
   198  					*o = Some(reflect.ValueOf(v.Convert(val.Type()).Interface()).Interface().(T))
   199  				} else {
   200  					if err := json.Unmarshal(v.Convert(byteSliceType).Interface().([]byte), &o.value); err != nil {
   201  						return err
   202  					}
   203  				}
   204  				o.isSome = true
   205  				return nil
   206  			}
   207  		}
   208  		return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", value, o.value)
   209  	}
   210  	return nil
   211  }