github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/run_control_test.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql_test
    12  
    13  import (
    14  	"context"
    15  	gosql "database/sql"
    16  	gosqldriver "database/sql/driver"
    17  	"fmt"
    18  	"math/rand"
    19  	"strings"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/cockroachdb/cockroach/pkg/base"
    24  	"github.com/cockroachdb/cockroach/pkg/sql"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    27  	"github.com/cockroachdb/cockroach/pkg/testutils"
    28  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    29  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    30  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    31  	"github.com/cockroachdb/cockroach/pkg/util/retry"
    32  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    33  	"github.com/cockroachdb/errors"
    34  	"github.com/lib/pq"
    35  )
    36  
    37  func TestCancelSelectQuery(t *testing.T) {
    38  	defer leaktest.AfterTest(t)()
    39  
    40  	const queryToCancel = "SELECT * FROM generate_series(1,20000000)"
    41  
    42  	var conn1 *gosql.DB
    43  	var conn2 *gosql.DB
    44  
    45  	tc := serverutils.StartTestCluster(t, 2, /* numNodes */
    46  		base.TestClusterArgs{
    47  			ReplicationMode: base.ReplicationManual,
    48  		})
    49  	defer tc.Stopper().Stop(context.Background())
    50  
    51  	conn1 = tc.ServerConn(0)
    52  	conn2 = tc.ServerConn(1)
    53  
    54  	sem := make(chan struct{})
    55  	errChan := make(chan error)
    56  
    57  	go func() {
    58  		sem <- struct{}{}
    59  		rows, err := conn2.Query(queryToCancel)
    60  		if err != nil {
    61  			errChan <- err
    62  			return
    63  		}
    64  		for rows.Next() {
    65  		}
    66  		if err = rows.Err(); err != nil {
    67  			errChan <- err
    68  		}
    69  	}()
    70  
    71  	<-sem
    72  	time.Sleep(time.Second * 2)
    73  
    74  	const cancelQuery = "CANCEL QUERIES SELECT query_id FROM [SHOW CLUSTER QUERIES] WHERE node_id = 2"
    75  
    76  	if _, err := conn1.Exec(cancelQuery); err != nil {
    77  		t.Fatal(err)
    78  	}
    79  
    80  	select {
    81  	case err := <-errChan:
    82  		if !isClientsideQueryCanceledErr(err) {
    83  			t.Fatal(err)
    84  		}
    85  	case <-time.After(time.Second * 5):
    86  		t.Fatal("no error received from query supposed to be canceled")
    87  	}
    88  
    89  }
    90  
    91  // TestCancelDistSQLQuery runs a distsql query and cancels it randomly at
    92  // various points of execution.
    93  func TestCancelDistSQLQuery(t *testing.T) {
    94  	defer leaktest.AfterTest(t)()
    95  	const queryToCancel = "SELECT * FROM nums ORDER BY num"
    96  	cancelQuery := fmt.Sprintf("CANCEL QUERIES SELECT query_id FROM [SHOW CLUSTER QUERIES] WHERE query = '%s'", queryToCancel)
    97  
    98  	// conn1 is used for the query above. conn2 is solely for the CANCEL statement.
    99  	var conn1 *gosql.DB
   100  	var conn2 *gosql.DB
   101  
   102  	var queryLatency *time.Duration
   103  	sem := make(chan struct{}, 1)
   104  	rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
   105  	tc := serverutils.StartTestCluster(t, 2, /* numNodes */
   106  		base.TestClusterArgs{
   107  			ReplicationMode: base.ReplicationManual,
   108  			ServerArgs: base.TestServerArgs{
   109  				UseDatabase: "test",
   110  				Knobs: base.TestingKnobs{
   111  					SQLExecutor: &sql.ExecutorTestingKnobs{
   112  						BeforeExecute: func(_ context.Context, stmt string) {
   113  							if strings.HasPrefix(stmt, queryToCancel) {
   114  								// Wait for the race to start.
   115  								<-sem
   116  							} else if strings.HasPrefix(stmt, cancelQuery) {
   117  								// Signal to start the race.
   118  								sleepTime := time.Duration(rng.Int63n(int64(*queryLatency)))
   119  								sem <- struct{}{}
   120  								time.Sleep(sleepTime)
   121  							}
   122  						},
   123  					},
   124  				},
   125  			},
   126  		})
   127  	defer tc.Stopper().Stop(context.Background())
   128  
   129  	conn1 = tc.ServerConn(0)
   130  	conn2 = tc.ServerConn(1)
   131  
   132  	sqlutils.CreateTable(t, conn1, "nums", "num INT", 0, nil)
   133  	if _, err := conn1.Exec("INSERT INTO nums SELECT generate_series(1,100)"); err != nil {
   134  		t.Fatal(err)
   135  	}
   136  
   137  	if _, err := conn1.Exec("ALTER TABLE nums SPLIT AT VALUES (50)"); err != nil {
   138  		t.Fatal(err)
   139  	}
   140  
   141  	// Make the second node the leaseholder for the first range to distribute the
   142  	// query. This may have to retry if the second store's descriptor has not yet
   143  	// propagated to the first store's StorePool.
   144  	testutils.SucceedsSoon(t, func() error {
   145  		_, err := conn1.Exec(fmt.Sprintf(
   146  			"ALTER TABLE nums EXPERIMENTAL_RELOCATE VALUES (ARRAY[%d], 1)",
   147  			tc.Server(1).GetFirstStoreID()))
   148  		return err
   149  	})
   150  
   151  	// Run queryToCancel to be able to get an estimate of how long it should
   152  	// take. The goroutine in charge of cancellation will sleep a random
   153  	// amount of time within this bound. Signal sem so that it can run
   154  	// unhindered.
   155  	sem <- struct{}{}
   156  	start := timeutil.Now()
   157  	if _, err := conn1.Exec(queryToCancel); err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	execTime := timeutil.Since(start)
   161  	queryLatency = &execTime
   162  
   163  	errChan := make(chan error)
   164  	go func() {
   165  		_, err := conn1.Exec(queryToCancel)
   166  		errChan <- err
   167  	}()
   168  	_, err := conn2.Exec(cancelQuery)
   169  	if err != nil && !testutils.IsError(err, "query ID") {
   170  		t.Fatal(err)
   171  	}
   172  
   173  	err = <-errChan
   174  	if err == nil {
   175  		// A successful cancellation does not imply that the query was canceled.
   176  		return
   177  	}
   178  	if !isClientsideQueryCanceledErr(err) {
   179  		t.Fatalf("expected error with specific error code, got: %s", err)
   180  	}
   181  }
   182  
   183  func testCancelSession(t *testing.T, hasActiveSession bool) {
   184  	ctx := context.Background()
   185  
   186  	numNodes := 2
   187  	tc := serverutils.StartTestCluster(t, numNodes,
   188  		base.TestClusterArgs{
   189  			ReplicationMode: base.ReplicationManual,
   190  		})
   191  	defer tc.Stopper().Stop(ctx)
   192  
   193  	// Since we're testing session cancellation, use single connections instead of
   194  	// connection pools.
   195  	var err error
   196  	conn1, err := tc.ServerConn(0).Conn(ctx)
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  	conn2, err := tc.ServerConn(1).Conn(ctx)
   201  	if err != nil {
   202  		t.Fatal(err)
   203  	}
   204  
   205  	// Wait for node 2 to know about both sessions.
   206  	if err := retry.ForDuration(10*time.Second, func() error {
   207  		rows, err := conn2.QueryContext(ctx, "SELECT * FROM [SHOW CLUSTER SESSIONS] WHERE application_name NOT LIKE '$%'")
   208  		if err != nil {
   209  			return err
   210  		}
   211  
   212  		m, err := sqlutils.RowsToStrMatrix(rows)
   213  		if err != nil {
   214  			return err
   215  		}
   216  
   217  		if numRows := len(m); numRows != numNodes {
   218  			return fmt.Errorf("expected %d sessions but found %d\n%s",
   219  				numNodes, numRows, sqlutils.MatrixToStr(m))
   220  		}
   221  
   222  		return nil
   223  	}); err != nil {
   224  		t.Fatal(err)
   225  	}
   226  
   227  	// Get node 1's session ID now, so that we don't need to serialize the session
   228  	// later and race with the active query's type-checking and name resolution.
   229  	rows, err := conn1.QueryContext(
   230  		ctx, "SELECT session_id FROM [SHOW LOCAL SESSIONS]",
   231  	)
   232  	if err != nil {
   233  		t.Fatal(err)
   234  	}
   235  
   236  	var id string
   237  	if !rows.Next() {
   238  		t.Fatal("no sessions on node 1")
   239  	}
   240  	if err := rows.Scan(&id); err != nil {
   241  		t.Fatal(err)
   242  	}
   243  	if err := rows.Close(); err != nil {
   244  		t.Fatal(err)
   245  	}
   246  
   247  	// Now that we've obtained the session ID, query planning won't race with
   248  	// session serialization, so we can kick it off now.
   249  	errChan := make(chan error, 1)
   250  	if hasActiveSession {
   251  		go func() {
   252  			var err error
   253  			_, err = conn1.ExecContext(ctx, "SELECT pg_sleep(1000000)")
   254  			errChan <- err
   255  		}()
   256  	}
   257  
   258  	// Cancel the session on node 1.
   259  	if _, err = conn2.ExecContext(ctx, fmt.Sprintf("CANCEL SESSION '%s'", id)); err != nil {
   260  		t.Fatal(err)
   261  	}
   262  
   263  	if hasActiveSession {
   264  		// Verify that the query was canceled because the session closed.
   265  		err = <-errChan
   266  	} else {
   267  		// Verify that the connection is closed.
   268  		_, err = conn1.ExecContext(ctx, "SELECT 1")
   269  	}
   270  
   271  	if !errors.Is(err, gosqldriver.ErrBadConn) {
   272  		t.Fatalf("session not canceled; actual error: %s", err)
   273  	}
   274  }
   275  
   276  func TestCancelMultipleSessions(t *testing.T) {
   277  	defer leaktest.AfterTest(t)()
   278  	ctx := context.Background()
   279  
   280  	tc := serverutils.StartTestCluster(t, 2, /* numNodes */
   281  		base.TestClusterArgs{
   282  			ReplicationMode: base.ReplicationManual,
   283  		})
   284  	defer tc.Stopper().Stop(ctx)
   285  
   286  	// Open two connections on node 1.
   287  	var conns [2]*gosql.Conn
   288  	for i := 0; i < 2; i++ {
   289  		var err error
   290  		if conns[i], err = tc.ServerConn(0).Conn(ctx); err != nil {
   291  			t.Fatal(err)
   292  		}
   293  		if _, err := conns[i].ExecContext(ctx, "SET application_name = 'killme'"); err != nil {
   294  			t.Fatal(err)
   295  		}
   296  	}
   297  	// Open a control connection on node 2.
   298  	ctlconn, err := tc.ServerConn(1).Conn(ctx)
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  
   303  	// Cancel the sessions on node 1.
   304  	if _, err = ctlconn.ExecContext(ctx,
   305  		`CANCEL SESSIONS SELECT session_id FROM [SHOW CLUSTER SESSIONS] WHERE application_name = 'killme'`,
   306  	); err != nil {
   307  		t.Fatal(err)
   308  	}
   309  
   310  	// Verify that the connections on node 1 are closed.
   311  	for i := 0; i < 2; i++ {
   312  		_, err := conns[i].ExecContext(ctx, "SELECT 1")
   313  		if !errors.Is(err, gosqldriver.ErrBadConn) {
   314  			t.Fatalf("session %d not canceled; actual error: %s", i, err)
   315  		}
   316  	}
   317  }
   318  
   319  func TestIdleCancelSession(t *testing.T) {
   320  	defer leaktest.AfterTest(t)()
   321  	testCancelSession(t, false /* hasActiveSession */)
   322  }
   323  
   324  func TestActiveCancelSession(t *testing.T) {
   325  	defer leaktest.AfterTest(t)()
   326  	testCancelSession(t, true /* hasActiveSession */)
   327  }
   328  
   329  func TestCancelIfExists(t *testing.T) {
   330  	defer leaktest.AfterTest(t)()
   331  
   332  	tc := serverutils.StartTestCluster(t, 1, /* numNodes */
   333  		base.TestClusterArgs{
   334  			ReplicationMode: base.ReplicationManual,
   335  		})
   336  	defer tc.Stopper().Stop(context.Background())
   337  
   338  	conn := tc.ServerConn(0)
   339  
   340  	var err error
   341  
   342  	// Try to cancel a query that doesn't exist.
   343  	_, err = conn.Exec("CANCEL QUERY IF EXISTS '00000000000000000000000000000001'")
   344  	if err != nil {
   345  		t.Fatal(err)
   346  	}
   347  
   348  	// Try to cancel a session that doesn't exist.
   349  	_, err = conn.Exec("CANCEL SESSION IF EXISTS '00000000000000000000000000000001'")
   350  	if err != nil {
   351  		t.Fatal(err)
   352  	}
   353  }
   354  
   355  func isClientsideQueryCanceledErr(err error) bool {
   356  	if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
   357  		return pqErr.Code == pgcode.QueryCanceled
   358  	}
   359  	return pgerror.GetPGCode(err) == pgcode.QueryCanceled
   360  }