github.com/solongordon/pop@v4.10.0+incompatible/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"reflect"
     7  	"regexp"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/gobuffalo/pop/associations"
    12  	"github.com/gobuffalo/pop/logging"
    13  	"github.com/gofrs/uuid"
    14  	"github.com/pkg/errors"
    15  )
    16  
    17  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    18  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    19  
    20  // Find the first record of the model in the database with a particular id.
    21  //
    22  //	c.Find(&User{}, 1)
    23  func (c *Connection) Find(model interface{}, id interface{}) error {
    24  	return Q(c).Find(model, id)
    25  }
    26  
    27  // Find the first record of the model in the database with a particular id.
    28  //
    29  //	q.Find(&User{}, 1)
    30  func (q *Query) Find(model interface{}, id interface{}) error {
    31  	m := &Model{Value: model}
    32  	idq := m.whereID()
    33  	switch t := id.(type) {
    34  	case uuid.UUID:
    35  		return q.Where(idq, t.String()).First(model)
    36  	case string:
    37  		l := len(t)
    38  		if l > 0 {
    39  			// Handle leading '0':
    40  			// if the string have a leading '0' and is not "0", prevent parsing to int
    41  			if t[0] != '0' || l == 1 {
    42  				var err error
    43  				id, err = strconv.Atoi(t)
    44  				if err != nil {
    45  					return q.Where(idq, t).First(model)
    46  				}
    47  			}
    48  		}
    49  	}
    50  
    51  	return q.Where(idq, id).First(model)
    52  }
    53  
    54  // First record of the model in the database that matches the query.
    55  //
    56  //	c.First(&User{})
    57  func (c *Connection) First(model interface{}) error {
    58  	return Q(c).First(model)
    59  }
    60  
    61  // First record of the model in the database that matches the query.
    62  //
    63  //	q.Where("name = ?", "mark").First(&User{})
    64  func (q *Query) First(model interface{}) error {
    65  	err := q.Connection.timeFunc("First", func() error {
    66  		q.Limit(1)
    67  		m := &Model{Value: model}
    68  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
    69  			return err
    70  		}
    71  		return m.afterFind(q.Connection)
    72  	})
    73  
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	if q.eager {
    79  		err = q.eagerAssociations(model)
    80  		q.disableEager()
    81  		return err
    82  	}
    83  	return nil
    84  }
    85  
    86  // Last record of the model in the database that matches the query.
    87  //
    88  //	c.Last(&User{})
    89  func (c *Connection) Last(model interface{}) error {
    90  	return Q(c).Last(model)
    91  }
    92  
    93  // Last record of the model in the database that matches the query.
    94  //
    95  //	q.Where("name = ?", "mark").Last(&User{})
    96  func (q *Query) Last(model interface{}) error {
    97  	err := q.Connection.timeFunc("Last", func() error {
    98  		q.Limit(1)
    99  		q.Order("created_at DESC, id DESC")
   100  		m := &Model{Value: model}
   101  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
   102  			return err
   103  		}
   104  		return m.afterFind(q.Connection)
   105  	})
   106  
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if q.eager {
   112  		err = q.eagerAssociations(model)
   113  		q.disableEager()
   114  		return err
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  // All retrieves all of the records in the database that match the query.
   121  //
   122  //	c.All(&[]User{})
   123  func (c *Connection) All(models interface{}) error {
   124  	return Q(c).All(models)
   125  }
   126  
   127  // All retrieves all of the records in the database that match the query.
   128  //
   129  //	q.Where("name = ?", "mark").All(&[]User{})
   130  func (q *Query) All(models interface{}) error {
   131  	err := q.Connection.timeFunc("All", func() error {
   132  		m := &Model{Value: models}
   133  		err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
   134  		if err != nil {
   135  			return err
   136  		}
   137  		err = q.paginateModel(models)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		return m.afterFind(q.Connection)
   142  	})
   143  
   144  	if err != nil {
   145  		return errors.Wrap(err, "unable to fetch records")
   146  	}
   147  
   148  	if q.eager {
   149  		err = q.eagerAssociations(models)
   150  		q.disableEager()
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (q *Query) paginateModel(models interface{}) error {
   158  	if q.Paginator == nil {
   159  		return nil
   160  	}
   161  
   162  	ct, err := q.Count(models)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	q.Paginator.TotalEntriesSize = ct
   168  	st := reflect.ValueOf(models).Elem()
   169  	q.Paginator.CurrentEntriesSize = st.Len()
   170  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   171  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   172  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   173  	}
   174  	return nil
   175  }
   176  
   177  // Load loads all association or the fields specified in params for
   178  // an already loaded model.
   179  //
   180  // tx.First(&u)
   181  // tx.Load(&u)
   182  func (c *Connection) Load(model interface{}, fields ...string) error {
   183  	q := Q(c)
   184  	q.eagerFields = fields
   185  	err := q.eagerAssociations(model)
   186  	q.disableEager()
   187  	return err
   188  }
   189  
   190  func (q *Query) eagerAssociations(model interface{}) error {
   191  	var err error
   192  
   193  	// eagerAssociations for a slice or array model passed as a param.
   194  	v := reflect.ValueOf(model)
   195  	if reflect.Indirect(v).Kind() == reflect.Slice ||
   196  		reflect.Indirect(v).Kind() == reflect.Array {
   197  		v = v.Elem()
   198  		for i := 0; i < v.Len(); i++ {
   199  			err = q.eagerAssociations(v.Index(i).Addr().Interface())
   200  			if err != nil {
   201  				return err
   202  			}
   203  		}
   204  		return err
   205  	}
   206  
   207  	assos, err := associations.ForStruct(model, q.eagerFields...)
   208  	if err != nil {
   209  		return err
   210  	}
   211  
   212  	// disable eager mode for current connection.
   213  	q.eager = false
   214  	q.Connection.eager = false
   215  
   216  	for _, association := range assos {
   217  		if association.Skipped() {
   218  			continue
   219  		}
   220  
   221  		query := Q(q.Connection)
   222  
   223  		whereCondition, args := association.Constraint()
   224  		query = query.Where(whereCondition, args...)
   225  
   226  		// validates if association is Sortable
   227  		sortable := (*associations.AssociationSortable)(nil)
   228  		t := reflect.TypeOf(association)
   229  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   230  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   231  			out := m.Call([]reflect.Value{})
   232  			orderClause := out[0].String()
   233  			if orderClause != "" {
   234  				query = query.Order(orderClause)
   235  			}
   236  		}
   237  
   238  		sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
   239  		query = query.RawQuery(sqlSentence, args...)
   240  
   241  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   242  			err = query.All(association.Interface())
   243  		}
   244  
   245  		if association.Kind() == reflect.Struct {
   246  			err = query.First(association.Interface())
   247  		}
   248  
   249  		if err != nil && errors.Cause(err) != sql.ErrNoRows {
   250  			return err
   251  		}
   252  
   253  		// load all inner associations.
   254  		innerAssociations := association.InnerAssociations()
   255  		for _, inner := range innerAssociations {
   256  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   257  			innerQuery := Q(query.Connection)
   258  			innerQuery.eagerFields = []string{inner.Fields}
   259  			err = innerQuery.eagerAssociations(v.Addr().Interface())
   260  			if err != nil {
   261  				return err
   262  			}
   263  		}
   264  	}
   265  	return nil
   266  }
   267  
   268  // Exists returns true/false if a record exists in the database that matches
   269  // the query.
   270  //
   271  // 	q.Where("name = ?", "mark").Exists(&User{})
   272  func (q *Query) Exists(model interface{}) (bool, error) {
   273  	tmpQuery := Q(q.Connection)
   274  	q.Clone(tmpQuery) //avoid meddling with original query
   275  
   276  	var res bool
   277  
   278  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   279  		tmpQuery.Paginator = nil
   280  		tmpQuery.orderClauses = clauses{}
   281  		tmpQuery.limitResults = 0
   282  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   283  
   284  		// when query contains custom selected fields / executed using RawQuery,
   285  		// sql may already contains limit and offset
   286  		if rLimitOffset.MatchString(query) {
   287  			foundLimit := rLimitOffset.FindString(query)
   288  			query = query[0 : len(query)-len(foundLimit)]
   289  		} else if rLimit.MatchString(query) {
   290  			foundLimit := rLimit.FindString(query)
   291  			query = query[0 : len(query)-len(foundLimit)]
   292  		}
   293  
   294  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   295  		log(logging.SQL, existsQuery, args...)
   296  		return q.Connection.Store.Get(&res, existsQuery, args...)
   297  	})
   298  	return res, err
   299  }
   300  
   301  // Count the number of records in the database.
   302  //
   303  //	c.Count(&User{})
   304  func (c *Connection) Count(model interface{}) (int, error) {
   305  	return Q(c).Count(model)
   306  }
   307  
   308  // Count the number of records in the database.
   309  //
   310  //	q.Where("name = ?", "mark").Count(&User{})
   311  func (q Query) Count(model interface{}) (int, error) {
   312  	return q.CountByField(model, "*")
   313  }
   314  
   315  // CountByField counts the number of records in the database, for a given field.
   316  //
   317  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   318  func (q Query) CountByField(model interface{}, field string) (int, error) {
   319  	tmpQuery := Q(q.Connection)
   320  	q.Clone(tmpQuery) //avoid meddling with original query
   321  
   322  	res := &rowCount{}
   323  
   324  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   325  		tmpQuery.Paginator = nil
   326  		tmpQuery.orderClauses = clauses{}
   327  		tmpQuery.limitResults = 0
   328  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   329  		//when query contains custom selected fields / executed using RawQuery,
   330  		//	sql may already contains limit and offset
   331  
   332  		if rLimitOffset.MatchString(query) {
   333  			foundLimit := rLimitOffset.FindString(query)
   334  			query = query[0 : len(query)-len(foundLimit)]
   335  		} else if rLimit.MatchString(query) {
   336  			foundLimit := rLimit.FindString(query)
   337  			query = query[0 : len(query)-len(foundLimit)]
   338  		}
   339  
   340  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   341  		log(logging.SQL, countQuery, args...)
   342  		return q.Connection.Store.Get(res, countQuery, args...)
   343  	})
   344  	return res.Count, err
   345  }
   346  
   347  type rowCount struct {
   348  	Count int `db:"row_count"`
   349  }
   350  
   351  // Select allows to query only fields passed as parameter.
   352  // c.Select("field1", "field2").All(&model)
   353  // => SELECT field1, field2 FROM models
   354  func (c *Connection) Select(fields ...string) *Query {
   355  	return c.Q().Select(fields...)
   356  }
   357  
   358  // Select allows to query only fields passed as parameter.
   359  // c.Select("field1", "field2").All(&model)
   360  // => SELECT field1, field2 FROM models
   361  func (q *Query) Select(fields ...string) *Query {
   362  	for _, f := range fields {
   363  		if strings.TrimSpace(f) != "" {
   364  			q.addColumns = append(q.addColumns, f)
   365  		}
   366  	}
   367  	return q
   368  }