github.com/matrixorigin/matrixone@v1.2.0/pkg/tests/txn/sql_client.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package txn
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"sync"
    23  
    24  	_ "github.com/go-sql-driver/mysql"
    25  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    26  	"github.com/matrixorigin/matrixone/pkg/tests/service"
    27  	"github.com/matrixorigin/matrixone/pkg/txn/client"
    28  )
    29  
    30  var (
    31  	createDB  = `create database if not exists kv_test`
    32  	useDB     = `use kv_test;`
    33  	createSql = `create table if not exists txn_test_kv (kv_key varchar(20) primary key, kv_value varchar(10))`
    34  )
    35  
    36  // sqlClient use sql client to connect to CN node and use a table to simulate rr test KV operations
    37  type sqlClient struct {
    38  	cn service.CNService
    39  }
    40  
    41  func newSQLClient(env service.Cluster) (Client, error) {
    42  	cn, err := env.GetCNServiceIndexed(0)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	db, err := sql.Open("mysql", fmt.Sprintf("dump:111@tcp(%s)/", cn.SQLAddress()))
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	_, err = db.Exec(createDB)
    53  	if err != nil {
    54  		return nil, errors.Join(err, db.Close())
    55  	}
    56  
    57  	_, err = db.Exec(useDB)
    58  	if err != nil {
    59  		return nil, errors.Join(err, db.Close())
    60  	}
    61  
    62  	_, err = db.Exec(createSql)
    63  	if err != nil {
    64  		return nil, errors.Join(err, db.Close())
    65  	}
    66  
    67  	return &sqlClient{
    68  		cn: cn,
    69  	}, errors.Join(err, db.Close())
    70  }
    71  
    72  func (c *sqlClient) NewTxn(...client.TxnOption) (Txn, error) {
    73  	return newSQLTxn(c.cn)
    74  }
    75  
    76  type sqlTxn struct {
    77  	db  *sql.DB
    78  	txn *sql.Tx
    79  
    80  	mu struct {
    81  		sync.Mutex
    82  		closed bool
    83  	}
    84  }
    85  
    86  func newSQLTxn(cn service.CNService) (Txn, error) {
    87  	db, err := sql.Open("mysql", fmt.Sprintf("dump:111@tcp(%s)/kv_test", cn.SQLAddress()))
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	txn, err := db.Begin()
    93  	if err != nil {
    94  		return nil, errors.Join(err, db.Close())
    95  	}
    96  	return &sqlTxn{
    97  		db:  db,
    98  		txn: txn,
    99  	}, nil
   100  }
   101  
   102  func (kop *sqlTxn) Commit() error {
   103  	kop.mu.Lock()
   104  	defer kop.mu.Unlock()
   105  	if kop.mu.closed {
   106  		return moerr.NewTxnClosed(context.Background(), nil)
   107  	}
   108  
   109  	kop.mu.closed = true
   110  	return errors.Join(kop.txn.Commit(), kop.db.Close())
   111  }
   112  
   113  func (kop *sqlTxn) Rollback() error {
   114  	kop.mu.Lock()
   115  	defer kop.mu.Unlock()
   116  	if kop.mu.closed {
   117  		return nil
   118  	}
   119  
   120  	kop.mu.closed = true
   121  	return errors.Join(kop.txn.Rollback(), kop.db.Close())
   122  }
   123  
   124  func (kop *sqlTxn) Read(key string) (_ string, err error) {
   125  	rows, err := kop.txn.Query(fmt.Sprintf("select kv_value from txn_test_kv where kv_key = '%s'", key))
   126  	if err != nil {
   127  		return "", err
   128  	}
   129  	defer func() {
   130  		err = errors.Join(err, rows.Close(), rows.Err())
   131  	}()
   132  	if !rows.Next() {
   133  		return "", nil
   134  	}
   135  	v := ""
   136  	if err := rows.Scan(&v); err != nil {
   137  		return "", err
   138  	}
   139  	return v, nil
   140  }
   141  
   142  func (kop *sqlTxn) Write(key, value string) error {
   143  	v, err := kop.Read(key)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	if v == "" {
   149  		return kop.insert(key, value)
   150  	}
   151  	return kop.update(key, value)
   152  }
   153  
   154  func (kop *sqlTxn) ExecSQL(sql string) (sql.Result, error) {
   155  	return kop.txn.Exec(sql)
   156  }
   157  
   158  func (kop *sqlTxn) ExecSQLQuery(sql string) (*sql.Rows, error) {
   159  	return kop.txn.Query(sql)
   160  }
   161  
   162  func (kop *sqlTxn) insert(key, value string) error {
   163  	res, err := kop.txn.Exec(fmt.Sprintf("insert into txn_test_kv(kv_key, kv_value) values('%s', '%s')", key, value))
   164  	if err != nil {
   165  		return err
   166  	}
   167  	n, err := res.RowsAffected()
   168  	if err != nil {
   169  		panic(err)
   170  	}
   171  	if n != 1 {
   172  		panic(n)
   173  	}
   174  	return err
   175  }
   176  
   177  func (kop *sqlTxn) update(key, value string) error {
   178  	_, err := kop.txn.Exec(fmt.Sprintf("update txn_test_kv set kv_value = '%s' where kv_key = '%s'", value, key))
   179  	return err
   180  }