github.com/systematiccaos/gorm@v1.22.6/prepare_stmt.go (about) 1 package gorm 2 3 import ( 4 "context" 5 "database/sql" 6 "sync" 7 ) 8 9 type Stmt struct { 10 *sql.Stmt 11 Transaction bool 12 } 13 14 type PreparedStmtDB struct { 15 Stmts map[string]Stmt 16 PreparedSQL []string 17 Mux *sync.RWMutex 18 ConnPool 19 } 20 21 func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { 22 if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { 23 return dbConnector.GetDBConn() 24 } 25 26 if sqldb, ok := db.ConnPool.(*sql.DB); ok { 27 return sqldb, nil 28 } 29 30 return nil, ErrInvalidDB 31 } 32 33 func (db *PreparedStmtDB) Close() { 34 db.Mux.Lock() 35 defer db.Mux.Unlock() 36 37 for _, query := range db.PreparedSQL { 38 if stmt, ok := db.Stmts[query]; ok { 39 delete(db.Stmts, query) 40 go stmt.Close() 41 } 42 } 43 } 44 45 func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { 46 db.Mux.RLock() 47 if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { 48 db.Mux.RUnlock() 49 return stmt, nil 50 } 51 db.Mux.RUnlock() 52 53 db.Mux.Lock() 54 defer db.Mux.Unlock() 55 56 // double check 57 if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { 58 return stmt, nil 59 } else if ok { 60 go stmt.Close() 61 } 62 63 stmt, err := conn.PrepareContext(ctx, query) 64 if err == nil { 65 db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} 66 db.PreparedSQL = append(db.PreparedSQL, query) 67 } 68 69 return db.Stmts[query], err 70 } 71 72 func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { 73 if beginner, ok := db.ConnPool.(TxBeginner); ok { 74 tx, err := beginner.BeginTx(ctx, opt) 75 return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err 76 } 77 return nil, ErrInvalidTransaction 78 } 79 80 func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { 81 stmt, err := db.prepare(ctx, db.ConnPool, false, query) 82 if err == nil { 83 result, err = stmt.ExecContext(ctx, args...) 84 if err != nil { 85 db.Mux.Lock() 86 defer db.Mux.Unlock() 87 go stmt.Close() 88 delete(db.Stmts, query) 89 } 90 } 91 return result, err 92 } 93 94 func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 95 stmt, err := db.prepare(ctx, db.ConnPool, false, query) 96 if err == nil { 97 rows, err = stmt.QueryContext(ctx, args...) 98 if err != nil { 99 db.Mux.Lock() 100 defer db.Mux.Unlock() 101 102 go stmt.Close() 103 delete(db.Stmts, query) 104 } 105 } 106 return rows, err 107 } 108 109 func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 110 stmt, err := db.prepare(ctx, db.ConnPool, false, query) 111 if err == nil { 112 return stmt.QueryRowContext(ctx, args...) 113 } 114 return &sql.Row{} 115 } 116 117 type PreparedStmtTX struct { 118 *sql.Tx 119 PreparedStmtDB *PreparedStmtDB 120 } 121 122 func (tx *PreparedStmtTX) Commit() error { 123 if tx.Tx != nil { 124 return tx.Tx.Commit() 125 } 126 return ErrInvalidTransaction 127 } 128 129 func (tx *PreparedStmtTX) Rollback() error { 130 if tx.Tx != nil { 131 return tx.Tx.Rollback() 132 } 133 return ErrInvalidTransaction 134 } 135 136 func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { 137 stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) 138 if err == nil { 139 result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) 140 if err != nil { 141 tx.PreparedStmtDB.Mux.Lock() 142 defer tx.PreparedStmtDB.Mux.Unlock() 143 144 go stmt.Close() 145 delete(tx.PreparedStmtDB.Stmts, query) 146 } 147 } 148 return result, err 149 } 150 151 func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { 152 stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) 153 if err == nil { 154 rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) 155 if err != nil { 156 tx.PreparedStmtDB.Mux.Lock() 157 defer tx.PreparedStmtDB.Mux.Unlock() 158 159 go stmt.Close() 160 delete(tx.PreparedStmtDB.Stmts, query) 161 } 162 } 163 return rows, err 164 } 165 166 func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { 167 stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) 168 if err == nil { 169 return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) 170 } 171 return &sql.Row{} 172 }