github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/bulkinserter.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "fmt" 6 "strings" 7 "time" 8 9 "github.com/shuguocloud/go-zero/core/executors" 10 "github.com/shuguocloud/go-zero/core/logx" 11 "github.com/shuguocloud/go-zero/core/stringx" 12 ) 13 14 const ( 15 flushInterval = time.Second 16 maxBulkRows = 1000 17 valuesKeyword = "values" 18 ) 19 20 var emptyBulkStmt bulkStmt 21 22 type ( 23 // ResultHandler defines the method of result handlers. 24 ResultHandler func(sql.Result, error) 25 26 // A BulkInserter is used to batch insert records. 27 // Postgresql is not supported yet, because of the sql is formated with symbol `$`. 28 BulkInserter struct { 29 executor *executors.PeriodicalExecutor 30 inserter *dbInserter 31 stmt bulkStmt 32 } 33 34 bulkStmt struct { 35 prefix string 36 valueFormat string 37 suffix string 38 } 39 ) 40 41 // NewBulkInserter returns a BulkInserter. 42 func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) { 43 bkStmt, err := parseInsertStmt(stmt) 44 if err != nil { 45 return nil, err 46 } 47 48 inserter := &dbInserter{ 49 sqlConn: sqlConn, 50 stmt: bkStmt, 51 } 52 53 return &BulkInserter{ 54 executor: executors.NewPeriodicalExecutor(flushInterval, inserter), 55 inserter: inserter, 56 stmt: bkStmt, 57 }, nil 58 } 59 60 // Flush flushes all the pending records. 61 func (bi *BulkInserter) Flush() { 62 bi.executor.Flush() 63 } 64 65 // Insert inserts given args. 66 func (bi *BulkInserter) Insert(args ...interface{}) error { 67 value, err := format(bi.stmt.valueFormat, args...) 68 if err != nil { 69 return err 70 } 71 72 bi.executor.Add(value) 73 74 return nil 75 } 76 77 // SetResultHandler sets the given handler. 78 func (bi *BulkInserter) SetResultHandler(handler ResultHandler) { 79 bi.executor.Sync(func() { 80 bi.inserter.resultHandler = handler 81 }) 82 } 83 84 // UpdateOrDelete runs update or delete queries, which flushes pending records first. 85 func (bi *BulkInserter) UpdateOrDelete(fn func()) { 86 bi.executor.Flush() 87 fn() 88 } 89 90 // UpdateStmt updates the insert statement. 91 func (bi *BulkInserter) UpdateStmt(stmt string) error { 92 bkStmt, err := parseInsertStmt(stmt) 93 if err != nil { 94 return err 95 } 96 97 bi.executor.Flush() 98 bi.executor.Sync(func() { 99 bi.inserter.stmt = bkStmt 100 }) 101 102 return nil 103 } 104 105 type dbInserter struct { 106 sqlConn SqlConn 107 stmt bulkStmt 108 values []string 109 resultHandler ResultHandler 110 } 111 112 func (in *dbInserter) AddTask(task interface{}) bool { 113 in.values = append(in.values, task.(string)) 114 return len(in.values) >= maxBulkRows 115 } 116 117 func (in *dbInserter) Execute(bulk interface{}) { 118 values := bulk.([]string) 119 if len(values) == 0 { 120 return 121 } 122 123 stmtWithoutValues := in.stmt.prefix 124 valuesStr := strings.Join(values, ", ") 125 stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ") 126 if len(in.stmt.suffix) > 0 { 127 stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ") 128 } 129 result, err := in.sqlConn.Exec(stmt) 130 if in.resultHandler != nil { 131 in.resultHandler(result, err) 132 } else if err != nil { 133 logx.Errorf("sql: %s, error: %s", stmt, err) 134 } 135 } 136 137 func (in *dbInserter) RemoveAll() interface{} { 138 values := in.values 139 in.values = nil 140 return values 141 } 142 143 func parseInsertStmt(stmt string) (bulkStmt, error) { 144 lower := strings.ToLower(stmt) 145 pos := strings.Index(lower, valuesKeyword) 146 if pos <= 0 { 147 return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt) 148 } 149 150 var columns int 151 right := strings.LastIndexByte(lower[:pos], ')') 152 if right > 0 { 153 left := strings.LastIndexByte(lower[:right], '(') 154 if left > 0 { 155 values := lower[left+1 : right] 156 values = stringx.Filter(values, func(r rune) bool { 157 return r == ' ' || r == '\t' || r == '\r' || r == '\n' 158 }) 159 fields := strings.FieldsFunc(values, func(r rune) bool { 160 return r == ',' 161 }) 162 columns = len(fields) 163 } 164 } 165 166 var variables int 167 var valueFormat string 168 var suffix string 169 left := strings.IndexByte(lower[pos:], '(') 170 if left > 0 { 171 right = strings.IndexByte(lower[pos+left:], ')') 172 if right > 0 { 173 values := lower[pos+left : pos+left+right] 174 for _, x := range values { 175 if x == '?' { 176 variables++ 177 } 178 } 179 valueFormat = stmt[pos+left : pos+left+right+1] 180 suffix = strings.TrimSpace(stmt[pos+left+right+1:]) 181 } 182 } 183 184 if variables == 0 { 185 return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt) 186 } 187 if columns > 0 && columns != variables { 188 return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt) 189 } 190 191 return bulkStmt{ 192 prefix: stmt[:pos+len(valuesKeyword)], 193 valueFormat: valueFormat, 194 suffix: suffix, 195 }, nil 196 }