github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/bulkinserter.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "fmt" 6 "strings" 7 "time" 8 9 "github.com/lingyao2333/mo-zero/core/executors" 10 "github.com/lingyao2333/mo-zero/core/logx" 11 "github.com/lingyao2333/mo-zero/core/stringx" 12 ) 13 14 const ( 15 flushInterval = time.Second 16 maxBulkRows = 1000 17 valuesKeyword = "values" 18 ) 19 20 var emptyBulkStmt bulkStmt 21 22 type ( 23 // ResultHandler defines the method of result handlers. 24 ResultHandler func(sql.Result, error) 25 26 // A BulkInserter is used to batch insert records. 27 // Postgresql is not supported yet, because of the sql is formated with symbol `$`. 28 // Oracle is not supported yet, because of the sql is formated with symbol `:`. 29 BulkInserter struct { 30 executor *executors.PeriodicalExecutor 31 inserter *dbInserter 32 stmt bulkStmt 33 } 34 35 bulkStmt struct { 36 prefix string 37 valueFormat string 38 suffix string 39 } 40 ) 41 42 // NewBulkInserter returns a BulkInserter. 43 func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) { 44 bkStmt, err := parseInsertStmt(stmt) 45 if err != nil { 46 return nil, err 47 } 48 49 inserter := &dbInserter{ 50 sqlConn: sqlConn, 51 stmt: bkStmt, 52 } 53 54 return &BulkInserter{ 55 executor: executors.NewPeriodicalExecutor(flushInterval, inserter), 56 inserter: inserter, 57 stmt: bkStmt, 58 }, nil 59 } 60 61 // Flush flushes all the pending records. 62 func (bi *BulkInserter) Flush() { 63 bi.executor.Flush() 64 } 65 66 // Insert inserts given args. 67 func (bi *BulkInserter) Insert(args ...interface{}) error { 68 value, err := format(bi.stmt.valueFormat, args...) 69 if err != nil { 70 return err 71 } 72 73 bi.executor.Add(value) 74 75 return nil 76 } 77 78 // SetResultHandler sets the given handler. 79 func (bi *BulkInserter) SetResultHandler(handler ResultHandler) { 80 bi.executor.Sync(func() { 81 bi.inserter.resultHandler = handler 82 }) 83 } 84 85 // UpdateOrDelete runs update or delete queries, which flushes pending records first. 86 func (bi *BulkInserter) UpdateOrDelete(fn func()) { 87 bi.executor.Flush() 88 fn() 89 } 90 91 // UpdateStmt updates the insert statement. 92 func (bi *BulkInserter) UpdateStmt(stmt string) error { 93 bkStmt, err := parseInsertStmt(stmt) 94 if err != nil { 95 return err 96 } 97 98 bi.executor.Flush() 99 bi.executor.Sync(func() { 100 bi.inserter.stmt = bkStmt 101 }) 102 103 return nil 104 } 105 106 type dbInserter struct { 107 sqlConn SqlConn 108 stmt bulkStmt 109 values []string 110 resultHandler ResultHandler 111 } 112 113 func (in *dbInserter) AddTask(task interface{}) bool { 114 in.values = append(in.values, task.(string)) 115 return len(in.values) >= maxBulkRows 116 } 117 118 func (in *dbInserter) Execute(bulk interface{}) { 119 values := bulk.([]string) 120 if len(values) == 0 { 121 return 122 } 123 124 stmtWithoutValues := in.stmt.prefix 125 valuesStr := strings.Join(values, ", ") 126 stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ") 127 if len(in.stmt.suffix) > 0 { 128 stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ") 129 } 130 result, err := in.sqlConn.Exec(stmt) 131 if in.resultHandler != nil { 132 in.resultHandler(result, err) 133 } else if err != nil { 134 logx.Errorf("sql: %s, error: %s", stmt, err) 135 } 136 } 137 138 func (in *dbInserter) RemoveAll() interface{} { 139 values := in.values 140 in.values = nil 141 return values 142 } 143 144 func parseInsertStmt(stmt string) (bulkStmt, error) { 145 lower := strings.ToLower(stmt) 146 pos := strings.Index(lower, valuesKeyword) 147 if pos <= 0 { 148 return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt) 149 } 150 151 var columns int 152 right := strings.LastIndexByte(lower[:pos], ')') 153 if right > 0 { 154 left := strings.LastIndexByte(lower[:right], '(') 155 if left > 0 { 156 values := lower[left+1 : right] 157 values = stringx.Filter(values, func(r rune) bool { 158 return r == ' ' || r == '\t' || r == '\r' || r == '\n' 159 }) 160 fields := strings.FieldsFunc(values, func(r rune) bool { 161 return r == ',' 162 }) 163 columns = len(fields) 164 } 165 } 166 167 var variables int 168 var valueFormat string 169 var suffix string 170 left := strings.IndexByte(lower[pos:], '(') 171 if left > 0 { 172 right = strings.IndexByte(lower[pos+left:], ')') 173 if right > 0 { 174 values := lower[pos+left : pos+left+right] 175 for _, x := range values { 176 if x == '?' { 177 variables++ 178 } 179 } 180 valueFormat = stmt[pos+left : pos+left+right+1] 181 suffix = strings.TrimSpace(stmt[pos+left+right+1:]) 182 } 183 } 184 185 if variables == 0 { 186 return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt) 187 } 188 if columns > 0 && columns != variables { 189 return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt) 190 } 191 192 return bulkStmt{ 193 prefix: stmt[:pos+len(valuesKeyword)], 194 valueFormat: valueFormat, 195 suffix: suffix, 196 }, nil 197 }