github.com/doug-martin/goqu/v9@v9.19.0/internal/util/reflect.go (about)

     1  package util
     2  
     3  import (
     4  	"database/sql"
     5  	"reflect"
     6  	"strings"
     7  	"sync"
     8  
     9  	"github.com/doug-martin/goqu/v9/internal/errors"
    10  )
    11  
    12  const (
    13  	skipUpdateTagName     = "skipupdate"
    14  	skipInsertTagName     = "skipinsert"
    15  	defaultIfEmptyTagName = "defaultifempty"
    16  	omitNilTagName        = "omitnil"
    17  	omitEmptyTagName      = "omitempty"
    18  )
    19  
    20  var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
    21  
    22  func IsUint(k reflect.Kind) bool {
    23  	return (k == reflect.Uint) ||
    24  		(k == reflect.Uint8) ||
    25  		(k == reflect.Uint16) ||
    26  		(k == reflect.Uint32) ||
    27  		(k == reflect.Uint64)
    28  }
    29  
    30  func IsInt(k reflect.Kind) bool {
    31  	return (k == reflect.Int) ||
    32  		(k == reflect.Int8) ||
    33  		(k == reflect.Int16) ||
    34  		(k == reflect.Int32) ||
    35  		(k == reflect.Int64)
    36  }
    37  
    38  func IsFloat(k reflect.Kind) bool {
    39  	return (k == reflect.Float32) ||
    40  		(k == reflect.Float64)
    41  }
    42  
    43  func IsString(k reflect.Kind) bool {
    44  	return k == reflect.String
    45  }
    46  
    47  func IsBool(k reflect.Kind) bool {
    48  	return k == reflect.Bool
    49  }
    50  
    51  func IsSlice(k reflect.Kind) bool {
    52  	return k == reflect.Slice
    53  }
    54  
    55  func IsStruct(k reflect.Kind) bool {
    56  	return k == reflect.Struct
    57  }
    58  
    59  func IsInvalid(k reflect.Kind) bool {
    60  	return k == reflect.Invalid
    61  }
    62  
    63  func IsPointer(k reflect.Kind) bool {
    64  	return k == reflect.Ptr
    65  }
    66  
    67  func IsNil(v reflect.Value) bool {
    68  	if !v.IsValid() {
    69  		return true
    70  	}
    71  	switch v.Kind() {
    72  	case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
    73  		return v.IsNil()
    74  	default:
    75  		return false
    76  	}
    77  }
    78  
    79  func IsEmptyValue(v reflect.Value) bool {
    80  	return !v.IsValid() || v.IsZero()
    81  }
    82  
    83  var (
    84  	structMapCache     = make(map[interface{}]ColumnMap)
    85  	structMapCacheLock = sync.Mutex{}
    86  )
    87  
    88  var (
    89  	DefaultColumnRenameFunction = strings.ToLower
    90  	columnRenameFunction        = DefaultColumnRenameFunction
    91  	ignoreUntaggedFields        = false
    92  )
    93  
    94  func SetIgnoreUntaggedFields(ignore bool) {
    95  	// If the value here is changing, reset the struct map cache
    96  	if ignore != ignoreUntaggedFields {
    97  		ignoreUntaggedFields = ignore
    98  
    99  		structMapCacheLock.Lock()
   100  		defer structMapCacheLock.Unlock()
   101  
   102  		structMapCache = make(map[interface{}]ColumnMap)
   103  	}
   104  }
   105  
   106  func SetColumnRenameFunction(newFunction func(string) string) {
   107  	columnRenameFunction = newFunction
   108  }
   109  
   110  // GetSliceElementType returns the type for a slices elements.
   111  func GetSliceElementType(val reflect.Value) reflect.Type {
   112  	elemType := val.Type().Elem()
   113  	if elemType.Kind() == reflect.Ptr {
   114  		elemType = elemType.Elem()
   115  	}
   116  
   117  	return elemType
   118  }
   119  
   120  // AppendSliceElement will append val to slice. Handles slice of pointers and
   121  // not pointers. Val needs to be a pointer.
   122  func AppendSliceElement(slice, val reflect.Value) {
   123  	if slice.Type().Elem().Kind() == reflect.Ptr {
   124  		slice.Set(reflect.Append(slice, val))
   125  	} else {
   126  		slice.Set(reflect.Append(slice, reflect.Indirect(val)))
   127  	}
   128  }
   129  
   130  func GetTypeInfo(i interface{}, val reflect.Value) (reflect.Type, reflect.Kind) {
   131  	var t reflect.Type
   132  	valKind := val.Kind()
   133  	if valKind == reflect.Slice {
   134  		if reflect.ValueOf(i).Kind() == reflect.Ptr {
   135  			t = reflect.TypeOf(i).Elem().Elem()
   136  		} else {
   137  			t = reflect.TypeOf(i).Elem()
   138  		}
   139  		if t.Kind() == reflect.Ptr {
   140  			t = t.Elem()
   141  		}
   142  		valKind = t.Kind()
   143  	} else {
   144  		t = val.Type()
   145  	}
   146  	return t, valKind
   147  }
   148  
   149  func SafeGetFieldByIndex(v reflect.Value, fieldIndex []int) (result reflect.Value, isAvailable bool) {
   150  	switch len(fieldIndex) {
   151  	case 0:
   152  		return v, true
   153  	case 1:
   154  		return v.FieldByIndex(fieldIndex), true
   155  	default:
   156  		if f := reflect.Indirect(v.Field(fieldIndex[0])); f.IsValid() {
   157  			return SafeGetFieldByIndex(f, fieldIndex[1:])
   158  		}
   159  	}
   160  	return reflect.ValueOf(nil), false
   161  }
   162  
   163  func SafeSetFieldByIndex(v reflect.Value, fieldIndex []int, src interface{}) (result reflect.Value) {
   164  	v = reflect.Indirect(v)
   165  	switch len(fieldIndex) {
   166  	case 0:
   167  		return v
   168  	case 1:
   169  		f := v.FieldByIndex(fieldIndex)
   170  		srcVal := reflect.ValueOf(src)
   171  		f.Set(reflect.Indirect(srcVal))
   172  	default:
   173  		f := v.Field(fieldIndex[0])
   174  		switch f.Kind() {
   175  		case reflect.Ptr:
   176  			s := f
   177  			if f.IsNil() || !f.IsValid() {
   178  				s = reflect.New(f.Type().Elem())
   179  				f.Set(s)
   180  			}
   181  			SafeSetFieldByIndex(reflect.Indirect(s), fieldIndex[1:], src)
   182  		case reflect.Struct:
   183  			SafeSetFieldByIndex(f, fieldIndex[1:], src)
   184  		default: // use the original value
   185  		}
   186  	}
   187  	return v
   188  }
   189  
   190  type rowData = map[string]interface{}
   191  
   192  // AssignStructVals will assign the data from rd to i.
   193  func AssignStructVals(i interface{}, rd rowData, cm ColumnMap) {
   194  	val := reflect.Indirect(reflect.ValueOf(i))
   195  
   196  	for name, data := range cm {
   197  		src, ok := rd[name]
   198  		if ok {
   199  			SafeSetFieldByIndex(val, data.FieldIndex, src)
   200  		}
   201  	}
   202  }
   203  
   204  func GetColumnMap(i interface{}) (ColumnMap, error) {
   205  	val := reflect.Indirect(reflect.ValueOf(i))
   206  	t, valKind := GetTypeInfo(i, val)
   207  	if valKind != reflect.Struct {
   208  		return nil, errors.New("cannot scan into this type: %v", t) // #nosec
   209  	}
   210  
   211  	structMapCacheLock.Lock()
   212  	defer structMapCacheLock.Unlock()
   213  	if _, ok := structMapCache[t]; !ok {
   214  		structMapCache[t] = newColumnMap(t, []int{}, []string{})
   215  	}
   216  	return structMapCache[t], nil
   217  }