github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/sqlconn.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 6 "github.com/shuguocloud/go-zero/core/breaker" 7 "github.com/shuguocloud/go-zero/core/logx" 8 ) 9 10 // ErrNotFound is an alias of sql.ErrNoRows 11 var ErrNotFound = sql.ErrNoRows 12 13 type ( 14 // Session stands for raw connections or transaction sessions 15 Session interface { 16 Exec(query string, args ...interface{}) (sql.Result, error) 17 Prepare(query string) (StmtSession, error) 18 QueryRow(v interface{}, query string, args ...interface{}) error 19 QueryRowPartial(v interface{}, query string, args ...interface{}) error 20 QueryRows(v interface{}, query string, args ...interface{}) error 21 QueryRowsPartial(v interface{}, query string, args ...interface{}) error 22 } 23 24 // SqlConn only stands for raw connections, so Transact method can be called. 25 SqlConn interface { 26 Session 27 // RawDB is for other ORM to operate with, use it with caution. 28 // Notice: don't close it. 29 RawDB() (*sql.DB, error) 30 Transact(func(session Session) error) error 31 } 32 33 // SqlOption defines the method to customize a sql connection. 34 SqlOption func(*commonSqlConn) 35 36 // StmtSession interface represents a session that can be used to execute statements. 37 StmtSession interface { 38 Close() error 39 Exec(args ...interface{}) (sql.Result, error) 40 QueryRow(v interface{}, args ...interface{}) error 41 QueryRowPartial(v interface{}, args ...interface{}) error 42 QueryRows(v interface{}, args ...interface{}) error 43 QueryRowsPartial(v interface{}, args ...interface{}) error 44 } 45 46 // thread-safe 47 // Because CORBA doesn't support PREPARE, so we need to combine the 48 // query arguments into one string and do underlying query without arguments 49 commonSqlConn struct { 50 connProv connProvider 51 onError func(error) 52 beginTx beginnable 53 brk breaker.Breaker 54 accept func(error) bool 55 } 56 57 connProvider func() (*sql.DB, error) 58 59 sessionConn interface { 60 Exec(query string, args ...interface{}) (sql.Result, error) 61 Query(query string, args ...interface{}) (*sql.Rows, error) 62 } 63 64 statement struct { 65 query string 66 stmt *sql.Stmt 67 } 68 69 stmtConn interface { 70 Exec(args ...interface{}) (sql.Result, error) 71 Query(args ...interface{}) (*sql.Rows, error) 72 } 73 ) 74 75 // NewSqlConn returns a SqlConn with given driver name and datasource. 76 func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { 77 conn := &commonSqlConn{ 78 connProv: func() (*sql.DB, error) { 79 return getSqlConn(driverName, datasource) 80 }, 81 onError: func(err error) { 82 logInstanceError(datasource, err) 83 }, 84 beginTx: begin, 85 brk: breaker.NewBreaker(), 86 } 87 for _, opt := range opts { 88 opt(conn) 89 } 90 91 return conn 92 } 93 94 // NewSqlConnFromDB returns a SqlConn with the given sql.DB. 95 // Use it with caution, it's provided for other ORM to interact with. 96 func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { 97 conn := &commonSqlConn{ 98 connProv: func() (*sql.DB, error) { 99 return db, nil 100 }, 101 onError: func(err error) { 102 logx.Errorf("Error on getting sql instance: %v", err) 103 }, 104 beginTx: begin, 105 brk: breaker.NewBreaker(), 106 } 107 for _, opt := range opts { 108 opt(conn) 109 } 110 111 return conn 112 } 113 114 func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { 115 err = db.brk.DoWithAcceptable(func() error { 116 var conn *sql.DB 117 conn, err = db.connProv() 118 if err != nil { 119 db.onError(err) 120 return err 121 } 122 123 result, err = exec(conn, q, args...) 124 return err 125 }, db.acceptable) 126 127 return 128 } 129 130 func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { 131 err = db.brk.DoWithAcceptable(func() error { 132 var conn *sql.DB 133 conn, err = db.connProv() 134 if err != nil { 135 db.onError(err) 136 return err 137 } 138 139 st, err := conn.Prepare(query) 140 if err != nil { 141 return err 142 } 143 144 stmt = statement{ 145 query: query, 146 stmt: st, 147 } 148 return nil 149 }, db.acceptable) 150 151 return 152 } 153 154 func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error { 155 return db.queryRows(func(rows *sql.Rows) error { 156 return unmarshalRow(v, rows, true) 157 }, q, args...) 158 } 159 160 func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { 161 return db.queryRows(func(rows *sql.Rows) error { 162 return unmarshalRow(v, rows, false) 163 }, q, args...) 164 } 165 166 func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error { 167 return db.queryRows(func(rows *sql.Rows) error { 168 return unmarshalRows(v, rows, true) 169 }, q, args...) 170 } 171 172 func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { 173 return db.queryRows(func(rows *sql.Rows) error { 174 return unmarshalRows(v, rows, false) 175 }, q, args...) 176 } 177 178 func (db *commonSqlConn) RawDB() (*sql.DB, error) { 179 return db.connProv() 180 } 181 182 func (db *commonSqlConn) Transact(fn func(Session) error) error { 183 return db.brk.DoWithAcceptable(func() error { 184 return transact(db, db.beginTx, fn) 185 }, db.acceptable) 186 } 187 188 func (db *commonSqlConn) acceptable(err error) bool { 189 ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone 190 if db.accept == nil { 191 return ok 192 } 193 194 return ok || db.accept(err) 195 } 196 197 func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { 198 var qerr error 199 return db.brk.DoWithAcceptable(func() error { 200 conn, err := db.connProv() 201 if err != nil { 202 db.onError(err) 203 return err 204 } 205 206 return query(conn, func(rows *sql.Rows) error { 207 qerr = scanner(rows) 208 return qerr 209 }, q, args...) 210 }, func(err error) bool { 211 return qerr == err || db.acceptable(err) 212 }) 213 } 214 215 func (s statement) Close() error { 216 return s.stmt.Close() 217 } 218 219 func (s statement) Exec(args ...interface{}) (sql.Result, error) { 220 return execStmt(s.stmt, s.query, args...) 221 } 222 223 func (s statement) QueryRow(v interface{}, args ...interface{}) error { 224 return queryStmt(s.stmt, func(rows *sql.Rows) error { 225 return unmarshalRow(v, rows, true) 226 }, s.query, args...) 227 } 228 229 func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { 230 return queryStmt(s.stmt, func(rows *sql.Rows) error { 231 return unmarshalRow(v, rows, false) 232 }, s.query, args...) 233 } 234 235 func (s statement) QueryRows(v interface{}, args ...interface{}) error { 236 return queryStmt(s.stmt, func(rows *sql.Rows) error { 237 return unmarshalRows(v, rows, true) 238 }, s.query, args...) 239 } 240 241 func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { 242 return queryStmt(s.stmt, func(rows *sql.Rows) error { 243 return unmarshalRows(v, rows, false) 244 }, s.query, args...) 245 }