github.com/tufanbarisyildirim/pop@v4.13.1+incompatible/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"reflect"
     7  	"regexp"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/gobuffalo/pop/associations"
    12  	"github.com/gobuffalo/pop/logging"
    13  	"github.com/gofrs/uuid"
    14  	"github.com/pkg/errors"
    15  )
    16  
    17  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    18  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    19  
    20  // Find the first record of the model in the database with a particular id.
    21  //
    22  //	c.Find(&User{}, 1)
    23  func (c *Connection) Find(model interface{}, id interface{}) error {
    24  	return Q(c).Find(model, id)
    25  }
    26  
    27  // Find the first record of the model in the database with a particular id.
    28  //
    29  //	q.Find(&User{}, 1)
    30  func (q *Query) Find(model interface{}, id interface{}) error {
    31  	m := &Model{Value: model}
    32  	idq := m.whereID()
    33  	switch t := id.(type) {
    34  	case uuid.UUID:
    35  		return q.Where(idq, t.String()).First(model)
    36  	case string:
    37  		l := len(t)
    38  		if l > 0 {
    39  			// Handle leading '0':
    40  			// if the string have a leading '0' and is not "0", prevent parsing to int
    41  			if t[0] != '0' || l == 1 {
    42  				var err error
    43  				id, err = strconv.Atoi(t)
    44  				if err != nil {
    45  					return q.Where(idq, t).First(model)
    46  				}
    47  			}
    48  		}
    49  	}
    50  
    51  	return q.Where(idq, id).First(model)
    52  }
    53  
    54  // First record of the model in the database that matches the query.
    55  //
    56  //	c.First(&User{})
    57  func (c *Connection) First(model interface{}) error {
    58  	return Q(c).First(model)
    59  }
    60  
    61  // First record of the model in the database that matches the query.
    62  //
    63  //	q.Where("name = ?", "mark").First(&User{})
    64  func (q *Query) First(model interface{}) error {
    65  	err := q.Connection.timeFunc("First", func() error {
    66  		q.Limit(1)
    67  		m := &Model{Value: model}
    68  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
    69  			return err
    70  		}
    71  		return m.afterFind(q.Connection)
    72  	})
    73  
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	if q.eager {
    79  		err = q.eagerAssociations(model)
    80  		q.disableEager()
    81  		return err
    82  	}
    83  	return nil
    84  }
    85  
    86  // Last record of the model in the database that matches the query.
    87  //
    88  //	c.Last(&User{})
    89  func (c *Connection) Last(model interface{}) error {
    90  	return Q(c).Last(model)
    91  }
    92  
    93  // Last record of the model in the database that matches the query.
    94  //
    95  //	q.Where("name = ?", "mark").Last(&User{})
    96  func (q *Query) Last(model interface{}) error {
    97  	err := q.Connection.timeFunc("Last", func() error {
    98  		q.Limit(1)
    99  		q.Order("created_at DESC, id DESC")
   100  		m := &Model{Value: model}
   101  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
   102  			return err
   103  		}
   104  		return m.afterFind(q.Connection)
   105  	})
   106  
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if q.eager {
   112  		err = q.eagerAssociations(model)
   113  		q.disableEager()
   114  		return err
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  // All retrieves all of the records in the database that match the query.
   121  //
   122  //	c.All(&[]User{})
   123  func (c *Connection) All(models interface{}) error {
   124  	return Q(c).All(models)
   125  }
   126  
   127  // All retrieves all of the records in the database that match the query.
   128  //
   129  //	q.Where("name = ?", "mark").All(&[]User{})
   130  func (q *Query) All(models interface{}) error {
   131  	err := q.Connection.timeFunc("All", func() error {
   132  		m := &Model{Value: models}
   133  		err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
   134  		if err != nil {
   135  			return err
   136  		}
   137  		err = q.paginateModel(models)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		return m.afterFind(q.Connection)
   142  	})
   143  
   144  	if err != nil {
   145  		return errors.Wrap(err, "unable to fetch records")
   146  	}
   147  
   148  	if q.eager {
   149  		err = q.eagerAssociations(models)
   150  		q.disableEager()
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (q *Query) paginateModel(models interface{}) error {
   158  	if q.Paginator == nil {
   159  		return nil
   160  	}
   161  
   162  	ct, err := q.Count(models)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	q.Paginator.TotalEntriesSize = ct
   168  	st := reflect.ValueOf(models).Elem()
   169  	q.Paginator.CurrentEntriesSize = st.Len()
   170  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   171  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   172  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   173  	}
   174  	return nil
   175  }
   176  
   177  // Load loads all association or the fields specified in params for
   178  // an already loaded model.
   179  //
   180  // tx.First(&u)
   181  // tx.Load(&u)
   182  func (c *Connection) Load(model interface{}, fields ...string) error {
   183  	q := Q(c)
   184  	q.eagerFields = fields
   185  	err := q.eagerAssociations(model)
   186  	q.disableEager()
   187  	return err
   188  }
   189  
   190  func (q *Query) eagerAssociations(model interface{}) error {
   191  	var err error
   192  
   193  	// eagerAssociations for a slice or array model passed as a param.
   194  	v := reflect.ValueOf(model)
   195  	kind := reflect.Indirect(v).Kind()
   196  	if kind == reflect.Slice || kind == reflect.Array {
   197  		v = v.Elem()
   198  		for i := 0; i < v.Len(); i++ {
   199  			e := v.Index(i)
   200  			if e.Type().Kind() == reflect.Ptr {
   201  				// Already a pointer
   202  				err = q.eagerAssociations(e.Interface())
   203  			} else {
   204  				err = q.eagerAssociations(e.Addr().Interface())
   205  			}
   206  			if err != nil {
   207  				return err
   208  			}
   209  		}
   210  		return nil
   211  	}
   212  
   213  	// eagerAssociations for a single element
   214  	assos, err := associations.ForStruct(model, q.eagerFields...)
   215  	if err != nil {
   216  		return errors.Wrap(err, "could not retrieve associations")
   217  	}
   218  
   219  	// disable eager mode for current connection.
   220  	q.eager = false
   221  	q.Connection.eager = false
   222  
   223  	for _, association := range assos {
   224  		if association.Skipped() {
   225  			continue
   226  		}
   227  
   228  		query := Q(q.Connection)
   229  
   230  		whereCondition, args := association.Constraint()
   231  		query = query.Where(whereCondition, args...)
   232  
   233  		// validates if association is Sortable
   234  		sortable := (*associations.AssociationSortable)(nil)
   235  		t := reflect.TypeOf(association)
   236  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   237  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   238  			out := m.Call([]reflect.Value{})
   239  			orderClause := out[0].String()
   240  			if orderClause != "" {
   241  				query = query.Order(orderClause)
   242  			}
   243  		}
   244  
   245  		sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
   246  		query = query.RawQuery(sqlSentence, args...)
   247  
   248  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   249  			err = query.All(association.Interface())
   250  		}
   251  
   252  		if association.Kind() == reflect.Struct {
   253  			err = query.First(association.Interface())
   254  		}
   255  
   256  		if err != nil && errors.Cause(err) != sql.ErrNoRows {
   257  			return err
   258  		}
   259  
   260  		// load all inner associations.
   261  		innerAssociations := association.InnerAssociations()
   262  		for _, inner := range innerAssociations {
   263  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   264  			innerQuery := Q(query.Connection)
   265  			innerQuery.eagerFields = []string{inner.Fields}
   266  			err = innerQuery.eagerAssociations(v.Addr().Interface())
   267  			if err != nil {
   268  				return err
   269  			}
   270  		}
   271  	}
   272  	return nil
   273  }
   274  
   275  // Exists returns true/false if a record exists in the database that matches
   276  // the query.
   277  //
   278  // 	q.Where("name = ?", "mark").Exists(&User{})
   279  func (q *Query) Exists(model interface{}) (bool, error) {
   280  	tmpQuery := Q(q.Connection)
   281  	q.Clone(tmpQuery) //avoid meddling with original query
   282  
   283  	var res bool
   284  
   285  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   286  		tmpQuery.Paginator = nil
   287  		tmpQuery.orderClauses = clauses{}
   288  		tmpQuery.limitResults = 0
   289  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   290  
   291  		// when query contains custom selected fields / executed using RawQuery,
   292  		// sql may already contains limit and offset
   293  		if rLimitOffset.MatchString(query) {
   294  			foundLimit := rLimitOffset.FindString(query)
   295  			query = query[0 : len(query)-len(foundLimit)]
   296  		} else if rLimit.MatchString(query) {
   297  			foundLimit := rLimit.FindString(query)
   298  			query = query[0 : len(query)-len(foundLimit)]
   299  		}
   300  
   301  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   302  		log(logging.SQL, existsQuery, args...)
   303  		return q.Connection.Store.Get(&res, existsQuery, args...)
   304  	})
   305  	return res, err
   306  }
   307  
   308  // Count the number of records in the database.
   309  //
   310  //	c.Count(&User{})
   311  func (c *Connection) Count(model interface{}) (int, error) {
   312  	return Q(c).Count(model)
   313  }
   314  
   315  // Count the number of records in the database.
   316  //
   317  //	q.Where("name = ?", "mark").Count(&User{})
   318  func (q Query) Count(model interface{}) (int, error) {
   319  	return q.CountByField(model, "*")
   320  }
   321  
   322  // CountByField counts the number of records in the database, for a given field.
   323  //
   324  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   325  func (q Query) CountByField(model interface{}, field string) (int, error) {
   326  	tmpQuery := Q(q.Connection)
   327  	q.Clone(tmpQuery) //avoid meddling with original query
   328  
   329  	res := &rowCount{}
   330  
   331  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   332  		tmpQuery.Paginator = nil
   333  		tmpQuery.orderClauses = clauses{}
   334  		tmpQuery.limitResults = 0
   335  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   336  		//when query contains custom selected fields / executed using RawQuery,
   337  		//	sql may already contains limit and offset
   338  
   339  		if rLimitOffset.MatchString(query) {
   340  			foundLimit := rLimitOffset.FindString(query)
   341  			query = query[0 : len(query)-len(foundLimit)]
   342  		} else if rLimit.MatchString(query) {
   343  			foundLimit := rLimit.FindString(query)
   344  			query = query[0 : len(query)-len(foundLimit)]
   345  		}
   346  
   347  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   348  		log(logging.SQL, countQuery, args...)
   349  		return q.Connection.Store.Get(res, countQuery, args...)
   350  	})
   351  	return res.Count, err
   352  }
   353  
   354  type rowCount struct {
   355  	Count int `db:"row_count"`
   356  }
   357  
   358  // Select allows to query only fields passed as parameter.
   359  // c.Select("field1", "field2").All(&model)
   360  // => SELECT field1, field2 FROM models
   361  func (c *Connection) Select(fields ...string) *Query {
   362  	return c.Q().Select(fields...)
   363  }
   364  
   365  // Select allows to query only fields passed as parameter.
   366  // c.Select("field1", "field2").All(&model)
   367  // => SELECT field1, field2 FROM models
   368  func (q *Query) Select(fields ...string) *Query {
   369  	for _, f := range fields {
   370  		if strings.TrimSpace(f) != "" {
   371  			q.addColumns = append(q.addColumns, f)
   372  		}
   373  	}
   374  	return q
   375  }