github.com/dkishere/pop/v6@v6.103.1/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/dkishere/pop/v6/associations"
    13  	"github.com/dkishere/pop/v6/logging"
    14  	"github.com/gofrs/uuid"
    15  )
    16  
    17  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    18  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    19  
    20  // Find the first record of the model in the database with a particular id.
    21  //
    22  //	c.Find(&User{}, 1)
    23  func (c *Connection) Find(model interface{}, id interface{}) error {
    24  	return Q(c).Find(model, id)
    25  }
    26  
    27  // Find the first record of the model in the database with a particular id.
    28  //
    29  //	q.Find(&User{}, 1)
    30  func (q *Query) Find(model interface{}, id interface{}) error {
    31  	m := NewModel(model, q.Connection.Context())
    32  	idq := m.WhereID()
    33  	switch t := id.(type) {
    34  	case uuid.UUID:
    35  		return q.Where(idq, t.String()).First(model)
    36  	default:
    37  		// Pick argument type based on column type. This is required for keeping backwards compatibility with:
    38  		//
    39  		// https://github.com/gobuffalo/buffalo/blob/master/genny/resource/templates/use_model/actions/resource-name.go.tmpl#L76
    40  		pkt, err := m.PrimaryKeyType()
    41  		if err != nil {
    42  			return err
    43  		}
    44  
    45  		switch pkt {
    46  		case "int32", "int64", "uint32", "uint64", "int8", "uint8", "int16", "uint16", "int":
    47  			if tid, ok := id.(string); ok {
    48  				if intID, err := strconv.Atoi(tid); err == nil {
    49  					return q.Where(idq, intID).First(model)
    50  				}
    51  			}
    52  		}
    53  	}
    54  
    55  	return q.Where(idq, id).First(model)
    56  }
    57  
    58  // First record of the model in the database that matches the query.
    59  //
    60  //	c.First(&User{})
    61  func (c *Connection) First(model interface{}) error {
    62  	return Q(c).First(model)
    63  }
    64  
    65  // First record of the model in the database that matches the query.
    66  //
    67  //	q.Where("name = ?", "mark").First(&User{})
    68  func (q *Query) First(model interface{}) error {
    69  	err := q.Connection.timeFunc("First", func() error {
    70  		q.Limit(1)
    71  		m := NewModel(model, q.Connection.Context())
    72  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
    73  			return err
    74  		}
    75  		return m.afterFind(q.Connection)
    76  	})
    77  
    78  	if err != nil {
    79  		return err
    80  	}
    81  
    82  	if q.eager {
    83  		err = q.eagerAssociations(model)
    84  		q.disableEager()
    85  		return err
    86  	}
    87  	return nil
    88  }
    89  
    90  // Last record of the model in the database that matches the query.
    91  //
    92  //	c.Last(&User{})
    93  func (c *Connection) Last(model interface{}) error {
    94  	return Q(c).Last(model)
    95  }
    96  
    97  // Last record of the model in the database that matches the query.
    98  //
    99  //	q.Where("name = ?", "mark").Last(&User{})
   100  func (q *Query) Last(model interface{}) error {
   101  	err := q.Connection.timeFunc("Last", func() error {
   102  		q.Limit(1)
   103  		q.Order("created_at DESC, id DESC")
   104  		m := NewModel(model, q.Connection.Context())
   105  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
   106  			return err
   107  		}
   108  		return m.afterFind(q.Connection)
   109  	})
   110  
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	if q.eager {
   116  		err = q.eagerAssociations(model)
   117  		q.disableEager()
   118  		return err
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  // All retrieves all of the records in the database that match the query.
   125  //
   126  //	c.All(&[]User{})
   127  func (c *Connection) All(models interface{}) error {
   128  	return Q(c).All(models)
   129  }
   130  
   131  // All retrieves all of the records in the database that match the query.
   132  //
   133  //	q.Where("name = ?", "mark").All(&[]User{})
   134  func (q *Query) All(models interface{}) error {
   135  	err := q.Connection.timeFunc("All", func() error {
   136  		m := NewModel(models, q.Connection.Context())
   137  		err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		err = q.paginateModel(models)
   142  		if err != nil {
   143  			return err
   144  		}
   145  		return m.afterFind(q.Connection)
   146  	})
   147  
   148  	if err != nil {
   149  		return fmt.Errorf("unable to fetch records: %w", err)
   150  	}
   151  
   152  	if q.eager {
   153  		err = q.eagerAssociations(models)
   154  		q.disableEager()
   155  		return err
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func (q *Query) paginateModel(models interface{}) error {
   162  	if q.Paginator == nil {
   163  		return nil
   164  	}
   165  
   166  	ct, err := q.Count(models)
   167  	if err != nil {
   168  		return err
   169  	}
   170  
   171  	q.Paginator.TotalEntriesSize = ct
   172  	st := reflect.ValueOf(models).Elem()
   173  	q.Paginator.CurrentEntriesSize = st.Len()
   174  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   175  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   176  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   177  	}
   178  	return nil
   179  }
   180  
   181  // Load loads all association or the fields specified in params for
   182  // an already loaded model.
   183  //
   184  // tx.First(&u)
   185  // tx.Load(&u)
   186  func (c *Connection) Load(model interface{}, fields ...string) error {
   187  	q := Q(c)
   188  	q.eagerFields = fields
   189  	err := q.eagerAssociations(model)
   190  	q.disableEager()
   191  	return err
   192  }
   193  
   194  func (q *Query) eagerAssociations(model interface{}) error {
   195  	if q.eagerMode == eagerModeNil {
   196  		q.eagerMode = loadingAssociationsStrategy
   197  	}
   198  	if q.eagerMode == EagerPreload {
   199  		return preload(q.Connection, model, q.eagerFields...)
   200  	}
   201  
   202  	return q.eagerDefaultAssociations(model)
   203  }
   204  
   205  func (q *Query) eagerDefaultAssociations(model interface{}) error {
   206  	var err error
   207  
   208  	// eagerAssociations for a slice or array model passed as a param.
   209  	v := reflect.ValueOf(model)
   210  	kind := reflect.Indirect(v).Kind()
   211  	if kind == reflect.Slice || kind == reflect.Array {
   212  		v = v.Elem()
   213  		for i := 0; i < v.Len(); i++ {
   214  			e := v.Index(i)
   215  			if e.Type().Kind() == reflect.Ptr {
   216  				// Already a pointer
   217  				err = q.eagerAssociations(e.Interface())
   218  			} else {
   219  				err = q.eagerAssociations(e.Addr().Interface())
   220  			}
   221  			if err != nil {
   222  				return err
   223  			}
   224  		}
   225  		return nil
   226  	}
   227  
   228  	// eagerAssociations for a single element
   229  	assos, err := associations.ForStruct(model, q.eagerFields...)
   230  	if err != nil {
   231  		return fmt.Errorf("could not retrieve associations: %w", err)
   232  	}
   233  
   234  	// disable eager mode for current connection.
   235  	q.eager = false
   236  	q.Connection.eager = false
   237  
   238  	for _, association := range assos {
   239  		if association.Skipped() {
   240  			continue
   241  		}
   242  
   243  		query := Q(q.Connection)
   244  
   245  		whereCondition, args := association.Constraint()
   246  		query = query.Where(whereCondition, args...)
   247  
   248  		// validates if association is Sortable
   249  		sortable := (*associations.AssociationSortable)(nil)
   250  		t := reflect.TypeOf(association)
   251  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   252  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   253  			out := m.Call([]reflect.Value{})
   254  			orderClause := out[0].String()
   255  			if orderClause != "" {
   256  				query = query.Order(orderClause)
   257  			}
   258  		}
   259  
   260  		sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context()))
   261  		query = query.RawQuery(sqlSentence, args...)
   262  
   263  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   264  			err = query.All(association.Interface())
   265  		}
   266  
   267  		if association.Kind() == reflect.Struct {
   268  			err = query.First(association.Interface())
   269  		}
   270  
   271  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   272  			return err
   273  		}
   274  
   275  		if err == sql.ErrNoRows {
   276  			continue
   277  		}
   278  
   279  		// load all inner associations.
   280  		innerAssociations := association.InnerAssociations()
   281  		for _, inner := range innerAssociations {
   282  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   283  			innerQuery := Q(query.Connection)
   284  			innerQuery.eagerFields = inner.Fields
   285  			err = innerQuery.eagerAssociations(v.Addr().Interface())
   286  			if err != nil {
   287  				return err
   288  			}
   289  		}
   290  	}
   291  	return nil
   292  }
   293  
   294  // Exists returns true/false if a record exists in the database that matches
   295  // the query.
   296  //
   297  // 	q.Where("name = ?", "mark").Exists(&User{})
   298  func (q *Query) Exists(model interface{}) (bool, error) {
   299  	tmpQuery := Q(q.Connection)
   300  	q.Clone(tmpQuery) // avoid meddling with original query
   301  
   302  	var res bool
   303  
   304  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   305  		tmpQuery.Paginator = nil
   306  		tmpQuery.orderClauses = clauses{}
   307  		tmpQuery.limitResults = 0
   308  		query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context()))
   309  
   310  		// when query contains custom selected fields / executed using RawQuery,
   311  		// sql may already contains limit and offset
   312  		if rLimitOffset.MatchString(query) {
   313  			foundLimit := rLimitOffset.FindString(query)
   314  			query = query[0 : len(query)-len(foundLimit)]
   315  		} else if rLimit.MatchString(query) {
   316  			foundLimit := rLimit.FindString(query)
   317  			query = query[0 : len(query)-len(foundLimit)]
   318  		}
   319  
   320  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   321  		log(logging.SQL, existsQuery, args...)
   322  		return q.Connection.Store.Get(&res, existsQuery, args...)
   323  	})
   324  	return res, err
   325  }
   326  
   327  // Count the number of records in the database.
   328  //
   329  //	c.Count(&User{})
   330  func (c *Connection) Count(model interface{}) (int, error) {
   331  	return Q(c).Count(model)
   332  }
   333  
   334  // Count the number of records in the database.
   335  //
   336  //	q.Where("name = ?", "mark").Count(&User{})
   337  func (q Query) Count(model interface{}) (int, error) {
   338  	return q.CountByField(model, "*")
   339  }
   340  
   341  // CountByField counts the number of records in the database, for a given field.
   342  //
   343  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   344  func (q Query) CountByField(model interface{}, field string) (int, error) {
   345  	tmpQuery := Q(q.Connection)
   346  	q.Clone(tmpQuery) // avoid meddling with original query
   347  
   348  	res := &rowCount{}
   349  
   350  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   351  		tmpQuery.Paginator = nil
   352  		tmpQuery.orderClauses = clauses{}
   353  		tmpQuery.limitResults = 0
   354  		query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context()))
   355  		// when query contains custom selected fields / executed using RawQuery,
   356  		//	sql may already contains limit and offset
   357  
   358  		if rLimitOffset.MatchString(query) {
   359  			foundLimit := rLimitOffset.FindString(query)
   360  			query = query[0 : len(query)-len(foundLimit)]
   361  		} else if rLimit.MatchString(query) {
   362  			foundLimit := rLimit.FindString(query)
   363  			query = query[0 : len(query)-len(foundLimit)]
   364  		}
   365  
   366  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   367  		log(logging.SQL, countQuery, args...)
   368  		return q.Connection.Store.Get(res, countQuery, args...)
   369  	})
   370  	return res.Count, err
   371  }
   372  
// rowCount is the scan target for the single-row result of the count query
// issued by CountByField.
type rowCount struct {
	// Count is read from the "row_count" column of the result.
	Count int `db:"row_count"`
}
   376  
   377  // Select allows to query only fields passed as parameter.
   378  // c.Select("field1", "field2").All(&model)
   379  // => SELECT field1, field2 FROM models
   380  func (c *Connection) Select(fields ...string) *Query {
   381  	return c.Q().Select(fields...)
   382  }
   383  
   384  // Select allows to query only fields passed as parameter.
   385  // c.Select("field1", "field2").All(&model)
   386  // => SELECT field1, field2 FROM models
   387  func (q *Query) Select(fields ...string) *Query {
   388  	for _, f := range fields {
   389  		if strings.TrimSpace(f) != "" {
   390  			q.addColumns = append(q.addColumns, f)
   391  		}
   392  	}
   393  	return q
   394  }