github.com/naemono/pop@v4.13.1+incompatible/finders.go (about) 1 package pop 2 3 import ( 4 "database/sql" 5 "fmt" 6 "reflect" 7 "regexp" 8 "strconv" 9 "strings" 10 11 "github.com/gobuffalo/pop/associations" 12 "github.com/gobuffalo/pop/logging" 13 "github.com/gofrs/uuid" 14 "github.com/pkg/errors" 15 ) 16 17 var rLimitOffset = regexp.MustCompile("(?i)(limit [0-9]+ offset [0-9]+)$") 18 var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") 19 20 // Find the first record of the model in the database with a particular id. 21 // 22 // c.Find(&User{}, 1) 23 func (c *Connection) Find(model interface{}, id interface{}) error { 24 return Q(c).Find(model, id) 25 } 26 27 // Find the first record of the model in the database with a particular id. 28 // 29 // q.Find(&User{}, 1) 30 func (q *Query) Find(model interface{}, id interface{}) error { 31 m := &Model{Value: model} 32 idq := m.whereID() 33 switch t := id.(type) { 34 case uuid.UUID: 35 return q.Where(idq, t.String()).First(model) 36 case string: 37 l := len(t) 38 if l > 0 { 39 // Handle leading '0': 40 // if the string have a leading '0' and is not "0", prevent parsing to int 41 if t[0] != '0' || l == 1 { 42 var err error 43 id, err = strconv.Atoi(t) 44 if err != nil { 45 return q.Where(idq, t).First(model) 46 } 47 } 48 } 49 } 50 51 return q.Where(idq, id).First(model) 52 } 53 54 // First record of the model in the database that matches the query. 55 // 56 // c.First(&User{}) 57 func (c *Connection) First(model interface{}) error { 58 return Q(c).First(model) 59 } 60 61 // First record of the model in the database that matches the query. 62 // 63 // q.Where("name = ?", "mark").First(&User{}) 64 func (q *Query) First(model interface{}) error { 65 err := q.Connection.timeFunc("First", func() error { 66 q.Limit(1) 67 m := &Model{Value: model} 68 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 69 return err 70 } 71 return m.afterFind(q.Connection) 72 }) 73 74 if err != nil { 75 return err 76 } 77 78 if q.eager { 79 err = q.eagerAssociations(model) 80 q.disableEager() 81 return err 82 } 83 return nil 84 } 85 86 // Last record of the model in the database that matches the query. 87 // 88 // c.Last(&User{}) 89 func (c *Connection) Last(model interface{}) error { 90 return Q(c).Last(model) 91 } 92 93 // Last record of the model in the database that matches the query. 94 // 95 // q.Where("name = ?", "mark").Last(&User{}) 96 func (q *Query) Last(model interface{}) error { 97 err := q.Connection.timeFunc("Last", func() error { 98 q.Limit(1) 99 q.Order("created_at DESC, id DESC") 100 m := &Model{Value: model} 101 if err := q.Connection.Dialect.SelectOne(q.Connection.Store, m, *q); err != nil { 102 return err 103 } 104 return m.afterFind(q.Connection) 105 }) 106 107 if err != nil { 108 return err 109 } 110 111 if q.eager { 112 err = q.eagerAssociations(model) 113 q.disableEager() 114 return err 115 } 116 117 return nil 118 } 119 120 // All retrieves all of the records in the database that match the query. 121 // 122 // c.All(&[]User{}) 123 func (c *Connection) All(models interface{}) error { 124 return Q(c).All(models) 125 } 126 127 // All retrieves all of the records in the database that match the query. 128 // 129 // q.Where("name = ?", "mark").All(&[]User{}) 130 func (q *Query) All(models interface{}) error { 131 err := q.Connection.timeFunc("All", func() error { 132 m := &Model{Value: models} 133 err := q.Connection.Dialect.SelectMany(q.Connection.Store, m, *q) 134 if err != nil { 135 return err 136 } 137 err = q.paginateModel(models) 138 if err != nil { 139 return err 140 } 141 return m.afterFind(q.Connection) 142 }) 143 144 if err != nil { 145 return errors.Wrap(err, "unable to fetch records") 146 } 147 148 if q.eager { 149 err = q.eagerAssociations(models) 150 q.disableEager() 151 return err 152 } 153 154 return nil 155 } 156 157 func (q *Query) paginateModel(models interface{}) error { 158 if q.Paginator == nil { 159 return nil 160 } 161 162 ct, err := q.Count(models) 163 if err != nil { 164 return err 165 } 166 167 q.Paginator.TotalEntriesSize = ct 168 st := reflect.ValueOf(models).Elem() 169 q.Paginator.CurrentEntriesSize = st.Len() 170 q.Paginator.TotalPages = q.Paginator.TotalEntriesSize / q.Paginator.PerPage 171 if q.Paginator.TotalEntriesSize%q.Paginator.PerPage > 0 { 172 q.Paginator.TotalPages = q.Paginator.TotalPages + 1 173 } 174 return nil 175 } 176 177 // Load loads all association or the fields specified in params for 178 // an already loaded model. 179 // 180 // tx.First(&u) 181 // tx.Load(&u) 182 func (c *Connection) Load(model interface{}, fields ...string) error { 183 q := Q(c) 184 q.eagerFields = fields 185 err := q.eagerAssociations(model) 186 q.disableEager() 187 return err 188 } 189 190 func (q *Query) eagerAssociations(model interface{}) error { 191 var err error 192 193 // eagerAssociations for a slice or array model passed as a param. 194 v := reflect.ValueOf(model) 195 kind := reflect.Indirect(v).Kind() 196 if kind == reflect.Slice || kind == reflect.Array { 197 v = v.Elem() 198 for i := 0; i < v.Len(); i++ { 199 e := v.Index(i) 200 if e.Type().Kind() == reflect.Ptr { 201 // Already a pointer 202 err = q.eagerAssociations(e.Interface()) 203 } else { 204 err = q.eagerAssociations(e.Addr().Interface()) 205 } 206 if err != nil { 207 return err 208 } 209 } 210 return nil 211 } 212 213 // eagerAssociations for a single element 214 assos, err := associations.ForStruct(model, q.eagerFields...) 215 if err != nil { 216 return errors.Wrap(err, "could not retrieve associations") 217 } 218 219 // disable eager mode for current connection. 220 q.eager = false 221 q.Connection.eager = false 222 223 for _, association := range assos { 224 if association.Skipped() { 225 continue 226 } 227 228 query := Q(q.Connection) 229 230 whereCondition, args := association.Constraint() 231 query = query.Where(whereCondition, args...) 232 233 // validates if association is Sortable 234 sortable := (*associations.AssociationSortable)(nil) 235 t := reflect.TypeOf(association) 236 if t.Implements(reflect.TypeOf(sortable).Elem()) { 237 m := reflect.ValueOf(association).MethodByName("OrderBy") 238 out := m.Call([]reflect.Value{}) 239 orderClause := out[0].String() 240 if orderClause != "" { 241 query = query.Order(orderClause) 242 } 243 } 244 245 sqlSentence, args := query.ToSQL(&Model{Value: association.Interface()}) 246 query = query.RawQuery(sqlSentence, args...) 247 248 if association.Kind() == reflect.Slice || association.Kind() == reflect.Array { 249 err = query.All(association.Interface()) 250 } 251 252 if association.Kind() == reflect.Struct { 253 err = query.First(association.Interface()) 254 } 255 256 if err != nil && errors.Cause(err) != sql.ErrNoRows { 257 return err 258 } 259 260 // load all inner associations. 261 innerAssociations := association.InnerAssociations() 262 for _, inner := range innerAssociations { 263 v = reflect.Indirect(reflect.ValueOf(model)).FieldByName(inner.Name) 264 innerQuery := Q(query.Connection) 265 innerQuery.eagerFields = []string{inner.Fields} 266 err = innerQuery.eagerAssociations(v.Addr().Interface()) 267 if err != nil { 268 return err 269 } 270 } 271 } 272 return nil 273 } 274 275 // Exists returns true/false if a record exists in the database that matches 276 // the query. 277 // 278 // q.Where("name = ?", "mark").Exists(&User{}) 279 func (q *Query) Exists(model interface{}) (bool, error) { 280 tmpQuery := Q(q.Connection) 281 q.Clone(tmpQuery) //avoid meddling with original query 282 283 var res bool 284 285 err := tmpQuery.Connection.timeFunc("Exists", func() error { 286 tmpQuery.Paginator = nil 287 tmpQuery.orderClauses = clauses{} 288 tmpQuery.limitResults = 0 289 query, args := tmpQuery.ToSQL(&Model{Value: model}) 290 291 // when query contains custom selected fields / executed using RawQuery, 292 // sql may already contains limit and offset 293 if rLimitOffset.MatchString(query) { 294 foundLimit := rLimitOffset.FindString(query) 295 query = query[0 : len(query)-len(foundLimit)] 296 } else if rLimit.MatchString(query) { 297 foundLimit := rLimit.FindString(query) 298 query = query[0 : len(query)-len(foundLimit)] 299 } 300 301 existsQuery := fmt.Sprintf("SELECT EXISTS (%s)", query) 302 log(logging.SQL, existsQuery, args...) 303 return q.Connection.Store.Get(&res, existsQuery, args...) 304 }) 305 return res, err 306 } 307 308 // Count the number of records in the database. 309 // 310 // c.Count(&User{}) 311 func (c *Connection) Count(model interface{}) (int, error) { 312 return Q(c).Count(model) 313 } 314 315 // Count the number of records in the database. 316 // 317 // q.Where("name = ?", "mark").Count(&User{}) 318 func (q Query) Count(model interface{}) (int, error) { 319 return q.CountByField(model, "*") 320 } 321 322 // CountByField counts the number of records in the database, for a given field. 323 // 324 // q.Where("sex = ?", "f").Count(&User{}, "name") 325 func (q Query) CountByField(model interface{}, field string) (int, error) { 326 tmpQuery := Q(q.Connection) 327 q.Clone(tmpQuery) //avoid meddling with original query 328 329 res := &rowCount{} 330 331 err := tmpQuery.Connection.timeFunc("CountByField", func() error { 332 tmpQuery.Paginator = nil 333 tmpQuery.orderClauses = clauses{} 334 tmpQuery.limitResults = 0 335 query, args := tmpQuery.ToSQL(&Model{Value: model}) 336 //when query contains custom selected fields / executed using RawQuery, 337 // sql may already contains limit and offset 338 339 if rLimitOffset.MatchString(query) { 340 foundLimit := rLimitOffset.FindString(query) 341 query = query[0 : len(query)-len(foundLimit)] 342 } else if rLimit.MatchString(query) { 343 foundLimit := rLimit.FindString(query) 344 query = query[0 : len(query)-len(foundLimit)] 345 } 346 347 countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) 348 log(logging.SQL, countQuery, args...) 349 return q.Connection.Store.Get(res, countQuery, args...) 350 }) 351 return res.Count, err 352 } 353 354 type rowCount struct { 355 Count int `db:"row_count"` 356 } 357 358 // Select allows to query only fields passed as parameter. 359 // c.Select("field1", "field2").All(&model) 360 // => SELECT field1, field2 FROM models 361 func (c *Connection) Select(fields ...string) *Query { 362 return c.Q().Select(fields...) 363 } 364 365 // Select allows to query only fields passed as parameter. 366 // c.Select("field1", "field2").All(&model) 367 // => SELECT field1, field2 FROM models 368 func (q *Query) Select(fields ...string) *Query { 369 for _, f := range fields { 370 if strings.TrimSpace(f) != "" { 371 q.addColumns = append(q.addColumns, f) 372 } 373 } 374 return q 375 }