github.com/gobuffalo/pop/v6@v6.1.2-0.20230426125638-01ebd5b92a24/executors.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"time"
     7  
     8  	"github.com/gobuffalo/pop/v6/associations"
     9  	"github.com/gobuffalo/pop/v6/columns"
    10  	"github.com/gobuffalo/pop/v6/logging"
    11  	"github.com/gobuffalo/validate/v3"
    12  	"github.com/gofrs/uuid"
    13  )
    14  
    15  // Reload fetch fresh data for a given model, using its ID.
    16  func (c *Connection) Reload(model interface{}) error {
    17  	sm := NewModel(model, c.Context())
    18  	return sm.iterate(func(m *Model) error {
    19  		return c.Find(m.Value, m.ID())
    20  	})
    21  }
    22  
    23  // TODO: consider merging the following two methods.
    24  
    25  // Exec runs the given query.
    26  func (q *Query) Exec() error {
    27  	return q.Connection.timeFunc("Exec", func() error {
    28  		sql, args := q.ToSQL(nil)
    29  		if sql == "" {
    30  			return fmt.Errorf("empty query")
    31  		}
    32  
    33  		txlog(logging.SQL, q.Connection, sql, args...)
    34  		_, err := q.Connection.Store.Exec(sql, args...)
    35  		return err
    36  	})
    37  }
    38  
    39  // ExecWithCount runs the given query, and returns the amount of
    40  // affected rows.
    41  func (q *Query) ExecWithCount() (int, error) {
    42  	count := int64(0)
    43  	return int(count), q.Connection.timeFunc("Exec", func() error {
    44  		sql, args := q.ToSQL(nil)
    45  		if sql == "" {
    46  			return fmt.Errorf("empty query")
    47  		}
    48  
    49  		txlog(logging.SQL, q.Connection, sql, args...)
    50  		result, err := q.Connection.Store.Exec(sql, args...)
    51  		if err != nil {
    52  			return err
    53  		}
    54  
    55  		count, err = result.RowsAffected()
    56  		return err
    57  	})
    58  }
    59  
    60  // ValidateAndSave applies validation rules on the given entry, then save it
    61  // if the validation succeed, excluding the given columns.
    62  //
    63  // If model is a slice, each item of the slice is validated then saved in the database.
    64  func (c *Connection) ValidateAndSave(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
    65  	sm := NewModel(model, c.Context())
    66  	if err := sm.beforeValidate(c); err != nil {
    67  		return nil, err
    68  	}
    69  	verrs, err := sm.validateSave(c)
    70  	if err != nil {
    71  		return verrs, err
    72  	}
    73  	if verrs.HasAny() {
    74  		return verrs, nil
    75  	}
    76  	return verrs, c.Save(model, excludeColumns...)
    77  }
    78  
// emptyUUID is the string representation of the nil UUID (uuid.Nil).
var emptyUUID = uuid.Nil.String()
    80  
    81  // IsZeroOfUnderlyingType will check if the value of anything is the equal to the Zero value of that type.
    82  func IsZeroOfUnderlyingType(x interface{}) bool {
    83  	return reflect.DeepEqual(x, reflect.Zero(reflect.TypeOf(x)).Interface())
    84  }
    85  
    86  // Save wraps the Create and Update methods. It executes a Create if no ID is provided with the entry;
    87  // or issues an Update otherwise.
    88  //
    89  // If model is a slice, each item of the slice is saved in the database.
    90  func (c *Connection) Save(model interface{}, excludeColumns ...string) error {
    91  	sm := NewModel(model, c.Context())
    92  	return sm.iterate(func(m *Model) error {
    93  		id, err := m.fieldByName("ID")
    94  		if err != nil {
    95  			return err
    96  		}
    97  		if IsZeroOfUnderlyingType(id.Interface()) {
    98  			return c.Create(m.Value, excludeColumns...)
    99  		}
   100  		return c.Update(m.Value, excludeColumns...)
   101  	})
   102  }
   103  
   104  // ValidateAndCreate applies validation rules on the given entry, then creates it
   105  // if the validation succeed, excluding the given columns.
   106  //
   107  // If model is a slice, each item of the slice is validated then created in the database.
   108  func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
   109  	sm := NewModel(model, c.Context())
   110  
   111  	isEager := c.eager
   112  	hasEagerFields := c.eagerFields
   113  
   114  	if err := sm.beforeValidate(c); err != nil {
   115  		return nil, err
   116  	}
   117  	verrs, err := sm.validateCreate(c)
   118  	if err != nil {
   119  		return verrs, err
   120  	}
   121  	if verrs.HasAny() {
   122  		return verrs, nil
   123  	}
   124  
   125  	if c.eager {
   126  		asos, err := associations.ForStruct(model, c.eagerFields...)
   127  		if err != nil {
   128  			return verrs, fmt.Errorf("could not retrieve associations: %w", err)
   129  		}
   130  
   131  		if len(asos) == 0 {
   132  			log(logging.Debug, "no associations found for given struct, disable eager mode")
   133  			c.disableEager()
   134  			return verrs, c.Create(model, excludeColumns...)
   135  		}
   136  
   137  		before := asos.AssociationsBeforeCreatable()
   138  		for index := range before {
   139  			i := before[index].BeforeInterface()
   140  			if i == nil {
   141  				continue
   142  			}
   143  
   144  			sm := NewModel(i, c.Context())
   145  			verrs, err := sm.validateAndOnlyCreate(c)
   146  			if err != nil || verrs.HasAny() {
   147  				return verrs, err
   148  			}
   149  		}
   150  
   151  		after := asos.AssociationsAfterCreatable()
   152  		for index := range after {
   153  			i := after[index].AfterInterface()
   154  			if i == nil {
   155  				continue
   156  			}
   157  
   158  			sm := NewModel(i, c.Context())
   159  			verrs, err := sm.validateAndOnlyCreate(c)
   160  			if err != nil || verrs.HasAny() {
   161  				return verrs, err
   162  			}
   163  		}
   164  
   165  		sm := NewModel(model, c.Context())
   166  		verrs, err = sm.validateCreate(c)
   167  		if err != nil || verrs.HasAny() {
   168  			return verrs, err
   169  		}
   170  	}
   171  
   172  	c.eager = isEager
   173  	c.eagerFields = hasEagerFields
   174  	return verrs, c.Create(model, excludeColumns...)
   175  }
   176  
// Create adds a new given entry to the database, excluding the given columns.
// It updates `created_at` and `updated_at` columns automatically.
//
// If model is a slice, each item of the slice is created in the database.
//
// Create supports two modes:
// * Flat (default): Associate existing nested objects only. NO creation or update of nested objects.
// * Eager: Associate existing nested objects and create non-existent objects. NO change to existing objects.
//
// For every entry the beforeSave and beforeCreate hooks run before the
// insert, and the afterCreate and afterSave hooks run after it; the first
// error returned by any step aborts the creation of that entry.
func (c *Connection) Create(model interface{}, excludeColumns ...string) error {
	// Remember the requested mode, then switch the connection out of eager
	// mode before doing any work. The nested c.Create calls below read the
	// connection state again, so they presumably run in flat mode.
	var isEager = c.eager

	c.disableEager()

	sm := NewModel(model, c.Context())
	return sm.iterate(func(m *Model) error {
		return c.timeFunc("Create", func() error {
			// Per-entry copy of the mode, so one entry without associations
			// does not change the mode used for the other entries.
			var localIsEager = isEager
			asos, err := associations.ForStruct(m.Value, c.eagerFields...)
			if err != nil {
				return fmt.Errorf("could not retrieve associations: %w", err)
			}

			if localIsEager && len(asos) == 0 {
				// No association, fallback to non-eager mode.
				localIsEager = false
			}

			if err = m.beforeSave(c); err != nil {
				return err
			}

			if err = m.beforeCreate(c); err != nil {
				return err
			}

			processAssoc := len(asos) > 0

			if processAssoc {
				// Associations handled before the entry itself is inserted.
				before := asos.AssociationsBeforeCreatable()
				for index := range before {
					i := before[index].BeforeInterface()
					if i == nil {
						continue
					}

					if localIsEager {
						// Eager mode: create the nested entries that have no
						// ID yet; entries with a non-zero ID are left alone.
						sm := NewModel(i, c.Context())
						err = sm.iterate(func(m *Model) error {
							id, err := m.fieldByName("ID")
							if err != nil {
								return err
							}
							if IsZeroOfUnderlyingType(id.Interface()) {
								return c.Create(m.Value)
							}
							return nil
						})

						if err != nil {
							return err
						}
					}

					// BeforeSetup runs in both modes, after any nested creation.
					err = before[index].BeforeSetup()
					if err != nil {
						return err
					}
				}
			}

			tn := m.TableName()
			cols := m.Columns()

			// excludeColumns is only applied when the entry's table is the
			// same as the table of the model passed to Create.
			if tn == sm.TableName() {
				cols.Remove(excludeColumns...)
			}

			// Both timestamps get the same instant, truncated to microseconds.
			now := nowFunc().Truncate(time.Microsecond)
			m.setUpdatedAt(now)
			m.setCreatedAt(now)

			if err = c.Dialect.Create(c, m, cols); err != nil {
				return err
			}

			if processAssoc {
				// Associations handled once the entry itself has been inserted.
				after := asos.AssociationsAfterCreatable()
				for index := range after {
					if localIsEager {
						err = after[index].AfterSetup()
						if err != nil {
							return err
						}

						i := after[index].AfterInterface()
						if i == nil {
							// Note: this also skips AfterProcess for this association.
							continue
						}

						sm := NewModel(i, c.Context())
						err = sm.iterate(func(m *Model) error {
							fbn, err := m.fieldByName("ID")
							if err != nil {
								return err
							}
							id := fbn.Interface()
							if IsZeroOfUnderlyingType(id) {
								return c.Create(m.Value)
							}

							// The nested entry carries an ID: create it only
							// when no row with that ID is found. A failing
							// existence check is handled like "not found"
							// and its error is not returned.
							exists, errE := Q(c).Where(m.WhereID(), id).Exists(i)
							if errE != nil || !exists {
								return c.Create(m.Value)
							}
							return nil
						})

						if err != nil {
							return err
						}
					}
					// The AfterProcess statement is only executed when the
					// connection has a transaction (c.TX is set) and the
					// statement is not empty.
					stm := after[index].AfterProcess()
					if c.TX != nil && !stm.Empty() {
						err := c.RawQuery(c.Dialect.TranslateSQL(stm.Statement), stm.Args...).Exec()
						if err != nil {
							return err
						}
					}
				}

				// Execute every raw statement the statement-based associations provide.
				stms := asos.AssociationsCreatableStatement()
				for index := range stms {
					statements := stms[index].Statements()
					for _, stm := range statements {
						err := c.RawQuery(c.Dialect.TranslateSQL(stm.Statement), stm.Args...).Exec()
						if err != nil {
							return err
						}
					}
				}
			}

			if err = m.afterCreate(c); err != nil {
				return err
			}

			return m.afterSave(c)
		})
	})
}
   327  
   328  // ValidateAndUpdate applies validation rules on the given entry, then update it
   329  // if the validation succeed, excluding the given columns.
   330  //
   331  // If model is a slice, each item of the slice is validated then updated in the database.
   332  func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...string) (*validate.Errors, error) {
   333  	sm := NewModel(model, c.Context())
   334  	if err := sm.beforeValidate(c); err != nil {
   335  		return nil, err
   336  	}
   337  	verrs, err := sm.validateUpdate(c)
   338  	if err != nil {
   339  		return verrs, err
   340  	}
   341  	if verrs.HasAny() {
   342  		return verrs, nil
   343  	}
   344  	return verrs, c.Update(model, excludeColumns...)
   345  }
   346  
   347  // Update writes changes from an entry to the database, excluding the given columns.
   348  // It updates the `updated_at` column automatically.
   349  //
   350  // If model is a slice, each item of the slice is updated in the database.
   351  func (c *Connection) Update(model interface{}, excludeColumns ...string) error {
   352  	sm := NewModel(model, c.Context())
   353  	return sm.iterate(func(m *Model) error {
   354  		return c.timeFunc("Update", func() error {
   355  			var err error
   356  
   357  			if err = m.beforeSave(c); err != nil {
   358  				return err
   359  			}
   360  			if err = m.beforeUpdate(c); err != nil {
   361  				return err
   362  			}
   363  
   364  			tn := m.TableName()
   365  			cols := columns.ForStructWithAlias(model, tn, m.As, m.IDField())
   366  			cols.Remove(m.IDField(), "created_at")
   367  
   368  			if tn == sm.TableName() {
   369  				cols.Remove(excludeColumns...)
   370  			}
   371  
   372  			now := nowFunc().Truncate(time.Microsecond)
   373  			m.setUpdatedAt(now)
   374  
   375  			if err = c.Dialect.Update(c, m, cols); err != nil {
   376  				return err
   377  			}
   378  			if err = m.afterUpdate(c); err != nil {
   379  				return err
   380  			}
   381  
   382  			return m.afterSave(c)
   383  		})
   384  	})
   385  }
   386  
   387  // UpdateQuery updates all rows matched by the query. The new values are read
   388  // from the first argument, which must be a struct. The column names to be
   389  // updated must be listed explicitly in subsequent arguments. The ID and
   390  // CreatedAt columns are never updated. The UpdatedAt column is updated
   391  // automatically.
   392  //
   393  // UpdateQuery does not execute (before|after)(Create|Update|Save) callbacks.
   394  //
   395  // Calling UpdateQuery with no columnNames will result in only the UpdatedAt
   396  // column being updated.
   397  func (q *Query) UpdateQuery(model interface{}, columnNames ...string) (int64, error) {
   398  	sm := NewModel(model, q.Connection.Context())
   399  	modelKind := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(model))).Kind()
   400  	if modelKind != reflect.Struct {
   401  		return 0, fmt.Errorf("model must be a struct; got %s", modelKind)
   402  	}
   403  
   404  	cols := columns.NewColumnsWithAlias(sm.TableName(), sm.As, sm.IDField())
   405  	cols.Add(columnNames...)
   406  	if _, err := sm.fieldByName("UpdatedAt"); err == nil {
   407  		cols.Add("updated_at")
   408  	}
   409  	cols.Remove(sm.IDField(), "created_at")
   410  
   411  	now := nowFunc().Truncate(time.Microsecond)
   412  	sm.setUpdatedAt(now)
   413  	return q.Connection.Dialect.UpdateQuery(q.Connection, sm, cols, *q)
   414  }
   415  
   416  // UpdateColumns writes changes from an entry to the database, including only the given columns
   417  // or all columns if no column names are provided.
   418  // It updates the `updated_at` column automatically if you include `updated_at` in columnNames.
   419  //
   420  // If model is a slice, each item of the slice is updated in the database.
   421  func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) error {
   422  	sm := NewModel(model, c.Context())
   423  	return sm.iterate(func(m *Model) error {
   424  		return c.timeFunc("Update", func() error {
   425  			var err error
   426  
   427  			if err = m.beforeSave(c); err != nil {
   428  				return err
   429  			}
   430  			if err = m.beforeUpdate(c); err != nil {
   431  				return err
   432  			}
   433  
   434  			tn := m.TableName()
   435  
   436  			cols := columns.Columns{}
   437  			if len(columnNames) > 0 && tn == sm.TableName() {
   438  				cols = columns.NewColumnsWithAlias(tn, m.As, sm.IDField())
   439  				cols.Add(columnNames...)
   440  
   441  			} else {
   442  				cols = columns.ForStructWithAlias(model, tn, m.As, m.IDField())
   443  			}
   444  			cols.Remove("id", "created_at")
   445  
   446  			now := nowFunc().Truncate(time.Microsecond)
   447  			m.setUpdatedAt(now)
   448  
   449  			if err = c.Dialect.Update(c, m, cols); err != nil {
   450  				return err
   451  			}
   452  			if err = m.afterUpdate(c); err != nil {
   453  				return err
   454  			}
   455  
   456  			return m.afterSave(c)
   457  		})
   458  	})
   459  }
   460  
   461  // Destroy deletes a given entry from the database.
   462  //
   463  // If model is a slice, each item of the slice is deleted from the database.
   464  func (c *Connection) Destroy(model interface{}) error {
   465  	sm := NewModel(model, c.Context())
   466  	return sm.iterate(func(m *Model) error {
   467  		return c.timeFunc("Destroy", func() error {
   468  			var err error
   469  
   470  			if err = m.beforeDestroy(c); err != nil {
   471  				return err
   472  			}
   473  			if err = c.Dialect.Destroy(c, m); err != nil {
   474  				return err
   475  			}
   476  
   477  			return m.afterDestroy(c)
   478  		})
   479  	})
   480  }
   481  
   482  func (q *Query) Delete(model interface{}) error {
   483  	q.Operation = Delete
   484  
   485  	return q.Connection.timeFunc("Delete", func() error {
   486  		m := NewModel(model, q.Connection.Context())
   487  		err := q.Connection.Dialect.Delete(q.Connection, m, *q)
   488  		if err != nil {
   489  			return err
   490  		}
   491  		return m.afterDestroy(q.Connection)
   492  	})
   493  }