github.com/duskeagle/pop@v4.10.1-0.20190417200916-92f2b794aab5+incompatible/finders.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/gobuffalo/pop/associations"
    13  	"github.com/gobuffalo/pop/logging"
    14  	"github.com/gofrs/uuid"
    15  	"github.com/pkg/errors"
    16   	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
    17  )
    18  
    19  var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$")
    20  var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$")
    21  
    22  // Find the first record of the model in the database with a particular id.
    23  //
    24  //	c.Find(&User{}, 1)
    25  func (c *Connection) Find(ctx context.Context, model interface{}, id interface{}) error {
    26  	return Q(c).Find(ctx, model, id)
    27  }
    28  
    29  // Find the first record of the model in the database with a particular id.
    30  //
    31  //	q.Find(&User{}, 1)
    32  func (q *Query) Find(ctx context.Context, model interface{}, id interface{}) error {
    33  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Find")
    34  	defer span.Finish()
    35  
    36  	m := &Model{Value: model}
    37  	idq := m.whereID()
    38  	switch t := id.(type) {
    39  	case uuid.UUID:
    40  		return q.Where(idq, t.String()).First(ctx, model)
    41  	case string:
    42  		l := len(t)
    43  		if l > 0 {
    44  			// Handle leading '0':
    45  			// if the string have a leading '0' and is not "0", prevent parsing to int
    46  			if t[0] != '0' || l == 1 {
    47  				var err error
    48  				id, err = strconv.Atoi(t)
    49  				if err != nil {
    50  					return q.Where(idq, t).First(ctx, model)
    51  				}
    52  			}
    53  		}
    54  	}
    55  
    56  	return q.Where(idq, id).First(ctx, model)
    57  }
    58  
    59  // First record of the model in the database that matches the query.
    60  //
    61  //	c.First(&User{})
    62  func (c *Connection) First(ctx context.Context, model interface{}) error {
    63  	return Q(c).First(ctx, model)
    64  }
    65  
    66  // First record of the model in the database that matches the query.
    67  //
    68  //	q.Where("name = ?", "mark").First(&User{})
    69  func (q *Query) First(ctx context.Context, model interface{}) error {
    70  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/First")
    71  	defer span.Finish()
    72  
    73  	err := q.Connection.timeFunc("First", func() error {
    74  		q.Limit(1)
    75  		m := &Model{Value: model}
    76  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
    77  			return err
    78  		}
    79  		return m.afterFind(ctx, q.Connection)
    80  	})
    81  
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	if q.eager {
    87  		err = q.eagerAssociations(ctx, model)
    88  		q.disableEager()
    89  		return err
    90  	}
    91  	return nil
    92  }
    93  
    94  // Last record of the model in the database that matches the query.
    95  //
    96  //	c.Last(&User{})
    97  func (c *Connection) Last(ctx context.Context, model interface{}) error {
    98  	return Q(c).Last(ctx, model)
    99  }
   100  
   101  // Last record of the model in the database that matches the query.
   102  //
   103  //	q.Where("name = ?", "mark").Last(&User{})
   104  func (q *Query) Last(ctx context.Context, model interface{}) error {
   105  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Last")
   106  	defer span.Finish()
   107  
   108  	err := q.Connection.timeFunc("Last", func() error {
   109  		q.Limit(1)
   110  		q.Order("created_at DESC, id DESC")
   111  		m := &Model{Value: model}
   112  		if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil {
   113  			return err
   114  		}
   115  		return m.afterFind(ctx, q.Connection)
   116  	})
   117  
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	if q.eager {
   123  		err = q.eagerAssociations(ctx, model)
   124  		q.disableEager()
   125  		return err
   126  	}
   127  
   128  	return nil
   129  }
   130  
   131  // All retrieves all of the records in the database that match the query.
   132  //
   133  //	c.All(&[]User{})
   134  func (c *Connection) All(ctx context.Context, models interface{}) error {
   135  	return Q(c).All(ctx, models)
   136  }
   137  
   138  // All retrieves all of the records in the database that match the query.
   139  //
   140  //	q.Where("name = ?", "mark").All(&[]User{})
   141  func (q *Query) All(ctx context.Context, models interface{}) error {
   142  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/All")
   143  	defer span.Finish()
   144  	span.SetTag("models", reflect.TypeOf(models).String())
   145  	err := q.Connection.timeFunc("All", func() error {
   146  		m := &Model{Value: models}
   147  		err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q)
   148  		if err != nil {
   149  			return err
   150  		}
   151  		err = q.paginateModel(ctx, models)
   152  		if err != nil {
   153  			return err
   154  		}
   155  		return m.afterFind(ctx, q.Connection)
   156  	})
   157  
   158  	if err != nil {
   159  		return errors.Wrap(err, "unable to fetch records")
   160  	}
   161  
   162  	if q.eager {
   163  		err = q.eagerAssociations(ctx, models)
   164  		q.disableEager()
   165  		return err
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  func (q *Query) paginateModel(ctx context.Context, models interface{}) error {
   172  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/paginateModel")
   173  	defer span.Finish()
   174  	span.SetTag("models", reflect.TypeOf(models).String())
   175  
   176  	if q.Paginator == nil {
   177  		return nil
   178  	}
   179  
   180  	ct, err := q.Count(models)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	q.Paginator.TotalEntriesSize = ct
   186  	st := reflect.ValueOf(models).Elem()
   187  	q.Paginator.CurrentEntriesSize = st.Len()
   188  	q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage
   189  	if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 {
   190  		q.Paginator.TotalPages = q.Paginator.TotalPages + 1
   191  	}
   192  	return nil
   193  }
   194  
   195  // Load loads all association or the fields specified in params for
   196  // an already loaded model.
   197  //
   198  // tx.First(&u)
   199  // tx.Load(&u)
   200  func (c *Connection) Load(ctx context.Context, model interface{}, fields ...string) error {
   201  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Load")
   202  	defer span.Finish()
   203  	q := Q(c)
   204  	q.eagerFields = fields
   205  	err := q.eagerAssociations(ctx, model)
   206  	q.disableEager()
   207  	return err
   208  }
   209  
   210  func (q *Query) eagerAssociations(ctx context.Context, model interface{}) error {
   211  	span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/eagerAssociations")
   212  	defer span.Finish()
   213  	span.SetTag("model", reflect.TypeOf(model).String())
   214  
   215  	var err error
   216  
   217  	// eagerAssociations for a slice or array model passed as a param.
   218  	v := reflect.ValueOf(model)
   219  	if reflect.Indirect(v).Kind() == reflect.Slice ||
   220  		reflect.Indirect(v).Kind() == reflect.Array {
   221  		v = v.Elem()
   222  		for i := 0; i < v.Len(); i++ {
   223  			err = q.eagerAssociations(ctx, v.Index(i).Addr().Interface())
   224  			if err != nil {
   225  				return err
   226  			}
   227  		}
   228  		return err
   229  	}
   230  
   231  	assos, err := associations.ForStruct(model, q.eagerFields...)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	// disable eager mode for current connection.
   237  	q.eager = false
   238  	q.Connection.eager = false
   239  
   240  	for _, association := range assos {
   241  		if association.Skipped() {
   242  			continue
   243  		}
   244  
   245  		query := Q(q.Connection)
   246  
   247  		whereCondition, args := association.Constraint()
   248  		query = query.Where(whereCondition, args...)
   249  
   250  		// validates if association is Sortable
   251  		sortable := (*associations.AssociationSortable)(nil)
   252  		t := reflect.TypeOf(association)
   253  		if t.Implements(reflect.TypeOf(sortable).Elem()) {
   254  			m := reflect.ValueOf(association).MethodByName("OrderBy")
   255  			out := m.Call([]reflect.Value{})
   256  			orderClause := out[0].String()
   257  			if orderClause != "" {
   258  				query = query.Order(orderClause)
   259  			}
   260  		}
   261  
   262  		sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()})
   263  		query = query.RawQuery(sqlSentence, args...)
   264  
   265  		if association.Kind() == reflect.Slice || association.Kind() == reflect.Array {
   266  			err = query.All(ctx, association.Interface())
   267  		}
   268  
   269  		if association.Kind() == reflect.Struct {
   270  			err = query.First(ctx, association.Interface())
   271  		}
   272  
   273  		if err != nil && errors.Cause(err) != sql.ErrNoRows {
   274  			return err
   275  		}
   276  
   277  		// load all inner associations.
   278  		innerAssociations := association.InnerAssociations()
   279  		for _, inner := range innerAssociations {
   280  			v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name)
   281  			innerQuery := Q(query.Connection)
   282  			innerQuery.eagerFields = []string{inner.Fields}
   283  			err = innerQuery.eagerAssociations(ctx, v.Addr().Interface())
   284  			if err != nil {
   285  				return err
   286  			}
   287  		}
   288  	}
   289  	return nil
   290  }
   291  
   292  // Exists returns true/false if a record exists in the database that matches
   293  // the query.
   294  //
   295  // 	q.Where("name = ?", "mark").Exists(&User{})
   296  func (q *Query) Exists(model interface{}) (bool, error) {
   297  	tmpQuery := Q(q.Connection)
   298  	q.Clone(tmpQuery) //avoid meddling with original query
   299  
   300  	var res bool
   301  
   302  	err := tmpQuery.Connection.timeFunc("Exists", func() error {
   303  		tmpQuery.Paginator = nil
   304  		tmpQuery.orderClauses = clauses{}
   305  		tmpQuery.limitResults = 0
   306  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   307  
   308  		// when query contains custom selected fields / executed using RawQuery,
   309  		// sql may already contains limit and offset
   310  		if rLimitOffset.MatchString(query) {
   311  			foundLimit := rLimitOffset.FindString(query)
   312  			query = query[0 : len(query)-len(foundLimit)]
   313  		} else if rLimit.MatchString(query) {
   314  			foundLimit := rLimit.FindString(query)
   315  			query = query[0 : len(query)-len(foundLimit)]
   316  		}
   317  
   318  		existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query)
   319  		log(logging.SQL, existsQuery, args...)
   320  		return q.Connection.Store.Get(&res, existsQuery, args...)
   321  	})
   322  	return res, err
   323  }
   324  
   325  // Count the number of records in the database.
   326  //
   327  //	c.Count(&User{})
   328  func (c *Connection) Count(model interface{}) (int, error) {
   329  	return Q(c).Count(model)
   330  }
   331  
   332  // Count the number of records in the database.
   333  //
   334  //	q.Where("name = ?", "mark").Count(&User{})
   335  func (q Query) Count(model interface{}) (int, error) {
   336  	return q.CountByField(model, "*")
   337  }
   338  
   339  // CountByField counts the number of records in the database, for a given field.
   340  //
   341  //	q.Where("sex = ?", "f").Count(&User{}, "name")
   342  func (q Query) CountByField(model interface{}, field string) (int, error) {
   343  	tmpQuery := Q(q.Connection)
   344  	q.Clone(tmpQuery) //avoid meddling with original query
   345  
   346  	res := &rowCount{}
   347  
   348  	err := tmpQuery.Connection.timeFunc("CountByField", func() error {
   349  		tmpQuery.Paginator = nil
   350  		tmpQuery.orderClauses = clauses{}
   351  		tmpQuery.limitResults = 0
   352  		query, args := tmpQuery.ToSQL(&Model{Value: model})
   353  		//when query contains custom selected fields / executed using RawQuery,
   354  		//	sql may already contains limit and offset
   355  
   356  		if rLimitOffset.MatchString(query) {
   357  			foundLimit := rLimitOffset.FindString(query)
   358  			query = query[0 : len(query)-len(foundLimit)]
   359  		} else if rLimit.MatchString(query) {
   360  			foundLimit := rLimit.FindString(query)
   361  			query = query[0 : len(query)-len(foundLimit)]
   362  		}
   363  
   364  		countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query)
   365  		log(logging.SQL, countQuery, args...)
   366  		return q.Connection.Store.Get(res, countQuery, args...)
   367  	})
   368  	return res.Count, err
   369  }
   370  
   371  type rowCount struct {
   372  	Count int `db:"row_count"`
   373  }
   374  
   375  // Select allows to query only fields passed as parameter.
   376  // c.Select("field1", "field2").All(&model)
   377  // => SELECT field1, field2 FROM models
   378  func (c *Connection) Select(fields ...string) *Query {
   379  	return c.Q().Select(fields...)
   380  }
   381  
   382  // Select allows to query only fields passed as parameter.
   383  // c.Select("field1", "field2").All(&model)
   384  // => SELECT field1, field2 FROM models
   385  func (q *Query) Select(fields ...string) *Query {
   386  	for _, f := range fields {
   387  		if strings.TrimSpace(f) != "" {
   388  			q.addColumns = append(q.addColumns, f)
   389  		}
   390  	}
   391  	return q
   392  }