github.com/matrixorigin/matrixone@v1.2.0/pkg/util/export/etl/db/db_holder.go

// Copyright 2022 Matrix Origin
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package db_holder

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/csv"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/matrixorigin/matrixone/pkg/common/moerr"
	"github.com/matrixorigin/matrixone/pkg/common/mpool"
	"github.com/matrixorigin/matrixone/pkg/util/export/table"
)

var (
	errNotReady = moerr.NewInvalidStateNoCtx("SQL writer's DB conn not ready")
)

// sqlWriterDBUser holds the db user for the logger.
var (
	sqlWriterDBUser atomic.Value
	dbAddressFunc   atomic.Value

	db            atomic.Value
	dbRefreshTime time.Time

	dbMux sync.Mutex

	DBConnErrCount atomic.Uint32
)

const MOLoggerUser = "mo_logger"
const MaxConnectionNumber = 1

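// DBConnRetryThreshold is the number of consecutive failed writes after which
// WriteRowRecords forces a brand-new connection to a randomly selected CN.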
const DBConnRetryThreshold = 8

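// DBRefreshTime is how long a cached connection is reused before
// GetOrInitDBConn transparently re-establishes it.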
const DBRefreshTime = time.Hour

type DBUser struct {
	UserName string
	Password string
}

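// SetSQLWriterDBUser stores the credentials the SQL writer uses when opening
// its DB connection, e.g. SetSQLWriterDBUser(MOLoggerUser, password) (a usage
// sketch; the caller decides which account and password to use).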
func SetSQLWriterDBUser(userName string, password string) {
	user := &DBUser{
		UserName: userName,
		Password: password,
	}
	sqlWriterDBUser.Store(user)
}
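
// GetSQLWriterDBUser returns the stored DB user, or errNotReady if
// SetSQLWriterDBUser has not been called yet.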
func GetSQLWriterDBUser() (*DBUser, error) {
	dbUser := sqlWriterDBUser.Load()
	if dbUser == nil {
		return nil, errNotReady
	}
	return dbUser.(*DBUser), nil
}

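// SetSQLWriterDBAddressFunc registers the callback used to resolve the address
// of a CN to connect to; the bool argument requests a randomly selected CN.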
func SetSQLWriterDBAddressFunc(f func(context.Context, bool) (string, error)) {
	dbAddressFunc.Store(f)
}

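// GetSQLWriterDBAddressFunc returns the registered address resolver, or nil if
// none has been set.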
func GetSQLWriterDBAddressFunc() func(context.Context, bool) (string, error) {
	if f := dbAddressFunc.Load(); f == nil {
		return nil
	} else {
		return f.(func(context.Context, bool) (string, error))
	}
}

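// SetDBConn caches the connection and schedules the next refresh for
// DBRefreshTime from now.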
func SetDBConn(conn *sql.DB) {
	db.Store(conn)
	dbRefreshTime = time.Now().Add(DBRefreshTime)
}

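// CloseDBConn closes the cached connection, if any.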
func CloseDBConn() {
	dbVal := db.Load()
	if dbVal == nil {
		return
	}
	dbConn := dbVal.(*sql.DB)
	if dbConn != nil {
		dbConn.Close()
	}
}

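// GetOrInitDBConn returns the cached *sql.DB, (re)establishing the connection
// when forceNewConn is set, when nothing has been cached yet, or when the
// cached connection is older than DBRefreshTime. randomCN is forwarded to the
// registered address resolver when picking a CN.
//
// A minimal usage sketch (assuming the DB user and address resolver were
// registered during startup):
//
//	conn, err := GetOrInitDBConn(false, false)
//	if err != nil {
//		return err
//	}
//	_ = conn // issue statements through the shared connection pool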
func GetOrInitDBConn(forceNewConn bool, randomCN bool) (*sql.DB, error) {
	dbMux.Lock()
	defer dbMux.Unlock()
	initFunc := func() error {
		CloseDBConn()
		dbUser, err := GetSQLWriterDBUser()
		if err != nil {
			return err
		}

		// TODO: trigger with new selected-CN, converge all connections
		addressFunc := GetSQLWriterDBAddressFunc()
		if addressFunc == nil {
			return errNotReady
		}
		dbAddress, err := addressFunc(context.Background(), randomCN)
		if err != nil {
			return err
		}
		dsn := fmt.Sprintf("%s:%s@tcp(%s)/?readTimeout=10s&writeTimeout=15s&timeout=15s&maxAllowedPacket=0",
			dbUser.UserName,
			dbUser.Password,
			dbAddress)
		newDBConn, err := sql.Open("mysql", dsn)
		if err != nil {
			return err
		}
		if _, err := newDBConn.Exec("set session disable_txn_trace=1"); err != nil {
			return errors.Join(err, newDBConn.Close())
		}

		// 45s suggested by xzxiong
		newDBConn.SetConnMaxLifetime(45 * time.Second)
		newDBConn.SetMaxOpenConns(MaxConnectionNumber)
		newDBConn.SetMaxIdleConns(MaxConnectionNumber)
		SetDBConn(newDBConn)
		return nil
	}

	if forceNewConn || db.Load() == nil || time.Now().After(dbRefreshTime) {
		if err := initFunc(); err != nil {
			return nil, err
		}
	}

	dbConn := db.Load().(*sql.DB)
	return dbConn, nil
}

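// WriteRowRecords bulk-inserts records into tbl within the given timeout and
// returns the number of records written. When more than DBConnRetryThreshold
// consecutive writes have failed, the next call forces a brand-new connection
// to a randomly selected CN and resets the error counter.
//
// A minimal usage sketch (rows and tbl are placeholders for the caller's
// CSV-ready records and table definition):
//
//	cnt, err := WriteRowRecords(rows, tbl, 5*time.Minute)
//	if err != nil {
//		// the batch was not written; handle or retry on the caller's side
//	}
//	_ = cnt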
func WriteRowRecords(records [][]string, tbl *table.Table, timeout time.Duration) (int, error) {
	if len(records) == 0 {
		return 0, nil
	}
	var err error

	var dbConn *sql.DB

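	// Too many consecutive failures: drop the cached connection, connect to a
	// freshly picked (possibly different) CN, and reset the counter.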
	if DBConnErrCount.Load() > DBConnRetryThreshold {
		dbConn, err = GetOrInitDBConn(true, true)
		DBConnErrCount.Store(0)
	} else {
		dbConn, err = GetOrInitDBConn(false, false)
	}
	if err != nil {
		return 0, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	err = bulkInsert(ctx, dbConn, records, tbl)
	if err != nil {
		DBConnErrCount.Add(1)
		return 0, err
	}

	return len(records), nil
}

const initedSize = 4 * mpool.MB

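// bufPool recycles the large, pre-sized buffers used to assemble the CSV
// payload that bulkInsert embeds into its LOAD DATA INLINE statement.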
var bufPool = sync.Pool{New: func() any {
	return bytes.NewBuffer(make([]byte, 0, initedSize))
}}

func getBuffer() *bytes.Buffer {
	return bufPool.Get().(*bytes.Buffer)
}

func putBuffer(buf *bytes.Buffer) {
	if buf != nil {
		buf.Reset()
		bufPool.Put(buf)
	}
}

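// CSVWriter accumulates CSV-encoded rows in a pooled buffer.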
type CSVWriter struct {
	ctx       context.Context
	formatter *csv.Writer
	buf       *bytes.Buffer
}

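// NewCSVWriter returns a CSVWriter backed by a buffer taken from bufPool;
// call Release when finished so the buffer is handed back to the pool.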
func NewCSVWriter(ctx context.Context) *CSVWriter {
	buf := getBuffer()
	buf.Reset()
	writer := csv.NewWriter(buf)

	w := &CSVWriter{
		ctx:       ctx,
		buf:       buf,
		formatter: writer,
	}
	return w
}

func (w *CSVWriter) WriteStrings(record []string) error {
	return w.formatter.Write(record)
}

func (w *CSVWriter) GetContent() string {
	w.formatter.Flush() // Ensure all data is written to buffer
	return w.buf.String()
}

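// Release returns the underlying buffer to the pool; the writer must not be
// used afterwards.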
func (w *CSVWriter) Release() {
	if w.buf != nil {
		putBuffer(w.buf) // return the buffer to the pool (putBuffer resets it)
		w.buf = nil
		w.formatter = nil
	}
}

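// bulkInsert encodes records as CSV and loads them into tbl with a single
// LOAD DATA INLINE statement. Backslashes and single quotes are escaped
// because the whole CSV payload is embedded in a single-quoted SQL literal.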
func bulkInsert(ctx context.Context, sqlDb *sql.DB, records [][]string, tbl *table.Table) error {
	if len(records) == 0 {
		return nil
	}

	csvWriter := NewCSVWriter(ctx)
	defer csvWriter.Release() // Ensures that the buffer is returned to the pool

	// Escape each field, then write the record to the CSVWriter
	for _, record := range records {
		for i, col := range record {
			record[i] = strings.ReplaceAll(strings.ReplaceAll(col, "\\", "\\\\"), "'", "''")
		}
		if err := csvWriter.WriteStrings(record); err != nil {
			return err
		}
	}

	csvData := csvWriter.GetContent()

	loadSQL := fmt.Sprintf("LOAD DATA INLINE FORMAT='csv', DATA='%s' INTO TABLE %s.%s FIELDS TERMINATED BY ','", csvData, tbl.Database, tbl.Table)

	// Execute the LOAD DATA statement, honoring the caller's timeout.
	_, execErr := sqlDb.ExecContext(ctx, loadSQL)

	return execErr
}

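// DBConnProvider supplies a *sql.DB on demand. GetOrInitDBConn satisfies this
// signature; IsRecordExisted accepts any provider.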
type DBConnProvider func(forceNewConn bool, randomCN bool) (*sql.DB, error)

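// IsRecordExisted reports whether record already exists in tbl. Only
// statement_info rows are currently checked, by (statement_id, status,
// request_at); for every other table it returns false.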
func IsRecordExisted(ctx context.Context, record []string, tbl *table.Table, getDBConn DBConnProvider) (bool, error) {
	dbConn, err := getDBConn(false, false)
	if err != nil {
		return false, err
	}

	if tbl.Table == "statement_info" {
		const stmtIDIndex = 0           // column index of statement_id in the record
		const statusIndex = 15          // column index of status
		const requestAtIndex = 12       // column index of request_at
		if len(record) <= statusIndex { // statusIndex is the largest index accessed below
			return false, nil
		}
		return isStatementExisted(ctx, dbConn, record[stmtIDIndex], record[statusIndex], record[requestAtIndex])
	}

	return false, nil
}

func isStatementExisted(ctx context.Context, db *sql.DB, stmtId string, status string, requestAt string) (bool, error) {
	var exists bool
	query := "SELECT EXISTS(SELECT 1 FROM `system`.statement_info WHERE statement_id = ? AND status = ? AND request_at = ?)"
	err := db.QueryRowContext(ctx, query, stmtId, status, requestAt).Scan(&exists)
	if err != nil {
		return false, err
	}
	return exists, nil
}

var gLabels map[string]string

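// SetLabelSelector stores the CN label selector returned by GetLabelSelector.
// The "account":"sys" label is added first, then the provided labels are
// copied in; an empty map leaves the selector unset.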
func SetLabelSelector(labels map[string]string) {
	if len(labels) == 0 {
		return
	}
	gLabels = make(map[string]string, len(labels)+1)
	gLabels["account"] = "sys"
	for k, v := range labels {
		gLabels[k] = v
	}
}

// GetLabelSelector returns the stored CN label selector.
// Tips: see the route.RouteForSuperTenant function for more details; it mainly depends on S1.
// Tips: gLabels should contain {"account":"sys"}, because
//   - clusterservice.Selector uses clusterservice.globbing to do regex-matching in route.RouteForSuperTenant, and
//   - with labels {"account":"sys", "role":"ob"}, the Selector can match pods that carry labels {"account":"*", "role":"ob"}.
func GetLabelSelector() map[string]string {
	return gLabels
}