go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/column.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package db
     9  
    10  import (
    11  	"database/sql"
    12  	"encoding/json"
    13  	"fmt"
    14  	"reflect"
    15  	"strings"
    16  )
    17  
    18  // NewColumnFromFieldTag reads the contents of a field tag, ex: `json:"foo" db:"bar,isprimarykey,isserial"
    19  func NewColumnFromFieldTag(field reflect.StructField) *Column {
    20  	db := field.Tag.Get("db")
    21  	if db != "-" {
    22  		col := Column{}
    23  		col.FieldName = field.Name
    24  		col.ColumnName = field.Name
    25  		col.FieldType = field.Type
    26  		if db != "" {
    27  			pieces := strings.Split(db, ",")
    28  			if !strings.HasPrefix(db, ",") {
    29  				// note: split will return the original string
    30  				// if the string _does not contain_ the split string
    31  				// so this index is safe to do, but we also
    32  				// generally want the first token in the csv
    33  				// regardless (why we use pieces at all)
    34  				col.ColumnName = pieces[0]
    35  			}
    36  			if len(pieces) > 1 {
    37  				for _, p := range pieces[1:] {
    38  					switch strings.TrimSpace(strings.ToLower(p)) {
    39  					case "pk":
    40  						col.IsPrimaryKey = true
    41  					case "uk":
    42  						col.IsUniqueKey = true
    43  					case "auto", "serial":
    44  						col.IsAuto = true
    45  					case "readonly":
    46  						col.IsReadOnly = true
    47  					case "inline":
    48  						col.Inline = true
    49  					case "json":
    50  						col.IsJSON = true
    51  					default:
    52  						panic("invalid struct tag key; " + p)
    53  					}
    54  				}
    55  			}
    56  		}
    57  		return &col
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  // Column represents a single field on a struct that is mapped to the database.
    64  type Column struct {
    65  	Parent       *Column
    66  	TableName    string
    67  	FieldName    string
    68  	FieldType    reflect.Type
    69  	ColumnName   string
    70  	Index        int
    71  	IsPrimaryKey bool
    72  	IsUniqueKey  bool
    73  	IsAuto       bool
    74  	IsReadOnly   bool
    75  	IsJSON       bool
    76  	Inline       bool
    77  }
    78  
    79  // SetValue sets the field on a database mapped object to the instance of `value`.
    80  func (c Column) SetValue(reference, value any) error {
    81  	return c.SetValueReflected(reflectValue(reference), value)
    82  }
    83  
    84  // SetValueReflected sets the field on a reflect value object to the instance of `value`.
    85  func (c Column) SetValueReflected(reference reflect.Value, value any) error {
    86  	objectField := reference.FieldByName(c.FieldName)
    87  
    88  	// check if we've been passed a reference for the target object
    89  	if !objectField.CanSet() {
    90  		return fmt.Errorf("hit a field we can't set; did you forget to pass the object as a reference? field: %s", c.FieldName)
    91  	}
    92  
    93  	// special case for `db:"...,json"` fields.
    94  	if c.IsJSON {
    95  		var deserialized interface{}
    96  		if objectField.Kind() == reflect.Ptr {
    97  			deserialized = reflect.New(objectField.Type().Elem()).Interface()
    98  		} else {
    99  			deserialized = objectField.Addr().Interface()
   100  		}
   101  
   102  		switch valueContents := value.(type) {
   103  		case *sql.NullString:
   104  			if !valueContents.Valid {
   105  				objectField.Set(reflect.Zero(objectField.Type()))
   106  				return nil
   107  			}
   108  			if err := json.Unmarshal([]byte(valueContents.String), deserialized); err != nil {
   109  				return err
   110  			}
   111  		case sql.NullString:
   112  			if !valueContents.Valid {
   113  				objectField.Set(reflect.Zero(objectField.Type()))
   114  				return nil
   115  			}
   116  			if err := json.Unmarshal([]byte(valueContents.String), deserialized); err != nil {
   117  				return err
   118  			}
   119  		case *string:
   120  			if err := json.Unmarshal([]byte(*valueContents), deserialized); err != nil {
   121  				return err
   122  			}
   123  		case string:
   124  			if err := json.Unmarshal([]byte(valueContents), deserialized); err != nil {
   125  				return err
   126  			}
   127  		case *[]byte:
   128  			if err := json.Unmarshal(*valueContents, deserialized); err != nil {
   129  				return err
   130  			}
   131  		case []byte:
   132  			if err := json.Unmarshal(valueContents, deserialized); err != nil {
   133  				return err
   134  			}
   135  		default:
   136  			return fmt.Errorf("set value; invalid type for assignment to json field; field %s", c.FieldName)
   137  		}
   138  
   139  		if rv := reflect.ValueOf(deserialized); !rv.IsValid() {
   140  			objectField.Set(reflect.Zero(objectField.Type()))
   141  		} else {
   142  			if objectField.Kind() == reflect.Ptr {
   143  				objectField.Set(rv)
   144  			} else {
   145  				objectField.Set(rv.Elem())
   146  			}
   147  		}
   148  		return nil
   149  	}
   150  
   151  	valueReflected := reflectValue(value)
   152  	if !valueReflected.IsValid() { // if the value is nil
   153  		objectField.Set(reflect.Zero(objectField.Type())) // zero the field
   154  		return nil
   155  	}
   156  
   157  	// if we can direct assign the value to the field
   158  	if valueReflected.Type().AssignableTo(objectField.Type()) {
   159  		objectField.Set(valueReflected)
   160  		return nil
   161  	}
   162  
   163  	// convert and assign
   164  	if valueReflected.Type().ConvertibleTo(objectField.Type()) ||
   165  		haveSameUnderlyingTypes(objectField, valueReflected) {
   166  		objectField.Set(valueReflected.Convert(objectField.Type()))
   167  		return nil
   168  	}
   169  
   170  	if objectField.Kind() == reflect.Ptr && valueReflected.CanAddr() {
   171  		if valueReflected.Addr().Type().AssignableTo(objectField.Type()) {
   172  			objectField.Set(valueReflected.Addr())
   173  			return nil
   174  		}
   175  		if valueReflected.Addr().Type().ConvertibleTo(objectField.Type()) {
   176  			objectField.Set(valueReflected.Convert(objectField.Elem().Type()).Addr())
   177  			return nil
   178  		}
   179  		return fmt.Errorf("set value; can addr value but can't figure out how to assign or convert; field: %s", c.FieldName)
   180  	}
   181  
   182  	return fmt.Errorf("set value; ran out of ways to set the field; field: %s", c.FieldName)
   183  }
   184  
   185  // GetValue returns the value for a column on a given database mapped object.
   186  func (c Column) GetValue(object any) interface{} {
   187  	value := reflectValue(object)
   188  	if c.Parent != nil {
   189  		embedded := value.Field(c.Parent.Index)
   190  		valueField := embedded.Field(c.Index)
   191  		return valueField.Interface()
   192  	}
   193  	valueField := value.Field(c.Index)
   194  	return valueField.Interface()
   195  }