github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/conn_executor_savepoints_test.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package sql_test 12 13 import ( 14 "context" 15 "fmt" 16 "strings" 17 "testing" 18 19 "github.com/cockroachdb/cockroach/pkg/base" 20 "github.com/cockroachdb/cockroach/pkg/sql" 21 "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" 22 "github.com/cockroachdb/cockroach/pkg/util" 23 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 24 "github.com/cockroachdb/datadriven" 25 ) 26 27 func TestSavepoints(t *testing.T) { 28 defer leaktest.AfterTest(t)() 29 30 ctx := context.Background() 31 datadriven.Walk(t, "testdata/savepoints", func(t *testing.T, path string) { 32 33 params := base.TestServerArgs{} 34 s, sqlConn, _ := serverutils.StartServer(t, params) 35 defer s.Stopper().Stop(ctx) 36 37 if _, err := sqlConn.Exec("CREATE TABLE progress(n INT, marker BOOL)"); err != nil { 38 t.Fatal(err) 39 } 40 41 datadriven.RunTest(t, path, func(t *testing.T, td *datadriven.TestData) string { 42 switch td.Cmd { 43 case "sql": 44 // Implicitly abort any previously-ongoing txn. 45 _, _ = sqlConn.Exec("ABORT") 46 // Prepare for the next test. 47 if _, err := sqlConn.Exec("DELETE FROM progress"); err != nil { 48 td.Fatalf(t, "cleaning up: %v", err) 49 } 50 51 // Prepare a buffer to accumulate the results. 52 var buf strings.Builder 53 54 // We're going to execute the input line-by-line. 55 stmts := strings.Split(td.Input, "\n") 56 57 // progressBar is going to show the cancellation of writes 58 // during rollbacks. 59 progressBar := make([]byte, len(stmts)) 60 erase := func(status string) { 61 char := byte('.') 62 if !isOpenTxn(status) { 63 char = 'X' 64 } 65 for i := range progressBar { 66 progressBar[i] = char 67 } 68 } 69 70 // stepNum is the index of the current statement 71 // in the input. 72 var stepNum int 73 74 // updateProgress loads the current set of writes 75 // into the progress bar. 76 updateProgress := func() { 77 rows, err := sqlConn.Query("SELECT n FROM progress") 78 if err != nil { 79 t.Logf("%d: reading progress: %v", stepNum, err) 80 // It's OK if we can't read this. 81 return 82 } 83 defer rows.Close() 84 for rows.Next() { 85 var n int 86 if err := rows.Scan(&n); err != nil { 87 td.Fatalf(t, "%d: unexpected error while reading progress: %v", stepNum, err) 88 } 89 if n < 1 || n > len(progressBar) { 90 td.Fatalf(t, "%d: unexpected stepnum in progress table: %d", stepNum, n) 91 } 92 progressBar[n-1] = '#' 93 } 94 } 95 96 // getTxnStatus retrieves the current txn state. 97 // This is guaranteed to always succeed because SHOW TRANSACTION STATUS 98 // is an observer statement. 99 getTxnStatus := func() string { 100 row := sqlConn.QueryRow("SHOW TRANSACTION STATUS") 101 var status string 102 if err := row.Scan(&status); err != nil { 103 td.Fatalf(t, "%d: unable to retrieve txn status: %v", stepNum, err) 104 } 105 return status 106 } 107 // showSavepointStatus is like getTxnStatus but retrieves the 108 // savepoint stack. 109 showSavepointStatus := func() { 110 rows, err := sqlConn.Query("SHOW SAVEPOINT STATUS") 111 if err != nil { 112 td.Fatalf(t, "%d: unable to retrieve savepoint status: %v", stepNum, err) 113 } 114 defer rows.Close() 115 116 comma := "" 117 hasSavepoints := false 118 for rows.Next() { 119 var name string 120 var isRestart bool 121 if err := rows.Scan(&name, &isRestart); err != nil { 122 td.Fatalf(t, "%d: unexpected error while reading savepoints: %v", stepNum, err) 123 } 124 if isRestart { 125 name += "(r)" 126 } 127 buf.WriteString(comma) 128 buf.WriteString(name) 129 hasSavepoints = true 130 comma = ">" 131 } 132 if !hasSavepoints { 133 buf.WriteString("(none)") 134 } 135 } 136 // report shows the progress of execution so far after 137 // each statement executed. 138 report := func(beforeStatus, afterStatus string) { 139 erase(afterStatus) 140 if isOpenTxn(afterStatus) { 141 updateProgress() 142 } 143 fmt.Fprintf(&buf, "-- %-11s -> %-11s %s ", beforeStatus, afterStatus, string(progressBar)) 144 buf.WriteByte(' ') 145 showSavepointStatus() 146 buf.WriteByte('\n') 147 } 148 149 // The actual execution of the statements starts here. 150 151 beforeStatus := getTxnStatus() 152 for i, stmt := range stmts { 153 stepNum = i + 1 154 // Before each statement, mark the progress so far with 155 // a KV write. 156 if isOpenTxn(beforeStatus) { 157 _, err := sqlConn.Exec("INSERT INTO progress(n, marker) VALUES ($1, true)", stepNum) 158 if err != nil { 159 td.Fatalf(t, "%d: before-stmt: %v", stepNum, err) 160 } 161 } 162 163 // Run the statement and report errors/results. 164 fmt.Fprintf(&buf, "%d: %s -- ", stepNum, stmt) 165 execRes, err := sqlConn.Exec(stmt) 166 if err != nil { 167 fmt.Fprintf(&buf, "%v\n", err) 168 } else { 169 nRows, err := execRes.RowsAffected() 170 if err != nil { 171 fmt.Fprintf(&buf, "error retrieving rows: %v\n", err) 172 } else { 173 fmt.Fprintf(&buf, "%d row%s\n", nRows, util.Pluralize(nRows)) 174 } 175 } 176 177 // Report progress on the next line 178 afterStatus := getTxnStatus() 179 report(beforeStatus, afterStatus) 180 beforeStatus = afterStatus 181 } 182 183 return buf.String() 184 185 default: 186 td.Fatalf(t, "unknown directive: %s", td.Cmd) 187 } 188 return "" 189 }) 190 }) 191 } 192 193 func isOpenTxn(status string) bool { 194 return status == sql.OpenStateStr || status == sql.NoTxnStateStr 195 }