github.com/systematiccaos/gorm@v1.22.6/gorm.go (about)

     1  package gorm
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"sort"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/systematiccaos/gorm/clause"
    12  	"github.com/systematiccaos/gorm/logger"
    13  	"github.com/systematiccaos/gorm/schema"
    14  )
    15  
    16  // for Config.cacheStore store PreparedStmtDB key
    17  const preparedStmtDBKey = "preparedStmt"
    18  
// Config GORM config.
//
// Handles returned by getInstance share a *Config by pointer, while
// DB.Session works on a copy. Open fills in defaults for NamingStrategy,
// Logger, NowFunc, Plugins, ClauseBuilders and cacheStore when they are unset.
type Config struct {
	// By default GORM performs single create, update and delete operations
	// inside a transaction to ensure data integrity.
	// You can disable it by setting `SkipDefaultTransaction` to true
	SkipDefaultTransaction bool
	// NamingStrategy decides table and column names; Open defaults it to
	// schema.NamingStrategy{} when nil.
	NamingStrategy schema.Namer
	// FullSaveAssociations enables full saving of associations (the behavior
	// itself is implemented outside this file).
	FullSaveAssociations bool
	// Logger receives GORM's log output; Open defaults it to logger.Default.
	Logger logger.Interface
	// NowFunc the function to be used when creating a new timestamp; Open
	// defaults it to time.Now().Local().
	NowFunc func() time.Time
	// DryRun generates SQL without executing it.
	DryRun bool
	// PrepareStmt executes queries through cached prepared statements; when
	// set, Open wraps ConnPool in a PreparedStmtDB.
	PrepareStmt bool
	// DisableAutomaticPing skips the Ping that Open otherwise sends to the
	// connection pool after the dialector has initialized it.
	DisableAutomaticPing bool
	// DisableForeignKeyConstraintWhenMigrating presumably stops migrations
	// from creating foreign key constraints (consumed outside this file).
	DisableForeignKeyConstraintWhenMigrating bool
	// DisableNestedTransaction disable nested transaction
	DisableNestedTransaction bool
	// AllowGlobalUpdate allow global update (presumably updates/deletes
	// without conditions; enforced outside this file).
	AllowGlobalUpdate bool
	// QueryFields executes the SQL query with all fields of the table
	QueryFields bool
	// CreateBatchSize default create batch size
	CreateBatchSize int

	// ClauseBuilders clause builder overrides; Open initializes it to an
	// empty map when nil.
	ClauseBuilders map[string]clause.ClauseBuilder
	// ConnPool db conn pool
	ConnPool ConnPool
	// Dialector database dialector; Open sets it from its dialector argument
	// when that argument is non-nil.
	Dialector
	// Plugins registered plugins, keyed by Plugin.Name(). DB.Use adds to it
	// and Config.AfterInitialize initializes every entry.
	Plugins map[string]Plugin

	// callbacks is built by Open via initializeCallbacks and exposed through
	// DB.Callback.
	callbacks *callbacks
	// cacheStore holds state shared by handles of one Open call, e.g. the
	// PreparedStmtDB stored under preparedStmtDBKey.
	cacheStore *sync.Map
}
    61  
    62  func (c *Config) Apply(config *Config) error {
    63  	if config != c {
    64  		*config = *c
    65  	}
    66  	return nil
    67  }
    68  
    69  func (c *Config) AfterInitialize(db *DB) error {
    70  	if db != nil {
    71  		for _, plugin := range c.Plugins {
    72  			if err := plugin.Initialize(db); err != nil {
    73  				return err
    74  			}
    75  		}
    76  	}
    77  	return nil
    78  }
    79  
// Option customizes Open. Open calls Apply on every non-nil option (with
// *Config options ordered first) before it builds the *DB, and defers an
// AfterInitialize call for each of them.
type Option interface {
	// Apply mutates the Config that Open is assembling.
	Apply(*Config) error
	// AfterInitialize runs once Open has built the *DB. Implementations should
	// tolerate a nil *DB, as Config.AfterInitialize does.
	AfterInitialize(*DB) error
}
    84  
// DB GORM DB definition
type DB struct {
	*Config
	// Error holds the error(s) recorded for this handle; AddError chains new
	// errors onto an existing one.
	Error error
	// RowsAffected is the affected-row count of the last operation (presumably
	// set by the callbacks, which live outside this file).
	RowsAffected int64
	// Statement is the statement being built for this handle.
	Statement *Statement
	// clone drives getInstance: zero (or negative) returns the handle itself,
	// 1 returns a new handle with a fresh Statement that keeps only ConnPool
	// and Context, and any larger value returns a new handle with a clone of
	// Statement.
	clone int
}
    93  
// Session session config when create session with Session() method.
//
// Zero values leave the parent's settings untouched: boolean fields can only
// switch a setting on, CreateBatchSize applies only when positive, and nil
// Context, Logger or NowFunc are ignored.
type Session struct {
	DryRun bool
	// PrepareStmt routes the session's statement through a PreparedStmtDB that
	// shares the statement cache Open created.
	PrepareStmt bool
	// NewDB makes later chained calls start from an empty statement instead of
	// carrying over the current statement's conditions.
	NewDB bool
	// SkipHooks sets Statement.SkipHooks on the session's statement.
	SkipHooks                bool
	SkipDefaultTransaction   bool
	DisableNestedTransaction bool
	AllowGlobalUpdate        bool
	FullSaveAssociations     bool
	QueryFields              bool
	// Context replaces the statement's context when non-nil.
	Context         context.Context
	Logger          logger.Interface
	NowFunc         func() time.Time
	CreateBatchSize int
}
   110  
   111  // Open initialize db session based on dialector
   112  func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
   113  	config := &Config{}
   114  
   115  	sort.Slice(opts, func(i, j int) bool {
   116  		_, isConfig := opts[i].(*Config)
   117  		_, isConfig2 := opts[j].(*Config)
   118  		return isConfig && !isConfig2
   119  	})
   120  
   121  	for _, opt := range opts {
   122  		if opt != nil {
   123  			if err := opt.Apply(config); err != nil {
   124  				return nil, err
   125  			}
   126  			defer func(opt Option) {
   127  				if errr := opt.AfterInitialize(db); errr != nil {
   128  					err = errr
   129  				}
   130  			}(opt)
   131  		}
   132  	}
   133  
   134  	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
   135  		if err = d.Apply(config); err != nil {
   136  			return
   137  		}
   138  	}
   139  
   140  	if config.NamingStrategy == nil {
   141  		config.NamingStrategy = schema.NamingStrategy{}
   142  	}
   143  
   144  	if config.Logger == nil {
   145  		config.Logger = logger.Default
   146  	}
   147  
   148  	if config.NowFunc == nil {
   149  		config.NowFunc = func() time.Time { return time.Now().Local() }
   150  	}
   151  
   152  	if dialector != nil {
   153  		config.Dialector = dialector
   154  	}
   155  
   156  	if config.Plugins == nil {
   157  		config.Plugins = map[string]Plugin{}
   158  	}
   159  
   160  	if config.cacheStore == nil {
   161  		config.cacheStore = &sync.Map{}
   162  	}
   163  
   164  	db = &DB{Config: config, clone: 1}
   165  
   166  	db.callbacks = initializeCallbacks(db)
   167  
   168  	if config.ClauseBuilders == nil {
   169  		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
   170  	}
   171  
   172  	if config.Dialector != nil {
   173  		err = config.Dialector.Initialize(db)
   174  	}
   175  
   176  	preparedStmt := &PreparedStmtDB{
   177  		ConnPool:    db.ConnPool,
   178  		Stmts:       map[string]Stmt{},
   179  		Mux:         &sync.RWMutex{},
   180  		PreparedSQL: make([]string, 0, 100),
   181  	}
   182  	db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
   183  
   184  	if config.PrepareStmt {
   185  		db.ConnPool = preparedStmt
   186  	}
   187  
   188  	db.Statement = &Statement{
   189  		DB:       db,
   190  		ConnPool: db.ConnPool,
   191  		Context:  context.Background(),
   192  		Clauses:  map[string]clause.Clause{},
   193  	}
   194  
   195  	if err == nil && !config.DisableAutomaticPing {
   196  		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
   197  			err = pinger.Ping()
   198  		}
   199  	}
   200  
   201  	if err != nil {
   202  		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
   203  	}
   204  
   205  	return
   206  }
   207  
   208  // Session create new db session
   209  func (db *DB) Session(config *Session) *DB {
   210  	var (
   211  		txConfig = *db.Config
   212  		tx       = &DB{
   213  			Config:    &txConfig,
   214  			Statement: db.Statement,
   215  			Error:     db.Error,
   216  			clone:     1,
   217  		}
   218  	)
   219  	if config.CreateBatchSize > 0 {
   220  		tx.Config.CreateBatchSize = config.CreateBatchSize
   221  	}
   222  
   223  	if config.SkipDefaultTransaction {
   224  		tx.Config.SkipDefaultTransaction = true
   225  	}
   226  
   227  	if config.AllowGlobalUpdate {
   228  		txConfig.AllowGlobalUpdate = true
   229  	}
   230  
   231  	if config.FullSaveAssociations {
   232  		txConfig.FullSaveAssociations = true
   233  	}
   234  
   235  	if config.Context != nil || config.PrepareStmt || config.SkipHooks {
   236  		tx.Statement = tx.Statement.clone()
   237  		tx.Statement.DB = tx
   238  	}
   239  
   240  	if config.Context != nil {
   241  		tx.Statement.Context = config.Context
   242  	}
   243  
   244  	if config.PrepareStmt {
   245  		if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
   246  			preparedStmt := v.(*PreparedStmtDB)
   247  			tx.Statement.ConnPool = &PreparedStmtDB{
   248  				ConnPool: db.Config.ConnPool,
   249  				Mux:      preparedStmt.Mux,
   250  				Stmts:    preparedStmt.Stmts,
   251  			}
   252  			txConfig.ConnPool = tx.Statement.ConnPool
   253  			txConfig.PrepareStmt = true
   254  		}
   255  	}
   256  
   257  	if config.SkipHooks {
   258  		tx.Statement.SkipHooks = true
   259  	}
   260  
   261  	if config.DisableNestedTransaction {
   262  		txConfig.DisableNestedTransaction = true
   263  	}
   264  
   265  	if !config.NewDB {
   266  		tx.clone = 2
   267  	}
   268  
   269  	if config.DryRun {
   270  		tx.Config.DryRun = true
   271  	}
   272  
   273  	if config.QueryFields {
   274  		tx.Config.QueryFields = true
   275  	}
   276  
   277  	if config.Logger != nil {
   278  		tx.Config.Logger = config.Logger
   279  	}
   280  
   281  	if config.NowFunc != nil {
   282  		tx.Config.NowFunc = config.NowFunc
   283  	}
   284  
   285  	return tx
   286  }
   287  
   288  // WithContext change current instance db's context to ctx
   289  func (db *DB) WithContext(ctx context.Context) *DB {
   290  	return db.Session(&Session{Context: ctx})
   291  }
   292  
   293  // Debug start debug mode
   294  func (db *DB) Debug() (tx *DB) {
   295  	return db.Session(&Session{
   296  		Logger: db.Logger.LogMode(logger.Info),
   297  	})
   298  }
   299  
   300  // Set store value with key into current db instance's context
   301  func (db *DB) Set(key string, value interface{}) *DB {
   302  	tx := db.getInstance()
   303  	tx.Statement.Settings.Store(key, value)
   304  	return tx
   305  }
   306  
   307  // Get get value with key from current db instance's context
   308  func (db *DB) Get(key string) (interface{}, bool) {
   309  	return db.Statement.Settings.Load(key)
   310  }
   311  
   312  // InstanceSet store value with key into current db instance's context
   313  func (db *DB) InstanceSet(key string, value interface{}) *DB {
   314  	tx := db.getInstance()
   315  	tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
   316  	return tx
   317  }
   318  
   319  // InstanceGet get value with key from current db instance's context
   320  func (db *DB) InstanceGet(key string) (interface{}, bool) {
   321  	return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
   322  }
   323  
   324  // Callback returns callback manager
   325  func (db *DB) Callback() *callbacks {
   326  	return db.callbacks
   327  }
   328  
   329  // AddError add error to db
   330  func (db *DB) AddError(err error) error {
   331  	if db.Error == nil {
   332  		db.Error = err
   333  	} else if err != nil {
   334  		db.Error = fmt.Errorf("%v; %w", db.Error, err)
   335  	}
   336  	return db.Error
   337  }
   338  
   339  // DB returns `*sql.DB`
   340  func (db *DB) DB() (*sql.DB, error) {
   341  	connPool := db.ConnPool
   342  
   343  	if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
   344  		return dbConnector.GetDBConn()
   345  	}
   346  
   347  	if sqldb, ok := connPool.(*sql.DB); ok {
   348  		return sqldb, nil
   349  	}
   350  
   351  	return nil, ErrInvalidDB
   352  }
   353  
   354  func (db *DB) getInstance() *DB {
   355  	if db.clone > 0 {
   356  		tx := &DB{Config: db.Config, Error: db.Error}
   357  
   358  		if db.clone == 1 {
   359  			// clone with new statement
   360  			tx.Statement = &Statement{
   361  				DB:       tx,
   362  				ConnPool: db.Statement.ConnPool,
   363  				Context:  db.Statement.Context,
   364  				Clauses:  map[string]clause.Clause{},
   365  				Vars:     make([]interface{}, 0, 8),
   366  			}
   367  		} else {
   368  			// with clone statement
   369  			tx.Statement = db.Statement.clone()
   370  			tx.Statement.DB = tx
   371  		}
   372  
   373  		return tx
   374  	}
   375  
   376  	return db
   377  }
   378  
   379  func Expr(expr string, args ...interface{}) clause.Expr {
   380  	return clause.Expr{SQL: expr, Vars: args}
   381  }
   382  
   383  func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
   384  	var (
   385  		tx                      = db.getInstance()
   386  		stmt                    = tx.Statement
   387  		modelSchema, joinSchema *schema.Schema
   388  	)
   389  
   390  	err := stmt.Parse(model)
   391  	if err != nil {
   392  		return err
   393  	}
   394  	modelSchema = stmt.Schema
   395  
   396  	err = stmt.Parse(joinTable)
   397  	if err != nil {
   398  		return err
   399  	}
   400  	joinSchema = stmt.Schema
   401  
   402  	relation, ok := modelSchema.Relationships.Relations[field]
   403  	isRelation := ok && relation.JoinTable != nil
   404  	if !isRelation {
   405  		return fmt.Errorf("failed to found relation: %s", field)
   406  	}
   407  
   408  	for _, ref := range relation.References {
   409  		f := joinSchema.LookUpField(ref.ForeignKey.DBName)
   410  		if f == nil {
   411  			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
   412  		}
   413  
   414  		f.DataType = ref.ForeignKey.DataType
   415  		f.GORMDataType = ref.ForeignKey.GORMDataType
   416  		if f.Size == 0 {
   417  			f.Size = ref.ForeignKey.Size
   418  		}
   419  		ref.ForeignKey = f
   420  	}
   421  
   422  	for name, rel := range relation.JoinTable.Relationships.Relations {
   423  		if _, ok := joinSchema.Relationships.Relations[name]; !ok {
   424  			rel.Schema = joinSchema
   425  			joinSchema.Relationships.Relations[name] = rel
   426  		}
   427  	}
   428  	relation.JoinTable = joinSchema
   429  
   430  	return nil
   431  }
   432  
   433  func (db *DB) Use(plugin Plugin) error {
   434  	name := plugin.Name()
   435  	if _, ok := db.Plugins[name]; ok {
   436  		return ErrRegistered
   437  	}
   438  	if err := plugin.Initialize(db); err != nil {
   439  		return err
   440  	}
   441  	db.Plugins[name] = plugin
   442  	return nil
   443  }
   444  
   445  // ToSQL for generate SQL string.
   446  //
   447  // db.ToSQL(func(tx *gorm.DB) *gorm.DB {
   448  // 		return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
   449  // 			.Limit(10).Offset(5)
   450  //			.Order("name ASC")
   451  //			.First(&User{})
   452  // })
   453  func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
   454  	tx := queryFn(db.Session(&Session{DryRun: true}))
   455  	stmt := tx.Statement
   456  
   457  	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
   458  }