github.com/wfusion/gofusion@v1.1.14/db/callbacks/auto_incr.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"reflect"
     5  	"strings"
     6  
     7  	"gorm.io/gorm"
     8  	"gorm.io/gorm/clause"
     9  	"gorm.io/gorm/utils"
    10  
    11  	comUtl "github.com/wfusion/gofusion/common/utils"
    12  
    13  	. "gorm.io/driver/mysql"
    14  )
    15  
    16  func CreateAutoIncr(db *gorm.DB, gormDialector gorm.Dialector, autoIncrIncr int64) {
    17  	withReturning := false
    18  	dialector := gormDialector.(*Dialector)
    19  	if !dialector.Config.SkipInitializeWithVersion && strings.Contains(dialector.ServerVersion, "MariaDB") {
    20  		withReturning = checkVersion(dialector.ServerVersion, "10.5")
    21  	}
    22  
    23  	lastInsertIDReversed := false
    24  	if !dialector.Config.DisableWithReturning && withReturning {
    25  		lastInsertIDReversed = true
    26  	}
    27  
    28  	comUtl.MustSuccess(
    29  		db.Callback().Create().Replace("gorm:create", func(db *gorm.DB) {
    30  			if db.Error != nil {
    31  				return
    32  			}
    33  
    34  			BuildCreateSQL(db)
    35  
    36  			isDryRun := !db.DryRun && db.Error == nil
    37  			if !isDryRun {
    38  				return
    39  			}
    40  
    41  			ok, mode := hasReturning(db, utils.Contains(db.Callback().Create().Clauses, "RETURNING"))
    42  			if ok {
    43  				if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
    44  					if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
    45  						mode |= gorm.ScanOnConflictDoNothing
    46  					}
    47  				}
    48  
    49  				rows, err := db.Statement.ConnPool.QueryContext(
    50  					db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
    51  				)
    52  				if db.AddError(err) == nil {
    53  					defer func() { _ = db.AddError(rows.Close()) }()
    54  					gorm.Scan(rows, db, mode)
    55  				}
    56  
    57  				return
    58  			}
    59  
    60  			result, err := db.Statement.ConnPool.ExecContext(
    61  				db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
    62  			)
    63  			if err != nil {
    64  				_ = db.AddError(err)
    65  				return
    66  			}
    67  
    68  			db.RowsAffected, _ = result.RowsAffected()
    69  			if db.RowsAffected != 0 && db.Statement.Schema != nil &&
    70  				db.Statement.Schema.PrioritizedPrimaryField != nil &&
    71  				db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
    72  				insertID, err := result.LastInsertId()
    73  				insertOk := err == nil && insertID > 0
    74  				if !insertOk {
    75  					_ = db.AddError(err)
    76  					return
    77  				}
    78  
    79  				switch db.Statement.ReflectValue.Kind() {
    80  				case reflect.Slice, reflect.Array:
    81  					if lastInsertIDReversed {
    82  						for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
    83  							rv := db.Statement.ReflectValue.Index(i)
    84  							if reflect.Indirect(rv).Kind() != reflect.Struct {
    85  								break
    86  							}
    87  
    88  							_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
    89  							if isZero {
    90  								_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
    91  								//insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
    92  								insertID -= autoIncrIncr
    93  							}
    94  						}
    95  					} else {
    96  						for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
    97  							rv := db.Statement.ReflectValue.Index(i)
    98  							if reflect.Indirect(rv).Kind() != reflect.Struct {
    99  								break
   100  							}
   101  
   102  							if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
   103  								_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
   104  								//insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
   105  								insertID += autoIncrIncr
   106  							}
   107  						}
   108  					}
   109  				case reflect.Struct:
   110  					_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
   111  					if isZero {
   112  						_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
   113  					}
   114  				}
   115  			}
   116  		}),
   117  	)
   118  }