github.com/matrixorigin/matrixone@v1.2.0/pkg/util/export/etl/db/db_holder.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package db_holder 16 17 import ( 18 "bytes" 19 "context" 20 "database/sql" 21 "encoding/csv" 22 "errors" 23 "fmt" 24 "strings" 25 "sync" 26 "sync/atomic" 27 "time" 28 29 "github.com/matrixorigin/matrixone/pkg/common/moerr" 30 "github.com/matrixorigin/matrixone/pkg/common/mpool" 31 "github.com/matrixorigin/matrixone/pkg/util/export/table" 32 ) 33 34 var ( 35 errNotReady = moerr.NewInvalidStateNoCtx("SQL writer's DB conn not ready") 36 ) 37 38 // sqlWriterDBUser holds the db user for logger 39 var ( 40 sqlWriterDBUser atomic.Value 41 dbAddressFunc atomic.Value 42 43 db atomic.Value 44 dbRefreshTime time.Time 45 46 dbMux sync.Mutex 47 48 DBConnErrCount atomic.Uint32 49 ) 50 51 const MOLoggerUser = "mo_logger" 52 const MaxConnectionNumber = 1 53 54 const DBConnRetryThreshold = 8 55 56 const DBRefreshTime = time.Hour 57 58 type DBUser struct { 59 UserName string 60 Password string 61 } 62 63 func SetSQLWriterDBUser(userName string, password string) { 64 user := &DBUser{ 65 UserName: userName, 66 Password: password, 67 } 68 sqlWriterDBUser.Store(user) 69 } 70 func GetSQLWriterDBUser() (*DBUser, error) { 71 dbUser := sqlWriterDBUser.Load() 72 if dbUser == nil { 73 return nil, errNotReady 74 } else { 75 return sqlWriterDBUser.Load().(*DBUser), nil 76 77 } 78 } 79 80 func SetSQLWriterDBAddressFunc(f func(context.Context, bool) (string, error)) { 81 dbAddressFunc.Store(f) 82 } 83 84 func GetSQLWriterDBAddressFunc() func(context.Context, bool) (string, error) { 85 if f := dbAddressFunc.Load(); f == nil { 86 return nil 87 } else { 88 return f.(func(context.Context, bool) (string, error)) 89 } 90 } 91 92 func SetDBConn(conn *sql.DB) { 93 db.Store(conn) 94 dbRefreshTime = time.Now().Add(DBRefreshTime) 95 } 96 97 func CloseDBConn() { 98 dbVal := db.Load() 99 if dbVal == nil { 100 return 101 } 102 dbConn := dbVal.(*sql.DB) 103 if dbConn != nil { 104 dbConn.Close() 105 } 106 } 107 108 func GetOrInitDBConn(forceNewConn bool, randomCN bool) (*sql.DB, error) { 109 dbMux.Lock() 110 defer dbMux.Unlock() 111 initFunc := func() error { 112 CloseDBConn() 113 dbUser, _ := GetSQLWriterDBUser() 114 if dbUser == nil { 115 return errNotReady 116 } 117 118 // TODO: trigger with new selected-CN, converge all connections 119 addressFunc := GetSQLWriterDBAddressFunc() 120 if addressFunc == nil { 121 return errNotReady 122 } 123 dbAddress, err := addressFunc(context.Background(), randomCN) 124 if err != nil { 125 return err 126 } 127 dsn := 128 fmt.Sprintf("%s:%s@tcp(%s)/?readTimeout=10s&writeTimeout=15s&timeout=15s&maxAllowedPacket=0", 129 dbUser.UserName, 130 dbUser.Password, 131 dbAddress) 132 newDBConn, err := sql.Open("mysql", dsn) 133 if err != nil { 134 return err 135 } 136 if _, err := newDBConn.Exec("set session disable_txn_trace=1"); err != nil { 137 return errors.Join(err, newDBConn.Close()) 138 } 139 140 //45s suggest by xzxiong 141 newDBConn.SetConnMaxLifetime(45 * time.Second) 142 newDBConn.SetMaxOpenConns(MaxConnectionNumber) 143 newDBConn.SetMaxIdleConns(MaxConnectionNumber) 144 SetDBConn(newDBConn) 145 return nil 146 } 147 148 if forceNewConn || db.Load() == nil { 149 err := initFunc() 150 if err != nil { 151 return nil, err 152 } 153 } else if time.Now().After(dbRefreshTime) { 154 err := initFunc() 155 if err != nil { 156 return nil, err 157 } 158 } 159 160 dbConn := db.Load().(*sql.DB) 161 return dbConn, nil 162 } 163 164 func WriteRowRecords(records [][]string, tbl *table.Table, timeout time.Duration) (int, error) { 165 if len(records) == 0 { 166 return 0, nil 167 } 168 var err error 169 170 var dbConn *sql.DB 171 172 if DBConnErrCount.Load() > DBConnRetryThreshold { 173 dbConn, err = GetOrInitDBConn(true, true) 174 DBConnErrCount.Store(0) 175 } else { 176 dbConn, err = GetOrInitDBConn(false, false) 177 } 178 if err != nil { 179 return 0, err 180 } 181 182 ctx, cancel := context.WithTimeout(context.Background(), timeout) 183 defer cancel() 184 185 err = bulkInsert(ctx, dbConn, records, tbl) 186 if err != nil { 187 DBConnErrCount.Add(1) 188 return 0, err 189 } 190 191 return len(records), nil 192 } 193 194 const initedSize = 4 * mpool.MB 195 196 var bufPool = sync.Pool{New: func() any { 197 return bytes.NewBuffer(make([]byte, 0, initedSize)) 198 }} 199 200 func getBuffer() *bytes.Buffer { 201 return bufPool.Get().(*bytes.Buffer) 202 } 203 204 func putBuffer(buf *bytes.Buffer) { 205 if buf != nil { 206 buf.Reset() 207 bufPool.Put(buf) 208 } 209 } 210 211 type CSVWriter struct { 212 ctx context.Context 213 formatter *csv.Writer 214 buf *bytes.Buffer 215 } 216 217 func NewCSVWriter(ctx context.Context) *CSVWriter { 218 buf := getBuffer() 219 buf.Reset() 220 writer := csv.NewWriter(buf) 221 222 w := &CSVWriter{ 223 ctx: ctx, 224 buf: buf, 225 formatter: writer, 226 } 227 return w 228 } 229 230 func (w *CSVWriter) WriteStrings(record []string) error { 231 if err := w.formatter.Write(record); err != nil { 232 return err 233 } 234 return nil 235 } 236 237 func (w *CSVWriter) GetContent() string { 238 w.formatter.Flush() // Ensure all data is written to buffer 239 return w.buf.String() 240 } 241 242 func (w *CSVWriter) Release() { 243 if w.buf != nil { 244 w.buf.Reset() 245 w.buf = nil 246 w.formatter = nil 247 } 248 putBuffer(w.buf) 249 } 250 251 func bulkInsert(ctx context.Context, sqlDb *sql.DB, records [][]string, tbl *table.Table) error { 252 if len(records) == 0 { 253 return nil 254 } 255 256 csvWriter := NewCSVWriter(ctx) 257 defer csvWriter.Release() // Ensures that the buffer is returned to the pool 258 259 // Write each record of the chunk to the CSVWriter 260 for _, record := range records { 261 for i, col := range record { 262 record[i] = strings.ReplaceAll(strings.ReplaceAll(col, "\\", "\\\\"), "'", "''") 263 } 264 if err := csvWriter.WriteStrings(record); err != nil { 265 return err 266 } 267 } 268 269 csvData := csvWriter.GetContent() 270 271 loadSQL := fmt.Sprintf("LOAD DATA INLINE FORMAT='csv', DATA='%s' INTO TABLE %s.%s FIELDS TERMINATED BY ','", csvData, tbl.Database, tbl.Table) 272 273 // Use the transaction to execute the SQL command 274 275 _, execErr := sqlDb.Exec(loadSQL) 276 277 return execErr 278 279 } 280 281 type DBConnProvider func(forceNewConn bool, randomCN bool) (*sql.DB, error) 282 283 func IsRecordExisted(ctx context.Context, record []string, tbl *table.Table, getDBConn DBConnProvider) (bool, error) { 284 dbConn, err := getDBConn(false, false) 285 if err != nil { 286 return false, err 287 } 288 289 if tbl.Table == "statement_info" { 290 const stmtIDIndex = 0 // Replace with actual index for statement ID if different 291 const statusIndex = 15 // Replace with actual index for status 292 const requestAtIndex = 12 // Replace with actual index for request_at 293 if len(record) <= statusIndex { // Use the largest index you will access 294 return false, nil 295 } 296 return isStatementExisted(ctx, dbConn, record[stmtIDIndex], record[statusIndex], record[requestAtIndex]) 297 } 298 299 return false, nil 300 } 301 302 func isStatementExisted(ctx context.Context, db *sql.DB, stmtId string, status string, request_at string) (bool, error) { 303 var exists bool 304 query := "SELECT EXISTS(SELECT 1 FROM `system`.statement_info WHERE statement_id = ? AND status = ? AND request_at = ?)" 305 err := db.QueryRowContext(ctx, query, stmtId, status, request_at).Scan(&exists) 306 if err != nil { 307 return false, err 308 } 309 return exists, nil 310 } 311 312 var gLabels map[string]string = nil 313 314 func SetLabelSelector(labels map[string]string) { 315 if len(labels) == 0 { 316 return 317 } 318 gLabels = make(map[string]string, len(labels)+1) 319 gLabels["account"] = "sys" 320 for k, v := range labels { 321 gLabels[k] = v 322 } 323 } 324 325 // GetLabelSelector 326 // Tips: more details in route.RouteForSuperTenant function. It mainly depends on S1. 327 // Tips: gLabels better contain {"account":"sys"}. 328 // - Because clusterservice.Selector using clusterservice.globbing do regex-match in route.RouteForSuperTenant 329 // - If you use labels{"account":"sys", "role":"ob"}, the Selector can match those pods, which have labels{"account":"*", "role":"ob"} 330 func GetLabelSelector() map[string]string { 331 return gLabels 332 }