github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/bulkinserter.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/shuguocloud/go-zero/core/executors"
    10  	"github.com/shuguocloud/go-zero/core/logx"
    11  	"github.com/shuguocloud/go-zero/core/stringx"
    12  )
    13  
    14  const (
        	// flushInterval is the interval passed to executors.NewPeriodicalExecutor in NewBulkInserter.
    15  	flushInterval = time.Second
        	// maxBulkRows is the buffered row count at which dbInserter.AddTask reports true.
    16  	maxBulkRows   = 1000
        	// valuesKeyword is the lower-case keyword parseInsertStmt searches for to split a statement.
    17  	valuesKeyword = "values"
    18  )
    19  
    20  var emptyBulkStmt bulkStmt // zero value that parseInsertStmt returns together with an error
    21  
    22  type (
    23  	// ResultHandler defines the method of result handlers.
        	// dbInserter.Execute calls it with the values returned by SqlConn.Exec.
    24  	ResultHandler func(sql.Result, error)
    25    
    26  	// A BulkInserter is used to batch insert records.
    27  	// PostgreSQL is not supported yet, because its SQL is formatted with the `$` symbol.
    28  	BulkInserter struct {
    29  		executor *executors.PeriodicalExecutor // created in NewBulkInserter with flushInterval; Insert adds rows to it
    30  		inserter *dbInserter // task container handed to the executor
    31  		stmt     bulkStmt // parsed statement; Insert formats rows with its valueFormat
    32  	}
    33    
        	// bulkStmt is an insert statement split by parseInsertStmt into reusable parts.
    34  	bulkStmt struct {
    35  		prefix      string // statement text up to and including the values keyword
    36  		valueFormat string // parenthesized placeholder tuple that follows the keyword, e.g. "(?, ?)"
    37  		suffix      string // trimmed text after the tuple; Execute appends it when non-empty
    38  	}
    39  )
    40  
    41  // NewBulkInserter returns a BulkInserter.
    42  func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
    43  	bkStmt, err := parseInsertStmt(stmt)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	inserter := &dbInserter{
    49  		sqlConn: sqlConn,
    50  		stmt:    bkStmt,
    51  	}
    52  
    53  	return &BulkInserter{
    54  		executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
    55  		inserter: inserter,
    56  		stmt:     bkStmt,
    57  	}, nil
    58  }
    59  
    60  // Flush flushes all the pending records.
        // It delegates directly to the Flush method of the underlying periodical executor.
    61  func (bi *BulkInserter) Flush() {
    62  	bi.executor.Flush()
    63  }
    64  
    65  // Insert inserts given args.
    66  func (bi *BulkInserter) Insert(args ...interface{}) error {
    67  	value, err := format(bi.stmt.valueFormat, args...)
    68  	if err != nil {
    69  		return err
    70  	}
    71  
    72  	bi.executor.Add(value)
    73  
    74  	return nil
    75  }
    76  
    77  // SetResultHandler sets the given handler.
    78  func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
    79  	bi.executor.Sync(func() {
    80  		bi.inserter.resultHandler = handler
    81  	})
    82  }
    83  
    84  // UpdateOrDelete runs update or delete queries, which flushes pending records first.
    85  func (bi *BulkInserter) UpdateOrDelete(fn func()) {
    86  	bi.executor.Flush()
    87  	fn()
    88  }
    89  
    90  // UpdateStmt updates the insert statement.
    91  func (bi *BulkInserter) UpdateStmt(stmt string) error {
    92  	bkStmt, err := parseInsertStmt(stmt)
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	bi.executor.Flush()
    98  	bi.executor.Sync(func() {
    99  		bi.inserter.stmt = bkStmt
   100  	})
   101  
   102  	return nil
   103  }
   104  
        // dbInserter buffers formatted value tuples and writes them with one
        // multi-row insert statement. NewBulkInserter hands it to the periodical
        // executor, which drives AddTask, RemoveAll and Execute.
   105  type dbInserter struct {
   106  	sqlConn       SqlConn // connection Execute runs the assembled statement on
   107  	stmt          bulkStmt // parsed statement supplying the prefix and suffix
   108  	values        []string // rows appended by AddTask and drained by RemoveAll
   109  	resultHandler ResultHandler // optional; when nil, Execute logs errors instead
   110  }
   111  
   112  func (in *dbInserter) AddTask(task interface{}) bool {
   113  	in.values = append(in.values, task.(string))
   114  	return len(in.values) >= maxBulkRows
   115  }
   116  
   117  func (in *dbInserter) Execute(bulk interface{}) {
   118  	values := bulk.([]string)
   119  	if len(values) == 0 {
   120  		return
   121  	}
   122  
   123  	stmtWithoutValues := in.stmt.prefix
   124  	valuesStr := strings.Join(values, ", ")
   125  	stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
   126  	if len(in.stmt.suffix) > 0 {
   127  		stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
   128  	}
   129  	result, err := in.sqlConn.Exec(stmt)
   130  	if in.resultHandler != nil {
   131  		in.resultHandler(result, err)
   132  	} else if err != nil {
   133  		logx.Errorf("sql: %s, error: %s", stmt, err)
   134  	}
   135  }
   136  
   137  func (in *dbInserter) RemoveAll() interface{} {
   138  	values := in.values
   139  	in.values = nil
   140  	return values
   141  }
   142  
   143  func parseInsertStmt(stmt string) (bulkStmt, error) {
   144  	lower := strings.ToLower(stmt)
   145  	pos := strings.Index(lower, valuesKeyword)
   146  	if pos <= 0 {
   147  		return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
   148  	}
   149  
   150  	var columns int
   151  	right := strings.LastIndexByte(lower[:pos], ')')
   152  	if right > 0 {
   153  		left := strings.LastIndexByte(lower[:right], '(')
   154  		if left > 0 {
   155  			values := lower[left+1 : right]
   156  			values = stringx.Filter(values, func(r rune) bool {
   157  				return r == ' ' || r == '\t' || r == '\r' || r == '\n'
   158  			})
   159  			fields := strings.FieldsFunc(values, func(r rune) bool {
   160  				return r == ','
   161  			})
   162  			columns = len(fields)
   163  		}
   164  	}
   165  
   166  	var variables int
   167  	var valueFormat string
   168  	var suffix string
   169  	left := strings.IndexByte(lower[pos:], '(')
   170  	if left > 0 {
   171  		right = strings.IndexByte(lower[pos+left:], ')')
   172  		if right > 0 {
   173  			values := lower[pos+left : pos+left+right]
   174  			for _, x := range values {
   175  				if x == '?' {
   176  					variables++
   177  				}
   178  			}
   179  			valueFormat = stmt[pos+left : pos+left+right+1]
   180  			suffix = strings.TrimSpace(stmt[pos+left+right+1:])
   181  		}
   182  	}
   183  
   184  	if variables == 0 {
   185  		return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
   186  	}
   187  	if columns > 0 && columns != variables {
   188  		return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
   189  	}
   190  
   191  	return bulkStmt{
   192  		prefix:      stmt[:pos+len(valuesKeyword)],
   193  		valueFormat: valueFormat,
   194  		suffix:      suffix,
   195  	}, nil
   196  }