github.com/dkishere/pop/v6@v6.103.1/model.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/dkishere/pop/v6/columns"
    11  	"github.com/gobuffalo/flect"
    12  	nflect "github.com/gobuffalo/flect/name"
    13  	"github.com/gofrs/uuid"
    14  )
    15  
// nowFunc is a package-level variable initialised to time.Now.
// NOTE(review): presumably kept as a variable so the clock can be swapped out
// (e.g. in tests) — confirm against its callers, which are not visible here.
var nowFunc = time.Now
    17  
// Value is the contents of a `Model`. It is an empty interface, so any Go
// value can be wrapped; the Model methods inspect it through reflection and
// type assertions.
type Value interface{}
    20  
// modelIterable is the callback signature accepted by (*Model).iterate. It
// receives the Model to process; a non-nil error stops the iteration and is
// returned to the caller of iterate.
type modelIterable func(*Model) error
    22  
// Model is used throughout Pop to wrap the end user interface
// that is passed in to many functions.
//
// Value is the wrapped user value. ctx is the context handed to
// TableNameAbleWithContext implementations when the table name is resolved.
// As is an optional alias; when it is empty, Alias derives one from the
// table name.
type Model struct {
	Value
	ctx context.Context
	As  string
}
    30  
    31  // NewModel returns a new model with the specified value and context.
    32  func NewModel(v Value, ctx context.Context) *Model {
    33  	return &Model{Value: v, ctx: ctx}
    34  }
    35  
    36  // ID returns the ID of the Model. All models must have an `ID` field this is
    37  // of type `int`,`int64` or of type `uuid.UUID`.
    38  func (m *Model) ID() interface{} {
    39  	fbn, err := m.field(0)
    40  	if err != nil {
    41  		return nil
    42  	}
    43  	if pkt, _ := m.PrimaryKeyType(); pkt == "UUID" {
    44  		return fbn.Interface().(uuid.UUID).String()
    45  	}
    46  	return fbn.Interface()
    47  }
    48  
    49  // IDField returns the name of the DB field used for the ID.
    50  // By default, it will return "id".
    51  func (m *Model) IDField() string {
    52  	modelType := reflect.TypeOf(m.Value)
    53  
    54  	// remove all indirections
    55  	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Array {
    56  		modelType = modelType.Elem()
    57  	}
    58  
    59  	if modelType.Kind() == reflect.String {
    60  		return "id"
    61  	}
    62  
    63  	field := modelType.Field(0)
    64  	dbField := field.Tag.Get("db")
    65  	if dbField == "" {
    66  		return "id"
    67  	}
    68  	return dbField
    69  }
    70  
    71  // PrimaryKeyType gives the primary key type of the `Model`.
    72  func (m *Model) PrimaryKeyType() (string, error) {
    73  	fbn, err := m.field(0)
    74  	if err != nil {
    75  		return "", fmt.Errorf("model %T is missing required field ID", m.Value)
    76  	}
    77  	return fbn.Type().Name(), nil
    78  }
    79  
// TableNameAble allows a model to customize the mapping between its Go type
// and its database table. For example, the value `User{}` automatically maps
// to "users"; implementing `TableNameAble` lets the model return whatever
// table name it needs instead.
type TableNameAble interface {
	// TableName returns the name of the table backing the model.
	TableName() string
}
    87  
// TableNameAbleWithContext is equivalent to TableNameAble, but its TableName
// method is passed the query's context. Useful in cases where the table name
// depends on values carried by that context.
type TableNameAbleWithContext interface {
	// TableName returns the name of the table backing the model for ctx.
	TableName(ctx context.Context) string
}
    94  
    95  // TableName returns the corresponding name of the underlying database table
    96  // for a given `Model`. See also `TableNameAble` to change the default name of the table.
    97  func (m *Model) TableName() string {
    98  	if s, ok := m.Value.(string); ok {
    99  		return s
   100  	}
   101  
   102  	if n, ok := m.Value.(TableNameAble); ok {
   103  		return n.TableName()
   104  	}
   105  
   106  	if n, ok := m.Value.(TableNameAbleWithContext); ok {
   107  		if m.ctx == nil {
   108  			m.ctx = context.TODO()
   109  		}
   110  		return n.TableName(m.ctx)
   111  	}
   112  
   113  	return m.typeName(reflect.TypeOf(m.Value))
   114  }
   115  
   116  func (m *Model) Columns() columns.Columns {
   117  	return columns.ForStructWithAlias(m.Value, m.TableName(), m.As, m.IDField())
   118  }
   119  
   120  func (m *Model) cacheKey(t reflect.Type) string {
   121  	return t.PkgPath() + "." + t.Name()
   122  }
   123  
   124  func (m *Model) typeName(t reflect.Type) (name string) {
   125  	if t.Kind() == reflect.Ptr {
   126  		t = t.Elem()
   127  	}
   128  	switch t.Kind() {
   129  	case reflect.Slice, reflect.Array:
   130  		el := t.Elem()
   131  		if el.Kind() == reflect.Ptr {
   132  			el = el.Elem()
   133  		}
   134  
   135  		// validates if the elem of slice or array implements TableNameAble interface.
   136  		var tableNameAble *TableNameAble
   137  		if el.Implements(reflect.TypeOf(tableNameAble).Elem()) {
   138  			v := reflect.New(el)
   139  			out := v.MethodByName("TableName").Call([]reflect.Value{})
   140  			return out[0].String()
   141  		}
   142  
   143  		// validates if the elem of slice or array implements TableNameAbleWithContext interface.
   144  		var tableNameAbleWithContext *TableNameAbleWithContext
   145  		if el.Implements(reflect.TypeOf(tableNameAbleWithContext).Elem()) {
   146  			v := reflect.New(el)
   147  			out := v.MethodByName("TableName").Call([]reflect.Value{reflect.ValueOf(m.ctx)})
   148  			return out[0].String()
   149  
   150  			// We do not want to cache contextualized TableNames because that would break
   151  			// the contextualization.
   152  		}
   153  		return nflect.Tableize(el.Name())
   154  	default:
   155  		return nflect.Tableize(t.Name())
   156  	}
   157  }
   158  
   159  func (m *Model) field(i int) (reflect.Value, error) {
   160  	el := reflect.ValueOf(m.Value).Elem()
   161  	fbn := el.Field(i)
   162  	if !fbn.IsValid() {
   163  		return fbn, fmt.Errorf("model does not have a field %s", i)
   164  	}
   165  	return fbn, nil
   166  }
   167  
   168  func (m *Model) fieldByName(s string) (reflect.Value, error) {
   169  	el := reflect.ValueOf(m.Value).Elem()
   170  	fbn := el.FieldByName(s)
   171  	if !fbn.IsValid() {
   172  		return fbn, fmt.Errorf("model does not have a field named %s", s)
   173  	}
   174  	return fbn, nil
   175  }
   176  
   177  func (m *Model) associationName() string {
   178  	tn := flect.Singularize(m.TableName())
   179  	return fmt.Sprintf("%s_id", tn)
   180  }
   181  
   182  func (m *Model) setID(i interface{}) {
   183  	fbn, err := m.field(0)
   184  	if err == nil {
   185  		v := reflect.ValueOf(i)
   186  		switch fbn.Kind() {
   187  		case reflect.Int, reflect.Int64:
   188  			fbn.SetInt(v.Int())
   189  		default:
   190  			fbn.Set(reflect.ValueOf(i))
   191  		}
   192  	}
   193  }
   194  
   195  func (m *Model) setCreatedAt(now time.Time) {
   196  	fbn, err := m.fieldByName("CreatedAt")
   197  	if err == nil {
   198  		v := fbn.Interface()
   199  		if !IsZeroOfUnderlyingType(v) {
   200  			// Do not override already set CreatedAt
   201  			return
   202  		}
   203  		switch v.(type) {
   204  		case int, int64:
   205  			fbn.SetInt(now.Unix())
   206  		default:
   207  			fbn.Set(reflect.ValueOf(now))
   208  		}
   209  	}
   210  }
   211  
   212  func (m *Model) setUpdatedAt(now time.Time) {
   213  	fbn, err := m.fieldByName("UpdatedAt")
   214  	if err == nil {
   215  		v := fbn.Interface()
   216  		switch v.(type) {
   217  		case int, int64:
   218  			fbn.SetInt(now.Unix())
   219  		default:
   220  			fbn.Set(reflect.ValueOf(now))
   221  		}
   222  	}
   223  }
   224  
   225  func (m *Model) WhereID() string {
   226  	return fmt.Sprintf("%s.%s = ?", m.Alias(), m.IDField())
   227  }
   228  
   229  func (m *Model) Alias() string {
   230  	as := m.As
   231  	if as == "" {
   232  		as = strings.ReplaceAll(m.TableName(), ".", "_")
   233  	}
   234  	return as
   235  }
   236  
   237  func (m *Model) WhereNamedID() string {
   238  	return fmt.Sprintf("%s.%s = :%s", m.Alias(), m.IDField(), m.IDField())
   239  }
   240  
   241  func (m *Model) isSlice() bool {
   242  	v := reflect.Indirect(reflect.ValueOf(m.Value))
   243  	return v.Kind() == reflect.Slice || v.Kind() == reflect.Array
   244  }
   245  
   246  func (m *Model) iterate(fn modelIterable) error {
   247  	if m.isSlice() {
   248  		v := reflect.Indirect(reflect.ValueOf(m.Value))
   249  		for i := 0; i < v.Len(); i++ {
   250  			val := v.Index(i)
   251  			newModel := &Model{
   252  				Value: val.Addr().Interface(),
   253  				ctx:   m.ctx,
   254  			}
   255  			err := fn(newModel)
   256  
   257  			if err != nil {
   258  				return err
   259  			}
   260  		}
   261  		return nil
   262  	}
   263  
   264  	return fn(m)
   265  }