github.com/rjgonzale/pop/v5@v5.1.3-dev/finders.go (about) 1 package pop 2 3 import ( 4 "database/sql" 5 "fmt" 6 "reflect" 7 "regexp" 8 "strconv" 9 "strings" 10 11 "github.com/gobuffalo/pop/v5/associations" 12 "github.com/gobuffalo/pop/v5/logging" 13 "github.com/gofrs/uuid" 14 "github.com/pkg/errors" 15 ) 16 17 var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$") 18 var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") 19 20 // Find the first record of the model in the database with a particular id. 21 // 22 // c.Find(&User{}, 1) 23 func (c *Connection) Find(model interface{}, id interface{}) error { 24 return Q(c).Find(model, id) 25 } 26 27 // Find the first record of the model in the database with a particular id. 28 // 29 // q.Find(&User{}, 1) 30 func (q *Query) Find(model interface{}, id interface{}) error { 31 m := &Model{Value: model} 32 idq := m.whereID() 33 switch t := id.(type) { 34 case uuid.UUID: 35 return q.Where(idq, t.String()).First(model) 36 case string: 37 l := len(t) 38 if l > 0 { 39 // Handle leading '0': 40 // if the string have a leading '0' and is not "0", prevent parsing to int 41 if t[0] != '0' || l == 1 { 42 var err error 43 id, err = strconv.Atoi(t) 44 if err != nil { 45 return q.Where(idq, t).First(model) 46 } 47 } 48 } 49 } 50 51 return q.Where(idq, id).First(model) 52 } 53 54 // First record of the model in the database that matches the query. 55 // 56 // c.First(&User{}) 57 func (c *Connection) First(model interface{}) error { 58 return Q(c).First(model) 59 } 60 61 // First record of the model in the database that matches the query. 62 // 63 // q.Where("name = ?", "mark").First(&User{}) 64 func (q *Query) First(model interface{}) error { 65 err := q.Connection.timeFunc("First", func() error { 66 q.Limit(1) 67 m := &Model{Value: model} 68 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 69 return err 70 } 71 return m.afterFind(q.Connection) 72 }) 73 74 if err != nil { 75 return err 76 } 77 78 if q.eager { 79 err = q.eagerAssociations(model) 80 q.disableEager() 81 return err 82 } 83 return nil 84 } 85 86 // Last record of the model in the database that matches the query. 87 // 88 // c.Last(&User{}) 89 func (c *Connection) Last(model interface{}) error { 90 return Q(c).Last(model) 91 } 92 93 // Last record of the model in the database that matches the query. 94 // 95 // q.Where("name = ?", "mark").Last(&User{}) 96 func (q *Query) Last(model interface{}) error { 97 err := q.Connection.timeFunc("Last", func() error { 98 q.Limit(1) 99 q.Order("created_at DESC, id DESC") 100 m := &Model{Value: model} 101 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 102 return err 103 } 104 return m.afterFind(q.Connection) 105 }) 106 107 if err != nil { 108 return err 109 } 110 111 if q.eager { 112 err = q.eagerAssociations(model) 113 q.disableEager() 114 return err 115 } 116 117 return nil 118 } 119 120 // All retrieves all of the records in the database that match the query. 121 // 122 // c.All(&[]User{}) 123 func (c *Connection) All(models interface{}) error { 124 return Q(c).All(models) 125 } 126 127 // All retrieves all of the records in the database that match the query. 128 // 129 // q.Where("name = ?", "mark").All(&[]User{}) 130 func (q *Query) All(models interface{}) error { 131 err := q.Connection.timeFunc("All", func() error { 132 m := &Model{Value: models} 133 err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) 134 if err != nil { 135 return err 136 } 137 err = q.paginateModel(models) 138 if err != nil { 139 return err 140 } 141 return m.afterFind(q.Connection) 142 }) 143 144 if err != nil { 145 return errors.Wrap(err, "unable to fetch records") 146 } 147 148 if q.eager { 149 err = q.eagerAssociations(models) 150 q.disableEager() 151 return err 152 } 153 154 return nil 155 } 156 157 func (q *Query) paginateModel(models interface{}) error { 158 if q.Paginator == nil { 159 return nil 160 } 161 162 ct, err := q.Count(models) 163 if err != nil { 164 return err 165 } 166 167 q.Paginator.TotalEntriesSize = ct 168 st := reflect.ValueOf(models).Elem() 169 q.Paginator.CurrentEntriesSize = st.Len() 170 q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage 171 if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 { 172 q.Paginator.TotalPages = q.Paginator.TotalPages + 1 173 } 174 return nil 175 } 176 177 // Load loads all association or the fields specified in params for 178 // an already loaded model. 179 // 180 // tx.First(&u) 181 // tx.Load(&u) 182 func (c *Connection) Load(model interface{}, fields ...string) error { 183 q := Q(c) 184 q.eagerFields = fields 185 err := q.eagerAssociations(model) 186 q.disableEager() 187 return err 188 } 189 190 func (q *Query) eagerAssociations(model interface{}) error { 191 if q.eagerMode == eagerModeNil { 192 q.eagerMode = loadingAssociationsStrategy 193 } 194 if q.eagerMode == EagerPreload { 195 return preload(q.Connection, model, q.eagerFields...) 196 } 197 198 return q.eagerDefaultAssociations(model) 199 } 200 201 func (q *Query) eagerDefaultAssociations(model interface{}) error { 202 var err error 203 204 // eagerAssociations for a slice or array model passed as a param. 205 v := reflect.ValueOf(model) 206 kind := reflect.Indirect(v).Kind() 207 if kind == reflect.Slice || kind == reflect.Array { 208 v = v.Elem() 209 for i := 0; i < v.Len(); i++ { 210 e := v.Index(i) 211 if e.Type().Kind() == reflect.Ptr { 212 // Already a pointer 213 err = q.eagerAssociations(e.Interface()) 214 } else { 215 err = q.eagerAssociations(e.Addr().Interface()) 216 } 217 if err != nil { 218 return err 219 } 220 } 221 return nil 222 } 223 224 // eagerAssociations for a single element 225 assos, err := associations.ForStruct(model, q.eagerFields...) 226 if err != nil { 227 return errors.Wrap(err, "could not retrieve associations") 228 } 229 230 // disable eager mode for current connection. 231 q.eager = false 232 q.Connection.eager = false 233 234 for _, association := range assos { 235 if association.Skipped() { 236 continue 237 } 238 239 query := Q(q.Connection) 240 241 whereCondition, args := association.Constraint() 242 query = query.Where(whereCondition, args...) 243 244 // validates if association is Sortable 245 sortable := (*associations.AssociationSortable)(nil) 246 t := reflect.TypeOf(association) 247 if t.Implements(reflect.TypeOf(sortable).Elem()) { 248 m := reflect.ValueOf(association).MethodByName("OrderBy") 249 out := m.Call([]reflect.Value{}) 250 orderClause := out[0].String() 251 if orderClause != "" { 252 query = query.Order(orderClause) 253 } 254 } 255 256 sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) 257 query = query.RawQuery(sqlSentence, args...) 258 259 if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { 260 err = query.All(association.Interface()) 261 } 262 263 if association.Kind() == reflect.Struct { 264 err = query.First(association.Interface()) 265 } 266 267 if err != nil && errors.Cause(err) != sql.ErrNoRows { 268 return err 269 } 270 271 // load all inner associations. 272 innerAssociations := association.InnerAssociations() 273 for _, inner := range innerAssociations { 274 v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) 275 innerQuery := Q(query.Connection) 276 innerQuery.eagerFields = []string{inner.Fields} 277 err = innerQuery.eagerAssociations(v.Addr().Interface()) 278 if err != nil { 279 return err 280 } 281 } 282 } 283 return nil 284 } 285 286 // Exists returns true/false if a record exists in the database that matches 287 // the query. 288 // 289 // q.Where("name = ?", "mark").Exists(&User{}) 290 func (q *Query) Exists(model interface{}) (bool, error) { 291 tmpQuery := Q(q.Connection) 292 q.Clone(tmpQuery) //avoid meddling with original query 293 294 var res bool 295 296 err := tmpQuery.Connection.timeFunc("Exists", func() error { 297 tmpQuery.Paginator = nil 298 tmpQuery.orderClauses = clauses{} 299 tmpQuery.limitResults = 0 300 query, args := tmpQuery.ToSQL(&Model{Value: model}) 301 302 // when query contains custom selected fields / executed using RawQuery, 303 // sql may already contains limit and offset 304 if rLimitOffset.MatchString(query) { 305 foundLimit := rLimitOffset.FindString(query) 306 query = query[0 : len(query)-len(foundLimit)] 307 } else if rLimit.MatchString(query) { 308 foundLimit := rLimit.FindString(query) 309 query = query[0 : len(query)-len(foundLimit)] 310 } 311 312 existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query) 313 log(logging.SQL, existsQuery, args...) 314 return q.Connection.Store.Get(&res, existsQuery, args...) 315 }) 316 return res, err 317 } 318 319 // Count the number of records in the database. 320 // 321 // c.Count(&User{}) 322 func (c *Connection) Count(model interface{}) (int, error) { 323 return Q(c).Count(model) 324 } 325 326 // Count the number of records in the database. 327 // 328 // q.Where("name = ?", "mark").Count(&User{}) 329 func (q Query) Count(model interface{}) (int, error) { 330 return q.CountByField(model, "*") 331 } 332 333 // CountByField counts the number of records in the database, for a given field. 334 // 335 // q.Where("sex = ?", "f").Count(&User{}, "name") 336 func (q Query) CountByField(model interface{}, field string) (int, error) { 337 tmpQuery := Q(q.Connection) 338 q.Clone(tmpQuery) //avoid meddling with original query 339 340 res := &rowCount{} 341 342 err := tmpQuery.Connection.timeFunc("CountByField", func() error { 343 tmpQuery.Paginator = nil 344 tmpQuery.orderClauses = clauses{} 345 tmpQuery.limitResults = 0 346 query, args := tmpQuery.ToSQL(&Model{Value: model}) 347 //when query contains custom selected fields / executed using RawQuery, 348 // sql may already contains limit and offset 349 350 if rLimitOffset.MatchString(query) { 351 foundLimit := rLimitOffset.FindString(query) 352 query = query[0 : len(query)-len(foundLimit)] 353 } else if rLimit.MatchString(query) { 354 foundLimit := rLimit.FindString(query) 355 query = query[0 : len(query)-len(foundLimit)] 356 } 357 358 countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) 359 log(logging.SQL, countQuery, args...) 360 return q.Connection.Store.Get(res, countQuery, args...) 361 }) 362 return res.Count, err 363 } 364 365 type rowCount struct { 366 Count int `db:"row_count"` 367 } 368 369 // Select allows to query only fields passed as parameter. 370 // c.Select("field1", "field2").All(&model) 371 // => SELECT field1, field2 FROM models 372 func (c *Connection) Select(fields ...string) *Query { 373 return c.Q().Select(fields...) 374 } 375 376 // Select allows to query only fields passed as parameter. 377 // c.Select("field1", "field2").All(&model) 378 // => SELECT field1, field2 FROM models 379 func (q *Query) Select(fields ...string) *Query { 380 for _, f := range fields { 381 if strings.TrimSpace(f) != "" { 382 q.addColumns = append(q.addColumns, f) 383 } 384 } 385 return q 386 }