github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/election/storage_sql.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package election
    15  
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"encoding/json"
    20  	"fmt"
    21  
    22  	"github.com/pingcap/tiflow/engine/pkg/meta/mock"
    23  	"github.com/pingcap/tiflow/pkg/errors"
    24  )
    25  
    26  var _ Storage = &SQLStorage{}
    27  
    28  const (
    29  	sqlRecordID     = 1
    30  	sqlCreateTable  = "CREATE TABLE IF NOT EXISTS %s (id int NOT NULL, version bigint NOT NULL, record text NOT NULL, PRIMARY KEY (id))"
    31  	sqlQueryRecord  = "SELECT version, record FROM %s WHERE id = ?"
    32  	sqlInsertRecord = "INSERT INTO %s (id, version, record) VALUES (?, ?, ?)"
    33  	sqlUpdateRecord = "UPDATE %s SET version = ?, record = ? WHERE id = ? AND version = ?"
    34  )
    35  
    36  // SQLStorage is a storage implementation based on SQL database.
    37  type SQLStorage struct {
    38  	db        *sql.DB
    39  	tableName string
    40  }
    41  
    42  // NewSQLStorage creates a new SQLStorage.
    43  func NewSQLStorage(db *sql.DB, tableName string) (*SQLStorage, error) {
    44  	if _, err := db.Exec(fmt.Sprintf(sqlCreateTable, tableName)); err != nil {
    45  		return nil, errors.Trace(err)
    46  	}
    47  	return &SQLStorage{
    48  		db:        db,
    49  		tableName: tableName,
    50  	}, nil
    51  }
    52  
    53  // NewInMemorySQLStorage creates a new SQLStorage in memory based on SQLite.
    54  func NewInMemorySQLStorage(dbName string, tableName string) (*SQLStorage, error) {
    55  	dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName)
    56  
    57  	db, err := sql.Open(mock.ThreadeSafeSqliteDriverName, dsn)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	return NewSQLStorage(db, tableName)
    62  }
    63  
    64  // Get implements Storage.Get.
    65  func (s *SQLStorage) Get(ctx context.Context) (*Record, error) {
    66  	var (
    67  		version     int64
    68  		recordBytes []byte
    69  	)
    70  	err := s.db.QueryRowContext(ctx,
    71  		fmt.Sprintf(sqlQueryRecord, s.tableName), sqlRecordID).
    72  		Scan(&version, &recordBytes)
    73  	if err != nil {
    74  		if err == sql.ErrNoRows {
    75  			err = nil
    76  		}
    77  		return &Record{}, errors.Trace(err)
    78  	}
    79  	var record Record
    80  	if err := json.Unmarshal(recordBytes, &record); err != nil {
    81  		return nil, errors.Trace(err)
    82  	}
    83  	record.Version = version
    84  	return &record, nil
    85  }
    86  
    87  // Update implements Storage.Update.
    88  func (s *SQLStorage) Update(ctx context.Context, record *Record, _ bool) error {
    89  	recordBytes, err := json.Marshal(&record)
    90  	if err != nil {
    91  		return errors.Trace(err)
    92  	}
    93  
    94  	if record.Version == 0 {
    95  		_, err := s.db.ExecContext(ctx,
    96  			fmt.Sprintf(sqlInsertRecord, s.tableName), sqlRecordID, record.Version+1, recordBytes)
    97  		return errors.Trace(err)
    98  	}
    99  
   100  	result, err := s.db.ExecContext(ctx,
   101  		fmt.Sprintf(sqlUpdateRecord, s.tableName), record.Version+1, recordBytes, sqlRecordID, record.Version)
   102  	if err != nil {
   103  		return errors.Trace(err)
   104  	}
   105  	rowsAffected, err := result.RowsAffected()
   106  	if err != nil {
   107  		return errors.Trace(err)
   108  	}
   109  	if rowsAffected != 1 {
   110  		return errors.ErrElectionRecordConflict.GenWithStackByArgs()
   111  	}
   112  	return nil
   113  }