github.com/friesencr/pop/v6@v6.1.6/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/friesencr/pop/v6/associations"
    13  	"github.com/friesencr/pop/v6/logging"
    14  	"github.com/gofrs/uuid/v5"
    15  )
    16  
    17  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    18  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    19  
    20  // Find the first record of the model in the database with a particular id.
    21  //
    22  //	c.Find(&User{}, 1)
    23  func (c *Connection) Find(model interface{}, id interface{}) error {
    24  	return Q(c).Find(model, id)
    25  }
    26  
    27  // Find the first record of the model in the database with a particular id.
    28  //
    29  //	q.Find(&User{}, 1)
    30  func (q *Query) Find(model interface{}, id interface{}) error {
    31  	m := NewModel(model, q.Connection.Context())
    32  	idq := m.WhereID()
    33  	switch t := id.(type) {
    34  	case uuid.UUID:
    35  		return q.Where(idq, t.String()).First(model)
    36  	default:
    37  		// Pick argument type based on column type. This is required for keeping backwards compatibility with:
    38  		//
    39  		// https://github.com/gobuffalo/buffalo/blob/master/genny/resource/templates/use_model/actions/resource-name.go.tmpl#L76
    40  		pkt, err := m.PrimaryKeyType()
    41  		if err != nil {
    42  			return err
    43  		}
    44  
    45  		switch pkt {
    46  		case "int32", "int64", "uint32", "uint64", "int8", "uint8", "int16", "uint16", "int":
    47  			if tid, ok := id.(string); ok {
    48  				if intID, err := strconv.Atoi(tid); err == nil {
    49  					return q.Where(idq, intID).First(model)
    50  				}
    51  			}
    52  		}
    53  	}
    54  
    55  	return q.Where(idq, id).First(model)
    56  }
    57  
    58  // First record of the model in the database that matches the query.
    59  //
    60  //	c.First(&User{})
    61  func (c *Connection) First(model interface{}) error {
    62  	return Q(c).First(model)
    63  }
    64  
    65  // First record of the model in the database that matches the query.
    66  //
    67  //	q.Where("name = ?", "mark").First(&User{})
    68  func (q *Query) First(model interface{}) error {
    69  	var m *Model
    70  	err := q.Connection.timeFunc("First", func() error {
    71  		q.Limit(1)
    72  		m = NewModel(model, q.Connection.Context())
    73  		if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
    74  			return err
    75  		}
    76  		return m.afterFind(q.Connection, false)
    77  	})
    78  
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	if q.eager {
    84  		err := q.eagerAssociations(model)
    85  		q.disableEager()
    86  		if err != nil {
    87  			return err
    88  		}
    89  		return m.afterFind(q.Connection, true)
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  // Last record of the model in the database that matches the query.
    96  //
    97  //	c.Last(&User{})
    98  func (c *Connection) Last(model interface{}) error {
    99  	return Q(c).Last(model)
   100  }
   101  
   102  // Last record of the model in the database that matches the query.
   103  //
   104  //	q.Where("name = ?", "mark").Last(&User{})
   105  func (q *Query) Last(model interface{}) error {
   106  	var m *Model
   107  	err := q.Connection.timeFunc("Last", func() error {
   108  		q.Limit(1)
   109  		q.Order("created_at DESC, id DESC")
   110  		m = NewModel(model, q.Connection.Context())
   111  		if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
   112  			return err
   113  		}
   114  		return m.afterFind(q.Connection, false)
   115  	})
   116  
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	if q.eager {
   122  		err = q.eagerAssociations(model)
   123  		q.disableEager()
   124  		if err != nil {
   125  			return err
   126  		}
   127  		return m.afterFind(q.Connection, true)
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  // All retrieves all of the records in the database that match the query.
   134  //
   135  //	c.All(&[]User{})
   136  func (c *Connection) All(models interface{}) error {
   137  	return Q(c).All(models)
   138  }
   139  
   140  // All retrieves all of the records in the database that match the query.
   141  //
   142  //	q.Where("name = ?", "mark").All(&[]User{})
   143  func (q *Query) All(models interface{}) error {
   144  	var m *Model
   145  	err := q.Connection.timeFunc("All", func() error {
   146  		m = NewModel(models, q.Connection.Context())
   147  		err := q.Connection.Dialect.SelectMany(q.Connection, m, *q)
   148  		if err != nil {
   149  			return err
   150  		}
   151  
   152  		err = q.paginateModel(models)
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		return m.afterFind(q.Connection, false)
   158  	})
   159  
   160  	if err != nil {
   161  		return fmt.Errorf("unable to fetch records: %w", err)
   162  	}
   163  
   164  	if q.eager {
   165  		err = q.eagerAssociations(models)
   166  		q.disableEager()
   167  		if err != nil {
   168  			return err
   169  		}
   170  		return m.afterFind(q.Connection, true)
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  func (q *Query) paginateModel(models interface{}) error {
   177  	if q.Paginator == nil {
   178  		return nil
   179  	}
   180  
   181  	ct, err := q.Count(models)
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	q.Paginator.TotalEntriesSize = ct
   187  	st := reflect.ValueOf(models).Elem()
   188  	q.Paginator.CurrentEntriesSize = st.Len()
   189  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   190  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   191  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   192  	}
   193  	return nil
   194  }
   195  
   196  // Load loads all association or the fields specified in params for
   197  // an already loaded model.
   198  //
   199  // tx.First(&u)
   200  // tx.Load(&u)
   201  func (c *Connection) Load(model interface{}, fields ...string) error {
   202  	q := Q(c)
   203  	q.eagerFields = fields
   204  	err := q.eagerAssociations(model)
   205  	q.disableEager()
   206  	return err
   207  }
   208  
   209  func (q *Query) eagerAssociations(model interface{}) error {
   210  	if q.eagerMode == eagerModeNil {
   211  		q.eagerMode = loadingAssociationsStrategy
   212  	}
   213  	if q.eagerMode == EagerPreload {
   214  		return preload(q.Connection, model, q.eagerFields...)
   215  	}
   216  
   217  	return q.eagerDefaultAssociations(model)
   218  }
   219  
   220  func (q *Query) eagerDefaultAssociations(model interface{}) error {
   221  	var err error
   222  
   223  	// eagerAssociations for a slice or array model passed as a param.
   224  	v := reflect.ValueOf(model)
   225  	kind := reflect.Indirect(v).Kind()
   226  	if kind == reflect.Slice || kind == reflect.Array {
   227  		v = v.Elem()
   228  		for i := 0; i < v.Len(); i++ {
   229  			e := v.Index(i)
   230  			if e.Type().Kind() == reflect.Ptr {
   231  				// Already a pointer
   232  				err = q.eagerAssociations(e.Interface())
   233  			} else {
   234  				err = q.eagerAssociations(e.Addr().Interface())
   235  			}
   236  			if err != nil {
   237  				return err
   238  			}
   239  		}
   240  		return nil
   241  	}
   242  
   243  	// eagerAssociations for a single element
   244  	assos, err := associations.ForStruct(model, q.eagerFields...)
   245  	if err != nil {
   246  		return fmt.Errorf("could not retrieve associations: %w", err)
   247  	}
   248  
   249  	// disable eager mode for current connection.
   250  	q.eager = false
   251  	q.Connection.eager = false
   252  
   253  	for _, association := range assos {
   254  		if association.Skipped() {
   255  			continue
   256  		}
   257  
   258  		query := Q(q.Connection)
   259  
   260  		whereCondition, args := association.Constraint()
   261  		query = query.Where(whereCondition, args...)
   262  
   263  		// validates if association is Sortable
   264  		sortable := (*associations.AssociationSortable)(nil)
   265  		t := reflect.TypeOf(association)
   266  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   267  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   268  			out := m.Call([]reflect.Value{})
   269  			orderClause := out[0].String()
   270  			if orderClause != "" {
   271  				query = query.Order(orderClause)
   272  			}
   273  		}
   274  
   275  		sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context()))
   276  		query = query.RawQuery(sqlSentence, args...)
   277  
   278  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   279  			err = query.All(association.Interface())
   280  		}
   281  
   282  		if association.Kind() == reflect.Struct {
   283  			err = query.First(association.Interface())
   284  		}
   285  
   286  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   287  			return err
   288  		}
   289  
   290  		if err == sql.ErrNoRows {
   291  			continue
   292  		}
   293  
   294  		// load all inner associations.
   295  		innerAssociations := association.InnerAssociations()
   296  		for _, inner := range innerAssociations {
   297  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   298  			innerQuery := Q(query.Connection)
   299  			innerQuery.eagerFields = inner.Fields
   300  
   301  			switch v.Kind() {
   302  			case reflect.Ptr:
   303  				err = innerQuery.eagerAssociations(v.Interface())
   304  			default:
   305  				err = innerQuery.eagerAssociations(v.Addr().Interface())
   306  			}
   307  
   308  			if err != nil {
   309  				return err
   310  			}
   311  		}
   312  	}
   313  	return nil
   314  }
   315  
   316  // Exists returns true/false if a record exists in the database that matches
   317  // the query.
   318  //
   319  //	q.Where("name = ?", "mark").Exists(&User{})
   320  func (q *Query) Exists(model interface{}) (bool, error) {
   321  	tmpQuery := Q(q.Connection)
   322  	q.Clone(tmpQuery) // avoid meddling with original query
   323  
   324  	var res bool
   325  
   326  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   327  		tmpQuery.Paginator = nil
   328  		tmpQuery.orderClauses = clauses{}
   329  		tmpQuery.limitResults = 0
   330  		query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context()))
   331  
   332  		// when query contains custom selected fields / executed using RawQuery,
   333  		// sql may already contains limit and offset
   334  		if rLimitOffset.MatchString(query) {
   335  			foundLimit := rLimitOffset.FindString(query)
   336  			query = query[0 : len(query)-len(foundLimit)]
   337  		} else if rLimit.MatchString(query) {
   338  			foundLimit := rLimit.FindString(query)
   339  			query = query[0 : len(query)-len(foundLimit)]
   340  		}
   341  
   342  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   343  		txlog(logging.SQL, q.Connection, existsQuery, args...)
   344  		return q.Connection.Store.Get(&res, existsQuery, args...)
   345  	})
   346  	return res, err
   347  }
   348  
   349  // Count the number of records in the database.
   350  //
   351  //	c.Count(&User{})
   352  func (c *Connection) Count(model interface{}) (int, error) {
   353  	return Q(c).Count(model)
   354  }
   355  
   356  // Count the number of records in the database.
   357  //
   358  //	q.Where("name = ?", "mark").Count(&User{})
   359  func (q Query) Count(model interface{}) (int, error) {
   360  	return q.CountByField(model, "*")
   361  }
   362  
   363  // CountByField counts the number of records in the database, for a given field.
   364  //
   365  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   366  func (q Query) CountByField(model interface{}, field string) (int, error) {
   367  	tmpQuery := Q(q.Connection)
   368  	q.Clone(tmpQuery) // avoid meddling with original query
   369  
   370  	res := &rowCount{}
   371  
   372  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   373  		tmpQuery.Paginator = nil
   374  		tmpQuery.orderClauses = clauses{}
   375  		tmpQuery.limitResults = 0
   376  		query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context()))
   377  		// when query contains custom selected fields / executed using RawQuery,
   378  		//	sql may already contains limit and offset
   379  
   380  		if rLimitOffset.MatchString(query) {
   381  			foundLimit := rLimitOffset.FindString(query)
   382  			query = query[0 : len(query)-len(foundLimit)]
   383  		} else if rLimit.MatchString(query) {
   384  			foundLimit := rLimit.FindString(query)
   385  			query = query[0 : len(query)-len(foundLimit)]
   386  		}
   387  
   388  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   389  		txlog(logging.SQL, q.Connection, countQuery, args...)
   390  		return q.Connection.Store.Get(res, countQuery, args...)
   391  	})
   392  	return res.Count, err
   393  }
   394  
   395  type rowCount struct {
   396  	Count int `db:"row_count"`
   397  }
   398  
   399  // Select allows to query only fields passed as parameter.
   400  // c.Select("field1", "field2").All(&model)
   401  // => SELECT field1, field2 FROM models
   402  func (c *Connection) Select(fields ...string) *Query {
   403  	return c.Q().Select(fields...)
   404  }
   405  
   406  // Select allows to query only fields passed as parameter.
   407  // c.Select("field1", "field2").All(&model)
   408  // => SELECT field1, field2 FROM models
   409  func (q *Query) Select(fields ...string) *Query {
   410  	for _, f := range fields {
   411  		if strings.TrimSpace(f) != "" {
   412  			q.addColumns = append(q.addColumns, f)
   413  		}
   414  	}
   415  	return q
   416  }