github.com/systematiccaos/gorm@v1.22.6/finisher_api.go (about)

     1  package gorm
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  
    10  	"github.com/systematiccaos/gorm/clause"
    11  	"github.com/systematiccaos/gorm/logger"
    12  	"github.com/systematiccaos/gorm/schema"
    13  	"github.com/systematiccaos/gorm/utils"
    14  )
    15  
    16  // Create insert the value into database
    17  func (db *DB) Create(value interface{}) (tx *DB) {
    18  	if db.CreateBatchSize > 0 {
    19  		return db.CreateInBatches(value, db.CreateBatchSize)
    20  	}
    21  
    22  	tx = db.getInstance()
    23  	tx.Statement.Dest = value
    24  	return tx.callbacks.Create().Execute(tx)
    25  }
    26  
    27  // CreateInBatches insert the value in batches into database
    28  func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
    29  	reflectValue := reflect.Indirect(reflect.ValueOf(value))
    30  
    31  	switch reflectValue.Kind() {
    32  	case reflect.Slice, reflect.Array:
    33  		var rowsAffected int64
    34  		tx = db.getInstance()
    35  
    36  		callFc := func(tx *DB) error {
    37  			// the reflection length judgment of the optimized value
    38  			reflectLen := reflectValue.Len()
    39  			for i := 0; i < reflectLen; i += batchSize {
    40  				ends := i + batchSize
    41  				if ends > reflectLen {
    42  					ends = reflectLen
    43  				}
    44  
    45  				subtx := tx.getInstance()
    46  				subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
    47  				subtx.callbacks.Create().Execute(subtx)
    48  				if subtx.Error != nil {
    49  					return subtx.Error
    50  				}
    51  				rowsAffected += subtx.RowsAffected
    52  			}
    53  			return nil
    54  		}
    55  
    56  		if tx.SkipDefaultTransaction {
    57  			tx.AddError(callFc(tx.Session(&Session{})))
    58  		} else {
    59  			tx.AddError(tx.Transaction(callFc))
    60  		}
    61  
    62  		tx.RowsAffected = rowsAffected
    63  	default:
    64  		tx = db.getInstance()
    65  		tx.Statement.Dest = value
    66  		tx = tx.callbacks.Create().Execute(tx)
    67  	}
    68  	return
    69  }
    70  
    71  // Save update value in database, if the value doesn't have primary key, will insert it
    72  func (db *DB) Save(value interface{}) (tx *DB) {
    73  	tx = db.getInstance()
    74  	tx.Statement.Dest = value
    75  
    76  	reflectValue := reflect.Indirect(reflect.ValueOf(value))
    77  	switch reflectValue.Kind() {
    78  	case reflect.Slice, reflect.Array:
    79  		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
    80  			tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
    81  		}
    82  		tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
    83  	case reflect.Struct:
    84  		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
    85  			for _, pf := range tx.Statement.Schema.PrimaryFields {
    86  				if _, isZero := pf.ValueOf(reflectValue); isZero {
    87  					return tx.callbacks.Create().Execute(tx)
    88  				}
    89  			}
    90  		}
    91  
    92  		fallthrough
    93  	default:
    94  		selectedUpdate := len(tx.Statement.Selects) != 0
    95  		// when updating, use all fields including those zero-value fields
    96  		if !selectedUpdate {
    97  			tx.Statement.Selects = append(tx.Statement.Selects, "*")
    98  		}
    99  
   100  		tx = tx.callbacks.Update().Execute(tx)
   101  
   102  		if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
   103  			result := reflect.New(tx.Statement.Schema.ModelType).Interface()
   104  			if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) {
   105  				return tx.Create(value)
   106  			}
   107  		}
   108  	}
   109  
   110  	return
   111  }
   112  
   113  // First find first record that match given conditions, order by primary key
   114  func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
   115  	tx = db.Limit(1).Order(clause.OrderByColumn{
   116  		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
   117  	})
   118  	if len(conds) > 0 {
   119  		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
   120  			tx.Statement.AddClause(clause.Where{Exprs: exprs})
   121  		}
   122  	}
   123  	tx.Statement.RaiseErrorOnNotFound = true
   124  	tx.Statement.Dest = dest
   125  	return tx.callbacks.Query().Execute(tx)
   126  }
   127  
   128  // Take return a record that match given conditions, the order will depend on the database implementation
   129  func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
   130  	tx = db.Limit(1)
   131  	if len(conds) > 0 {
   132  		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
   133  			tx.Statement.AddClause(clause.Where{Exprs: exprs})
   134  		}
   135  	}
   136  	tx.Statement.RaiseErrorOnNotFound = true
   137  	tx.Statement.Dest = dest
   138  	return tx.callbacks.Query().Execute(tx)
   139  }
   140  
   141  // Last find last record that match given conditions, order by primary key
   142  func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
   143  	tx = db.Limit(1).Order(clause.OrderByColumn{
   144  		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
   145  		Desc:   true,
   146  	})
   147  	if len(conds) > 0 {
   148  		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
   149  			tx.Statement.AddClause(clause.Where{Exprs: exprs})
   150  		}
   151  	}
   152  	tx.Statement.RaiseErrorOnNotFound = true
   153  	tx.Statement.Dest = dest
   154  	return tx.callbacks.Query().Execute(tx)
   155  }
   156  
   157  // Find find records that match given conditions
   158  func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
   159  	tx = db.getInstance()
   160  	if len(conds) > 0 {
   161  		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
   162  			tx.Statement.AddClause(clause.Where{Exprs: exprs})
   163  		}
   164  	}
   165  	tx.Statement.Dest = dest
   166  	return tx.callbacks.Query().Execute(tx)
   167  }
   168  
   169  // FindInBatches find records in batches
   170  func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
   171  	var (
   172  		tx = db.Order(clause.OrderByColumn{
   173  			Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
   174  		}).Session(&Session{})
   175  		queryDB      = tx
   176  		rowsAffected int64
   177  		batch        int
   178  	)
   179  
   180  	for {
   181  		result := queryDB.Limit(batchSize).Find(dest)
   182  		rowsAffected += result.RowsAffected
   183  		batch++
   184  
   185  		if result.Error == nil && result.RowsAffected != 0 {
   186  			tx.AddError(fc(result, batch))
   187  		} else if result.Error != nil {
   188  			tx.AddError(result.Error)
   189  		}
   190  
   191  		if tx.Error != nil || int(result.RowsAffected) < batchSize {
   192  			break
   193  		}
   194  
   195  		// Optimize for-break
   196  		resultsValue := reflect.Indirect(reflect.ValueOf(dest))
   197  		if result.Statement.Schema.PrioritizedPrimaryField == nil {
   198  			tx.AddError(ErrPrimaryKeyRequired)
   199  			break
   200  		}
   201  
   202  		primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
   203  		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
   204  	}
   205  
   206  	tx.RowsAffected = rowsAffected
   207  	return tx
   208  }
   209  
// assignInterfacesToValue copies values into tx.Statement.ReflectValue. Each
// value is matched to a field of tx.Statement.Schema. Errors from field.Set
// are recorded on tx.
//
// Each element of values is handled by type:
//   - []clause.Expression: every clause.Eq whose column (a string or
//     clause.Column) names a schema field is assigned that value.
//     clause.AndConditions are processed recursively. Other expressions are
//     ignored.
//   - clause.Expression or one of the listed map types: converted with
//     BuildCondition, then processed as an expression list.
//   - anything else: if it parses as a schema and is a struct, each readable,
//     non-zero field is copied to the target field of the same name. If it
//     does not parse, the whole values list is turned into a condition with
//     BuildCondition(values[0], values[1:]...) (presumably a query string
//     followed by its arguments), assigned, and processing stops.
//
// NOTE(review): tx.Statement.Schema is dereferenced without a nil check, so
// callers are assumed to have parsed the model first.
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
	for _, value := range values {
		switch v := value.(type) {
		case []clause.Expression:
			for _, expr := range v {
				if eq, ok := expr.(clause.Eq); ok {
					switch column := eq.Column.(type) {
					case string:
						if field := tx.Statement.Schema.LookUpField(column); field != nil {
							tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
						}
					case clause.Column:
						if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
							tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
						}
					}
				} else if andCond, ok := expr.(clause.AndConditions); ok {
					// Nested AND group: assign its equality conditions too.
					tx.assignInterfacesToValue(andCond.Exprs)
				}
			}
		case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
			// Normalize to an expression list and reuse the case above.
			if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
				tx.assignInterfacesToValue(exprs)
			}
		default:
			if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
				reflectValue := reflect.Indirect(reflect.ValueOf(value))
				switch reflectValue.Kind() {
				case reflect.Struct:
					// Copy only non-zero readable fields, matched by field name.
					for _, f := range s.Fields {
						if f.Readable {
							if v, isZero := f.ValueOf(reflectValue); !isZero {
								if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
									tx.AddError(field.Set(tx.Statement.ReflectValue, v))
								}
							}
						}
					}
				}
			} else if len(values) > 0 {
				// Not a model: treat the whole argument list as one condition
				// and stop, since later values were consumed as its arguments.
				if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
					tx.assignInterfacesToValue(exprs)
				}
				return
			}
		}
	}
}
   258  
   259  func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
   260  	queryTx := db.Limit(1).Order(clause.OrderByColumn{
   261  		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
   262  	})
   263  
   264  	if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
   265  		if c, ok := tx.Statement.Clauses["WHERE"]; ok {
   266  			if where, ok := c.Expression.(clause.Where); ok {
   267  				tx.assignInterfacesToValue(where.Exprs)
   268  			}
   269  		}
   270  
   271  		// initialize with attrs, conds
   272  		if len(tx.Statement.attrs) > 0 {
   273  			tx.assignInterfacesToValue(tx.Statement.attrs...)
   274  		}
   275  	}
   276  
   277  	// initialize with attrs, conds
   278  	if len(tx.Statement.assigns) > 0 {
   279  		tx.assignInterfacesToValue(tx.Statement.assigns...)
   280  	}
   281  	return
   282  }
   283  
   284  func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
   285  	queryTx := db.Limit(1).Order(clause.OrderByColumn{
   286  		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
   287  	})
   288  	if tx = queryTx.Find(dest, conds...); tx.Error == nil {
   289  		if tx.RowsAffected == 0 {
   290  			if c, ok := tx.Statement.Clauses["WHERE"]; ok {
   291  				if where, ok := c.Expression.(clause.Where); ok {
   292  					tx.assignInterfacesToValue(where.Exprs)
   293  				}
   294  			}
   295  
   296  			// initialize with attrs, conds
   297  			if len(tx.Statement.attrs) > 0 {
   298  				tx.assignInterfacesToValue(tx.Statement.attrs...)
   299  			}
   300  
   301  			// initialize with attrs, conds
   302  			if len(tx.Statement.assigns) > 0 {
   303  				tx.assignInterfacesToValue(tx.Statement.assigns...)
   304  			}
   305  
   306  			return tx.Create(dest)
   307  		} else if len(db.Statement.assigns) > 0 {
   308  			exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
   309  			assigns := map[string]interface{}{}
   310  			for _, expr := range exprs {
   311  				if eq, ok := expr.(clause.Eq); ok {
   312  					switch column := eq.Column.(type) {
   313  					case string:
   314  						assigns[column] = eq.Value
   315  					case clause.Column:
   316  						assigns[column.Name] = eq.Value
   317  					default:
   318  					}
   319  				}
   320  			}
   321  
   322  			return tx.Model(dest).Updates(assigns)
   323  		}
   324  	}
   325  	return tx
   326  }
   327  
   328  // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
   329  func (db *DB) Update(column string, value interface{}) (tx *DB) {
   330  	tx = db.getInstance()
   331  	tx.Statement.Dest = map[string]interface{}{column: value}
   332  	return tx.callbacks.Update().Execute(tx)
   333  }
   334  
   335  // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
   336  func (db *DB) Updates(values interface{}) (tx *DB) {
   337  	tx = db.getInstance()
   338  	tx.Statement.Dest = values
   339  	return tx.callbacks.Update().Execute(tx)
   340  }
   341  
   342  func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
   343  	tx = db.getInstance()
   344  	tx.Statement.Dest = map[string]interface{}{column: value}
   345  	tx.Statement.SkipHooks = true
   346  	return tx.callbacks.Update().Execute(tx)
   347  }
   348  
   349  func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
   350  	tx = db.getInstance()
   351  	tx.Statement.Dest = values
   352  	tx.Statement.SkipHooks = true
   353  	return tx.callbacks.Update().Execute(tx)
   354  }
   355  
   356  // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
   357  func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
   358  	tx = db.getInstance()
   359  	if len(conds) > 0 {
   360  		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
   361  			tx.Statement.AddClause(clause.Where{Exprs: exprs})
   362  		}
   363  	}
   364  	tx.Statement.Dest = value
   365  	return tx.callbacks.Delete().Execute(tx)
   366  }
   367  
// Count executes a COUNT query built from the current statement and stores
// the result in *count.
//
// The SELECT clause is replaced for the duration of the call:
//   - no Select: count(*);
//   - first Select already starting with "count(": left as is;
//   - a single Select that splits into one token, or three tokens with "AS"
//     or "." in the middle (per utils.IsValidDBNameChar): COUNT(col), or
//     COUNT(DISTINCT(col)) when Distinct is set. A struct field name is
//     translated to its DB column name when the model parses. A bare "*"
//     without Distinct stays count(*);
//   - anything else: count(*).
//
// ORDER BY is removed for the query unless GROUP BY is present. Deferred
// functions restore the caller's SELECT and ORDER BY clauses on return. If
// Count borrowed Dest as the Model, another deferred function resets Model
// to nil.
//
// With GROUP BY, or whenever RowsAffected != 1, *count is overwritten with
// RowsAffected. Otherwise the value the query callbacks scanned into count is
// kept.
func (db *DB) Count(count *int64) (tx *DB) {
	tx = db.getInstance()
	// Borrow Dest as the model so the statement can be parsed; undone on return.
	if tx.Statement.Model == nil {
		tx.Statement.Model = tx.Statement.Dest
		defer func() {
			tx.Statement.Model = nil
		}()
	}

	// On return, put the caller's SELECT clause back, or drop the temporary one.
	if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
		defer func() {
			tx.Statement.Clauses["SELECT"] = selectClause
		}()
	} else {
		defer delete(tx.Statement.Clauses, "SELECT")
	}

	if len(tx.Statement.Selects) == 0 {
		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
	} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
		expr := clause.Expr{SQL: "count(*)"}

		if len(tx.Statement.Selects) == 1 {
			dbName := tx.Statement.Selects[0]
			// Only simple selections get a column-specific COUNT.
			fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
			if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
				if tx.Statement.Parse(tx.Statement.Model) == nil {
					// Translate a struct field name to its column name.
					if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
						dbName = f.DBName
					}
				}

				if tx.Statement.Distinct {
					expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
				} else if dbName != "*" {
					expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
				}
			}
		}

		tx.Statement.AddClause(clause.Select{Expression: expr})
	}

	// Unless grouping, drop ORDER BY for the count query and restore it afterwards.
	if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
		if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
			delete(tx.Statement.Clauses, "ORDER BY")
			defer func() {
				tx.Statement.Clauses["ORDER BY"] = orderByClause
			}()
		}
	}

	tx.Statement.Dest = count
	tx = tx.callbacks.Query().Execute(tx)

	// With GROUP BY (presumably one result row per group), or any row count
	// other than one, the row count itself becomes the result.
	if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
		*count = tx.RowsAffected
	}

	return
}
   429  
   430  func (db *DB) Row() *sql.Row {
   431  	tx := db.getInstance().Set("rows", false)
   432  	tx = tx.callbacks.Row().Execute(tx)
   433  	row, ok := tx.Statement.Dest.(*sql.Row)
   434  	if !ok && tx.DryRun {
   435  		db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
   436  	}
   437  	return row
   438  }
   439  
   440  func (db *DB) Rows() (*sql.Rows, error) {
   441  	tx := db.getInstance().Set("rows", true)
   442  	tx = tx.callbacks.Row().Execute(tx)
   443  	rows, ok := tx.Statement.Dest.(*sql.Rows)
   444  	if !ok && tx.DryRun && tx.Error == nil {
   445  		tx.Error = ErrDryRunModeUnsupported
   446  	}
   447  	return rows, tx.Error
   448  }
   449  
   450  // Scan scan value to a struct
   451  func (db *DB) Scan(dest interface{}) (tx *DB) {
   452  	config := *db.Config
   453  	currentLogger, newLogger := config.Logger, logger.Recorder.New()
   454  	config.Logger = newLogger
   455  
   456  	tx = db.getInstance()
   457  	tx.Config = &config
   458  
   459  	if rows, err := tx.Rows(); err == nil {
   460  		if rows.Next() {
   461  			tx.ScanRows(rows, dest)
   462  		} else {
   463  			tx.RowsAffected = 0
   464  		}
   465  		tx.AddError(rows.Close())
   466  	}
   467  
   468  	currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
   469  		return newLogger.SQL, tx.RowsAffected
   470  	}, tx.Error)
   471  	tx.Logger = currentLogger
   472  	return
   473  }
   474  
   475  // Pluck used to query single column from a model as a map
   476  //     var ages []int64
   477  //     db.Model(&users).Pluck("age", &ages)
   478  func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
   479  	tx = db.getInstance()
   480  	if tx.Statement.Model != nil {
   481  		if tx.Statement.Parse(tx.Statement.Model) == nil {
   482  			if f := tx.Statement.Schema.LookUpField(column); f != nil {
   483  				column = f.DBName
   484  			}
   485  		}
   486  	}
   487  
   488  	if len(tx.Statement.Selects) != 1 {
   489  		fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
   490  		tx.Statement.AddClauseIfNotExists(clause.Select{
   491  			Distinct: tx.Statement.Distinct,
   492  			Columns:  []clause.Column{{Name: column, Raw: len(fields) != 1}},
   493  		})
   494  	}
   495  	tx.Statement.Dest = dest
   496  	return tx.callbacks.Query().Execute(tx)
   497  }
   498  
   499  func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
   500  	tx := db.getInstance()
   501  	if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
   502  		tx.AddError(err)
   503  	}
   504  	tx.Statement.Dest = dest
   505  	tx.Statement.ReflectValue = reflect.ValueOf(dest)
   506  	for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
   507  		elem := tx.Statement.ReflectValue.Elem()
   508  		if !elem.IsValid() {
   509  			elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
   510  			tx.Statement.ReflectValue.Set(elem)
   511  		}
   512  		tx.Statement.ReflectValue = elem
   513  	}
   514  	Scan(rows, tx, ScanInitialized)
   515  	return tx.Error
   516  }
   517  
   518  // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
   519  func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
   520  	panicked := true
   521  
   522  	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
   523  		// nested transaction
   524  		if !db.DisableNestedTransaction {
   525  			err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
   526  			defer func() {
   527  				// Make sure to rollback when panic, Block error or Commit error
   528  				if panicked || err != nil {
   529  					db.RollbackTo(fmt.Sprintf("sp%p", fc))
   530  				}
   531  			}()
   532  		}
   533  
   534  		if err == nil {
   535  			err = fc(db.Session(&Session{}))
   536  		}
   537  	} else {
   538  		tx := db.Begin(opts...)
   539  
   540  		defer func() {
   541  			// Make sure to rollback when panic, Block error or Commit error
   542  			if panicked || err != nil {
   543  				tx.Rollback()
   544  			}
   545  		}()
   546  
   547  		if err = tx.Error; err == nil {
   548  			err = fc(tx)
   549  		}
   550  
   551  		if err == nil {
   552  			err = tx.Commit().Error
   553  		}
   554  	}
   555  
   556  	panicked = false
   557  	return
   558  }
   559  
   560  // Begin begins a transaction
   561  func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
   562  	var (
   563  		// clone statement
   564  		tx  = db.getInstance().Session(&Session{Context: db.Statement.Context})
   565  		opt *sql.TxOptions
   566  		err error
   567  	)
   568  
   569  	if len(opts) > 0 {
   570  		opt = opts[0]
   571  	}
   572  
   573  	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
   574  		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
   575  	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
   576  		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
   577  	} else {
   578  		err = ErrInvalidTransaction
   579  	}
   580  
   581  	if err != nil {
   582  		tx.AddError(err)
   583  	}
   584  
   585  	return tx
   586  }
   587  
   588  // Commit commit a transaction
   589  func (db *DB) Commit() *DB {
   590  	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
   591  		db.AddError(committer.Commit())
   592  	} else {
   593  		db.AddError(ErrInvalidTransaction)
   594  	}
   595  	return db
   596  }
   597  
   598  // Rollback rollback a transaction
   599  func (db *DB) Rollback() *DB {
   600  	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
   601  		if !reflect.ValueOf(committer).IsNil() {
   602  			db.AddError(committer.Rollback())
   603  		}
   604  	} else {
   605  		db.AddError(ErrInvalidTransaction)
   606  	}
   607  	return db
   608  }
   609  
   610  func (db *DB) SavePoint(name string) *DB {
   611  	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
   612  		db.AddError(savePointer.SavePoint(db, name))
   613  	} else {
   614  		db.AddError(ErrUnsupportedDriver)
   615  	}
   616  	return db
   617  }
   618  
   619  func (db *DB) RollbackTo(name string) *DB {
   620  	if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
   621  		db.AddError(savePointer.RollbackTo(db, name))
   622  	} else {
   623  		db.AddError(ErrUnsupportedDriver)
   624  	}
   625  	return db
   626  }
   627  
   628  // Exec execute raw sql
   629  func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
   630  	tx = db.getInstance()
   631  	tx.Statement.SQL = strings.Builder{}
   632  
   633  	if strings.Contains(sql, "@") {
   634  		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
   635  	} else {
   636  		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
   637  	}
   638  
   639  	return tx.callbacks.Raw().Execute(tx)
   640  }