github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/utils/sqlutil/batch.go (about) 1 package sqlutil 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "strings" 9 "sync" 10 11 "github.com/jxskiss/gopkg/v2/unsafe/reflectx" 12 "github.com/jxskiss/gopkg/v2/utils/structtag" 13 "github.com/jxskiss/gopkg/v2/utils/strutil" 14 ) 15 16 // InsertOptions holds options to use with batch inserting operation. 17 type InsertOptions struct { 18 Context context.Context 19 TableName string 20 Quote string 21 OmitCols []string 22 23 Ignore bool 24 OnDuplicateKey string 25 OnConflict string 26 } 27 28 func (p *InsertOptions) apply(opts ...InsertOpt) *InsertOptions { 29 for _, f := range opts { 30 f(p) 31 } 32 return p 33 } 34 35 func (p *InsertOptions) quote(name string) string { 36 if p.Quote == "" { 37 return name 38 } 39 return p.Quote + name + p.Quote 40 } 41 42 // InsertOpt represents an inserting option to use with batch 43 // inserting operation. 44 type InsertOpt func(*InsertOptions) 45 46 // WithContext makes the query executed with `ExecContext` if available. 47 func WithContext(ctx context.Context) InsertOpt { 48 return func(opts *InsertOptions) { 49 opts.Context = ctx 50 } 51 } 52 53 // WithTable makes the generated query to use provided table name. 54 func WithTable(tableName string) InsertOpt { 55 return func(opts *InsertOptions) { 56 opts.TableName = tableName 57 } 58 } 59 60 // WithQuote quotes the table name and column names with the given string. 61 func WithQuote(quote string) InsertOpt { 62 return func(opts *InsertOptions) { 63 opts.Quote = quote 64 } 65 } 66 67 // OmitColumns exclude given columns from the generated query. 68 func OmitColumns(cols ...string) InsertOpt { 69 return func(opts *InsertOptions) { 70 opts.OmitCols = cols 71 } 72 } 73 74 // WithIgnore adds the mysql "IGNORE" adverb to the the generated query. 75 func WithIgnore() InsertOpt { 76 return func(opts *InsertOptions) { 77 opts.Ignore = true 78 } 79 } 80 81 // OnDuplicateKey appends the mysql "ON DUPLICATE KEY" clause to the generated query. 82 func OnDuplicateKey(clause string) InsertOpt { 83 return func(opts *InsertOptions) { 84 opts.OnDuplicateKey = clause 85 } 86 } 87 88 // OnConflict appends the postgresql "ON CONFLICT" clause to the generated query. 89 func OnConflict(clause string) InsertOpt { 90 return func(opts *InsertOptions) { 91 opts.OnConflict = clause 92 } 93 } 94 95 // Executor is the minimal interface for batch inserting requires. 96 // The interface is implemented by *sql.DB, *sql.Tx, *sqlx.DB, *sqlx.Tx. 97 type Executor interface { 98 Exec(query string, args ...any) (sql.Result, error) 99 } 100 101 // ContextExecutor is an optional interface to support context execution. 102 // If `BatchInsert` function is called with `WithContext` option, and the 103 // provided Executor implements this interface, then the method 104 // `ExecContext` will be called instead of the method `Exec`. 105 type ContextExecutor interface { 106 ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 107 } 108 109 // BatchInsert generates SQL and executes it on the provided Executor. 110 // The provided param `rows` must be a slice of struct or pointer to struct, 111 // and the slice must have at least one element, or it returns error. 112 func BatchInsert(conn Executor, rows any, opts ...InsertOpt) (result sql.Result, err error) { 113 defer func() { 114 if r := recover(); r != nil { 115 err = fmt.Errorf("%v", r) 116 } 117 }() 118 options := new(InsertOptions).apply(opts...) 119 query, args := makeBatchInsertSQL("BatchInsert", rows, options) 120 if options.Context != nil { 121 if ctxConn, ok := conn.(ContextExecutor); ok { 122 result, err = ctxConn.ExecContext(options.Context, query, args...) 123 } else { 124 result, err = conn.Exec(query, args...) 125 } 126 } else { 127 result, err = conn.Exec(query, args...) 128 } 129 return 130 } 131 132 // MakeBatchInsertSQL generates SQL and returns the arguments to execute on database connection. 133 // The provided param `rows` must be a slice of struct or pointer to struct, 134 // and the slice must have at least one element, or it panics. 135 // 136 // The returned query uses `?` as parameter placeholder, if you are using this function 137 // with database which don't use `?` as placeholder, you may check the `Rebind` function 138 // from package `github.com/jmoiron/sqlx` to replace placeholders. 139 func MakeBatchInsertSQL(rows any, opts ...InsertOpt) (sql string, args []any) { 140 options := new(InsertOptions).apply(opts...) 141 return makeBatchInsertSQL("MakeBatchInsertSQL", rows, options) 142 } 143 144 func makeBatchInsertSQL(where string, rows any, opts *InsertOptions) (sql string, args []any) { 145 assertSliceOfStructAndLength(where, rows) 146 147 typInfo := parseType(rows) 148 if len(opts.TableName) == 0 { 149 opts.TableName = typInfo.tableName 150 } 151 152 var buf strings.Builder 153 154 // mysql: insert ignore 155 if opts.Ignore { 156 buf.WriteString("INSERT IGNORE INTO ") 157 } else { 158 buf.WriteString("INSERT INTO ") 159 } 160 161 // table name 162 buf.WriteString(opts.quote(opts.TableName)) 163 164 // column names 165 var omitFieldIndex []int 166 buf.WriteString(" (") 167 for i, col := range typInfo.colNames { 168 if inStrings(opts.OmitCols, col) { 169 omitFieldIndex = append(omitFieldIndex, typInfo.fieldIndex[i]) 170 continue 171 } 172 buf.WriteString(opts.quote(col)) 173 if i < len(typInfo.colNames)-1 { 174 buf.WriteByte(',') 175 } 176 } 177 buf.WriteByte(')') 178 179 // value placeholders 180 placeholders := typInfo.placeholders 181 fieldIndex := typInfo.fieldIndex 182 if len(omitFieldIndex) > 0 { 183 fieldIndex = diffInts(fieldIndex, omitFieldIndex) 184 placeholders = makePlaceholders(len(fieldIndex)) 185 } 186 buf.WriteString(" VALUES ") 187 rowsVal := reflect.ValueOf(rows) 188 length := rowsVal.Len() 189 fieldNum := len(typInfo.fieldIndex) 190 args = make([]any, 0, length*fieldNum) 191 for i := 0; i < length; i++ { 192 if i > 0 { 193 buf.WriteByte(',') 194 } 195 buf.WriteString(placeholders) 196 elem := reflect.Indirect(rowsVal.Index(i)) 197 for _, j := range fieldIndex { 198 args = append(args, elem.Field(j).Interface()) 199 } 200 } 201 202 // mysql: on duplicate key clause 203 if len(opts.OnDuplicateKey) > 0 { 204 buf.WriteString(" ON DUPLICATE KEY ") 205 buf.WriteString(opts.OnDuplicateKey) 206 } 207 208 // postgresql: on conflict clause 209 if len(opts.OnConflict) > 0 { 210 buf.WriteString(" ON CONFLICT ") 211 buf.WriteString(opts.OnConflict) 212 } 213 214 sql = buf.String() 215 return sql, args 216 } 217 218 var typeCache sync.Map 219 220 type typeInfo struct { 221 tableName string 222 colNames []string 223 placeholders string 224 fieldIndex []int 225 } 226 227 func parseType(rows any) *typeInfo { 228 typ := reflect.TypeOf(rows) 229 cachedInfo, ok := typeCache.Load(typ) 230 if ok { 231 return cachedInfo.(*typeInfo) 232 } 233 234 elemTyp := indirectType(indirectType(typ).Elem()) 235 tableName := strutil.ToSnakeCase(elemTyp.Name()) 236 fieldNum := elemTyp.NumField() 237 colNames := make([]string, 0, fieldNum) 238 fieldIndex := make([]int, 0) 239 for i := 0; i < fieldNum; i++ { 240 field := elemTyp.Field(i) 241 col := "" 242 243 // ignore unexported fields 244 if len(field.PkgPath) != 0 { 245 continue 246 } 247 248 // be compatible with sqlx column name tag 249 dbTag := field.Tag.Get("db") 250 opts := structtag.ParseOptions(dbTag, ",", "") 251 if len(opts) > 0 { 252 if opts[0].String() == "-" { 253 continue 254 } 255 col = opts[0].String() 256 } 257 258 // be compatible with gorm column name tag 259 if col == "" { 260 gormTag := field.Tag.Get("gorm") 261 opts = structtag.ParseOptions(gormTag, ";", ":") 262 if len(opts) > 0 { 263 if opts[0].Key() == "-" { 264 continue 265 } 266 colopt, found := opts.Get("column") 267 if found && colopt.Value() != "" { 268 col = colopt.Value() 269 } 270 } 271 } 272 273 // default 274 if col == "" { 275 col = strutil.ToSnakeCase(field.Name) 276 } 277 278 colNames = append(colNames, col) 279 fieldIndex = append(fieldIndex, i) 280 } 281 282 placeholders := makePlaceholders(len(fieldIndex)) 283 info := &typeInfo{ 284 tableName: tableName, 285 colNames: colNames, 286 placeholders: placeholders, 287 fieldIndex: fieldIndex, 288 } 289 typeCache.Store(typ, info) 290 return info 291 } 292 293 func makePlaceholders(n int) string { 294 marks := strings.Repeat("?,", n) 295 marks = strings.TrimSuffix(marks, ",") 296 return "(" + marks + ")" 297 } 298 299 func indirectType(typ reflect.Type) reflect.Type { 300 if typ.Kind() != reflect.Ptr { 301 return typ 302 } 303 return typ.Elem() 304 } 305 306 func inStrings(slice []string, elem string) bool { 307 for _, x := range slice { 308 if x == elem { 309 return true 310 } 311 } 312 return false 313 } 314 315 func inInts(slice []int, elem int) bool { 316 for _, x := range slice { 317 if x == elem { 318 return true 319 } 320 } 321 return false 322 } 323 324 func diffInts(a, b []int) []int { 325 out := make([]int, 0, len(a)) 326 for _, x := range a { 327 if inInts(b, x) { 328 continue 329 } 330 out = append(out, x) 331 } 332 return out 333 } 334 335 func assertSliceOfStructAndLength(where string, rows any) { 336 sliceTyp := reflect.TypeOf(rows) 337 if sliceTyp == nil || sliceTyp.Kind() != reflect.Slice { 338 panic(where + ": param is nil or not a slice") 339 } 340 elemTyp := sliceTyp.Elem() 341 if indirectType(elemTyp).Kind() != reflect.Struct { 342 panic(where + ": slice element is not struct or pointer to struct") 343 } 344 if reflectx.SliceLen(rows) == 0 { 345 panic(where + ": slice length is zero") 346 } 347 }