github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/execsql.go (about)

     1  package sqx
     2  
     3  // refer https://yougg.github.io/2017/08/24/用go语言写一个简单的mysql客户端/
     4  import (
     5  	"context"
     6  	"database/sql"
     7  	"time"
     8  
     9  	"github.com/bingoohuang/gg/pkg/ss"
    10  )
    11  
// Result defines the result structure of sql execution.
//
// It comes in two shapes. For statements run through Query (IsQuerySQL
// true), Headers and Rows are filled on success. For all other statements,
// RowsAffected and LastInsertID are filled. Error, CostTime and FirstKey are
// populated in both cases.
type Result struct {
	Error        error         // error from the execution, nil on success
	CostTime     time.Duration // time elapsed while executing and collecting the result
	Headers      []string      // column names of a query result
	Rows         [][]string    // query rows rendered as strings (see ExecOption)
	RowsAffected int64         // rows affected by a non-query statement
	LastInsertID int64         // last insert id; only filled for "INSERT" statements
	IsQuerySQL   bool          // true when the statement was run as a query
	FirstKey     string        // statement keyword as reported by IsQuerySQL (e.g. "INSERT")
}
    23  
    24  func (r Result) Return(start time.Time, err error) Result {
    25  	r.CostTime = time.Since(start)
    26  	r.Error = err
    27  	return r
    28  }
    29  
// SQLExec is the minimal database handle used by Exec: Exec runs statements
// that return no rows and Query runs statements that return rows. Its method
// set matches the non-context Exec/Query methods of *sql.DB and *sql.Tx.
type SQLExec interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
	Query(query string, args ...interface{}) (*sql.Rows, error)
}
    35  
// SQLExecContext is the context-aware counterpart of SQLExec: the same two
// operations, each taking a context.Context as the first argument.
type SQLExecContext interface {
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
    40  
// SqxDB is a database handle offering both the plain (SQLExec) and the
// context-aware (SQLExecContext) Exec/Query methods.
type SqxDB interface {
	SQLExec
	SQLExecContext
}
    45  
// ExecOption controls how Exec collects and renders the rows of a query result.
type ExecOption struct {
	MaxRows     int    // maximum number of rows to collect; <= 0 means unlimited
	NullReplace string // text substituted for SQL NULL values
	BlobReplace string // non-empty enables masking of *LOB columns; NOTE(review): only emptiness is checked, this text is never emitted
}
    51  
    52  func (o ExecOption) reachMaxRows(row int) bool {
    53  	return o.MaxRows > 0 && row >= o.MaxRows
    54  }
    55  
    56  // Exec executes a SQL.
    57  func Exec(db SQLExec, query string, option ExecOption) Result {
    58  	firstKey, isQuerySQL := IsQuerySQL(query)
    59  	if isQuerySQL {
    60  		return processQuery(db, query, firstKey, option)
    61  	}
    62  
    63  	return execNonQuery(db, query, firstKey)
    64  }
    65  
    66  func processQuery(db SQLExec, query string, firstKey string, option ExecOption) (r Result) {
    67  	start := time.Now()
    68  	r.FirstKey = firstKey
    69  	r.IsQuerySQL = true
    70  
    71  	rows, err := db.Query(query)
    72  	if err != nil || rows != nil && rows.Err() != nil {
    73  		if err == nil {
    74  			err = rows.Err()
    75  		}
    76  
    77  		return r.Return(start, err)
    78  	}
    79  
    80  	defer rows.Close()
    81  
    82  	columns, err := rows.Columns()
    83  	if err != nil {
    84  		return r.Return(start, err)
    85  	}
    86  
    87  	r.Headers = columns
    88  
    89  	columnSize := len(columns)
    90  	columnTypes, _ := rows.ColumnTypes()
    91  	data := make([][]string, 0)
    92  
    93  	var columnLobs []bool
    94  	if option.BlobReplace != "" {
    95  		columnLobs = make([]bool, columnSize)
    96  		for i := 0; i < len(columnTypes); i++ {
    97  			columnLobs[i] = ss.ContainsFold(columnTypes[i].DatabaseTypeName(), "LOB")
    98  		}
    99  	}
   100  
   101  	for row := 0; rows.Next() && !option.reachMaxRows(row); row++ {
   102  		holders := make([]sql.NullString, columnSize)
   103  		pointers := make([]interface{}, columnSize)
   104  
   105  		for i := 0; i < columnSize; i++ {
   106  			pointers[i] = &holders[i]
   107  		}
   108  
   109  		if err := rows.Scan(pointers...); err != nil {
   110  			return r.Return(start, err)
   111  		}
   112  
   113  		values := make([]string, columnSize)
   114  
   115  		for i, v := range holders {
   116  			values[i] = ss.If(v.Valid, v.String, option.NullReplace)
   117  			if option.BlobReplace != "" && v.Valid && columnLobs[i] {
   118  				values[i] = "(" + columnTypes[i].DatabaseTypeName() + ")"
   119  			}
   120  		}
   121  
   122  		data = append(data, values)
   123  	}
   124  
   125  	r.Rows = data
   126  
   127  	return r.Return(start, nil)
   128  }
   129  
   130  func execNonQuery(db SQLExec, query string, firstKey string) Result {
   131  	start := time.Now()
   132  	r, err := db.Exec(query)
   133  
   134  	var affected int64
   135  	if r != nil {
   136  		affected, _ = r.RowsAffected()
   137  	}
   138  
   139  	var lastInsertID int64
   140  	if r != nil && firstKey == "INSERT" {
   141  		lastInsertID, _ = r.LastInsertId()
   142  	}
   143  
   144  	return Result{
   145  		Error:        err,
   146  		CostTime:     time.Since(start),
   147  		RowsAffected: affected,
   148  		IsQuerySQL:   false,
   149  		LastInsertID: lastInsertID,
   150  		FirstKey:     firstKey,
   151  	}
   152  }