github.com/rjgonzale/pop/v5@v5.1.3-dev/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"reflect"
     7  	"regexp"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/gobuffalo/pop/v5/associations"
    12  	"github.com/gobuffalo/pop/v5/logging"
    13  	"github.com/gofrs/uuid"
    14  	"github.com/pkg/errors"
    15  )
    16  
    17  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    18  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    19  
    20  // Find the first record of the model in the database with a particular id.
    21  //
    22  //	c.Find(&User{}, 1)
    23  func (c *Connection) Find(model interface{}, id interface{}) error {
    24  	return Q(c).Find(model, id)
    25  }
    26  
    27  // Find the first record of the model in the database with a particular id.
    28  //
    29  //	q.Find(&User{}, 1)
    30  func (q *Query) Find(model interface{}, id interface{}) error {
    31  	m := &Model{Value: model}
    32  	idq := m.whereID()
    33  	switch t := id.(type) {
    34  	case uuid.UUID:
    35  		return q.Where(idq, t.String()).First(model)
    36  	case string:
    37  		l := len(t)
    38  		if l > 0 {
    39  			// Handle leading '0':
    40  			// if the string have a leading '0' and is not "0", prevent parsing to int
    41  			if t[0] != '0' || l == 1 {
    42  				var err error
    43  				id, err = strconv.Atoi(t)
    44  				if err != nil {
    45  					return q.Where(idq, t).First(model)
    46  				}
    47  			}
    48  		}
    49  	}
    50  
    51  	return q.Where(idq, id).First(model)
    52  }
    53  
    54  // First record of the model in the database that matches the query.
    55  //
    56  //	c.First(&User{})
    57  func (c *Connection) First(model interface{}) error {
    58  	return Q(c).First(model)
    59  }
    60  
    61  // First record of the model in the database that matches the query.
    62  //
    63  //	q.Where("name = ?", "mark").First(&User{})
    64  func (q *Query) First(model interface{}) error {
    65  	err := q.Connection.timeFunc("First", func() error {
    66  		q.Limit(1)
    67  		m := &Model{Value: model}
    68  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
    69  			return err
    70  		}
    71  		return m.afterFind(q.Connection)
    72  	})
    73  
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	if q.eager {
    79  		err = q.eagerAssociations(model)
    80  		q.disableEager()
    81  		return err
    82  	}
    83  	return nil
    84  }
    85  
    86  // Last record of the model in the database that matches the query.
    87  //
    88  //	c.Last(&User{})
    89  func (c *Connection) Last(model interface{}) error {
    90  	return Q(c).Last(model)
    91  }
    92  
    93  // Last record of the model in the database that matches the query.
    94  //
    95  //	q.Where("name = ?", "mark").Last(&User{})
    96  func (q *Query) Last(model interface{}) error {
    97  	err := q.Connection.timeFunc("Last", func() error {
    98  		q.Limit(1)
    99  		q.Order("created_at DESC, id DESC")
   100  		m := &Model{Value: model}
   101  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
   102  			return err
   103  		}
   104  		return m.afterFind(q.Connection)
   105  	})
   106  
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if q.eager {
   112  		err = q.eagerAssociations(model)
   113  		q.disableEager()
   114  		return err
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  // All retrieves all of the records in the database that match the query.
   121  //
   122  //	c.All(&[]User{})
   123  func (c *Connection) All(models interface{}) error {
   124  	return Q(c).All(models)
   125  }
   126  
   127  // All retrieves all of the records in the database that match the query.
   128  //
   129  //	q.Where("name = ?", "mark").All(&[]User{})
   130  func (q *Query) All(models interface{}) error {
   131  	err := q.Connection.timeFunc("All", func() error {
   132  		m := &Model{Value: models}
   133  		err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
   134  		if err != nil {
   135  			return err
   136  		}
   137  		err = q.paginateModel(models)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		return m.afterFind(q.Connection)
   142  	})
   143  
   144  	if err != nil {
   145  		return errors.Wrap(err, "unable to fetch records")
   146  	}
   147  
   148  	if q.eager {
   149  		err = q.eagerAssociations(models)
   150  		q.disableEager()
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (q *Query) paginateModel(models interface{}) error {
   158  	if q.Paginator == nil {
   159  		return nil
   160  	}
   161  
   162  	ct, err := q.Count(models)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	q.Paginator.TotalEntriesSize = ct
   168  	st := reflect.ValueOf(models).Elem()
   169  	q.Paginator.CurrentEntriesSize = st.Len()
   170  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   171  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   172  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   173  	}
   174  	return nil
   175  }
   176  
   177  // Load loads all association or the fields specified in params for
   178  // an already loaded model.
   179  //
   180  // tx.First(&u)
   181  // tx.Load(&u)
   182  func (c *Connection) Load(model interface{}, fields ...string) error {
   183  	q := Q(c)
   184  	q.eagerFields = fields
   185  	err := q.eagerAssociations(model)
   186  	q.disableEager()
   187  	return err
   188  }
   189  
   190  func (q *Query) eagerAssociations(model interface{}) error {
   191  	if q.eagerMode == eagerModeNil {
   192  		q.eagerMode = loadingAssociationsStrategy
   193  	}
   194  	if q.eagerMode == EagerPreload {
   195  		return preload(q.Connection, model, q.eagerFields...)
   196  	}
   197  
   198  	return q.eagerDefaultAssociations(model)
   199  }
   200  
   201  func (q *Query) eagerDefaultAssociations(model interface{}) error {
   202  	var err error
   203  
   204  	// eagerAssociations for a slice or array model passed as a param.
   205  	v := reflect.ValueOf(model)
   206  	kind := reflect.Indirect(v).Kind()
   207  	if kind == reflect.Slice || kind == reflect.Array {
   208  		v = v.Elem()
   209  		for i := 0; i < v.Len(); i++ {
   210  			e := v.Index(i)
   211  			if e.Type().Kind() == reflect.Ptr {
   212  				// Already a pointer
   213  				err = q.eagerAssociations(e.Interface())
   214  			} else {
   215  				err = q.eagerAssociations(e.Addr().Interface())
   216  			}
   217  			if err != nil {
   218  				return err
   219  			}
   220  		}
   221  		return nil
   222  	}
   223  
   224  	// eagerAssociations for a single element
   225  	assos, err := associations.ForStruct(model, q.eagerFields...)
   226  	if err != nil {
   227  		return errors.Wrap(err, "could not retrieve associations")
   228  	}
   229  
   230  	// disable eager mode for current connection.
   231  	q.eager = false
   232  	q.Connection.eager = false
   233  
   234  	for _, association := range assos {
   235  		if association.Skipped() {
   236  			continue
   237  		}
   238  
   239  		query := Q(q.Connection)
   240  
   241  		whereCondition, args := association.Constraint()
   242  		query = query.Where(whereCondition, args...)
   243  
   244  		// validates if association is Sortable
   245  		sortable := (*associations.AssociationSortable)(nil)
   246  		t := reflect.TypeOf(association)
   247  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   248  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   249  			out := m.Call([]reflect.Value{})
   250  			orderClause := out[0].String()
   251  			if orderClause != "" {
   252  				query = query.Order(orderClause)
   253  			}
   254  		}
   255  
   256  		sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
   257  		query = query.RawQuery(sqlSentence, args...)
   258  
   259  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   260  			err = query.All(association.Interface())
   261  		}
   262  
   263  		if association.Kind() == reflect.Struct {
   264  			err = query.First(association.Interface())
   265  		}
   266  
   267  		if err != nil && errors.Cause(err) != sql.ErrNoRows {
   268  			return err
   269  		}
   270  
   271  		// load all inner associations.
   272  		innerAssociations := association.InnerAssociations()
   273  		for _, inner := range innerAssociations {
   274  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   275  			innerQuery := Q(query.Connection)
   276  			innerQuery.eagerFields = []string{inner.Fields}
   277  			err = innerQuery.eagerAssociations(v.Addr().Interface())
   278  			if err != nil {
   279  				return err
   280  			}
   281  		}
   282  	}
   283  	return nil
   284  }
   285  
   286  // Exists returns true/false if a record exists in the database that matches
   287  // the query.
   288  //
   289  // 	q.Where("name = ?", "mark").Exists(&User{})
   290  func (q *Query) Exists(model interface{}) (bool, error) {
   291  	tmpQuery := Q(q.Connection)
   292  	q.Clone(tmpQuery) //avoid meddling with original query
   293  
   294  	var res bool
   295  
   296  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   297  		tmpQuery.Paginator = nil
   298  		tmpQuery.orderClauses = clauses{}
   299  		tmpQuery.limitResults = 0
   300  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   301  
   302  		// when query contains custom selected fields / executed using RawQuery,
   303  		// sql may already contains limit and offset
   304  		if rLimitOffset.MatchString(query) {
   305  			foundLimit := rLimitOffset.FindString(query)
   306  			query = query[0 : len(query)-len(foundLimit)]
   307  		} else if rLimit.MatchString(query) {
   308  			foundLimit := rLimit.FindString(query)
   309  			query = query[0 : len(query)-len(foundLimit)]
   310  		}
   311  
   312  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   313  		log(logging.SQL, existsQuery, args...)
   314  		return q.Connection.Store.Get(&res, existsQuery, args...)
   315  	})
   316  	return res, err
   317  }
   318  
   319  // Count the number of records in the database.
   320  //
   321  //	c.Count(&User{})
   322  func (c *Connection) Count(model interface{}) (int, error) {
   323  	return Q(c).Count(model)
   324  }
   325  
   326  // Count the number of records in the database.
   327  //
   328  //	q.Where("name = ?", "mark").Count(&User{})
   329  func (q Query) Count(model interface{}) (int, error) {
   330  	return q.CountByField(model, "*")
   331  }
   332  
   333  // CountByField counts the number of records in the database, for a given field.
   334  //
   335  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   336  func (q Query) CountByField(model interface{}, field string) (int, error) {
   337  	tmpQuery := Q(q.Connection)
   338  	q.Clone(tmpQuery) //avoid meddling with original query
   339  
   340  	res := &rowCount{}
   341  
   342  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   343  		tmpQuery.Paginator = nil
   344  		tmpQuery.orderClauses = clauses{}
   345  		tmpQuery.limitResults = 0
   346  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   347  		//when query contains custom selected fields / executed using RawQuery,
   348  		//	sql may already contains limit and offset
   349  
   350  		if rLimitOffset.MatchString(query) {
   351  			foundLimit := rLimitOffset.FindString(query)
   352  			query = query[0 : len(query)-len(foundLimit)]
   353  		} else if rLimit.MatchString(query) {
   354  			foundLimit := rLimit.FindString(query)
   355  			query = query[0 : len(query)-len(foundLimit)]
   356  		}
   357  
   358  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   359  		log(logging.SQL, countQuery, args...)
   360  		return q.Connection.Store.Get(res, countQuery, args...)
   361  	})
   362  	return res.Count, err
   363  }
   364  
   365  type rowCount struct {
   366  	Count int `db:"row_count"`
   367  }
   368  
   369  // Select allows to query only fields passed as parameter.
   370  // c.Select("field1", "field2").All(&model)
   371  // => SELECT field1, field2 FROM models
   372  func (c *Connection) Select(fields ...string) *Query {
   373  	return c.Q().Select(fields...)
   374  }
   375  
   376  // Select allows to query only fields passed as parameter.
   377  // c.Select("field1", "field2").All(&model)
   378  // => SELECT field1, field2 FROM models
   379  func (q *Query) Select(fields ...string) *Query {
   380  	for _, f := range fields {
   381  		if strings.TrimSpace(f) != "" {
   382  			q.addColumns = append(q.addColumns, f)
   383  		}
   384  	}
   385  	return q
   386  }