github.com/pingcap/chaos@v0.0.0-20190710112158-c86faf4b3719/db/tidb/bank.go (about)

     1  package tidb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"math/rand"
    10  	"sort"
    11  	"time"
    12  
    13  	"github.com/anishathalye/porcupine"
    14  	pchecker "github.com/pingcap/chaos/pkg/check/porcupine"
    15  	"github.com/pingcap/chaos/pkg/core"
    16  	"github.com/pingcap/chaos/pkg/history"
    17  
    18  	// use mysql
    19  	_ "github.com/go-sql-driver/mysql"
    20  )
    21  
    22  const (
    23  	accountNum  = 5
    24  	initBalance = int64(1000)
    25  )
    26  
    27  type bankClient struct {
    28  	db         *sql.DB
    29  	r          *rand.Rand
    30  	accountNum int
    31  }
    32  
    33  func (c *bankClient) SetUp(ctx context.Context, nodes []string, node string) error {
    34  	c.r = rand.New(rand.NewSource(time.Now().UnixNano()))
    35  	db, err := sql.Open("mysql", fmt.Sprintf("root@tcp(%s:4000)/test", node))
    36  	if err != nil {
    37  		return err
    38  	}
    39  	c.db = db
    40  
    41  	db.SetMaxIdleConns(1 + c.accountNum)
    42  
    43  	// Do SetUp in the first node
    44  	if node != nodes[0] {
    45  		return nil
    46  	}
    47  
    48  	log.Printf("begin to create table accounts on node %s", node)
    49  	sql := `create table if not exists accounts
    50  			(id     int not null primary key,
    51  			balance bigint not null)`
    52  
    53  	if _, err = db.ExecContext(ctx, sql); err != nil {
    54  		return err
    55  	}
    56  
    57  	for i := 0; i < c.accountNum; i++ {
    58  		if _, err = db.ExecContext(ctx, "insert into accounts values (?, ?)", i, initBalance); err != nil {
    59  			return err
    60  		}
    61  	}
    62  
    63  	return nil
    64  }
    65  
    66  func (c *bankClient) TearDown(ctx context.Context, nodes []string, node string) error {
    67  	return c.db.Close()
    68  }
    69  
    70  func (c *bankClient) invokeRead(ctx context.Context, r bankRequest) bankResponse {
    71  	txn, err := c.db.Begin()
    72  
    73  	if err != nil {
    74  		return bankResponse{Unknown: true}
    75  	}
    76  	defer txn.Rollback()
    77  
    78  	var tso uint64
    79  	if err = txn.QueryRow("select @@tidb_current_ts").Scan(&tso); err != nil {
    80  		return bankResponse{Unknown: true}
    81  	}
    82  
    83  	rows, err := txn.QueryContext(ctx, "select balance from accounts")
    84  	if err != nil {
    85  		return bankResponse{Unknown: true}
    86  	}
    87  	defer rows.Close()
    88  
    89  	balances := make([]int64, 0, c.accountNum)
    90  	for rows.Next() {
    91  		var v int64
    92  		if err = rows.Scan(&v); err != nil {
    93  			return bankResponse{Unknown: true}
    94  		}
    95  		balances = append(balances, v)
    96  	}
    97  
    98  	return bankResponse{Balances: balances, Tso: tso}
    99  }
   100  
   101  func (c *bankClient) Invoke(ctx context.Context, node string, r interface{}) interface{} {
   102  	arg := r.(bankRequest)
   103  	if arg.Op == 0 {
   104  		return c.invokeRead(ctx, arg)
   105  	}
   106  
   107  	txn, err := c.db.Begin()
   108  
   109  	if err != nil {
   110  		return bankResponse{Ok: false}
   111  	}
   112  	defer txn.Rollback()
   113  
   114  	var (
   115  		fromBalance int64
   116  		toBalance   int64
   117  		tso         uint64
   118  	)
   119  
   120  	if err = txn.QueryRow("select @@tidb_current_ts").Scan(&tso); err != nil {
   121  		return bankResponse{Ok: false}
   122  	}
   123  
   124  	if err = txn.QueryRowContext(ctx, "select balance from accounts where id = ? for update", arg.From).Scan(&fromBalance); err != nil {
   125  		return bankResponse{Ok: false}
   126  	}
   127  
   128  	if err = txn.QueryRowContext(ctx, "select balance from accounts where id = ? for update", arg.To).Scan(&toBalance); err != nil {
   129  		return bankResponse{Ok: false}
   130  	}
   131  
   132  	if fromBalance < arg.Amount {
   133  		return bankResponse{Ok: false}
   134  	}
   135  
   136  	if _, err = txn.ExecContext(ctx, "update accounts set balance = balance - ? where id = ?", arg.Amount, arg.From); err != nil {
   137  		return bankResponse{Ok: false}
   138  	}
   139  
   140  	if _, err = txn.ExecContext(ctx, "update accounts set balance = balance + ? where id = ?", arg.Amount, arg.To); err != nil {
   141  		return bankResponse{Ok: false}
   142  	}
   143  
   144  	if err = txn.Commit(); err != nil {
   145  		return bankResponse{Unknown: true, Tso: tso, FromBalance: fromBalance, ToBalance: toBalance}
   146  	}
   147  
   148  	return bankResponse{Ok: true, Tso: tso, FromBalance: fromBalance, ToBalance: toBalance}
   149  }
   150  
   151  func (c *bankClient) NextRequest() interface{} {
   152  	r := bankRequest{
   153  		Op: c.r.Int() % 2,
   154  	}
   155  	if r.Op == 0 {
   156  		return r
   157  	}
   158  
   159  	r.From = c.r.Intn(c.accountNum)
   160  
   161  	r.To = c.r.Intn(c.accountNum)
   162  	if r.From == r.To {
   163  		r.To = (r.To + 1) % c.accountNum
   164  	}
   165  
   166  	r.Amount = 5
   167  	return r
   168  }
   169  
   170  func (c *bankClient) DumpState(ctx context.Context) (interface{}, error) {
   171  	txn, err := c.db.Begin()
   172  
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	defer txn.Rollback()
   177  
   178  	rows, err := txn.QueryContext(ctx, "select balance from accounts")
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	defer rows.Close()
   183  
   184  	balances := make([]int64, 0, c.accountNum)
   185  	for rows.Next() {
   186  		var v int64
   187  		if err = rows.Scan(&v); err != nil {
   188  			return nil, err
   189  		}
   190  		balances = append(balances, v)
   191  	}
   192  	return balances, nil
   193  }
   194  
   195  // BankClientCreator creates a bank test client for tidb.
   196  type BankClientCreator struct {
   197  }
   198  
   199  // Create creates a client.
   200  func (BankClientCreator) Create(node string) core.Client {
   201  	return &bankClient{
   202  		accountNum: accountNum,
   203  	}
   204  }
   205  
   206  type bankRequest struct {
   207  	// 0: read
   208  	// 1: transfer
   209  	Op     int
   210  	From   int
   211  	To     int
   212  	Amount int64
   213  }
   214  
   215  type bankResponse struct {
   216  	// Transaction start timestamp
   217  	Tso uint64
   218  	// read result
   219  	Balances []int64
   220  	// transfer ok or not
   221  	Ok bool
   222  	// FromBalance is the previous from balance before transafer
   223  	FromBalance int64
   224  	// ToBalance is the previous to balance before transafer
   225  	ToBalance int64
   226  	// read/transfer unknown
   227  	Unknown bool
   228  }
   229  
   230  var _ core.UnknownResponse = (*bankResponse)(nil)
   231  
   232  // IsUnknown implements UnknownResponse interface
   233  func (r bankResponse) IsUnknown() bool {
   234  	return r.Unknown
   235  }
   236  
   237  func balancesEqual(a, b []int64) bool {
   238  	if len(a) != len(b) {
   239  		return false
   240  	}
   241  
   242  	for i := 0; i < len(a); i++ {
   243  		if a[i] != b[i] {
   244  			return false
   245  		}
   246  	}
   247  
   248  	return true
   249  }
   250  
   251  type bank struct {
   252  	accountNum    int
   253  	perparedState *[]int64
   254  }
   255  
   256  func (b *bank) Prepare(state interface{}) {
   257  	s := state.([]int64)
   258  	b.perparedState = &s
   259  }
   260  
   261  func (b *bank) Init() interface{} {
   262  	if b.perparedState != nil {
   263  		return *b.perparedState
   264  	}
   265  
   266  	// Or make a brand new state.
   267  	v := make([]int64, b.accountNum)
   268  	for i := 0; i < b.accountNum; i++ {
   269  		v[i] = initBalance
   270  	}
   271  	return v
   272  }
   273  
   274  func (*bank) Step(state interface{}, input interface{}, output interface{}) (bool, interface{}) {
   275  	st := state.([]int64)
   276  	inp := input.(bankRequest)
   277  	out := output.(bankResponse)
   278  
   279  	if inp.Op == 0 {
   280  		// read
   281  		ok := out.Unknown || balancesEqual(st, out.Balances)
   282  		return ok, state
   283  	}
   284  
   285  	// for transfer
   286  	if !out.Ok && !out.Unknown {
   287  		return true, state
   288  	}
   289  
   290  	newSt := append([]int64{}, st...)
   291  	newSt[inp.From] -= inp.Amount
   292  	newSt[inp.To] += inp.Amount
   293  	return out.Ok || out.Unknown, newSt
   294  }
   295  
   296  func (*bank) Equal(state1, state2 interface{}) bool {
   297  	st1 := state1.([]int64)
   298  	st2 := state2.([]int64)
   299  	return balancesEqual(st1, st2)
   300  }
   301  
   302  func (*bank) Name() string {
   303  	return "tidb_bank"
   304  }
   305  
   306  // BankModel is the model of bank in TiDB
   307  func BankModel() core.Model {
   308  	return &bank{
   309  		accountNum: accountNum,
   310  	}
   311  }
   312  
   313  type bankParser struct{}
   314  
   315  // OnRequest impls history.RecordParser.OnRequest
   316  func (p bankParser) OnRequest(data json.RawMessage) (interface{}, error) {
   317  	r := bankRequest{}
   318  	err := json.Unmarshal(data, &r)
   319  	return r, err
   320  }
   321  
   322  // OnResponse impls history.RecordParser.OnRequest
   323  func (p bankParser) OnResponse(data json.RawMessage) (interface{}, error) {
   324  	r := bankResponse{}
   325  	err := json.Unmarshal(data, &r)
   326  	if r.Unknown {
   327  		return nil, err
   328  	}
   329  	return r, err
   330  }
   331  
   332  // OnNoopResponse impls history.RecordParser.OnRequest
   333  func (p bankParser) OnNoopResponse() interface{} {
   334  	return bankResponse{Unknown: true}
   335  }
   336  
   337  func (p bankParser) OnState(data json.RawMessage) (interface{}, error) {
   338  	var state []int64
   339  	err := json.Unmarshal(data, &state)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  	return state, nil
   344  }
   345  
   346  // BankParser parses a history of bank operations.
   347  func BankParser() history.RecordParser {
   348  	return bankParser{}
   349  }
   350  
   351  type tsoEvent struct {
   352  	Tso uint64
   353  	Op  int
   354  	// For transfer
   355  	From        int
   356  	To          int
   357  	FromBalance int64
   358  	ToBalance   int64
   359  	Amount      int64
   360  	// For read
   361  	Balances []int64
   362  
   363  	Unknown bool
   364  }
   365  
   366  func (e *tsoEvent) String() string {
   367  	if e.Op == 0 {
   368  		return fmt.Sprintf("%d, read %v, unknown %v", e.Tso, e.Balances, e.Unknown)
   369  	}
   370  
   371  	return fmt.Sprintf("%d, transafer %d %d(%d) -> %d(%d), unknown %v", e.Tso, e.Amount, e.From, e.FromBalance, e.To, e.ToBalance, e.Unknown)
   372  }
   373  
   374  // GetBalances gets the two balances of account before and after the transfer.
   375  func (e *tsoEvent) GetBalances(index int) (int64, int64) {
   376  	if index == e.From {
   377  		return e.FromBalance, e.FromBalance - e.Amount
   378  	}
   379  
   380  	return e.ToBalance, e.ToBalance + e.Amount
   381  }
   382  
   383  type tsoEvents []*tsoEvent
   384  
   385  func (s tsoEvents) Len() int           { return len(s) }
   386  func (s tsoEvents) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
   387  func (s tsoEvents) Less(i, j int) bool { return s[i].Tso < s[j].Tso }
   388  
   389  // TODO: remove porcupine dependence.
   390  func generateTsoEvents(events []porcupine.Event) tsoEvents {
   391  	tEvents := make(tsoEvents, 0, len(events))
   392  
   393  	mapEvents := make(map[uint]porcupine.Event, len(events))
   394  	for _, event := range events {
   395  		if event.Kind == porcupine.CallEvent {
   396  			mapEvents[event.Id] = event
   397  			continue
   398  		}
   399  
   400  		// Handle Return Event
   401  		// Find the corresponding Call Event
   402  		callEvent, ok := mapEvents[event.Id]
   403  		if !ok {
   404  			continue
   405  		}
   406  		delete(mapEvents, event.Id)
   407  
   408  		request := callEvent.Value.(bankRequest)
   409  		response := event.Value.(bankResponse)
   410  
   411  		if response.Tso == 0 {
   412  			// We don't care operation which has no TSO.
   413  			continue
   414  		}
   415  
   416  		tEvent := tsoEvent{
   417  			Tso:     response.Tso,
   418  			Op:      request.Op,
   419  			Unknown: response.Unknown,
   420  		}
   421  		if request.Op == 0 {
   422  			tEvent.Balances = response.Balances
   423  		} else {
   424  			tEvent.From = request.From
   425  			tEvent.To = request.To
   426  			tEvent.Amount = request.Amount
   427  			tEvent.FromBalance = response.FromBalance
   428  			tEvent.ToBalance = response.ToBalance
   429  		}
   430  
   431  		tEvents = append(tEvents, &tEvent)
   432  	}
   433  	sort.Sort(tEvents)
   434  	return tEvents
   435  }
   436  
   437  // mergeTransferEvents checks whether e can be merged into the events.
   438  // We may meet following cases for one account:
   439  // Assume last event starts at T1, the checking event starts at T2.
   440  // 1:
   441  // 	T1: [1000] -> [900], Unknown
   442  //	T2: [900] -> [800], Unknown?
   443  // Here T2 reads 900, so we can ensure T1 is successful no matter T1 is unknown or not.
   444  // We can set T1 to OK. After T1 is set to OK, we must check T1 to its previous events.
   445  // 2:
   446  //	T1: [1000] -> [900], OK
   447  //	T2: [1000] -> [800], Unknown
   448  // Here T1 is successful, but T2 is unknown, it is fine now.
   449  // 3:
   450  // 	T1: [1000] -> [900], Ok
   451  //	T2: [1000] -> [800], Ok
   452  // Invalid, because we use SSI here, even T2 can read 1000, it can't change it because
   453  // it must conflict with T1.
   454  // 4:
   455  // 	T1: [1000] -> [900], Unknown?
   456  //	T2: [800] -> [700], Unknown?
   457  // Invalid, T2 reads a stale value.
   458  func mergeTransferEvents(index int, events tsoEvents, e *tsoEvent) (tsoEvents, error) {
   459  	curBalance, _ := e.GetBalances(index)
   460  
   461  	if !checkBalance(index, events, curBalance) {
   462  		return nil, fmt.Errorf("%d %v invalid event %s", index, events, e)
   463  	}
   464  
   465  	events = append(events, e)
   466  
   467  	// Get the last successful event e2
   468  	lastIdx, err := checkTransferEvents(index, events)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	// clear all events before the successful event
   474  	return events[lastIdx:], nil
   475  }
   476  
   477  // For all the successful transfer events, we must form a transfer chain like
   478  // T1 [1000] -> [900]
   479  // T2 [900] -> [800]
   480  // T3 [800] -> [700]
   481  // The function will return the last successful event index, if no found, return 0
   482  func checkTransferEvents(index int, events tsoEvents) (int, error) {
   483  	var (
   484  		lastEvent *tsoEvent
   485  		lastIndex int
   486  	)
   487  	for i, e := range events {
   488  		if e.Unknown {
   489  			continue
   490  		}
   491  
   492  		if lastEvent != nil {
   493  			_, next := lastEvent.GetBalances(index)
   494  			cur, _ := e.GetBalances(index)
   495  			if next != cur {
   496  				return 0, fmt.Errorf("invalid events from %s to %s", lastEvent, e)
   497  			}
   498  		}
   499  
   500  		lastIndex = i
   501  		lastEvent = e
   502  	}
   503  
   504  	return lastIndex, nil
   505  }
   506  
   507  func checkBalance(index int, events tsoEvents, curBalance int64) bool {
   508  	if len(events) == 0 {
   509  		return curBalance == initBalance
   510  	}
   511  
   512  	for i := len(events) - 1; i >= 0; i-- {
   513  		lastEvent := events[i]
   514  		cur, next := lastEvent.GetBalances(index)
   515  		if next == curBalance {
   516  			// We read the next balance of the last event, which means the last transfer is
   517  			// successful
   518  			lastEvent.Unknown = false
   519  			return true
   520  		}
   521  
   522  		if cur == curBalance {
   523  			// Oh, we read the same balance with the last event
   524  			return true
   525  		}
   526  	}
   527  
   528  	return false
   529  }
   530  
   531  // verifyReadEvent verifies the read event.
   532  func verifyReadEvent(possibleEvents []tsoEvents, e *tsoEvent) bool {
   533  	if e.Unknown {
   534  		return true
   535  	}
   536  
   537  	sum := int64(0)
   538  	for i, balance := range e.Balances {
   539  		sum += balance
   540  
   541  		if !checkBalance(i, possibleEvents[i], balance) {
   542  			log.Printf("invalid event %s, balance mismatch", e)
   543  			return false
   544  		}
   545  	}
   546  
   547  	if sum != int64(len(e.Balances))*initBalance {
   548  		log.Printf("invalid event %s, sum corruption", e)
   549  		return false
   550  	}
   551  
   552  	return true
   553  }
   554  
   555  func verifyTsoEvents(events tsoEvents) bool {
   556  	possibleEvents := make([]tsoEvents, accountNum)
   557  
   558  	var err error
   559  	for _, event := range events {
   560  		if event.Op == 0 {
   561  			if !verifyReadEvent(possibleEvents, event) {
   562  				return false
   563  			}
   564  		}
   565  
   566  		if event.Op == 1 {
   567  			from := event.From
   568  			possibleEvents[from], err = mergeTransferEvents(from, possibleEvents[from], event)
   569  			if err != nil {
   570  				log.Print(err.Error())
   571  				return false
   572  			}
   573  
   574  			to := event.To
   575  			possibleEvents[to], err = mergeTransferEvents(to, possibleEvents[to], event)
   576  			if err != nil {
   577  				log.Print(err.Error())
   578  				return false
   579  			}
   580  		}
   581  	}
   582  
   583  	return true
   584  }
   585  
   586  // bankTsoChecker uses a direct way because we know every timestamp of the transaction.
   587  // So we can order all transactions with timetamp and replay them.
   588  type bankTsoChecker struct {
   589  }
   590  
   591  // Check checks the bank history.
   592  func (bankTsoChecker) Check(_ core.Model, ops []core.Operation) (bool, error) {
   593  	events, err := pchecker.ConvertOperationsToEvents(ops)
   594  	if err != nil {
   595  		return false, err
   596  	}
   597  	tEvents := generateTsoEvents(events)
   598  	return verifyTsoEvents(tEvents), nil
   599  }
   600  
   601  // Name returns the name of the verifier.
   602  func (bankTsoChecker) Name() string {
   603  	return "tidb_bank_tso_checker"
   604  }
   605  
   606  // BankTsoChecker checks the bank history with the help of tso.
   607  func BankTsoChecker() core.Checker {
   608  	return bankTsoChecker{}
   609  }