github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/conn_executor_savepoints_test.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql_test
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/base"
    20  	"github.com/cockroachdb/cockroach/pkg/sql"
    21  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    22  	"github.com/cockroachdb/cockroach/pkg/util"
    23  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    24  	"github.com/cockroachdb/datadriven"
    25  )
    26  
    27  func TestSavepoints(t *testing.T) {
    28  	defer leaktest.AfterTest(t)()
    29  
    30  	ctx := context.Background()
    31  	datadriven.Walk(t, "testdata/savepoints", func(t *testing.T, path string) {
    32  
    33  		params := base.TestServerArgs{}
    34  		s, sqlConn, _ := serverutils.StartServer(t, params)
    35  		defer s.Stopper().Stop(ctx)
    36  
    37  		if _, err := sqlConn.Exec("CREATE TABLE progress(n INT, marker BOOL)"); err != nil {
    38  			t.Fatal(err)
    39  		}
    40  
    41  		datadriven.RunTest(t, path, func(t *testing.T, td *datadriven.TestData) string {
    42  			switch td.Cmd {
    43  			case "sql":
    44  				// Implicitly abort any previously-ongoing txn.
    45  				_, _ = sqlConn.Exec("ABORT")
    46  				// Prepare for the next test.
    47  				if _, err := sqlConn.Exec("DELETE FROM progress"); err != nil {
    48  					td.Fatalf(t, "cleaning up: %v", err)
    49  				}
    50  
    51  				// Prepare a buffer to accumulate the results.
    52  				var buf strings.Builder
    53  
    54  				// We're going to execute the input line-by-line.
    55  				stmts := strings.Split(td.Input, "\n")
    56  
    57  				// progressBar is going to show the cancellation of writes
    58  				// during rollbacks.
    59  				progressBar := make([]byte, len(stmts))
    60  				erase := func(status string) {
    61  					char := byte('.')
    62  					if !isOpenTxn(status) {
    63  						char = 'X'
    64  					}
    65  					for i := range progressBar {
    66  						progressBar[i] = char
    67  					}
    68  				}
    69  
    70  				// stepNum is the index of the current statement
    71  				// in the input.
    72  				var stepNum int
    73  
    74  				// updateProgress loads the current set of writes
    75  				// into the progress bar.
    76  				updateProgress := func() {
    77  					rows, err := sqlConn.Query("SELECT n FROM progress")
    78  					if err != nil {
    79  						t.Logf("%d: reading progress: %v", stepNum, err)
    80  						// It's OK if we can't read this.
    81  						return
    82  					}
    83  					defer rows.Close()
    84  					for rows.Next() {
    85  						var n int
    86  						if err := rows.Scan(&n); err != nil {
    87  							td.Fatalf(t, "%d: unexpected error while reading progress: %v", stepNum, err)
    88  						}
    89  						if n < 1 || n > len(progressBar) {
    90  							td.Fatalf(t, "%d: unexpected stepnum in progress table: %d", stepNum, n)
    91  						}
    92  						progressBar[n-1] = '#'
    93  					}
    94  				}
    95  
    96  				// getTxnStatus retrieves the current txn state.
    97  				// This is guaranteed to always succeed because SHOW TRANSACTION STATUS
    98  				// is an observer statement.
    99  				getTxnStatus := func() string {
   100  					row := sqlConn.QueryRow("SHOW TRANSACTION STATUS")
   101  					var status string
   102  					if err := row.Scan(&status); err != nil {
   103  						td.Fatalf(t, "%d: unable to retrieve txn status: %v", stepNum, err)
   104  					}
   105  					return status
   106  				}
   107  				// showSavepointStatus is like getTxnStatus but retrieves the
   108  				// savepoint stack.
   109  				showSavepointStatus := func() {
   110  					rows, err := sqlConn.Query("SHOW SAVEPOINT STATUS")
   111  					if err != nil {
   112  						td.Fatalf(t, "%d: unable to retrieve savepoint status: %v", stepNum, err)
   113  					}
   114  					defer rows.Close()
   115  
   116  					comma := ""
   117  					hasSavepoints := false
   118  					for rows.Next() {
   119  						var name string
   120  						var isRestart bool
   121  						if err := rows.Scan(&name, &isRestart); err != nil {
   122  							td.Fatalf(t, "%d: unexpected error while reading savepoints: %v", stepNum, err)
   123  						}
   124  						if isRestart {
   125  							name += "(r)"
   126  						}
   127  						buf.WriteString(comma)
   128  						buf.WriteString(name)
   129  						hasSavepoints = true
   130  						comma = ">"
   131  					}
   132  					if !hasSavepoints {
   133  						buf.WriteString("(none)")
   134  					}
   135  				}
   136  				// report shows the progress of execution so far after
   137  				// each statement executed.
   138  				report := func(beforeStatus, afterStatus string) {
   139  					erase(afterStatus)
   140  					if isOpenTxn(afterStatus) {
   141  						updateProgress()
   142  					}
   143  					fmt.Fprintf(&buf, "-- %-11s -> %-11s %s ", beforeStatus, afterStatus, string(progressBar))
   144  					buf.WriteByte(' ')
   145  					showSavepointStatus()
   146  					buf.WriteByte('\n')
   147  				}
   148  
   149  				// The actual execution of the statements starts here.
   150  
   151  				beforeStatus := getTxnStatus()
   152  				for i, stmt := range stmts {
   153  					stepNum = i + 1
   154  					// Before each statement, mark the progress so far with
   155  					// a KV write.
   156  					if isOpenTxn(beforeStatus) {
   157  						_, err := sqlConn.Exec("INSERT INTO progress(n, marker) VALUES ($1, true)", stepNum)
   158  						if err != nil {
   159  							td.Fatalf(t, "%d: before-stmt: %v", stepNum, err)
   160  						}
   161  					}
   162  
   163  					// Run the statement and report errors/results.
   164  					fmt.Fprintf(&buf, "%d: %s -- ", stepNum, stmt)
   165  					execRes, err := sqlConn.Exec(stmt)
   166  					if err != nil {
   167  						fmt.Fprintf(&buf, "%v\n", err)
   168  					} else {
   169  						nRows, err := execRes.RowsAffected()
   170  						if err != nil {
   171  							fmt.Fprintf(&buf, "error retrieving rows: %v\n", err)
   172  						} else {
   173  							fmt.Fprintf(&buf, "%d row%s\n", nRows, util.Pluralize(nRows))
   174  						}
   175  					}
   176  
   177  					// Report progress on the next line
   178  					afterStatus := getTxnStatus()
   179  					report(beforeStatus, afterStatus)
   180  					beforeStatus = afterStatus
   181  				}
   182  
   183  				return buf.String()
   184  
   185  			default:
   186  				td.Fatalf(t, "unknown directive: %s", td.Cmd)
   187  			}
   188  			return ""
   189  		})
   190  	})
   191  }
   192  
   193  func isOpenTxn(status string) bool {
   194  	return status == sql.OpenStateStr || status == sql.NoTxnStateStr
   195  }