github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/model.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/gobuffalo/flect"
    11  	nflect "github.com/gobuffalo/flect/name"
    12  	"github.com/gobuffalo/pop/v6/columns"
    13  	"github.com/gofrs/uuid"
    14  )
    15  
    16  var nowFunc = time.Now
    17  
    18  // SetNowFunc allows an override of time.Now for customizing CreatedAt/UpdatedAt
    19  func SetNowFunc(f func() time.Time) {
    20  	nowFunc = f
    21  }
    22  
    23  // Value is the contents of a `Model`.
    24  type Value interface{}
    25  
    26  type modelIterable func(*Model) error
    27  
    28  // Model is used throughout Pop to wrap the end user interface
    29  // that is passed in to many functions.
    30  type Model struct {
    31  	Value
    32  	ctx context.Context
    33  	As  string
    34  }
    35  
    36  // NewModel returns a new model with the specified value and context.
    37  func NewModel(v Value, ctx context.Context) *Model {
    38  	return &Model{Value: v, ctx: ctx}
    39  }
    40  
    41  // ID returns the ID of the Model. All models must have an `ID` field this is
    42  // of type `int`,`int64` or of type `uuid.UUID`.
    43  func (m *Model) ID() interface{} {
    44  	fbn, err := m.fieldByName("ID")
    45  	if err != nil {
    46  		return nil
    47  	}
    48  	if pkt, _ := m.PrimaryKeyType(); pkt == "UUID" {
    49  		return fbn.Interface().(uuid.UUID).String()
    50  	}
    51  	return fbn.Interface()
    52  }
    53  
    54  // IDField returns the name of the DB field used for the ID.
    55  // By default, it will return "id".
    56  func (m *Model) IDField() string {
    57  	modelType := reflect.TypeOf(m.Value)
    58  
    59  	// remove all indirections
    60  	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Array {
    61  		modelType = modelType.Elem()
    62  	}
    63  
    64  	if modelType.Kind() == reflect.String {
    65  		return "id"
    66  	}
    67  
    68  	field, ok := modelType.FieldByName("ID")
    69  	if !ok {
    70  		return "id"
    71  	}
    72  	dbField := field.Tag.Get("db")
    73  	if dbField == "" {
    74  		return "id"
    75  	}
    76  	return dbField
    77  }
    78  
    79  // PrimaryKeyType gives the primary key type of the `Model`.
    80  func (m *Model) PrimaryKeyType() (string, error) {
    81  	fbn, err := m.fieldByName("ID")
    82  	if err != nil {
    83  		return "", fmt.Errorf("model %T is missing required field ID", m.Value)
    84  	}
    85  	return fbn.Type().Name(), nil
    86  }
    87  
    88  // TableNameAble interface allows for the customize table mapping
    89  // between a name and the database. For example the value
    90  // `User{}` will automatically map to "users". Implementing `TableNameAble`
    91  // would allow this to change to be changed to whatever you would like.
    92  type TableNameAble interface {
    93  	TableName() string
    94  }
    95  
    96  // TableNameAbleWithContext is equal to TableNameAble but will
    97  // be passed the queries' context. Useful in cases where the
    98  // table name depends on e.g.
    99  type TableNameAbleWithContext interface {
   100  	TableName(ctx context.Context) string
   101  }
   102  
   103  // TableName returns the corresponding name of the underlying database table
   104  // for a given `Model`. See also `TableNameAble` to change the default name of the table.
   105  func (m *Model) TableName() string {
   106  	if s, ok := m.Value.(string); ok {
   107  		return s
   108  	}
   109  
   110  	if n, ok := m.Value.(TableNameAble); ok {
   111  		return n.TableName()
   112  	}
   113  
   114  	if n, ok := m.Value.(TableNameAbleWithContext); ok {
   115  		if m.ctx == nil {
   116  			m.ctx = context.TODO()
   117  		}
   118  		return n.TableName(m.ctx)
   119  	}
   120  
   121  	return m.typeName(reflect.TypeOf(m.Value))
   122  }
   123  
   124  func (m *Model) Columns() columns.Columns {
   125  	return columns.ForStructWithAlias(m.Value, m.TableName(), m.As, m.IDField())
   126  }
   127  
   128  func (m *Model) cacheKey(t reflect.Type) string {
   129  	return t.PkgPath() + "." + t.Name()
   130  }
   131  
   132  func (m *Model) typeName(t reflect.Type) (name string) {
   133  	if t.Kind() == reflect.Ptr {
   134  		t = t.Elem()
   135  	}
   136  	switch t.Kind() {
   137  	case reflect.Slice, reflect.Array:
   138  		el := t.Elem()
   139  		if el.Kind() == reflect.Ptr {
   140  			el = el.Elem()
   141  		}
   142  
   143  		// validates if the elem of slice or array implements TableNameAble interface.
   144  		var tableNameAble *TableNameAble
   145  		if el.Implements(reflect.TypeOf(tableNameAble).Elem()) {
   146  			v := reflect.New(el)
   147  			out := v.MethodByName("TableName").Call([]reflect.Value{})
   148  			return out[0].String()
   149  		}
   150  
   151  		// validates if the elem of slice or array implements TableNameAbleWithContext interface.
   152  		var tableNameAbleWithContext *TableNameAbleWithContext
   153  		if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) {
   154  			v := reflect.New(el)
   155  			out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)})
   156  			return out[0].String()
   157  
   158  			// We do not want to cache contextualized TableNames because that would break
   159  			// the contextualization.
   160  		}
   161  		return nflect.Tableize(el.Name())
   162  	default:
   163  		return nflect.Tableize(t.Name())
   164  	}
   165  }
   166  
   167  func (m *Model) fieldByName(s string) (reflect.Value, error) {
   168  	el := reflect.ValueOf(m.Value).Elem()
   169  	fbn := el.FieldByName(s)
   170  	if !fbn.IsValid() {
   171  		return fbn, fmt.Errorf("model does not have a field named %s", s)
   172  	}
   173  	return fbn, nil
   174  }
   175  
   176  func (m *Model) associationName() string {
   177  	tn := flect.Singularize(m.TableName())
   178  	return fmt.Sprintf("%s_id", tn)
   179  }
   180  
   181  func (m *Model) setID(i interface{}) {
   182  	fbn, err := m.fieldByName("ID")
   183  	if err == nil {
   184  		v := reflect.ValueOf(i)
   185  		switch fbn.Kind() {
   186  		case reflect.Int, reflect.Int64:
   187  			fbn.SetInt(v.Int())
   188  		default:
   189  			fbn.Set(reflect.ValueOf(i))
   190  		}
   191  	}
   192  }
   193  
   194  func (m *Model) setCreatedAt(now time.Time) {
   195  	fbn, err := m.fieldByName("CreatedAt")
   196  	if err == nil {
   197  		v := fbn.Interface()
   198  		if !IsZeroOfUnderlyingType(v) {
   199  			// Do not override already set CreatedAt
   200  			return
   201  		}
   202  		switch v.(type) {
   203  		case int, int64:
   204  			fbn.SetInt(now.Unix())
   205  		default:
   206  			fbn.Set(reflect.ValueOf(now))
   207  		}
   208  	}
   209  }
   210  
   211  func (m *Model) setUpdatedAt(now time.Time) {
   212  	fbn, err := m.fieldByName("UpdatedAt")
   213  	if err == nil {
   214  		v := fbn.Interface()
   215  		switch v.(type) {
   216  		case int, int64:
   217  			fbn.SetInt(now.Unix())
   218  		default:
   219  			fbn.Set(reflect.ValueOf(now))
   220  		}
   221  	}
   222  }
   223  
   224  func (m *Model) WhereID() string {
   225  	return fmt.Sprintf("%s.%s = ?", m.Alias(), m.IDField())
   226  }
   227  
   228  func (m *Model) Alias() string {
   229  	as := m.As
   230  	if as == "" {
   231  		as = strings.ReplaceAll(m.TableName(), ".", "_")
   232  	}
   233  	return as
   234  }
   235  
   236  func (m *Model) WhereNamedID() string {
   237  	return fmt.Sprintf("%s.%s = :%s", m.Alias(), m.IDField(), m.IDField())
   238  }
   239  
   240  func (m *Model) isSlice() bool {
   241  	v := reflect.Indirect(reflect.ValueOf(m.Value))
   242  	return v.Kind() == reflect.Slice || v.Kind() == reflect.Array
   243  }
   244  
   245  func (m *Model) iterate(fn modelIterable) error {
   246  	if m.isSlice() {
   247  		v := reflect.Indirect(reflect.ValueOf(m.Value))
   248  		for i := 0; i < v.Len(); i++ {
   249  			val := v.Index(i)
   250  			newModel := &Model{
   251  				Value: val.Addr().Interface(),
   252  				ctx:   m.ctx,
   253  			}
   254  			err := fn(newModel)
   255  
   256  			if err != nil {
   257  				return err
   258  			}
   259  		}
   260  		return nil
   261  	}
   262  
   263  	return fn(m)
   264  }