github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/finders.go (about) 1 package pop 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "reflect" 8 "regexp" 9 "strconv" 10 "strings" 11 12 "github.com/gobuffalo/pop/v6/associations" 13 "github.com/gobuffalo/pop/v6/logging" 14 "github.com/gofrs/uuid" 15 ) 16 17 var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$") 18 var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") 19 20 // Find the first record of the model in the database with a particular id. 21 // 22 // c.Find(&User{}, 1) 23 func (c *Connection) Find(model interface{}, id interface{}) error { 24 return Q(c).Find(model, id) 25 } 26 27 // Find the first record of the model in the database with a particular id. 28 // 29 // q.Find(&User{}, 1) 30 func (q *Query) Find(model interface{}, id interface{}) error { 31 m := NewModel(model, q.Connection.Context()) 32 idq := m.WhereID() 33 switch t := id.(type) { 34 case uuid.UUID: 35 return q.Where(idq, t.String()).First(model) 36 default: 37 // Pick argument type based on column type. This is required for keeping backwards compatibility with: 38 // 39 // https://github.com/gobuffalo/buffalo/blob/master/genny/resource/templates/use_model/actions/resource-name.go.tmpl#L76 40 pkt, err := m.PrimaryKeyType() 41 if err != nil { 42 return err 43 } 44 45 switch pkt { 46 case "int32", "int64", "uint32", "uint64", "int8", "uint8", "int16", "uint16", "int": 47 if tid, ok := id.(string); ok { 48 if intID, err := strconv.Atoi(tid); err == nil { 49 return q.Where(idq, intID).First(model) 50 } 51 } 52 } 53 } 54 55 return q.Where(idq, id).First(model) 56 } 57 58 // First record of the model in the database that matches the query. 59 // 60 // c.First(&User{}) 61 func (c *Connection) First(model interface{}) error { 62 return Q(c).First(model) 63 } 64 65 // First record of the model in the database that matches the query. 66 // 67 // q.Where("name = ?", "mark").First(&User{}) 68 func (q *Query) First(model interface{}) error { 69 var m *Model 70 err := q.Connection.timeFunc("First", func() error { 71 q.Limit(1) 72 m = NewModel(model, q.Connection.Context()) 73 if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil { 74 return err 75 } 76 return m.afterFind(q.Connection, false) 77 }) 78 79 if err != nil { 80 return err 81 } 82 83 if q.eager { 84 err := q.eagerAssociations(model) 85 q.disableEager() 86 if err != nil { 87 return err 88 } 89 return m.afterFind(q.Connection, true) 90 } 91 92 return nil 93 } 94 95 // Last record of the model in the database that matches the query. 96 // 97 // c.Last(&User{}) 98 func (c *Connection) Last(model interface{}) error { 99 return Q(c).Last(model) 100 } 101 102 // Last record of the model in the database that matches the query. 103 // 104 // q.Where("name = ?", "mark").Last(&User{}) 105 func (q *Query) Last(model interface{}) error { 106 var m *Model 107 err := q.Connection.timeFunc("Last", func() error { 108 q.Limit(1) 109 q.Order("created_at DESC, id DESC") 110 m = NewModel(model, q.Connection.Context()) 111 if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil { 112 return err 113 } 114 return m.afterFind(q.Connection, false) 115 }) 116 117 if err != nil { 118 return err 119 } 120 121 if q.eager { 122 err = q.eagerAssociations(model) 123 q.disableEager() 124 if err != nil { 125 return err 126 } 127 return m.afterFind(q.Connection, true) 128 } 129 130 return nil 131 } 132 133 // All retrieves all of the records in the database that match the query. 134 // 135 // c.All(&[]User{}) 136 func (c *Connection) All(models interface{}) error { 137 return Q(c).All(models) 138 } 139 140 // All retrieves all of the records in the database that match the query. 141 // 142 // q.Where("name = ?", "mark").All(&[]User{}) 143 func (q *Query) All(models interface{}) error { 144 var m *Model 145 err := q.Connection.timeFunc("All", func() error { 146 m = NewModel(models, q.Connection.Context()) 147 err := q.Connection.Dialect.SelectMany(q.Connection, m, *q) 148 if err != nil { 149 return err 150 } 151 152 err = q.paginateModel(models) 153 if err != nil { 154 return err 155 } 156 157 return m.afterFind(q.Connection, false) 158 }) 159 160 if err != nil { 161 return fmt.Errorf("unable to fetch records: %w", err) 162 } 163 164 if q.eager { 165 err = q.eagerAssociations(models) 166 q.disableEager() 167 if err != nil { 168 return err 169 } 170 return m.afterFind(q.Connection, true) 171 } 172 173 return nil 174 } 175 176 func (q *Query) paginateModel(models interface{}) error { 177 if q.Paginator == nil { 178 return nil 179 } 180 181 ct, err := q.Count(models) 182 if err != nil { 183 return err 184 } 185 186 q.Paginator.TotalEntriesSize = ct 187 st := reflect.ValueOf(models).Elem() 188 q.Paginator.CurrentEntriesSize = st.Len() 189 q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage 190 if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 { 191 q.Paginator.TotalPages = q.Paginator.TotalPages + 1 192 } 193 return nil 194 } 195 196 // Load loads all association or the fields specified in params for 197 // an already loaded model. 198 // 199 // tx.First(&u) 200 // tx.Load(&u) 201 func (c *Connection) Load(model interface{}, fields ...string) error { 202 q := Q(c) 203 q.eagerFields = fields 204 err := q.eagerAssociations(model) 205 q.disableEager() 206 return err 207 } 208 209 func (q *Query) eagerAssociations(model interface{}) error { 210 if q.eagerMode == eagerModeNil { 211 q.eagerMode = loadingAssociationsStrategy 212 } 213 if q.eagerMode == EagerPreload { 214 return preload(q.Connection, model, q.eagerFields...) 215 } 216 217 return q.eagerDefaultAssociations(model) 218 } 219 220 func (q *Query) eagerDefaultAssociations(model interface{}) error { 221 var err error 222 223 // eagerAssociations for a slice or array model passed as a param. 224 v := reflect.ValueOf(model) 225 kind := reflect.Indirect(v).Kind() 226 if kind == reflect.Slice || kind == reflect.Array { 227 v = v.Elem() 228 for i := 0; i < v.Len(); i++ { 229 e := v.Index(i) 230 if e.Type().Kind() == reflect.Ptr { 231 // Already a pointer 232 err = q.eagerAssociations(e.Interface()) 233 } else { 234 err = q.eagerAssociations(e.Addr().Interface()) 235 } 236 if err != nil { 237 return err 238 } 239 } 240 return nil 241 } 242 243 // eagerAssociations for a single element 244 assos, err := associations.ForStruct(model, q.eagerFields...) 245 if err != nil { 246 return fmt.Errorf("could not retrieve associations: %w", err) 247 } 248 249 // disable eager mode for current connection. 250 q.eager = false 251 q.Connection.eager = false 252 253 for _, association := range assos { 254 if association.Skipped() { 255 continue 256 } 257 258 query := Q(q.Connection) 259 260 whereCondition, args := association.Constraint() 261 query = query.Where(whereCondition, args...) 262 263 // validates if association is Sortable 264 sortable := (*associations.AssociationSortable)(nil) 265 t := reflect.TypeOf(association) 266 if t.Implements(reflect.TypeOf(sortable).Elem()) { 267 m := reflect.ValueOf(association).MethodByName("OrderBy") 268 out := m.Call([]reflect.Value{}) 269 orderClause := out[0].String() 270 if orderClause != "" { 271 query = query.Order(orderClause) 272 } 273 } 274 275 sqlSentence, args := query.ToSQL(NewModel(association.Interface(), query.Connection.Context())) 276 query = query.RawQuery(sqlSentence, args...) 277 278 if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { 279 err = query.All(association.Interface()) 280 } 281 282 if association.Kind() == reflect.Struct { 283 err = query.First(association.Interface()) 284 } 285 286 if err != nil && !errors.Is(err, sql.ErrNoRows) { 287 return err 288 } 289 290 if err == sql.ErrNoRows { 291 continue 292 } 293 294 // load all inner associations. 295 innerAssociations := association.InnerAssociations() 296 for _, inner := range innerAssociations { 297 v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) 298 innerQuery := Q(query.Connection) 299 innerQuery.eagerFields = inner.Fields 300 301 switch v.Kind() { 302 case reflect.Ptr: 303 err = innerQuery.eagerAssociations(v.Interface()) 304 default: 305 err = innerQuery.eagerAssociations(v.Addr().Interface()) 306 } 307 308 if err != nil { 309 return err 310 } 311 } 312 } 313 return nil 314 } 315 316 // Exists returns true/false if a record exists in the database that matches 317 // the query. 318 // 319 // q.Where("name = ?", "mark").Exists(&User{}) 320 func (q *Query) Exists(model interface{}) (bool, error) { 321 tmpQuery := Q(q.Connection) 322 q.Clone(tmpQuery) // avoid meddling with original query 323 324 var res bool 325 326 err := tmpQuery.Connection.timeFunc("Exists", func() error { 327 tmpQuery.Paginator = nil 328 tmpQuery.orderClauses = clauses{} 329 tmpQuery.limitResults = 0 330 query, args := tmpQuery.ToSQL(NewModel(model, tmpQuery.Connection.Context())) 331 332 // when query contains custom selected fields / executed using RawQuery, 333 // sql may already contains limit and offset 334 if rLimitOffset.MatchString(query) { 335 foundLimit := rLimitOffset.FindString(query) 336 query = query[0 : len(query)-len(foundLimit)] 337 } else if rLimit.MatchString(query) { 338 foundLimit := rLimit.FindString(query) 339 query = query[0 : len(query)-len(foundLimit)] 340 } 341 342 existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query) 343 txlog(logging.SQL, q.Connection, existsQuery, args...) 344 return q.Connection.Store.Get(&res, existsQuery, args...) 345 }) 346 return res, err 347 } 348 349 // Count the number of records in the database. 350 // 351 // c.Count(&User{}) 352 func (c *Connection) Count(model interface{}) (int, error) { 353 return Q(c).Count(model) 354 } 355 356 // Count the number of records in the database. 357 // 358 // q.Where("name = ?", "mark").Count(&User{}) 359 func (q Query) Count(model interface{}) (int, error) { 360 return q.CountByField(model, "*") 361 } 362 363 // CountByField counts the number of records in the database, for a given field. 364 // 365 // q.Where("sex = ?", "f").Count(&User{}, "name") 366 func (q Query) CountByField(model interface{}, field string) (int, error) { 367 tmpQuery := Q(q.Connection) 368 q.Clone(tmpQuery) // avoid meddling with original query 369 370 res := &rowCount{} 371 372 err := tmpQuery.Connection.timeFunc("CountByField", func() error { 373 tmpQuery.Paginator = nil 374 tmpQuery.orderClauses = clauses{} 375 tmpQuery.limitResults = 0 376 query, args := tmpQuery.ToSQL(NewModel(model, q.Connection.Context())) 377 // when query contains custom selected fields / executed using RawQuery, 378 // sql may already contains limit and offset 379 380 if rLimitOffset.MatchString(query) { 381 foundLimit := rLimitOffset.FindString(query) 382 query = query[0 : len(query)-len(foundLimit)] 383 } else if rLimit.MatchString(query) { 384 foundLimit := rLimit.FindString(query) 385 query = query[0 : len(query)-len(foundLimit)] 386 } 387 388 countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) 389 txlog(logging.SQL, q.Connection, countQuery, args...) 390 return q.Connection.Store.Get(res, countQuery, args...) 391 }) 392 return res.Count, err 393 } 394 395 type rowCount struct { 396 Count int `db:"row_count"` 397 } 398 399 // Select allows to query only fields passed as parameter. 400 // c.Select("field1", "field2").All(&model) 401 // => SELECT field1, field2 FROM models 402 func (c *Connection) Select(fields ...string) *Query { 403 return c.Q().Select(fields...) 404 } 405 406 // Select allows to query only fields passed as parameter. 407 // c.Select("field1", "field2").All(&model) 408 // => SELECT field1, field2 FROM models 409 func (q *Query) Select(fields ...string) *Query { 410 for _, f := range fields { 411 if strings.TrimSpace(f) != "" { 412 q.addColumns = append(q.addColumns, f) 413 } 414 } 415 return q 416 }