github.com/wangyougui/gf/v2@v2.6.5/database/gdb/gdb_core_underlying.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/wangyougui/gf.
     6  //
     7  
     8  package gdb
     9  
    10  import (
    11  	"context"
    12  	"database/sql"
    13  	"fmt"
    14  	"reflect"
    15  
    16  	"go.opentelemetry.io/otel"
    17  	"go.opentelemetry.io/otel/trace"
    18  
    19  	"github.com/wangyougui/gf/v2/util/gconv"
    20  
    21  	"github.com/wangyougui/gf/v2"
    22  	"github.com/wangyougui/gf/v2/container/gvar"
    23  	"github.com/wangyougui/gf/v2/errors/gcode"
    24  	"github.com/wangyougui/gf/v2/errors/gerror"
    25  	"github.com/wangyougui/gf/v2/internal/intlog"
    26  	"github.com/wangyougui/gf/v2/os/gtime"
    27  	"github.com/wangyougui/gf/v2/util/guid"
    28  )
    29  
    30  // Query commits one query SQL to underlying driver and returns the execution result.
    31  // It is most commonly used for data querying.
    32  func (c *Core) Query(ctx context.Context, sql string, args ...interface{}) (result Result, err error) {
    33  	return c.db.DoQuery(ctx, nil, sql, args...)
    34  }
    35  
    36  // DoQuery commits the sql string and its arguments to underlying driver
    37  // through given link object and returns the execution result.
    38  func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) {
    39  	// Transaction checks.
    40  	if link == nil {
    41  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
    42  			// Firstly, check and retrieve transaction link from context.
    43  			link = &txLink{tx.GetSqlTX()}
    44  		} else if link, err = c.SlaveLink(); err != nil {
    45  			// Or else it creates one from master node.
    46  			return nil, err
    47  		}
    48  	} else if !link.IsTransaction() {
    49  		// If current link is not transaction link, it checks and retrieves transaction from context.
    50  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
    51  			link = &txLink{tx.GetSqlTX()}
    52  		}
    53  	}
    54  
    55  	if c.db.GetConfig().QueryTimeout > 0 {
    56  		ctx, _ = context.WithTimeout(ctx, c.db.GetConfig().QueryTimeout)
    57  	}
    58  
    59  	// Sql filtering.
    60  	sql, args = c.FormatSqlBeforeExecuting(sql, args)
    61  	sql, args, err = c.db.DoFilter(ctx, link, sql, args)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	// SQL format and retrieve.
    66  	if v := ctx.Value(ctxKeyCatchSQL); v != nil {
    67  		var (
    68  			manager      = v.(*CatchSQLManager)
    69  			formattedSql = FormatSqlWithArgs(sql, args)
    70  		)
    71  		manager.SQLArray.Append(formattedSql)
    72  		if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
    73  			return nil, nil
    74  		}
    75  	}
    76  	// Link execution.
    77  	var out DoCommitOutput
    78  	out, err = c.db.DoCommit(ctx, DoCommitInput{
    79  		Link:          link,
    80  		Sql:           sql,
    81  		Args:          args,
    82  		Stmt:          nil,
    83  		Type:          SqlTypeQueryContext,
    84  		IsTransaction: link.IsTransaction(),
    85  	})
    86  	return out.Records, err
    87  }
    88  
    89  // Exec commits one query SQL to underlying driver and returns the execution result.
    90  // It is most commonly used for data inserting and updating.
    91  func (c *Core) Exec(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) {
    92  	return c.db.DoExec(ctx, nil, sql, args...)
    93  }
    94  
    95  // DoExec commits the sql string and its arguments to underlying driver
    96  // through given link object and returns the execution result.
    97  func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) {
    98  	// Transaction checks.
    99  	if link == nil {
   100  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   101  			// Firstly, check and retrieve transaction link from context.
   102  			link = &txLink{tx.GetSqlTX()}
   103  		} else if link, err = c.MasterLink(); err != nil {
   104  			// Or else it creates one from master node.
   105  			return nil, err
   106  		}
   107  	} else if !link.IsTransaction() {
   108  		// If current link is not transaction link, it checks and retrieves transaction from context.
   109  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   110  			link = &txLink{tx.GetSqlTX()}
   111  		}
   112  	}
   113  
   114  	if c.db.GetConfig().ExecTimeout > 0 {
   115  		var cancelFunc context.CancelFunc
   116  		ctx, cancelFunc = context.WithTimeout(ctx, c.db.GetConfig().ExecTimeout)
   117  		defer cancelFunc()
   118  	}
   119  
   120  	// SQL filtering.
   121  	sql, args = c.FormatSqlBeforeExecuting(sql, args)
   122  	sql, args, err = c.db.DoFilter(ctx, link, sql, args)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	// SQL format and retrieve.
   127  	if v := ctx.Value(ctxKeyCatchSQL); v != nil {
   128  		var (
   129  			manager      = v.(*CatchSQLManager)
   130  			formattedSql = FormatSqlWithArgs(sql, args)
   131  		)
   132  		manager.SQLArray.Append(formattedSql)
   133  		if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
   134  			return new(SqlResult), nil
   135  		}
   136  	}
   137  	// Link execution.
   138  	var out DoCommitOutput
   139  	out, err = c.db.DoCommit(ctx, DoCommitInput{
   140  		Link:          link,
   141  		Sql:           sql,
   142  		Args:          args,
   143  		Stmt:          nil,
   144  		Type:          SqlTypeExecContext,
   145  		IsTransaction: link.IsTransaction(),
   146  	})
   147  	return out.Result, err
   148  }
   149  
   150  // DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver.
   151  // The parameter `link` specifies the current database connection operation object. You can modify the sql
   152  // string `sql` and its arguments `args` as you wish before they're committed to driver.
   153  func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
   154  	return sql, args, nil
   155  }
   156  
   157  // DoCommit commits current sql and arguments to underlying sql driver.
   158  func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
   159  	// Inject internal data into ctx, especially for transaction creating.
   160  	ctx = c.InjectInternalCtxData(ctx)
   161  
   162  	var (
   163  		sqlTx                *sql.Tx
   164  		sqlStmt              *sql.Stmt
   165  		sqlRows              *sql.Rows
   166  		sqlResult            sql.Result
   167  		stmtSqlRows          *sql.Rows
   168  		stmtSqlRow           *sql.Row
   169  		rowsAffected         int64
   170  		cancelFuncForTimeout context.CancelFunc
   171  		formattedSql         = FormatSqlWithArgs(in.Sql, in.Args)
   172  		timestampMilli1      = gtime.TimestampMilli()
   173  	)
   174  
   175  	// Trace span start.
   176  	tr := otel.GetTracerProvider().Tracer(traceInstrumentName, trace.WithInstrumentationVersion(gf.VERSION))
   177  	ctx, span := tr.Start(ctx, in.Type, trace.WithSpanKind(trace.SpanKindInternal))
   178  	defer span.End()
   179  
   180  	// Execution cased by type.
   181  	switch in.Type {
   182  	case SqlTypeBegin:
   183  		if sqlTx, err = in.Db.Begin(); err == nil {
   184  			out.Tx = &TXCore{
   185  				db:            c.db,
   186  				tx:            sqlTx,
   187  				ctx:           context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)),
   188  				master:        in.Db,
   189  				transactionId: guid.S(),
   190  			}
   191  			ctx = out.Tx.GetCtx()
   192  		}
   193  		out.RawResult = sqlTx
   194  
   195  	case SqlTypeTXCommit:
   196  		err = in.Tx.Commit()
   197  
   198  	case SqlTypeTXRollback:
   199  		err = in.Tx.Rollback()
   200  
   201  	case SqlTypeExecContext:
   202  		if c.db.GetDryRun() {
   203  			sqlResult = new(SqlResult)
   204  		} else {
   205  			sqlResult, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
   206  		}
   207  		out.RawResult = sqlResult
   208  
   209  	case SqlTypeQueryContext:
   210  		sqlRows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...)
   211  		out.RawResult = sqlRows
   212  
   213  	case SqlTypePrepareContext:
   214  		sqlStmt, err = in.Link.PrepareContext(ctx, in.Sql)
   215  		out.RawResult = sqlStmt
   216  
   217  	case SqlTypeStmtExecContext:
   218  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
   219  		defer cancelFuncForTimeout()
   220  		if c.db.GetDryRun() {
   221  			sqlResult = new(SqlResult)
   222  		} else {
   223  			sqlResult, err = in.Stmt.ExecContext(ctx, in.Args...)
   224  		}
   225  		out.RawResult = sqlResult
   226  
   227  	case SqlTypeStmtQueryContext:
   228  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
   229  		defer cancelFuncForTimeout()
   230  		stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
   231  		out.RawResult = stmtSqlRows
   232  
   233  	case SqlTypeStmtQueryRowContext:
   234  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
   235  		defer cancelFuncForTimeout()
   236  		stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
   237  		out.RawResult = stmtSqlRow
   238  
   239  	default:
   240  		panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid SqlType "%s"`, in.Type))
   241  	}
   242  	// Result handling.
   243  	switch {
   244  	case sqlResult != nil && !c.GetIgnoreResultFromCtx(ctx):
   245  		rowsAffected, err = sqlResult.RowsAffected()
   246  		out.Result = sqlResult
   247  
   248  	case sqlRows != nil:
   249  		out.Records, err = c.RowsToResult(ctx, sqlRows)
   250  		rowsAffected = int64(len(out.Records))
   251  
   252  	case sqlStmt != nil:
   253  		out.Stmt = &Stmt{
   254  			Stmt: sqlStmt,
   255  			core: c,
   256  			link: in.Link,
   257  			sql:  in.Sql,
   258  		}
   259  	}
   260  	var (
   261  		timestampMilli2 = gtime.TimestampMilli()
   262  		sqlObj          = &Sql{
   263  			Sql:           in.Sql,
   264  			Type:          in.Type,
   265  			Args:          in.Args,
   266  			Format:        formattedSql,
   267  			Error:         err,
   268  			Start:         timestampMilli1,
   269  			End:           timestampMilli2,
   270  			Group:         c.db.GetGroup(),
   271  			Schema:        c.db.GetSchema(),
   272  			RowsAffected:  rowsAffected,
   273  			IsTransaction: in.IsTransaction,
   274  		}
   275  	)
   276  
   277  	// Tracing.
   278  	c.traceSpanEnd(ctx, span, sqlObj)
   279  
   280  	// Logging.
   281  	if c.db.GetDebug() {
   282  		c.writeSqlToLogger(ctx, sqlObj)
   283  	}
   284  	if err != nil && err != sql.ErrNoRows {
   285  		err = gerror.WrapCode(
   286  			gcode.CodeDbOperationError,
   287  			err,
   288  			FormatSqlWithArgs(in.Sql, in.Args),
   289  		)
   290  	}
   291  	return out, err
   292  }
   293  
   294  // Prepare creates a prepared statement for later queries or executions.
   295  // Multiple queries or executions may be run concurrently from the
   296  // returned statement.
   297  // The caller must call the statement's Close method
   298  // when the statement is no longer needed.
   299  //
   300  // The parameter `execOnMaster` specifies whether executing the sql on master node,
   301  // or else it executes the sql on slave node if master-slave configured.
   302  func (c *Core) Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) {
   303  	var (
   304  		err  error
   305  		link Link
   306  	)
   307  	if len(execOnMaster) > 0 && execOnMaster[0] {
   308  		if link, err = c.MasterLink(); err != nil {
   309  			return nil, err
   310  		}
   311  	} else {
   312  		if link, err = c.SlaveLink(); err != nil {
   313  			return nil, err
   314  		}
   315  	}
   316  	return c.db.DoPrepare(ctx, link, sql)
   317  }
   318  
   319  // DoPrepare calls prepare function on given link object and returns the statement object.
   320  func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt, err error) {
   321  	// Transaction checks.
   322  	if link == nil {
   323  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   324  			// Firstly, check and retrieve transaction link from context.
   325  			link = &txLink{tx.GetSqlTX()}
   326  		} else {
   327  			// Or else it creates one from master node.
   328  			var err error
   329  			if link, err = c.MasterLink(); err != nil {
   330  				return nil, err
   331  			}
   332  		}
   333  	} else if !link.IsTransaction() {
   334  		// If current link is not transaction link, it checks and retrieves transaction from context.
   335  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   336  			link = &txLink{tx.GetSqlTX()}
   337  		}
   338  	}
   339  
   340  	if c.db.GetConfig().PrepareTimeout > 0 {
   341  		// DO NOT USE cancel function in prepare statement.
   342  		ctx, _ = context.WithTimeout(ctx, c.db.GetConfig().PrepareTimeout)
   343  	}
   344  
   345  	// Link execution.
   346  	var out DoCommitOutput
   347  	out, err = c.db.DoCommit(ctx, DoCommitInput{
   348  		Link:          link,
   349  		Sql:           sql,
   350  		Type:          SqlTypePrepareContext,
   351  		IsTransaction: link.IsTransaction(),
   352  	})
   353  	return out.Stmt, err
   354  }
   355  
   356  // FormatUpsert formats and returns SQL clause part for upsert statement.
   357  // In default implements, this function performs upsert statement for MySQL like:
   358  // `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...`
   359  func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) {
   360  	var onDuplicateStr string
   361  	if option.OnDuplicateStr != "" {
   362  		onDuplicateStr = option.OnDuplicateStr
   363  	} else if len(option.OnDuplicateMap) > 0 {
   364  		for k, v := range option.OnDuplicateMap {
   365  			if len(onDuplicateStr) > 0 {
   366  				onDuplicateStr += ","
   367  			}
   368  			switch v.(type) {
   369  			case Raw, *Raw:
   370  				onDuplicateStr += fmt.Sprintf(
   371  					"%s=%s",
   372  					c.QuoteWord(k),
   373  					v,
   374  				)
   375  			default:
   376  				onDuplicateStr += fmt.Sprintf(
   377  					"%s=VALUES(%s)",
   378  					c.QuoteWord(k),
   379  					c.QuoteWord(gconv.String(v)),
   380  				)
   381  			}
   382  		}
   383  	} else {
   384  		for _, column := range columns {
   385  			// If it's SAVE operation, do not automatically update the creating time.
   386  			if c.IsSoftCreatedFieldName(column) {
   387  				continue
   388  			}
   389  			if len(onDuplicateStr) > 0 {
   390  				onDuplicateStr += ","
   391  			}
   392  			onDuplicateStr += fmt.Sprintf(
   393  				"%s=VALUES(%s)",
   394  				c.QuoteWord(column),
   395  				c.QuoteWord(column),
   396  			)
   397  		}
   398  	}
   399  
   400  	return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
   401  }
   402  
   403  // RowsToResult converts underlying data record type sql.Rows to Result type.
   404  func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
   405  	if rows == nil {
   406  		return nil, nil
   407  	}
   408  	defer func() {
   409  		if err := rows.Close(); err != nil {
   410  			intlog.Errorf(ctx, `%+v`, err)
   411  		}
   412  	}()
   413  	if !rows.Next() {
   414  		return nil, nil
   415  	}
   416  	// Column names and types.
   417  	columnTypes, err := rows.ColumnTypes()
   418  	if err != nil {
   419  		return nil, err
   420  	}
   421  
   422  	if len(columnTypes) > 0 {
   423  		if internalData := c.GetInternalCtxDataFromCtx(ctx); internalData != nil {
   424  			internalData.FirstResultColumn = columnTypes[0].Name()
   425  		}
   426  	}
   427  	var (
   428  		values   = make([]interface{}, len(columnTypes))
   429  		result   = make(Result, 0)
   430  		scanArgs = make([]interface{}, len(values))
   431  	)
   432  	for i := range values {
   433  		scanArgs[i] = &values[i]
   434  	}
   435  	for {
   436  		if err = rows.Scan(scanArgs...); err != nil {
   437  			return result, err
   438  		}
   439  		record := Record{}
   440  		for i, value := range values {
   441  			if value == nil {
   442  				// DO NOT use `gvar.New(nil)` here as it creates an initialized object
   443  				// which will cause struct converting issue.
   444  				record[columnTypes[i].Name()] = nil
   445  			} else {
   446  				var convertedValue interface{}
   447  				if convertedValue, err = c.columnValueToLocalValue(ctx, value, columnTypes[i]); err != nil {
   448  					return nil, err
   449  				}
   450  				record[columnTypes[i].Name()] = gvar.New(convertedValue)
   451  			}
   452  		}
   453  		result = append(result, record)
   454  		if !rows.Next() {
   455  			break
   456  		}
   457  	}
   458  	return result, nil
   459  }
   460  
   461  func (c *Core) columnValueToLocalValue(ctx context.Context, value interface{}, columnType *sql.ColumnType) (interface{}, error) {
   462  	var scanType = columnType.ScanType()
   463  	if scanType != nil {
   464  		// Common basic builtin types.
   465  		switch scanType.Kind() {
   466  		case
   467  			reflect.Bool,
   468  			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   469  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   470  			reflect.Float32, reflect.Float64:
   471  			return gconv.Convert(
   472  				gconv.String(value),
   473  				columnType.ScanType().String(),
   474  			), nil
   475  		}
   476  	}
   477  	// Other complex types, especially custom types.
   478  	return c.db.ConvertValueForLocal(ctx, columnType.DatabaseTypeName(), value)
   479  }