github.com/systematiccaos/gorm@v1.22.6/prepare_stmt.go (about)

     1  package gorm
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"sync"
     7  )
     8  
     9  type Stmt struct {
    10  	*sql.Stmt
    11  	Transaction bool
    12  }
    13  
    14  type PreparedStmtDB struct {
    15  	Stmts       map[string]Stmt
    16  	PreparedSQL []string
    17  	Mux         *sync.RWMutex
    18  	ConnPool
    19  }
    20  
    21  func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
    22  	if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
    23  		return dbConnector.GetDBConn()
    24  	}
    25  
    26  	if sqldb, ok := db.ConnPool.(*sql.DB); ok {
    27  		return sqldb, nil
    28  	}
    29  
    30  	return nil, ErrInvalidDB
    31  }
    32  
    33  func (db *PreparedStmtDB) Close() {
    34  	db.Mux.Lock()
    35  	defer db.Mux.Unlock()
    36  
    37  	for _, query := range db.PreparedSQL {
    38  		if stmt, ok := db.Stmts[query]; ok {
    39  			delete(db.Stmts, query)
    40  			go stmt.Close()
    41  		}
    42  	}
    43  }
    44  
    45  func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
    46  	db.Mux.RLock()
    47  	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
    48  		db.Mux.RUnlock()
    49  		return stmt, nil
    50  	}
    51  	db.Mux.RUnlock()
    52  
    53  	db.Mux.Lock()
    54  	defer db.Mux.Unlock()
    55  
    56  	// double check
    57  	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
    58  		return stmt, nil
    59  	} else if ok {
    60  		go stmt.Close()
    61  	}
    62  
    63  	stmt, err := conn.PrepareContext(ctx, query)
    64  	if err == nil {
    65  		db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
    66  		db.PreparedSQL = append(db.PreparedSQL, query)
    67  	}
    68  
    69  	return db.Stmts[query], err
    70  }
    71  
    72  func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
    73  	if beginner, ok := db.ConnPool.(TxBeginner); ok {
    74  		tx, err := beginner.BeginTx(ctx, opt)
    75  		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
    76  	}
    77  	return nil, ErrInvalidTransaction
    78  }
    79  
    80  func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
    81  	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
    82  	if err == nil {
    83  		result, err = stmt.ExecContext(ctx, args...)
    84  		if err != nil {
    85  			db.Mux.Lock()
    86  			defer db.Mux.Unlock()
    87  			go stmt.Close()
    88  			delete(db.Stmts, query)
    89  		}
    90  	}
    91  	return result, err
    92  }
    93  
    94  func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
    95  	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
    96  	if err == nil {
    97  		rows, err = stmt.QueryContext(ctx, args...)
    98  		if err != nil {
    99  			db.Mux.Lock()
   100  			defer db.Mux.Unlock()
   101  
   102  			go stmt.Close()
   103  			delete(db.Stmts, query)
   104  		}
   105  	}
   106  	return rows, err
   107  }
   108  
   109  func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   110  	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
   111  	if err == nil {
   112  		return stmt.QueryRowContext(ctx, args...)
   113  	}
   114  	return &sql.Row{}
   115  }
   116  
   117  type PreparedStmtTX struct {
   118  	*sql.Tx
   119  	PreparedStmtDB *PreparedStmtDB
   120  }
   121  
   122  func (tx *PreparedStmtTX) Commit() error {
   123  	if tx.Tx != nil {
   124  		return tx.Tx.Commit()
   125  	}
   126  	return ErrInvalidTransaction
   127  }
   128  
   129  func (tx *PreparedStmtTX) Rollback() error {
   130  	if tx.Tx != nil {
   131  		return tx.Tx.Rollback()
   132  	}
   133  	return ErrInvalidTransaction
   134  }
   135  
   136  func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
   137  	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
   138  	if err == nil {
   139  		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
   140  		if err != nil {
   141  			tx.PreparedStmtDB.Mux.Lock()
   142  			defer tx.PreparedStmtDB.Mux.Unlock()
   143  
   144  			go stmt.Close()
   145  			delete(tx.PreparedStmtDB.Stmts, query)
   146  		}
   147  	}
   148  	return result, err
   149  }
   150  
   151  func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
   152  	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
   153  	if err == nil {
   154  		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
   155  		if err != nil {
   156  			tx.PreparedStmtDB.Mux.Lock()
   157  			defer tx.PreparedStmtDB.Mux.Unlock()
   158  
   159  			go stmt.Close()
   160  			delete(tx.PreparedStmtDB.Stmts, query)
   161  		}
   162  	}
   163  	return rows, err
   164  }
   165  
   166  func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
   167  	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
   168  	if err == nil {
   169  		return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
   170  	}
   171  	return &sql.Row{}
   172  }