github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/bulkinserter.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/executors"
    10  	"github.com/lingyao2333/mo-zero/core/logx"
    11  	"github.com/lingyao2333/mo-zero/core/stringx"
    12  )
    13  
    14  const (
    15  	flushInterval = time.Second
    16  	maxBulkRows   = 1000
    17  	valuesKeyword = "values"
    18  )
    19  
    20  var emptyBulkStmt bulkStmt
    21  
    22  type (
    23  	// ResultHandler defines the method of result handlers.
    24  	ResultHandler func(sql.Result, error)
    25  
    26  	// A BulkInserter is used to batch insert records.
    27  	// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
    28  	// Oracle is not supported yet, because of the sql is formated with symbol `:`.
    29  	BulkInserter struct {
    30  		executor *executors.PeriodicalExecutor
    31  		inserter *dbInserter
    32  		stmt     bulkStmt
    33  	}
    34  
    35  	bulkStmt struct {
    36  		prefix      string
    37  		valueFormat string
    38  		suffix      string
    39  	}
    40  )
    41  
    42  // NewBulkInserter returns a BulkInserter.
    43  func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
    44  	bkStmt, err := parseInsertStmt(stmt)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	inserter := &dbInserter{
    50  		sqlConn: sqlConn,
    51  		stmt:    bkStmt,
    52  	}
    53  
    54  	return &BulkInserter{
    55  		executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
    56  		inserter: inserter,
    57  		stmt:     bkStmt,
    58  	}, nil
    59  }
    60  
    61  // Flush flushes all the pending records.
    62  func (bi *BulkInserter) Flush() {
    63  	bi.executor.Flush()
    64  }
    65  
    66  // Insert inserts given args.
    67  func (bi *BulkInserter) Insert(args ...interface{}) error {
    68  	value, err := format(bi.stmt.valueFormat, args...)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	bi.executor.Add(value)
    74  
    75  	return nil
    76  }
    77  
    78  // SetResultHandler sets the given handler.
    79  func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
    80  	bi.executor.Sync(func() {
    81  		bi.inserter.resultHandler = handler
    82  	})
    83  }
    84  
    85  // UpdateOrDelete runs update or delete queries, which flushes pending records first.
    86  func (bi *BulkInserter) UpdateOrDelete(fn func()) {
    87  	bi.executor.Flush()
    88  	fn()
    89  }
    90  
    91  // UpdateStmt updates the insert statement.
    92  func (bi *BulkInserter) UpdateStmt(stmt string) error {
    93  	bkStmt, err := parseInsertStmt(stmt)
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	bi.executor.Flush()
    99  	bi.executor.Sync(func() {
   100  		bi.inserter.stmt = bkStmt
   101  	})
   102  
   103  	return nil
   104  }
   105  
   106  type dbInserter struct {
   107  	sqlConn       SqlConn
   108  	stmt          bulkStmt
   109  	values        []string
   110  	resultHandler ResultHandler
   111  }
   112  
   113  func (in *dbInserter) AddTask(task interface{}) bool {
   114  	in.values = append(in.values, task.(string))
   115  	return len(in.values) >= maxBulkRows
   116  }
   117  
   118  func (in *dbInserter) Execute(bulk interface{}) {
   119  	values := bulk.([]string)
   120  	if len(values) == 0 {
   121  		return
   122  	}
   123  
   124  	stmtWithoutValues := in.stmt.prefix
   125  	valuesStr := strings.Join(values, ", ")
   126  	stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
   127  	if len(in.stmt.suffix) > 0 {
   128  		stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
   129  	}
   130  	result, err := in.sqlConn.Exec(stmt)
   131  	if in.resultHandler != nil {
   132  		in.resultHandler(result, err)
   133  	} else if err != nil {
   134  		logx.Errorf("sql: %s, error: %s", stmt, err)
   135  	}
   136  }
   137  
   138  func (in *dbInserter) RemoveAll() interface{} {
   139  	values := in.values
   140  	in.values = nil
   141  	return values
   142  }
   143  
   144  func parseInsertStmt(stmt string) (bulkStmt, error) {
   145  	lower := strings.ToLower(stmt)
   146  	pos := strings.Index(lower, valuesKeyword)
   147  	if pos <= 0 {
   148  		return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
   149  	}
   150  
   151  	var columns int
   152  	right := strings.LastIndexByte(lower[:pos], ')')
   153  	if right > 0 {
   154  		left := strings.LastIndexByte(lower[:right], '(')
   155  		if left > 0 {
   156  			values := lower[left+1 : right]
   157  			values = stringx.Filter(values, func(r rune) bool {
   158  				return r == ' ' || r == '\t' || r == '\r' || r == '\n'
   159  			})
   160  			fields := strings.FieldsFunc(values, func(r rune) bool {
   161  				return r == ','
   162  			})
   163  			columns = len(fields)
   164  		}
   165  	}
   166  
   167  	var variables int
   168  	var valueFormat string
   169  	var suffix string
   170  	left := strings.IndexByte(lower[pos:], '(')
   171  	if left > 0 {
   172  		right = strings.IndexByte(lower[pos+left:], ')')
   173  		if right > 0 {
   174  			values := lower[pos+left : pos+left+right]
   175  			for _, x := range values {
   176  				if x == '?' {
   177  					variables++
   178  				}
   179  			}
   180  			valueFormat = stmt[pos+left : pos+left+right+1]
   181  			suffix = strings.TrimSpace(stmt[pos+left+right+1:])
   182  		}
   183  	}
   184  
   185  	if variables == 0 {
   186  		return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
   187  	}
   188  	if columns > 0 && columns != variables {
   189  		return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
   190  	}
   191  
   192  	return bulkStmt{
   193  		prefix:      stmt[:pos+len(valuesKeyword)],
   194  		valueFormat: valueFormat,
   195  		suffix:      suffix,
   196  	}, nil
   197  }