github.com/systematiccaos/gorm@v1.22.6/gorm.go (about) 1 package gorm 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "sort" 8 "sync" 9 "time" 10 11 "github.com/systematiccaos/gorm/clause" 12 "github.com/systematiccaos/gorm/logger" 13 "github.com/systematiccaos/gorm/schema" 14 ) 15 16 // for Config.cacheStore store PreparedStmtDB key 17 const preparedStmtDBKey = "preparedStmt" 18 19 // Config GORM config 20 type Config struct { 21 // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity 22 // You can disable it by setting `SkipDefaultTransaction` to true 23 SkipDefaultTransaction bool 24 // NamingStrategy tables, columns naming strategy 25 NamingStrategy schema.Namer 26 // FullSaveAssociations full save associations 27 FullSaveAssociations bool 28 // Logger 29 Logger logger.Interface 30 // NowFunc the function to be used when creating a new timestamp 31 NowFunc func() time.Time 32 // DryRun generate sql without execute 33 DryRun bool 34 // PrepareStmt executes the given query in cached statement 35 PrepareStmt bool 36 // DisableAutomaticPing 37 DisableAutomaticPing bool 38 // DisableForeignKeyConstraintWhenMigrating 39 DisableForeignKeyConstraintWhenMigrating bool 40 // DisableNestedTransaction disable nested transaction 41 DisableNestedTransaction bool 42 // AllowGlobalUpdate allow global update 43 AllowGlobalUpdate bool 44 // QueryFields executes the SQL query with all fields of the table 45 QueryFields bool 46 // CreateBatchSize default create batch size 47 CreateBatchSize int 48 49 // ClauseBuilders clause builder 50 ClauseBuilders map[string]clause.ClauseBuilder 51 // ConnPool db conn pool 52 ConnPool ConnPool 53 // Dialector database dialector 54 Dialector 55 // Plugins registered plugins 56 Plugins map[string]Plugin 57 58 callbacks *callbacks 59 cacheStore *sync.Map 60 } 61 62 func (c *Config) Apply(config *Config) error { 63 if config != c { 64 *config = *c 65 } 66 return nil 67 } 68 69 func (c *Config) AfterInitialize(db *DB) error { 70 if db != nil { 71 for _, plugin := range c.Plugins { 72 if err := plugin.Initialize(db); err != nil { 73 return err 74 } 75 } 76 } 77 return nil 78 } 79 80 type Option interface { 81 Apply(*Config) error 82 AfterInitialize(*DB) error 83 } 84 85 // DB GORM DB definition 86 type DB struct { 87 *Config 88 Error error 89 RowsAffected int64 90 Statement *Statement 91 clone int 92 } 93 94 // Session session config when create session with Session() method 95 type Session struct { 96 DryRun bool 97 PrepareStmt bool 98 NewDB bool 99 SkipHooks bool 100 SkipDefaultTransaction bool 101 DisableNestedTransaction bool 102 AllowGlobalUpdate bool 103 FullSaveAssociations bool 104 QueryFields bool 105 Context context.Context 106 Logger logger.Interface 107 NowFunc func() time.Time 108 CreateBatchSize int 109 } 110 111 // Open initialize db session based on dialector 112 func Open(dialector Dialector, opts ...Option) (db *DB, err error) { 113 config := &Config{} 114 115 sort.Slice(opts, func(i, j int) bool { 116 _, isConfig := opts[i].(*Config) 117 _, isConfig2 := opts[j].(*Config) 118 return isConfig && !isConfig2 119 }) 120 121 for _, opt := range opts { 122 if opt != nil { 123 if err := opt.Apply(config); err != nil { 124 return nil, err 125 } 126 defer func(opt Option) { 127 if errr := opt.AfterInitialize(db); errr != nil { 128 err = errr 129 } 130 }(opt) 131 } 132 } 133 134 if d, ok := dialector.(interface{ Apply(*Config) error }); ok { 135 if err = d.Apply(config); err != nil { 136 return 137 } 138 } 139 140 if config.NamingStrategy == nil { 141 config.NamingStrategy = schema.NamingStrategy{} 142 } 143 144 if config.Logger == nil { 145 config.Logger = logger.Default 146 } 147 148 if config.NowFunc == nil { 149 config.NowFunc = func() time.Time { return time.Now().Local() } 150 } 151 152 if dialector != nil { 153 config.Dialector = dialector 154 } 155 156 if config.Plugins == nil { 157 config.Plugins = map[string]Plugin{} 158 } 159 160 if config.cacheStore == nil { 161 config.cacheStore = &sync.Map{} 162 } 163 164 db = &DB{Config: config, clone: 1} 165 166 db.callbacks = initializeCallbacks(db) 167 168 if config.ClauseBuilders == nil { 169 config.ClauseBuilders = map[string]clause.ClauseBuilder{} 170 } 171 172 if config.Dialector != nil { 173 err = config.Dialector.Initialize(db) 174 } 175 176 preparedStmt := &PreparedStmtDB{ 177 ConnPool: db.ConnPool, 178 Stmts: map[string]Stmt{}, 179 Mux: &sync.RWMutex{}, 180 PreparedSQL: make([]string, 0, 100), 181 } 182 db.cacheStore.Store(preparedStmtDBKey, preparedStmt) 183 184 if config.PrepareStmt { 185 db.ConnPool = preparedStmt 186 } 187 188 db.Statement = &Statement{ 189 DB: db, 190 ConnPool: db.ConnPool, 191 Context: context.Background(), 192 Clauses: map[string]clause.Clause{}, 193 } 194 195 if err == nil && !config.DisableAutomaticPing { 196 if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { 197 err = pinger.Ping() 198 } 199 } 200 201 if err != nil { 202 config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) 203 } 204 205 return 206 } 207 208 // Session create new db session 209 func (db *DB) Session(config *Session) *DB { 210 var ( 211 txConfig = *db.Config 212 tx = &DB{ 213 Config: &txConfig, 214 Statement: db.Statement, 215 Error: db.Error, 216 clone: 1, 217 } 218 ) 219 if config.CreateBatchSize > 0 { 220 tx.Config.CreateBatchSize = config.CreateBatchSize 221 } 222 223 if config.SkipDefaultTransaction { 224 tx.Config.SkipDefaultTransaction = true 225 } 226 227 if config.AllowGlobalUpdate { 228 txConfig.AllowGlobalUpdate = true 229 } 230 231 if config.FullSaveAssociations { 232 txConfig.FullSaveAssociations = true 233 } 234 235 if config.Context != nil || config.PrepareStmt || config.SkipHooks { 236 tx.Statement = tx.Statement.clone() 237 tx.Statement.DB = tx 238 } 239 240 if config.Context != nil { 241 tx.Statement.Context = config.Context 242 } 243 244 if config.PrepareStmt { 245 if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { 246 preparedStmt := v.(*PreparedStmtDB) 247 tx.Statement.ConnPool = &PreparedStmtDB{ 248 ConnPool: db.Config.ConnPool, 249 Mux: preparedStmt.Mux, 250 Stmts: preparedStmt.Stmts, 251 } 252 txConfig.ConnPool = tx.Statement.ConnPool 253 txConfig.PrepareStmt = true 254 } 255 } 256 257 if config.SkipHooks { 258 tx.Statement.SkipHooks = true 259 } 260 261 if config.DisableNestedTransaction { 262 txConfig.DisableNestedTransaction = true 263 } 264 265 if !config.NewDB { 266 tx.clone = 2 267 } 268 269 if config.DryRun { 270 tx.Config.DryRun = true 271 } 272 273 if config.QueryFields { 274 tx.Config.QueryFields = true 275 } 276 277 if config.Logger != nil { 278 tx.Config.Logger = config.Logger 279 } 280 281 if config.NowFunc != nil { 282 tx.Config.NowFunc = config.NowFunc 283 } 284 285 return tx 286 } 287 288 // WithContext change current instance db's context to ctx 289 func (db *DB) WithContext(ctx context.Context) *DB { 290 return db.Session(&Session{Context: ctx}) 291 } 292 293 // Debug start debug mode 294 func (db *DB) Debug() (tx *DB) { 295 return db.Session(&Session{ 296 Logger: db.Logger.LogMode(logger.Info), 297 }) 298 } 299 300 // Set store value with key into current db instance's context 301 func (db *DB) Set(key string, value interface{}) *DB { 302 tx := db.getInstance() 303 tx.Statement.Settings.Store(key, value) 304 return tx 305 } 306 307 // Get get value with key from current db instance's context 308 func (db *DB) Get(key string) (interface{}, bool) { 309 return db.Statement.Settings.Load(key) 310 } 311 312 // InstanceSet store value with key into current db instance's context 313 func (db *DB) InstanceSet(key string, value interface{}) *DB { 314 tx := db.getInstance() 315 tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) 316 return tx 317 } 318 319 // InstanceGet get value with key from current db instance's context 320 func (db *DB) InstanceGet(key string) (interface{}, bool) { 321 return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) 322 } 323 324 // Callback returns callback manager 325 func (db *DB) Callback() *callbacks { 326 return db.callbacks 327 } 328 329 // AddError add error to db 330 func (db *DB) AddError(err error) error { 331 if db.Error == nil { 332 db.Error = err 333 } else if err != nil { 334 db.Error = fmt.Errorf("%v; %w", db.Error, err) 335 } 336 return db.Error 337 } 338 339 // DB returns `*sql.DB` 340 func (db *DB) DB() (*sql.DB, error) { 341 connPool := db.ConnPool 342 343 if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { 344 return dbConnector.GetDBConn() 345 } 346 347 if sqldb, ok := connPool.(*sql.DB); ok { 348 return sqldb, nil 349 } 350 351 return nil, ErrInvalidDB 352 } 353 354 func (db *DB) getInstance() *DB { 355 if db.clone > 0 { 356 tx := &DB{Config: db.Config, Error: db.Error} 357 358 if db.clone == 1 { 359 // clone with new statement 360 tx.Statement = &Statement{ 361 DB: tx, 362 ConnPool: db.Statement.ConnPool, 363 Context: db.Statement.Context, 364 Clauses: map[string]clause.Clause{}, 365 Vars: make([]interface{}, 0, 8), 366 } 367 } else { 368 // with clone statement 369 tx.Statement = db.Statement.clone() 370 tx.Statement.DB = tx 371 } 372 373 return tx 374 } 375 376 return db 377 } 378 379 func Expr(expr string, args ...interface{}) clause.Expr { 380 return clause.Expr{SQL: expr, Vars: args} 381 } 382 383 func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { 384 var ( 385 tx = db.getInstance() 386 stmt = tx.Statement 387 modelSchema, joinSchema *schema.Schema 388 ) 389 390 err := stmt.Parse(model) 391 if err != nil { 392 return err 393 } 394 modelSchema = stmt.Schema 395 396 err = stmt.Parse(joinTable) 397 if err != nil { 398 return err 399 } 400 joinSchema = stmt.Schema 401 402 relation, ok := modelSchema.Relationships.Relations[field] 403 isRelation := ok && relation.JoinTable != nil 404 if !isRelation { 405 return fmt.Errorf("failed to found relation: %s", field) 406 } 407 408 for _, ref := range relation.References { 409 f := joinSchema.LookUpField(ref.ForeignKey.DBName) 410 if f == nil { 411 return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) 412 } 413 414 f.DataType = ref.ForeignKey.DataType 415 f.GORMDataType = ref.ForeignKey.GORMDataType 416 if f.Size == 0 { 417 f.Size = ref.ForeignKey.Size 418 } 419 ref.ForeignKey = f 420 } 421 422 for name, rel := range relation.JoinTable.Relationships.Relations { 423 if _, ok := joinSchema.Relationships.Relations[name]; !ok { 424 rel.Schema = joinSchema 425 joinSchema.Relationships.Relations[name] = rel 426 } 427 } 428 relation.JoinTable = joinSchema 429 430 return nil 431 } 432 433 func (db *DB) Use(plugin Plugin) error { 434 name := plugin.Name() 435 if _, ok := db.Plugins[name]; ok { 436 return ErrRegistered 437 } 438 if err := plugin.Initialize(db); err != nil { 439 return err 440 } 441 db.Plugins[name] = plugin 442 return nil 443 } 444 445 // ToSQL for generate SQL string. 446 // 447 // db.ToSQL(func(tx *gorm.DB) *gorm.DB { 448 // return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) 449 // .Limit(10).Offset(5) 450 // .Order("name ASC") 451 // .First(&User{}) 452 // }) 453 func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { 454 tx := queryFn(db.Session(&Session{DryRun: true})) 455 stmt := tx.Statement 456 457 return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) 458 }