github.com/gogf/gf/v2@v2.7.4/database/gdb/gdb_core_underlying.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/gogf/gf.
     6  //
     7  
     8  package gdb
     9  
    10  import (
    11  	"context"
    12  	"database/sql"
    13  	"fmt"
    14  	"reflect"
    15  
    16  	"go.opentelemetry.io/otel"
    17  	"go.opentelemetry.io/otel/trace"
    18  
    19  	"github.com/gogf/gf/v2/util/gconv"
    20  
    21  	"github.com/gogf/gf/v2"
    22  	"github.com/gogf/gf/v2/container/gvar"
    23  	"github.com/gogf/gf/v2/errors/gcode"
    24  	"github.com/gogf/gf/v2/errors/gerror"
    25  	"github.com/gogf/gf/v2/internal/intlog"
    26  	"github.com/gogf/gf/v2/os/gtime"
    27  	"github.com/gogf/gf/v2/util/guid"
    28  )
    29  
    30  // Query commits one query SQL to underlying driver and returns the execution result.
    31  // It is most commonly used for data querying.
    32  func (c *Core) Query(ctx context.Context, sql string, args ...interface{}) (result Result, err error) {
    33  	return c.db.DoQuery(ctx, nil, sql, args...)
    34  }
    35  
    36  // DoQuery commits the sql string and its arguments to underlying driver
    37  // through given link object and returns the execution result.
    38  func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) {
    39  	// Transaction checks.
    40  	if link == nil {
    41  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
    42  			// Firstly, check and retrieve transaction link from context.
    43  			link = &txLink{tx.GetSqlTX()}
    44  		} else if link, err = c.SlaveLink(); err != nil {
    45  			// Or else it creates one from master node.
    46  			return nil, err
    47  		}
    48  	} else if !link.IsTransaction() {
    49  		// If current link is not transaction link, it checks and retrieves transaction from context.
    50  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
    51  			link = &txLink{tx.GetSqlTX()}
    52  		}
    53  	}
    54  
    55  	if c.db.GetConfig().QueryTimeout > 0 {
    56  		ctx, _ = context.WithTimeout(ctx, c.db.GetConfig().QueryTimeout)
    57  	}
    58  
    59  	// Sql filtering.
    60  	sql, args = c.FormatSqlBeforeExecuting(sql, args)
    61  	sql, args, err = c.db.DoFilter(ctx, link, sql, args)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	// SQL format and retrieve.
    66  	if v := ctx.Value(ctxKeyCatchSQL); v != nil {
    67  		var (
    68  			manager      = v.(*CatchSQLManager)
    69  			formattedSql = FormatSqlWithArgs(sql, args)
    70  		)
    71  		manager.SQLArray.Append(formattedSql)
    72  		if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
    73  			return nil, nil
    74  		}
    75  	}
    76  	// Link execution.
    77  	var out DoCommitOutput
    78  	out, err = c.db.DoCommit(ctx, DoCommitInput{
    79  		Link:          link,
    80  		Sql:           sql,
    81  		Args:          args,
    82  		Stmt:          nil,
    83  		Type:          SqlTypeQueryContext,
    84  		IsTransaction: link.IsTransaction(),
    85  	})
    86  	return out.Records, err
    87  }
    88  
    89  // Exec commits one query SQL to underlying driver and returns the execution result.
    90  // It is most commonly used for data inserting and updating.
    91  func (c *Core) Exec(ctx context.Context, sql string, args ...interface{}) (result sql.Result, err error) {
    92  	return c.db.DoExec(ctx, nil, sql, args...)
    93  }
    94  
    95  // DoExec commits the sql string and its arguments to underlying driver
    96  // through given link object and returns the execution result.
    97  func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) {
    98  	// Transaction checks.
    99  	if link == nil {
   100  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   101  			// Firstly, check and retrieve transaction link from context.
   102  			link = &txLink{tx.GetSqlTX()}
   103  		} else if link, err = c.MasterLink(); err != nil {
   104  			// Or else it creates one from master node.
   105  			return nil, err
   106  		}
   107  	} else if !link.IsTransaction() {
   108  		// If current link is not transaction link, it checks and retrieves transaction from context.
   109  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   110  			link = &txLink{tx.GetSqlTX()}
   111  		}
   112  	}
   113  
   114  	if c.db.GetConfig().ExecTimeout > 0 {
   115  		var cancelFunc context.CancelFunc
   116  		ctx, cancelFunc = context.WithTimeout(ctx, c.db.GetConfig().ExecTimeout)
   117  		defer cancelFunc()
   118  	}
   119  
   120  	// SQL filtering.
   121  	sql, args = c.FormatSqlBeforeExecuting(sql, args)
   122  	sql, args, err = c.db.DoFilter(ctx, link, sql, args)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	// SQL format and retrieve.
   127  	if v := ctx.Value(ctxKeyCatchSQL); v != nil {
   128  		var (
   129  			manager      = v.(*CatchSQLManager)
   130  			formattedSql = FormatSqlWithArgs(sql, args)
   131  		)
   132  		manager.SQLArray.Append(formattedSql)
   133  		if !manager.DoCommit && ctx.Value(ctxKeyInternalProducedSQL) == nil {
   134  			return new(SqlResult), nil
   135  		}
   136  	}
   137  	// Link execution.
   138  	var out DoCommitOutput
   139  	out, err = c.db.DoCommit(ctx, DoCommitInput{
   140  		Link:          link,
   141  		Sql:           sql,
   142  		Args:          args,
   143  		Stmt:          nil,
   144  		Type:          SqlTypeExecContext,
   145  		IsTransaction: link.IsTransaction(),
   146  	})
   147  	return out.Result, err
   148  }
   149  
   150  // DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver.
   151  // The parameter `link` specifies the current database connection operation object. You can modify the sql
   152  // string `sql` and its arguments `args` as you wish before they're committed to driver.
   153  func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
   154  	return sql, args, nil
   155  }
   156  
   157  // DoCommit commits current sql and arguments to underlying sql driver.
   158  func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
   159  	var (
   160  		sqlTx                *sql.Tx
   161  		sqlStmt              *sql.Stmt
   162  		sqlRows              *sql.Rows
   163  		sqlResult            sql.Result
   164  		stmtSqlRows          *sql.Rows
   165  		stmtSqlRow           *sql.Row
   166  		rowsAffected         int64
   167  		cancelFuncForTimeout context.CancelFunc
   168  		formattedSql         = FormatSqlWithArgs(in.Sql, in.Args)
   169  		timestampMilli1      = gtime.TimestampMilli()
   170  	)
   171  
   172  	// Trace span start.
   173  	tr := otel.GetTracerProvider().Tracer(traceInstrumentName, trace.WithInstrumentationVersion(gf.VERSION))
   174  	ctx, span := tr.Start(ctx, string(in.Type), trace.WithSpanKind(trace.SpanKindInternal))
   175  	defer span.End()
   176  
   177  	// Execution cased by type.
   178  	switch in.Type {
   179  	case SqlTypeBegin:
   180  		if sqlTx, err = in.Db.Begin(); err == nil {
   181  			out.Tx = &TXCore{
   182  				db:            c.db,
   183  				tx:            sqlTx,
   184  				ctx:           context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)),
   185  				master:        in.Db,
   186  				transactionId: guid.S(),
   187  			}
   188  			ctx = out.Tx.GetCtx()
   189  		}
   190  		out.RawResult = sqlTx
   191  
   192  	case SqlTypeTXCommit:
   193  		err = in.Tx.Commit()
   194  
   195  	case SqlTypeTXRollback:
   196  		err = in.Tx.Rollback()
   197  
   198  	case SqlTypeExecContext:
   199  		if c.db.GetDryRun() {
   200  			sqlResult = new(SqlResult)
   201  		} else {
   202  			sqlResult, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
   203  		}
   204  		out.RawResult = sqlResult
   205  
   206  	case SqlTypeQueryContext:
   207  		sqlRows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...)
   208  		out.RawResult = sqlRows
   209  
   210  	case SqlTypePrepareContext:
   211  		sqlStmt, err = in.Link.PrepareContext(ctx, in.Sql)
   212  		out.RawResult = sqlStmt
   213  
   214  	case SqlTypeStmtExecContext:
   215  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec)
   216  		defer cancelFuncForTimeout()
   217  		if c.db.GetDryRun() {
   218  			sqlResult = new(SqlResult)
   219  		} else {
   220  			sqlResult, err = in.Stmt.ExecContext(ctx, in.Args...)
   221  		}
   222  		out.RawResult = sqlResult
   223  
   224  	case SqlTypeStmtQueryContext:
   225  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
   226  		defer cancelFuncForTimeout()
   227  		stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...)
   228  		out.RawResult = stmtSqlRows
   229  
   230  	case SqlTypeStmtQueryRowContext:
   231  		ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery)
   232  		defer cancelFuncForTimeout()
   233  		stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...)
   234  		out.RawResult = stmtSqlRow
   235  
   236  	default:
   237  		panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid SqlType "%s"`, in.Type))
   238  	}
   239  	// Result handling.
   240  	switch {
   241  	case sqlResult != nil && !c.GetIgnoreResultFromCtx(ctx):
   242  		rowsAffected, err = sqlResult.RowsAffected()
   243  		out.Result = sqlResult
   244  
   245  	case sqlRows != nil:
   246  		out.Records, err = c.RowsToResult(ctx, sqlRows)
   247  		rowsAffected = int64(len(out.Records))
   248  
   249  	case sqlStmt != nil:
   250  		out.Stmt = &Stmt{
   251  			Stmt: sqlStmt,
   252  			core: c,
   253  			link: in.Link,
   254  			sql:  in.Sql,
   255  		}
   256  	}
   257  	var (
   258  		timestampMilli2 = gtime.TimestampMilli()
   259  		sqlObj          = &Sql{
   260  			Sql:           in.Sql,
   261  			Type:          in.Type,
   262  			Args:          in.Args,
   263  			Format:        formattedSql,
   264  			Error:         err,
   265  			Start:         timestampMilli1,
   266  			End:           timestampMilli2,
   267  			Group:         c.db.GetGroup(),
   268  			Schema:        c.db.GetSchema(),
   269  			RowsAffected:  rowsAffected,
   270  			IsTransaction: in.IsTransaction,
   271  		}
   272  	)
   273  
   274  	// Tracing.
   275  	c.traceSpanEnd(ctx, span, sqlObj)
   276  
   277  	// Logging.
   278  	if c.db.GetDebug() {
   279  		c.writeSqlToLogger(ctx, sqlObj)
   280  	}
   281  	if err != nil && err != sql.ErrNoRows {
   282  		err = gerror.WrapCode(
   283  			gcode.CodeDbOperationError,
   284  			err,
   285  			FormatSqlWithArgs(in.Sql, in.Args),
   286  		)
   287  	}
   288  	return out, err
   289  }
   290  
   291  // Prepare creates a prepared statement for later queries or executions.
   292  // Multiple queries or executions may be run concurrently from the
   293  // returned statement.
   294  // The caller must call the statement's Close method
   295  // when the statement is no longer needed.
   296  //
   297  // The parameter `execOnMaster` specifies whether executing the sql on master node,
   298  // or else it executes the sql on slave node if master-slave configured.
   299  func (c *Core) Prepare(ctx context.Context, sql string, execOnMaster ...bool) (*Stmt, error) {
   300  	var (
   301  		err  error
   302  		link Link
   303  	)
   304  	if len(execOnMaster) > 0 && execOnMaster[0] {
   305  		if link, err = c.MasterLink(); err != nil {
   306  			return nil, err
   307  		}
   308  	} else {
   309  		if link, err = c.SlaveLink(); err != nil {
   310  			return nil, err
   311  		}
   312  	}
   313  	return c.db.DoPrepare(ctx, link, sql)
   314  }
   315  
   316  // DoPrepare calls prepare function on given link object and returns the statement object.
   317  func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt, err error) {
   318  	// Transaction checks.
   319  	if link == nil {
   320  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   321  			// Firstly, check and retrieve transaction link from context.
   322  			link = &txLink{tx.GetSqlTX()}
   323  		} else {
   324  			// Or else it creates one from master node.
   325  			var err error
   326  			if link, err = c.MasterLink(); err != nil {
   327  				return nil, err
   328  			}
   329  		}
   330  	} else if !link.IsTransaction() {
   331  		// If current link is not transaction link, it checks and retrieves transaction from context.
   332  		if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil {
   333  			link = &txLink{tx.GetSqlTX()}
   334  		}
   335  	}
   336  
   337  	if c.db.GetConfig().PrepareTimeout > 0 {
   338  		// DO NOT USE cancel function in prepare statement.
   339  		ctx, _ = context.WithTimeout(ctx, c.db.GetConfig().PrepareTimeout)
   340  	}
   341  
   342  	// Link execution.
   343  	var out DoCommitOutput
   344  	out, err = c.db.DoCommit(ctx, DoCommitInput{
   345  		Link:          link,
   346  		Sql:           sql,
   347  		Type:          SqlTypePrepareContext,
   348  		IsTransaction: link.IsTransaction(),
   349  	})
   350  	return out.Stmt, err
   351  }
   352  
   353  // FormatUpsert formats and returns SQL clause part for upsert statement.
   354  // In default implements, this function performs upsert statement for MySQL like:
   355  // `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...`
   356  func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) {
   357  	var onDuplicateStr string
   358  	if option.OnDuplicateStr != "" {
   359  		onDuplicateStr = option.OnDuplicateStr
   360  	} else if len(option.OnDuplicateMap) > 0 {
   361  		for k, v := range option.OnDuplicateMap {
   362  			if len(onDuplicateStr) > 0 {
   363  				onDuplicateStr += ","
   364  			}
   365  			switch v.(type) {
   366  			case Raw, *Raw:
   367  				onDuplicateStr += fmt.Sprintf(
   368  					"%s=%s",
   369  					c.QuoteWord(k),
   370  					v,
   371  				)
   372  			default:
   373  				onDuplicateStr += fmt.Sprintf(
   374  					"%s=VALUES(%s)",
   375  					c.QuoteWord(k),
   376  					c.QuoteWord(gconv.String(v)),
   377  				)
   378  			}
   379  		}
   380  	} else {
   381  		for _, column := range columns {
   382  			// If it's SAVE operation, do not automatically update the creating time.
   383  			if c.IsSoftCreatedFieldName(column) {
   384  				continue
   385  			}
   386  			if len(onDuplicateStr) > 0 {
   387  				onDuplicateStr += ","
   388  			}
   389  			onDuplicateStr += fmt.Sprintf(
   390  				"%s=VALUES(%s)",
   391  				c.QuoteWord(column),
   392  				c.QuoteWord(column),
   393  			)
   394  		}
   395  	}
   396  
   397  	return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
   398  }
   399  
   400  // RowsToResult converts underlying data record type sql.Rows to Result type.
   401  func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
   402  	if rows == nil {
   403  		return nil, nil
   404  	}
   405  	defer func() {
   406  		if err := rows.Close(); err != nil {
   407  			intlog.Errorf(ctx, `%+v`, err)
   408  		}
   409  	}()
   410  	if !rows.Next() {
   411  		return nil, nil
   412  	}
   413  	// Column names and types.
   414  	columnTypes, err := rows.ColumnTypes()
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  
   419  	if len(columnTypes) > 0 {
   420  		if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil {
   421  			internalData.FirstResultColumn = columnTypes[0].Name()
   422  		}
   423  	}
   424  	var (
   425  		values   = make([]interface{}, len(columnTypes))
   426  		result   = make(Result, 0)
   427  		scanArgs = make([]interface{}, len(values))
   428  	)
   429  	for i := range values {
   430  		scanArgs[i] = &values[i]
   431  	}
   432  	for {
   433  		if err = rows.Scan(scanArgs...); err != nil {
   434  			return result, err
   435  		}
   436  		record := Record{}
   437  		for i, value := range values {
   438  			if value == nil {
   439  				// DO NOT use `gvar.New(nil)` here as it creates an initialized object
   440  				// which will cause struct converting issue.
   441  				record[columnTypes[i].Name()] = nil
   442  			} else {
   443  				var convertedValue interface{}
   444  				if convertedValue, err = c.columnValueToLocalValue(ctx, value, columnTypes[i]); err != nil {
   445  					return nil, err
   446  				}
   447  				record[columnTypes[i].Name()] = gvar.New(convertedValue)
   448  			}
   449  		}
   450  		result = append(result, record)
   451  		if !rows.Next() {
   452  			break
   453  		}
   454  	}
   455  	return result, nil
   456  }
   457  
   458  // OrderRandomFunction returns the SQL function for random ordering.
   459  func (c *Core) OrderRandomFunction() string {
   460  	return "RAND()"
   461  }
   462  
   463  func (c *Core) columnValueToLocalValue(ctx context.Context, value interface{}, columnType *sql.ColumnType) (interface{}, error) {
   464  	var scanType = columnType.ScanType()
   465  	if scanType != nil {
   466  		// Common basic builtin types.
   467  		switch scanType.Kind() {
   468  		case
   469  			reflect.Bool,
   470  			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   471  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   472  			reflect.Float32, reflect.Float64:
   473  			return gconv.Convert(
   474  				gconv.String(value),
   475  				columnType.ScanType().String(),
   476  			), nil
   477  		}
   478  	}
   479  	// Other complex types, especially custom types.
   480  	return c.db.ConvertValueForLocal(ctx, columnType.DatabaseTypeName(), value)
   481  }