github.com/duskeagle/pop@v4.10.1-0.20190417200916-92f2b794aab5+incompatible/finders.go (about) 1 package pop 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "regexp" 9 "strconv" 10 "strings" 11 12 "github.com/gobuffalo/pop/associations" 13 "github.com/gobuffalo/pop/logging" 14 "github.com/gofrs/uuid" 15 "github.com/pkg/errors" 16 "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" 17 ) 18 19 var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$") 20 var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") 21 22 // Find the first record of the model in the database with a particular id. 23 // 24 // c.Find(&User{}, 1) 25 func (c *Connection) Find(ctx context.Context, model interface{}, id interface{}) error { 26 return Q(c).Find(ctx, model, id) 27 } 28 29 // Find the first record of the model in the database with a particular id. 30 // 31 // q.Find(&User{}, 1) 32 func (q *Query) Find(ctx context.Context, model interface{}, id interface{}) error { 33 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Find") 34 defer span.Finish() 35 36 m := &Model{Value: model} 37 idq := m.whereID() 38 switch t := id.(type) { 39 case uuid.UUID: 40 return q.Where(idq, t.String()).First(ctx, model) 41 case string: 42 l := len(t) 43 if l > 0 { 44 // Handle leading '0': 45 // if the string have a leading '0' and is not "0", prevent parsing to int 46 if t[0] != '0' || l == 1 { 47 var err error 48 id, err = strconv.Atoi(t) 49 if err != nil { 50 return q.Where(idq, t).First(ctx, model) 51 } 52 } 53 } 54 } 55 56 return q.Where(idq, id).First(ctx, model) 57 } 58 59 // First record of the model in the database that matches the query. 60 // 61 // c.First(&User{}) 62 func (c *Connection) First(ctx context.Context, model interface{}) error { 63 return Q(c).First(ctx, model) 64 } 65 66 // First record of the model in the database that matches the query. 67 // 68 // q.Where("name = ?", "mark").First(&User{}) 69 func (q *Query) First(ctx context.Context, model interface{}) error { 70 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/First") 71 defer span.Finish() 72 73 err := q.Connection.timeFunc("First", func() error { 74 q.Limit(1) 75 m := &Model{Value: model} 76 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 77 return err 78 } 79 return m.afterFind(ctx, q.Connection) 80 }) 81 82 if err != nil { 83 return err 84 } 85 86 if q.eager { 87 err = q.eagerAssociations(ctx, model) 88 q.disableEager() 89 return err 90 } 91 return nil 92 } 93 94 // Last record of the model in the database that matches the query. 95 // 96 // c.Last(&User{}) 97 func (c *Connection) Last(ctx context.Context, model interface{}) error { 98 return Q(c).Last(ctx, model) 99 } 100 101 // Last record of the model in the database that matches the query. 102 // 103 // q.Where("name = ?", "mark").Last(&User{}) 104 func (q *Query) Last(ctx context.Context, model interface{}) error { 105 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Last") 106 defer span.Finish() 107 108 err := q.Connection.timeFunc("Last", func() error { 109 q.Limit(1) 110 q.Order("created_at DESC, id DESC") 111 m := &Model{Value: model} 112 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 113 return err 114 } 115 return m.afterFind(ctx, q.Connection) 116 }) 117 118 if err != nil { 119 return err 120 } 121 122 if q.eager { 123 err = q.eagerAssociations(ctx, model) 124 q.disableEager() 125 return err 126 } 127 128 return nil 129 } 130 131 // All retrieves all of the records in the database that match the query. 132 // 133 // c.All(&[]User{}) 134 func (c *Connection) All(ctx context.Context, models interface{}) error { 135 return Q(c).All(ctx, models) 136 } 137 138 // All retrieves all of the records in the database that match the query. 139 // 140 // q.Where("name = ?", "mark").All(&[]User{}) 141 func (q *Query) All(ctx context.Context, models interface{}) error { 142 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/All") 143 defer span.Finish() 144 span.SetTag("models", reflect.TypeOf(models).String()) 145 err := q.Connection.timeFunc("All", func() error { 146 m := &Model{Value: models} 147 err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) 148 if err != nil { 149 return err 150 } 151 err = q.paginateModel(ctx, models) 152 if err != nil { 153 return err 154 } 155 return m.afterFind(ctx, q.Connection) 156 }) 157 158 if err != nil { 159 return errors.Wrap(err, "unable to fetch records") 160 } 161 162 if q.eager { 163 err = q.eagerAssociations(ctx, models) 164 q.disableEager() 165 return err 166 } 167 168 return nil 169 } 170 171 func (q *Query) paginateModel(ctx context.Context, models interface{}) error { 172 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/paginateModel") 173 defer span.Finish() 174 span.SetTag("models", reflect.TypeOf(models).String()) 175 176 if q.Paginator == nil { 177 return nil 178 } 179 180 ct, err := q.Count(models) 181 if err != nil { 182 return err 183 } 184 185 q.Paginator.TotalEntriesSize = ct 186 st := reflect.ValueOf(models).Elem() 187 q.Paginator.CurrentEntriesSize = st.Len() 188 q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage 189 if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 { 190 q.Paginator.TotalPages = q.Paginator.TotalPages + 1 191 } 192 return nil 193 } 194 195 // Load loads all association or the fields specified in params for 196 // an already loaded model. 197 // 198 // tx.First(&u) 199 // tx.Load(&u) 200 func (c *Connection) Load(ctx context.Context, model interface{}, fields ...string) error { 201 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/Load") 202 defer span.Finish() 203 q := Q(c) 204 q.eagerFields = fields 205 err := q.eagerAssociations(ctx, model) 206 q.disableEager() 207 return err 208 } 209 210 func (q *Query) eagerAssociations(ctx context.Context, model interface{}) error { 211 span, ctx := tracer.StartSpanFromContext(ctx, "pop/finders/eagerAssociations") 212 defer span.Finish() 213 span.SetTag("model", reflect.TypeOf(model).String()) 214 215 var err error 216 217 // eagerAssociations for a slice or array model passed as a param. 218 v := reflect.ValueOf(model) 219 if reflect.Indirect(v).Kind() == reflect.Slice || 220 reflect.Indirect(v).Kind() == reflect.Array { 221 v = v.Elem() 222 for i := 0; i < v.Len(); i++ { 223 err = q.eagerAssociations(ctx, v.Index(i).Addr().Interface()) 224 if err != nil { 225 return err 226 } 227 } 228 return err 229 } 230 231 assos, err := associations.ForStruct(model, q.eagerFields...) 232 if err != nil { 233 return err 234 } 235 236 // disable eager mode for current connection. 237 q.eager = false 238 q.Connection.eager = false 239 240 for _, association := range assos { 241 if association.Skipped() { 242 continue 243 } 244 245 query := Q(q.Connection) 246 247 whereCondition, args := association.Constraint() 248 query = query.Where(whereCondition, args...) 249 250 // validates if association is Sortable 251 sortable := (*associations.AssociationSortable)(nil) 252 t := reflect.TypeOf(association) 253 if t.Implements(reflect.TypeOf(sortable).Elem()) { 254 m := reflect.ValueOf(association).MethodByName("OrderBy") 255 out := m.Call([]reflect.Value{}) 256 orderClause := out[0].String() 257 if orderClause != "" { 258 query = query.Order(orderClause) 259 } 260 } 261 262 sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) 263 query = query.RawQuery(sqlSentence, args...) 264 265 if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { 266 err = query.All(ctx, association.Interface()) 267 } 268 269 if association.Kind() == reflect.Struct { 270 err = query.First(ctx, association.Interface()) 271 } 272 273 if err != nil && errors.Cause(err) != sql.ErrNoRows { 274 return err 275 } 276 277 // load all inner associations. 278 innerAssociations := association.InnerAssociations() 279 for _, inner := range innerAssociations { 280 v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) 281 innerQuery := Q(query.Connection) 282 innerQuery.eagerFields = []string{inner.Fields} 283 err = innerQuery.eagerAssociations(ctx, v.Addr().Interface()) 284 if err != nil { 285 return err 286 } 287 } 288 } 289 return nil 290 } 291 292 // Exists returns true/false if a record exists in the database that matches 293 // the query. 294 // 295 // q.Where("name = ?", "mark").Exists(&User{}) 296 func (q *Query) Exists(model interface{}) (bool, error) { 297 tmpQuery := Q(q.Connection) 298 q.Clone(tmpQuery) //avoid meddling with original query 299 300 var res bool 301 302 err := tmpQuery.Connection.timeFunc("Exists", func() error { 303 tmpQuery.Paginator = nil 304 tmpQuery.orderClauses = clauses{} 305 tmpQuery.limitResults = 0 306 query, args := tmpQuery.ToSQL(&Model{Value: model}) 307 308 // when query contains custom selected fields / executed using RawQuery, 309 // sql may already contains limit and offset 310 if rLimitOffset.MatchString(query) { 311 foundLimit := rLimitOffset.FindString(query) 312 query = query[0 : len(query)-len(foundLimit)] 313 } else if rLimit.MatchString(query) { 314 foundLimit := rLimit.FindString(query) 315 query = query[0 : len(query)-len(foundLimit)] 316 } 317 318 existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query) 319 log(logging.SQL, existsQuery, args...) 320 return q.Connection.Store.Get(&res, existsQuery, args...) 321 }) 322 return res, err 323 } 324 325 // Count the number of records in the database. 326 // 327 // c.Count(&User{}) 328 func (c *Connection) Count(model interface{}) (int, error) { 329 return Q(c).Count(model) 330 } 331 332 // Count the number of records in the database. 333 // 334 // q.Where("name = ?", "mark").Count(&User{}) 335 func (q Query) Count(model interface{}) (int, error) { 336 return q.CountByField(model, "*") 337 } 338 339 // CountByField counts the number of records in the database, for a given field. 340 // 341 // q.Where("sex = ?", "f").Count(&User{}, "name") 342 func (q Query) CountByField(model interface{}, field string) (int, error) { 343 tmpQuery := Q(q.Connection) 344 q.Clone(tmpQuery) //avoid meddling with original query 345 346 res := &rowCount{} 347 348 err := tmpQuery.Connection.timeFunc("CountByField", func() error { 349 tmpQuery.Paginator = nil 350 tmpQuery.orderClauses = clauses{} 351 tmpQuery.limitResults = 0 352 query, args := tmpQuery.ToSQL(&Model{Value: model}) 353 //when query contains custom selected fields / executed using RawQuery, 354 // sql may already contains limit and offset 355 356 if rLimitOffset.MatchString(query) { 357 foundLimit := rLimitOffset.FindString(query) 358 query = query[0 : len(query)-len(foundLimit)] 359 } else if rLimit.MatchString(query) { 360 foundLimit := rLimit.FindString(query) 361 query = query[0 : len(query)-len(foundLimit)] 362 } 363 364 countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) 365 log(logging.SQL, countQuery, args...) 366 return q.Connection.Store.Get(res, countQuery, args...) 367 }) 368 return res.Count, err 369 } 370 371 type rowCount struct { 372 Count int `db:"row_count"` 373 } 374 375 // Select allows to query only fields passed as parameter. 376 // c.Select("field1", "field2").All(&model) 377 // => SELECT field1, field2 FROM models 378 func (c *Connection) Select(fields ...string) *Query { 379 return c.Q().Select(fields...) 380 } 381 382 // Select allows to query only fields passed as parameter. 383 // c.Select("field1", "field2").All(&model) 384 // => SELECT field1, field2 FROM models 385 func (q *Query) Select(fields ...string) *Query { 386 for _, f := range fields { 387 if strings.TrimSpace(f) != "" { 388 q.addColumns = append(q.addColumns, f) 389 } 390 } 391 return q 392 }