github.com/dkishere/pop/v6@v6.103.1/finders.go (about) 1 package pop 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "reflect" 8 "regexp" 9 "strconv" 10 "strings" 11 12 "github.com/dkishere/pop/v6/associations" 13 "github.com/dkishere/pop/v6/logging" 14 "github.com/gofrs/uuid" 15 ) 16 17 var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$") 18 var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") 19 20 // Find the first record of the model in the database with a particular id. 21 // 22 // c.Find(&User{}, 1) 23 func (c *Connection) Find(model interface{}, id interface{}) error { 24 return Q(c).Find(model, id) 25 } 26 27 // Find the first record of the model in the database with a particular id. 28 // 29 // q.Find(&User{}, 1) 30 func (q *Query) Find(model interface{}, id interface{}) error { 31 m := NewModel(model, q.Connection.Context()) 32 idq := m.WhereID() 33 switch t := id.(type) { 34 case uuid.UUID: 35 return q.Where(idq, t.String()).First(model) 36 default: 37 // Pick argument type based on column type. This is required for keeping backwards compatibility with: 38 // 39 // https://github.com/gobuffalo/buffalo/blob/master/genny/resource/templates/use_model/actions/resource-name.go.tmpl#L76 40 pkt, err := m.PrimaryKeyType() 41 if err != nil { 42 return err 43 } 44 45 switch pkt { 46 case "int32", "int64", "uint32", "uint64", "int8", "uint8", "int16", "uint16", "int": 47 if tid, ok := id.(string); ok { 48 if intID, err := strconv.Atoi(tid); err == nil { 49 return q.Where(idq, intID).First(model) 50 } 51 } 52 } 53 } 54 55 return q.Where(idq, id).First(model) 56 } 57 58 // First record of the model in the database that matches the query. 59 // 60 // c.First(&User{}) 61 func (c *Connection) First(model interface{}) error { 62 return Q(c).First(model) 63 } 64 65 // First record of the model in the database that matches the query. 66 // 67 // q.Where("name = ?", "mark").First(&User{}) 68 func (q *Query) First(model interface{}) error { 69 err := q.Connection.timeFunc("First", func() error { 70 q.Limit(1) 71 m := NewModel(model, q.Connection.Context()) 72 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 73 return err 74 } 75 return m.afterFind(q.Connection) 76 }) 77 78 if err != nil { 79 return err 80 } 81 82 if q.eager { 83 err = q.eagerAssociations(model) 84 q.disableEager() 85 return err 86 } 87 return nil 88 } 89 90 // Last record of the model in the database that matches the query. 91 // 92 // c.Last(&User{}) 93 func (c *Connection) Last(model interface{}) error { 94 return Q(c).Last(model) 95 } 96 97 // Last record of the model in the database that matches the query. 98 // 99 // q.Where("name = ?", "mark").Last(&User{}) 100 func (q *Query) Last(model interface{}) error { 101 err := q.Connection.timeFunc("Last", func() error { 102 q.Limit(1) 103 q.Order("created_at DESC, id DESC") 104 m := NewModel(model, q.Connection.Context()) 105 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 106 return err 107 } 108 return m.afterFind(q.Connection) 109 }) 110 111 if err != nil { 112 return err 113 } 114 115 if q.eager { 116 err = q.eagerAssociations(model) 117 q.disableEager() 118 return err 119 } 120 121 return nil 122 } 123 124 // All retrieves all of the records in the database that match the query. 125 // 126 // c.All(&[]User{}) 127 func (c *Connection) All(models interface{}) error { 128 return Q(c).All(models) 129 } 130 131 // All retrieves all of the records in the database that match the query. 132 // 133 // q.Where("name = ?", "mark").All(&[]User{}) 134 func (q *Query) All(models interface{}) error { 135 err := q.Connection.timeFunc("All", func() error { 136 m := NewModel(models, q.Connection.Context()) 137 err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) 138 if err != nil { 139 return err 140 } 141 err = q.paginateModel(models) 142 if err != nil { 143 return err 144 } 145 return m.afterFind(q.Connection) 146 }) 147 148 if err != nil { 149 return fmt.Errorf("unable to fetch records: %w", err) 150 } 151 152 if q.eager { 153 err = q.eagerAssociations(models) 154 q.disableEager() 155 return err 156 } 157 158 return nil 159 } 160 161 func (q *Query) paginateModel(models interface{}) error { 162 if q.Paginator == nil { 163 return nil 164 } 165 166 ct, err := q.Count(models) 167 if err != nil { 168 return err 169 } 170 171 q.Paginator.TotalEntriesSize = ct 172 st := reflect.ValueOf(models).Elem() 173 q.Paginator.CurrentEntriesSize = st.Len() 174 q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage 175 if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 { 176 q.Paginator.TotalPages = q.Paginator.TotalPages + 1 177 } 178 return nil 179 } 180 181 // Load loads all association or the fields specified in params for 182 // an already loaded model. 183 // 184 // tx.First(&u) 185 // tx.Load(&u) 186 func (c *Connection) Load(model interface{}, fields ...string) error { 187 q := Q(c) 188 q.eagerFields = fields 189 err := q.eagerAssociations(model) 190 q.disableEager() 191 return err 192 } 193 194 func (q *Query) eagerAssociations(model interface{}) error { 195 if q.eagerMode == eagerModeNil { 196 q.eagerMode = loadingAssociationsStrategy 197 } 198 if q.eagerMode == EagerPreload { 199 return preload(q.Connection, model, q.eagerFields...) 200 } 201 202 return q.eagerDefaultAssociations(model) 203 } 204 205 func (q *Query) eagerDefaultAssociations(model interface{}) error { 206 var err error 207 208 // eagerAssociations for a slice or array model passed as a param. 209 v := reflect.ValueOf(model) 210 kind := reflect.Indirect(v).Kind() 211 if kind == reflect.Slice || kind == reflect.Array { 212 v = v.Elem() 213 for i := 0; i < v.Len(); i++ { 214 e := v.Index(i) 215 if e.Type().Kind() == reflect.Ptr { 216 // Already a pointer 217 err = q.eagerAssociations(e.Interface()) 218 } else { 219 err = q.eagerAssociations(e.Addr().Interface()) 220 } 221 if err != nil { 222 return err 223 } 224 } 225 return nil 226 } 227 228 // eagerAssociations for a single element 229 assos, err := associations.ForStruct(model, q.eagerFields...) 230 if err != nil { 231 return fmt.Errorf("could not retrieve associations: %w", err) 232 } 233 234 // disable eager mode for current connection. 235 q.eager = false 236 q.Connection.eager = false 237 238 for _, association := range assos { 239 if association.Skipped() { 240 continue 241 } 242 243 query := Q(q.Connection) 244 245 whereCondition, args := association.Constraint() 246 query = query.Where(whereCondition, args...) 247 248 // validates if association is Sortable 249 sortable := (*associations.AssociationSortable)(nil) 250 t := reflect.TypeOf(association) 251 if t.Implements(reflect.TypeOf(sortable).Elem()) { 252 m := reflect.ValueOf(association).MethodByName("OrderBy") 253 out := m.Call([]reflect.Value{}) 254 orderClause := out[0].String() 255 if orderClause != "" { 256 query = query.Order(orderClause) 257 } 258 } 259 260 sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context())) 261 query = query.RawQuery(sqlSentence, args...) 262 263 if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { 264 err = query.All(association.Interface()) 265 } 266 267 if association.Kind() == reflect.Struct { 268 err = query.First(association.Interface()) 269 } 270 271 if err != nil && !errors.Is(err, sql.ErrNoRows) { 272 return err 273 } 274 275 if err == sql.ErrNoRows { 276 continue 277 } 278 279 // load all inner associations. 280 innerAssociations := association.InnerAssociations() 281 for _, inner := range innerAssociations { 282 v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) 283 innerQuery := Q(query.Connection) 284 innerQuery.eagerFields = inner.Fields 285 err = innerQuery.eagerAssociations(v.Addr().Interface()) 286 if err != nil { 287 return err 288 } 289 } 290 } 291 return nil 292 } 293 294 // Exists returns true/false if a record exists in the database that matches 295 // the query. 296 // 297 // q.Where("name = ?", "mark").Exists(&User{}) 298 func (q *Query) Exists(model interface{}) (bool, error) { 299 tmpQuery := Q(q.Connection) 300 q.Clone(tmpQuery) // avoid meddling with original query 301 302 var res bool 303 304 err := tmpQuery.Connection.timeFunc("Exists", func() error { 305 tmpQuery.Paginator = nil 306 tmpQuery.orderClauses = clauses{} 307 tmpQuery.limitResults = 0 308 query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context())) 309 310 // when query contains custom selected fields / executed using RawQuery, 311 // sql may already contains limit and offset 312 if rLimitOffset.MatchString(query) { 313 foundLimit := rLimitOffset.FindString(query) 314 query = query[0 : len(query)-len(foundLimit)] 315 } else if rLimit.MatchString(query) { 316 foundLimit := rLimit.FindString(query) 317 query = query[0 : len(query)-len(foundLimit)] 318 } 319 320 existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query) 321 log(logging.SQL, existsQuery, args...) 322 return q.Connection.Store.Get(&res, existsQuery, args...) 323 }) 324 return res, err 325 } 326 327 // Count the number of records in the database. 328 // 329 // c.Count(&User{}) 330 func (c *Connection) Count(model interface{}) (int, error) { 331 return Q(c).Count(model) 332 } 333 334 // Count the number of records in the database. 335 // 336 // q.Where("name = ?", "mark").Count(&User{}) 337 func (q Query) Count(model interface{}) (int, error) { 338 return q.CountByField(model, "*") 339 } 340 341 // CountByField counts the number of records in the database, for a given field. 342 // 343 // q.Where("sex = ?", "f").Count(&User{}, "name") 344 func (q Query) CountByField(model interface{}, field string) (int, error) { 345 tmpQuery := Q(q.Connection) 346 q.Clone(tmpQuery) // avoid meddling with original query 347 348 res := &rowCount{} 349 350 err := tmpQuery.Connection.timeFunc("CountByField", func() error { 351 tmpQuery.Paginator = nil 352 tmpQuery.orderClauses = clauses{} 353 tmpQuery.limitResults = 0 354 query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context())) 355 // when query contains custom selected fields / executed using RawQuery, 356 // sql may already contains limit and offset 357 358 if rLimitOffset.MatchString(query) { 359 foundLimit := rLimitOffset.FindString(query) 360 query = query[0 : len(query)-len(foundLimit)] 361 } else if rLimit.MatchString(query) { 362 foundLimit := rLimit.FindString(query) 363 query = query[0 : len(query)-len(foundLimit)] 364 } 365 366 countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) 367 log(logging.SQL, countQuery, args...) 368 return q.Connection.Store.Get(res, countQuery, args...) 369 }) 370 return res.Count, err 371 } 372 373 type rowCount struct { 374 Count int `db:"row_count"` 375 } 376 377 // Select allows to query only fields passed as parameter. 378 // c.Select("field1", "field2").All(&model) 379 // => SELECT field1, field2 FROM models 380 func (c *Connection) Select(fields ...string) *Query { 381 return c.Q().Select(fields...) 382 } 383 384 // Select allows to query only fields passed as parameter. 385 // c.Select("field1", "field2").All(&model) 386 // => SELECT field1, field2 FROM models 387 func (q *Query) Select(fields ...string) *Query { 388 for _, f := range fields { 389 if strings.TrimSpace(f) != "" { 390 q.addColumns = append(q.addColumns, f) 391 } 392 } 393 return q 394 }