github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/election/storage_sql.go (about) 1 // Copyright 2022 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package election 15 16 import ( 17 "context" 18 "database/sql" 19 "encoding/json" 20 "fmt" 21 22 "github.com/pingcap/tiflow/engine/pkg/meta/mock" 23 "github.com/pingcap/tiflow/pkg/errors" 24 ) 25 26 var _ Storage = &SQLStorage{} 27 28 const ( 29 sqlRecordID = 1 30 sqlCreateTable = "CREATE TABLE IF NOT EXISTS %s (id int NOT NULL, version bigint NOT NULL, record text NOT NULL, PRIMARY KEY (id))" 31 sqlQueryRecord = "SELECT version, record FROM %s WHERE id = ?" 32 sqlInsertRecord = "INSERT INTO %s (id, version, record) VALUES (?, ?, ?)" 33 sqlUpdateRecord = "UPDATE %s SET version = ?, record = ? WHERE id = ? AND version = ?" 34 ) 35 36 // SQLStorage is a storage implementation based on SQL database. 37 type SQLStorage struct { 38 db *sql.DB 39 tableName string 40 } 41 42 // NewSQLStorage creates a new SQLStorage. 43 func NewSQLStorage(db *sql.DB, tableName string) (*SQLStorage, error) { 44 if _, err := db.Exec(fmt.Sprintf(sqlCreateTable, tableName)); err != nil { 45 return nil, errors.Trace(err) 46 } 47 return &SQLStorage{ 48 db: db, 49 tableName: tableName, 50 }, nil 51 } 52 53 // NewInMemorySQLStorage creates a new SQLStorage in memory based on SQLite. 54 func NewInMemorySQLStorage(dbName string, tableName string) (*SQLStorage, error) { 55 dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName) 56 57 db, err := sql.Open(mock.ThreadeSafeSqliteDriverName, dsn) 58 if err != nil { 59 return nil, err 60 } 61 return NewSQLStorage(db, tableName) 62 } 63 64 // Get implements Storage.Get. 65 func (s *SQLStorage) Get(ctx context.Context) (*Record, error) { 66 var ( 67 version int64 68 recordBytes []byte 69 ) 70 err := s.db.QueryRowContext(ctx, 71 fmt.Sprintf(sqlQueryRecord, s.tableName), sqlRecordID). 72 Scan(&version, &recordBytes) 73 if err != nil { 74 if err == sql.ErrNoRows { 75 err = nil 76 } 77 return &Record{}, errors.Trace(err) 78 } 79 var record Record 80 if err := json.Unmarshal(recordBytes, &record); err != nil { 81 return nil, errors.Trace(err) 82 } 83 record.Version = version 84 return &record, nil 85 } 86 87 // Update implements Storage.Update. 88 func (s *SQLStorage) Update(ctx context.Context, record *Record, _ bool) error { 89 recordBytes, err := json.Marshal(&record) 90 if err != nil { 91 return errors.Trace(err) 92 } 93 94 if record.Version == 0 { 95 _, err := s.db.ExecContext(ctx, 96 fmt.Sprintf(sqlInsertRecord, s.tableName), sqlRecordID, record.Version+1, recordBytes) 97 return errors.Trace(err) 98 } 99 100 result, err := s.db.ExecContext(ctx, 101 fmt.Sprintf(sqlUpdateRecord, s.tableName), record.Version+1, recordBytes, sqlRecordID, record.Version) 102 if err != nil { 103 return errors.Trace(err) 104 } 105 rowsAffected, err := result.RowsAffected() 106 if err != nil { 107 return errors.Trace(err) 108 } 109 if rowsAffected != 1 { 110 return errors.ErrElectionRecordConflict.GenWithStackByArgs() 111 } 112 return nil 113 }