
     1  package sqlutil
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  	"sync"
    11  	""
    12  	""
    13  	""
    14  )
// InsertOptions holds the options used by a batch inserting operation.
// Values are normally populated through InsertOpt functions rather than
// being set directly.
type InsertOptions struct {
	// Context, when non-nil, makes BatchInsert call ExecContext on
	// executors that implement ContextExecutor.
	Context context.Context
	// TableName is the target table. When empty, the name derived from
	// the element struct type is used instead.
	TableName string
	// Quote is placed on both sides of the table name and of every column
	// name. An empty value disables quoting.
	Quote string
	// OmitCols lists column names to leave out of the generated query.
	OmitCols []string

	// Ignore switches the statement prefix to "INSERT IGNORE INTO" (mysql).
	Ignore bool
	// OnDuplicateKey, when non-empty, is appended after " ON DUPLICATE KEY " (mysql).
	OnDuplicateKey string
	// OnConflict, when non-empty, is appended after " ON CONFLICT " (postgresql).
	OnConflict string
}
    28  func (p *InsertOptions) apply(opts ...InsertOpt) *InsertOptions {
    29  	for _, f := range opts {
    30  		f(p)
    31  	}
    32  	return p
    33  }
    35  func (p *InsertOptions) quote(name string) string {
    36  	if p.Quote == "" {
    37  		return name
    38  	}
    39  	return p.Quote + name + p.Quote
    40  }
// InsertOpt represents an option for a batch inserting operation.
// Each InsertOpt is a function that mutates the InsertOptions it receives;
// options are applied in the order they are passed.
type InsertOpt func(*InsertOptions)
    46  // WithContext makes the query executed with `ExecContext` if available.
    47  func WithContext(ctx context.Context) InsertOpt {
    48  	return func(opts *InsertOptions) {
    49  		opts.Context = ctx
    50  	}
    51  }
    53  // WithTable makes the generated query to use provided table name.
    54  func WithTable(tableName string) InsertOpt {
    55  	return func(opts *InsertOptions) {
    56  		opts.TableName = tableName
    57  	}
    58  }
    60  // WithQuote quotes the table name and column names with the given string.
    61  func WithQuote(quote string) InsertOpt {
    62  	return func(opts *InsertOptions) {
    63  		opts.Quote = quote
    64  	}
    65  }
    67  // OmitColumns exclude given columns from the generated query.
    68  func OmitColumns(cols ...string) InsertOpt {
    69  	return func(opts *InsertOptions) {
    70  		opts.OmitCols = cols
    71  	}
    72  }
    74  // WithIgnore adds the mysql "IGNORE" adverb to the the generated query.
    75  func WithIgnore() InsertOpt {
    76  	return func(opts *InsertOptions) {
    77  		opts.Ignore = true
    78  	}
    79  }
    81  // OnDuplicateKey appends the mysql "ON DUPLICATE KEY" clause to the generated query.
    82  func OnDuplicateKey(clause string) InsertOpt {
    83  	return func(opts *InsertOptions) {
    84  		opts.OnDuplicateKey = clause
    85  	}
    86  }
    88  // OnConflict appends the postgresql "ON CONFLICT" clause to the generated query.
    89  func OnConflict(clause string) InsertOpt {
    90  	return func(opts *InsertOptions) {
    91  		opts.OnConflict = clause
    92  	}
    93  }
// Executor is the minimal interface that batch inserting requires.
// The interface is implemented by *sql.DB, *sql.Tx, *sqlx.DB, *sqlx.Tx.
type Executor interface {
	// Exec executes query with the given positional args.
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// ContextExecutor is an optional interface to support context execution.
// If the `BatchInsert` function is called with the `WithContext` option, and
// the provided Executor implements this interface, then the method
// `ExecContext` will be called instead of the method `Exec`.
type ContextExecutor interface {
	// ExecContext executes query with the given positional args under ctx.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
   109  // BatchInsert generates SQL and executes it on the provided Executor.
   110  // The provided param `rows` must be a slice of struct or pointer to struct,
   111  // and the slice must have at least one element, or it returns error.
   112  func BatchInsert(conn Executor, rows interface{}, opts ...InsertOpt) (result sql.Result, err error) {
   113  	defer func() {
   114  		if r := recover(); r != nil {
   115  			err = fmt.Errorf("%v", r)
   116  		}
   117  	}()
   118  	options := new(InsertOptions).apply(opts...)
   119  	query, args := makeBatchInsertSQL("BatchInsert", rows, options)
   120  	if options.Context != nil {
   121  		if ctxConn, ok := conn.(ContextExecutor); ok {
   122  			result, err = ctxConn.ExecContext(options.Context, query, args...)
   123  		} else {
   124  			result, err = conn.Exec(query, args...)
   125  		}
   126  	} else {
   127  		result, err = conn.Exec(query, args...)
   128  	}
   129  	return
   130  }
   132  // MakeBatchInsertSQL generates SQL and returns the arguments to execute on database connection.
   133  // The provided param `rows` must be a slice of struct or pointer to struct,
   134  // and the slice must have at least one element, or it panics.
   135  //
   136  // The returned query uses `?` as parameter placeholder, if you are using this function
   137  // with database which don't use `?` as placeholder, you may check the `Rebind` function
   138  // from package `` to replace placeholders.
   139  func MakeBatchInsertSQL(rows interface{}, opts ...InsertOpt) (sql string, args []interface{}) {
   140  	options := new(InsertOptions).apply(opts...)
   141  	return makeBatchInsertSQL("MakeBatchInsertSQL", rows, options)
   142  }
   144  func makeBatchInsertSQL(where string, rows interface{}, opts *InsertOptions) (sql string, args []interface{}) {
   145  	assertSliceOfStructAndLength(where, rows)
   147  	typInfo := parseType(rows)
   148  	if len(opts.TableName) == 0 {
   149  		opts.TableName = typInfo.tableName
   150  	}
   152  	var buf strings.Builder
   154  	// mysql: insert ignore
   155  	if opts.Ignore {
   156  		buf.WriteString("INSERT IGNORE INTO ")
   157  	} else {
   158  		buf.WriteString("INSERT INTO ")
   159  	}
   161  	// table name
   162  	buf.WriteString(opts.quote(opts.TableName))
   164  	// column names
   165  	var omitFieldIndex []int
   166  	buf.WriteString(" (")
   167  	for i, col := range typInfo.colNames {
   168  		if inStrings(opts.OmitCols, col) {
   169  			omitFieldIndex = append(omitFieldIndex, typInfo.fieldIndex[i])
   170  			continue
   171  		}
   172  		buf.WriteString(opts.quote(col))
   173  		if i < len(typInfo.colNames)-1 {
   174  			buf.WriteByte(',')
   175  		}
   176  	}
   177  	buf.WriteByte(')')
   179  	// value placeholders
   180  	placeholders := typInfo.placeholders
   181  	fieldIndex := typInfo.fieldIndex
   182  	if len(omitFieldIndex) > 0 {
   183  		fieldIndex = diffInts(fieldIndex, omitFieldIndex)
   184  		placeholders = makePlaceholders(len(fieldIndex))
   185  	}
   186  	buf.WriteString(" VALUES ")
   187  	rowsVal := reflect.ValueOf(rows)
   188  	length := rowsVal.Len()
   189  	fieldNum := len(typInfo.fieldIndex)
   190  	args = make([]interface{}, 0, length*fieldNum)
   191  	for i := 0; i < length; i++ {
   192  		if i > 0 {
   193  			buf.WriteByte(',')
   194  		}
   195  		buf.WriteString(placeholders)
   196  		elem := reflect.Indirect(rowsVal.Index(i))
   197  		for _, j := range fieldIndex {
   198  			args = append(args, elem.Field(j).Interface())
   199  		}
   200  	}
   202  	// mysql: on duplicate key clause
   203  	if len(opts.OnDuplicateKey) > 0 {
   204  		buf.WriteString(" ON DUPLICATE KEY ")
   205  		buf.WriteString(opts.OnDuplicateKey)
   206  	}
   208  	// postgresql: on conflict clause
   209  	if len(opts.OnConflict) > 0 {
   210  		buf.WriteString(" ON CONFLICT ")
   211  		buf.WriteString(opts.OnConflict)
   212  	}
   214  	sql = buf.String()
   215  	return
   216  }
// typeCache stores the *typeInfo computed by parseType, keyed by the
// reflect.Type of the rows argument, so each type is analysed only once.
var typeCache sync.Map
// typeInfo is the cached result of analysing the element struct type of a
// rows slice.
type typeInfo struct {
	// tableName is derived from the struct type name via strutil.ToSnakeCase.
	tableName string
	// colNames holds one column name per included struct field.
	colNames []string
	// placeholders is a parenthesised group of "?" marks, one per column.
	placeholders string
	// fieldIndex holds the struct field index for each entry of colNames.
	fieldIndex []int
}
   227  func parseType(rows interface{}) *typeInfo {
   228  	typ := reflect.TypeOf(rows)
   229  	cachedInfo, ok := typeCache.Load(typ)
   230  	if ok {
   231  		return cachedInfo.(*typeInfo)
   232  	}
   234  	elemTyp := indirectType(indirectType(typ).Elem())
   235  	tableName := strutil.ToSnakeCase(elemTyp.Name())
   236  	fieldNum := elemTyp.NumField()
   237  	colNames := make([]string, 0, fieldNum)
   238  	fieldIndex := make([]int, 0)
   239  	for i := 0; i < fieldNum; i++ {
   240  		field := elemTyp.Field(i)
   241  		col := ""
   243  		// ignore unexported fields
   244  		if len(field.PkgPath) != 0 {
   245  			continue
   246  		}
   248  		// be compatible with sqlx column name tag
   249  		dbTag := field.Tag.Get("db")
   250  		opts := structtag.ParseOptions(dbTag, ",", "")
   251  		if len(opts) > 0 {
   252  			if opts[0].String() == "-" {
   253  				continue
   254  			}
   255  			col = opts[0].String()
   256  		}
   258  		// be compatible with gorm column name tag
   259  		if col == "" {
   260  			gormTag := field.Tag.Get("gorm")
   261  			opts = structtag.ParseOptions(gormTag, ";", ":")
   262  			if len(opts) > 0 {
   263  				if opts[0].Key() == "-" {
   264  					continue
   265  				}
   266  				colopt, found := opts.Get("column")
   267  				if found && colopt.Value() != "" {
   268  					col = colopt.Value()
   269  				}
   270  			}
   271  		}
   273  		// default
   274  		if col == "" {
   275  			col = strutil.ToSnakeCase(field.Name)
   276  		}
   278  		colNames = append(colNames, col)
   279  		fieldIndex = append(fieldIndex, i)
   280  	}
   282  	placeholders := makePlaceholders(len(fieldIndex))
   283  	info := &typeInfo{
   284  		tableName:    tableName,
   285  		colNames:     colNames,
   286  		placeholders: placeholders,
   287  		fieldIndex:   fieldIndex,
   288  	}
   289  	typeCache.Store(typ, info)
   290  	return info
   291  }
   293  func makePlaceholders(n int) string {
   294  	marks := strings.Repeat("?,", n)
   295  	marks = strings.TrimSuffix(marks, ",")
   296  	return "(" + marks + ")"
   297  }
   299  func indirectType(typ reflect.Type) reflect.Type {
   300  	if typ.Kind() != reflect.Ptr {
   301  		return typ
   302  	}
   303  	return typ.Elem()
   304  }
   306  func inStrings(slice []string, elem string) bool {
   307  	for _, x := range slice {
   308  		if x == elem {
   309  			return true
   310  		}
   311  	}
   312  	return false
   313  }
   315  func inInts(slice []int, elem int) bool {
   316  	for _, x := range slice {
   317  		if x == elem {
   318  			return true
   319  		}
   320  	}
   321  	return false
   322  }
   324  func diffInts(a, b []int) []int {
   325  	out := make([]int, 0, len(a))
   326  	for _, x := range a {
   327  		if inInts(b, x) {
   328  			continue
   329  		}
   330  		out = append(out, x)
   331  	}
   332  	return out
   333  }
   335  func assertSliceOfStructAndLength(where string, rows interface{}) {
   336  	sliceTyp := reflect.TypeOf(rows)
   337  	if sliceTyp == nil || sliceTyp.Kind() != reflect.Slice {
   338  		panic(where + ": param is nil or not a slice")
   339  	}
   340  	elemTyp := sliceTyp.Elem()
   341  	if indirectType(elemTyp).Kind() != reflect.Struct {
   342  		panic(where + ": slice element is not struct or pointer to struct")
   343  	}
   344  	sh := internal.UnpackSlice(rows)
   345  	if sh.Len == 0 {
   346  		panic(where + ": slice length is zero")
   347  	}
   348  }