// Source: github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/sqlconn.go

     1  package sqlx
     2  
import (
	"database/sql"
	"errors"

	"github.com/shuguocloud/go-zero/core/breaker"
	"github.com/shuguocloud/go-zero/core/logx"
)
     9  
    10  // ErrNotFound is an alias of sql.ErrNoRows
    11  var ErrNotFound = sql.ErrNoRows
    12  
    13  type (
    14  	// Session stands for raw connections or transaction sessions
    15  	Session interface {
    16  		Exec(query string, args ...interface{}) (sql.Result, error)
    17  		Prepare(query string) (StmtSession, error)
    18  		QueryRow(v interface{}, query string, args ...interface{}) error
    19  		QueryRowPartial(v interface{}, query string, args ...interface{}) error
    20  		QueryRows(v interface{}, query string, args ...interface{}) error
    21  		QueryRowsPartial(v interface{}, query string, args ...interface{}) error
    22  	}
    23  
    24  	// SqlConn only stands for raw connections, so Transact method can be called.
    25  	SqlConn interface {
    26  		Session
    27  		// RawDB is for other ORM to operate with, use it with caution.
    28  		// Notice: don't close it.
    29  		RawDB() (*sql.DB, error)
    30  		Transact(func(session Session) error) error
    31  	}
    32  
    33  	// SqlOption defines the method to customize a sql connection.
    34  	SqlOption func(*commonSqlConn)
    35  
    36  	// StmtSession interface represents a session that can be used to execute statements.
    37  	StmtSession interface {
    38  		Close() error
    39  		Exec(args ...interface{}) (sql.Result, error)
    40  		QueryRow(v interface{}, args ...interface{}) error
    41  		QueryRowPartial(v interface{}, args ...interface{}) error
    42  		QueryRows(v interface{}, args ...interface{}) error
    43  		QueryRowsPartial(v interface{}, args ...interface{}) error
    44  	}
    45  
    46  	// thread-safe
    47  	// Because CORBA doesn't support PREPARE, so we need to combine the
    48  	// query arguments into one string and do underlying query without arguments
    49  	commonSqlConn struct {
    50  		connProv connProvider
    51  		onError  func(error)
    52  		beginTx  beginnable
    53  		brk      breaker.Breaker
    54  		accept   func(error) bool
    55  	}
    56  
    57  	connProvider func() (*sql.DB, error)
    58  
    59  	sessionConn interface {
    60  		Exec(query string, args ...interface{}) (sql.Result, error)
    61  		Query(query string, args ...interface{}) (*sql.Rows, error)
    62  	}
    63  
    64  	statement struct {
    65  		query string
    66  		stmt  *sql.Stmt
    67  	}
    68  
    69  	stmtConn interface {
    70  		Exec(args ...interface{}) (sql.Result, error)
    71  		Query(args ...interface{}) (*sql.Rows, error)
    72  	}
    73  )
    74  
    75  // NewSqlConn returns a SqlConn with given driver name and datasource.
    76  func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
    77  	conn := &commonSqlConn{
    78  		connProv: func() (*sql.DB, error) {
    79  			return getSqlConn(driverName, datasource)
    80  		},
    81  		onError: func(err error) {
    82  			logInstanceError(datasource, err)
    83  		},
    84  		beginTx: begin,
    85  		brk:     breaker.NewBreaker(),
    86  	}
    87  	for _, opt := range opts {
    88  		opt(conn)
    89  	}
    90  
    91  	return conn
    92  }
    93  
    94  // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
    95  // Use it with caution, it's provided for other ORM to interact with.
    96  func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
    97  	conn := &commonSqlConn{
    98  		connProv: func() (*sql.DB, error) {
    99  			return db, nil
   100  		},
   101  		onError: func(err error) {
   102  			logx.Errorf("Error on getting sql instance: %v", err)
   103  		},
   104  		beginTx: begin,
   105  		brk:     breaker.NewBreaker(),
   106  	}
   107  	for _, opt := range opts {
   108  		opt(conn)
   109  	}
   110  
   111  	return conn
   112  }
   113  
   114  func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
   115  	err = db.brk.DoWithAcceptable(func() error {
   116  		var conn *sql.DB
   117  		conn, err = db.connProv()
   118  		if err != nil {
   119  			db.onError(err)
   120  			return err
   121  		}
   122  
   123  		result, err = exec(conn, q, args...)
   124  		return err
   125  	}, db.acceptable)
   126  
   127  	return
   128  }
   129  
   130  func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
   131  	err = db.brk.DoWithAcceptable(func() error {
   132  		var conn *sql.DB
   133  		conn, err = db.connProv()
   134  		if err != nil {
   135  			db.onError(err)
   136  			return err
   137  		}
   138  
   139  		st, err := conn.Prepare(query)
   140  		if err != nil {
   141  			return err
   142  		}
   143  
   144  		stmt = statement{
   145  			query: query,
   146  			stmt:  st,
   147  		}
   148  		return nil
   149  	}, db.acceptable)
   150  
   151  	return
   152  }
   153  
   154  func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
   155  	return db.queryRows(func(rows *sql.Rows) error {
   156  		return unmarshalRow(v, rows, true)
   157  	}, q, args...)
   158  }
   159  
   160  func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
   161  	return db.queryRows(func(rows *sql.Rows) error {
   162  		return unmarshalRow(v, rows, false)
   163  	}, q, args...)
   164  }
   165  
   166  func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
   167  	return db.queryRows(func(rows *sql.Rows) error {
   168  		return unmarshalRows(v, rows, true)
   169  	}, q, args...)
   170  }
   171  
   172  func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
   173  	return db.queryRows(func(rows *sql.Rows) error {
   174  		return unmarshalRows(v, rows, false)
   175  	}, q, args...)
   176  }
   177  
   178  func (db *commonSqlConn) RawDB() (*sql.DB, error) {
   179  	return db.connProv()
   180  }
   181  
   182  func (db *commonSqlConn) Transact(fn func(Session) error) error {
   183  	return db.brk.DoWithAcceptable(func() error {
   184  		return transact(db, db.beginTx, fn)
   185  	}, db.acceptable)
   186  }
   187  
   188  func (db *commonSqlConn) acceptable(err error) bool {
   189  	ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
   190  	if db.accept == nil {
   191  		return ok
   192  	}
   193  
   194  	return ok || db.accept(err)
   195  }
   196  
   197  func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
   198  	var qerr error
   199  	return db.brk.DoWithAcceptable(func() error {
   200  		conn, err := db.connProv()
   201  		if err != nil {
   202  			db.onError(err)
   203  			return err
   204  		}
   205  
   206  		return query(conn, func(rows *sql.Rows) error {
   207  			qerr = scanner(rows)
   208  			return qerr
   209  		}, q, args...)
   210  	}, func(err error) bool {
   211  		return qerr == err || db.acceptable(err)
   212  	})
   213  }
   214  
   215  func (s statement) Close() error {
   216  	return s.stmt.Close()
   217  }
   218  
   219  func (s statement) Exec(args ...interface{}) (sql.Result, error) {
   220  	return execStmt(s.stmt, s.query, args...)
   221  }
   222  
   223  func (s statement) QueryRow(v interface{}, args ...interface{}) error {
   224  	return queryStmt(s.stmt, func(rows *sql.Rows) error {
   225  		return unmarshalRow(v, rows, true)
   226  	}, s.query, args...)
   227  }
   228  
   229  func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
   230  	return queryStmt(s.stmt, func(rows *sql.Rows) error {
   231  		return unmarshalRow(v, rows, false)
   232  	}, s.query, args...)
   233  }
   234  
   235  func (s statement) QueryRows(v interface{}, args ...interface{}) error {
   236  	return queryStmt(s.stmt, func(rows *sql.Rows) error {
   237  		return unmarshalRows(v, rows, true)
   238  	}, s.query, args...)
   239  }
   240  
   241  func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
   242  	return queryStmt(s.stmt, func(rows *sql.Rows) error {
   243  		return unmarshalRows(v, rows, false)
   244  	}, s.query, args...)
   245  }