github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/logictest/parallel_test.go

// Copyright 2016 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
//
// The parallel_test adds an orchestration layer on top of the logic_test code,
// making it possible to run multiple test data files in parallel.
//
// Each test lives in a separate subdir under testdata/parallel_test. Each test
// dir contains a "test.yaml" file along with a set of files in logic test
// format. The test.yaml file corresponds to the parTestSpec structure below.
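//
// For illustration, a test.yaml could look like the following (the file names
// here are hypothetical):
//
//   cluster_size: 3
//   run:
//   - - node: 0
//       file: data_node0
//     - node: 1
//       file: data_node1
//   - - node: 0
//       file: verify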

package logictest

import (
	"context"
	gosql "database/sql"
	"flag"
	"fmt"
	"io/ioutil"
	"net/url"
	"path/filepath"
	"strings"
	"testing"

	"github.com/cockroachdb/cockroach/pkg/base"
	"github.com/cockroachdb/cockroach/pkg/config/zonepb"
	"github.com/cockroachdb/cockroach/pkg/keys"
	"github.com/cockroachdb/cockroach/pkg/security"
	"github.com/cockroachdb/cockroach/pkg/sql"
	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
	"github.com/cockroachdb/cockroach/pkg/util/log"
	"github.com/cockroachdb/cockroach/pkg/util/protoutil"
	"github.com/cockroachdb/cockroach/pkg/util/randutil"
	"github.com/cockroachdb/cockroach/pkg/util/stop"
	"github.com/gogo/protobuf/proto"
	yaml "gopkg.in/yaml.v2"
)

var (
	paralleltestdata = flag.String("partestdata", "testdata/parallel_test/[^.]*", "test data glob")
)

type parallelTest struct {
	*testing.T
	ctx     context.Context
	cluster serverutils.TestClusterInterface
	clients [][]*gosql.DB
}

func (t *parallelTest) close() {
	t.clients = nil
	if t.cluster != nil {
		t.cluster.Stopper().Stop(context.Background())
	}
}

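// processTestFile runs the logic test file at path through db, connected to
// node nodeIdx. If ch is not nil, a value is sent on it when processing
// finishes, regardless of success or failure.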
func (t *parallelTest) processTestFile(path string, nodeIdx int, db *gosql.DB, ch chan bool) {
	if ch != nil {
		defer func() { ch <- true }()
	}

	// Set up a dummy logicTest structure to use that code.
	rng, _ := randutil.NewPseudoRand()
	l := &logicTest{
		rootT:   t.T,
		cluster: t.cluster,
		nodeIdx: nodeIdx,
		db:      db,
		user:    security.RootUser,
		verbose: testing.Verbose() || log.V(1),
		rng:     rng,
	}
	if err := l.processTestFile(path, testClusterConfig{}); err != nil {
		log.Errorf(context.Background(), "error processing %s: %s", path, err)
		t.Error(err)
	}
}

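// getClient returns the clientIdx-th SQL client for node nodeIdx, opening
// additional connections on demand and registering them for cleanup with the
// cluster stopper.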
func (t *parallelTest) getClient(nodeIdx, clientIdx int) *gosql.DB {
	for len(t.clients[nodeIdx]) <= clientIdx {
		// Add a client.
		pgURL, cleanupFunc := sqlutils.PGUrl(t.T,
			t.cluster.Server(nodeIdx).ServingSQLAddr(),
			"TestParallel",
			url.User(security.RootUser))
		db, err := gosql.Open("postgres", pgURL.String())
		if err != nil {
			t.Fatal(err)
		}
		sqlutils.MakeSQLRunner(db).Exec(t, "SET DATABASE = test")
		t.cluster.Stopper().AddCloser(
			stop.CloserFn(func() {
				_ = db.Close()
				cleanupFunc()
			}))
		t.clients[nodeIdx] = append(t.clients[nodeIdx], db)
	}
	return t.clients[nodeIdx][clientIdx]
}

type parTestRunEntry struct {
	Node int    `yaml:"node"`
	File string `yaml:"file"`
}

type parTestSpec struct {
	SkipReason string `yaml:"skip_reason"`

	// ClusterSize is the number of nodes in the cluster. If 0, single node.
	ClusterSize int `yaml:"cluster_size"`

	// RangeSplitSize, if non-zero, is used as the maximum range size (in
	// bytes) for the default zone config, forcing data to split into many
	// small ranges.
	RangeSplitSize int `yaml:"range_split_size"`

	// Run contains a set of "run lists". The files in a run list are run in
	// parallel, and they all complete before the next run list starts.
	Run [][]parTestRunEntry `yaml:"run"`
}

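// run executes the test described by the test.yaml spec in dir: it sets up the
// cluster and then processes each run list in order, running the files within
// a run list in parallel.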
func (t *parallelTest) run(dir string) {
	// Process the spec file.
	mainFile := filepath.Join(dir, "test.yaml")
	yamlData, err := ioutil.ReadFile(mainFile)
	if err != nil {
		t.Fatalf("%s: %s", mainFile, err)
	}
	var spec parTestSpec
	if err := yaml.UnmarshalStrict(yamlData, &spec); err != nil {
		t.Fatalf("%s: %s", mainFile, err)
	}

	if spec.SkipReason != "" {
		t.Skip(spec.SkipReason)
	}

	log.Infof(t.ctx, "Running test %s", dir)
	if testing.Verbose() || log.V(1) {
		log.Infof(t.ctx, "spec: %+v", spec)
	}

	t.setup(&spec)
	defer t.close()

	for runListIdx, runList := range spec.Run {
		if testing.Verbose() || log.V(1) {
			var descr []string
			for _, re := range runList {
				descr = append(descr, fmt.Sprintf("%d:%s", re.Node, re.File))
			}
			log.Infof(t.ctx, "%s: run list %d: %s", mainFile, runListIdx,
				strings.Join(descr, ", "))
		}
		// Track, per node, how many clients have been handed out for this run list.
		numClients := make([]int, spec.ClusterSize)
		ch := make(chan bool)
		for _, re := range runList {
			client := t.getClient(re.Node, numClients[re.Node])
			numClients[re.Node]++
			go t.processTestFile(filepath.Join(dir, re.File), re.Node, client, ch)
		}
		// Wait for all the files in this run list to finish processing.
		for range runList {
			<-ch
		}
	}
}

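// setup starts a test cluster of spec.ClusterSize nodes (defaulting to one),
// disables DistSQL execution, optionally shrinks the range size, and creates
// the "test" database used by the test files.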
func (t *parallelTest) setup(spec *parTestSpec) {
	if spec.ClusterSize == 0 {
		spec.ClusterSize = 1
	}

	if testing.Verbose() || log.V(1) {
		log.Infof(t.ctx, "Cluster Size: %d", spec.ClusterSize)
	}

	t.cluster = serverutils.StartTestCluster(t, spec.ClusterSize, base.TestClusterArgs{})

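	// Disable DistSQL execution for these tests: put each node's cluster
	// settings into manual mode and override the cluster-wide default
	// execution mode to "off".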
	for i := 0; i < t.cluster.NumServers(); i++ {
		server := t.cluster.Server(i)
		mode := sessiondata.DistSQLOff
		st := server.ClusterSettings()
		st.Manual.Store(true)
		sql.DistSQLClusterExecMode.Override(&st.SV, int64(mode))
	}

	t.clients = make([][]*gosql.DB, spec.ClusterSize)
	for i := range t.clients {
		t.clients[i] = append(t.clients[i], t.cluster.ServerConn(i))
	}
	r0 := sqlutils.MakeSQLRunner(t.clients[0][0])

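	// If a range split size was specified, shrink the default zone config's
	// range size limits so the test data splits into many small ranges.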
	if spec.RangeSplitSize != 0 {
		if testing.Verbose() || log.V(1) {
			log.Infof(t.ctx, "Setting range split size: %d", spec.RangeSplitSize)
		}
		zoneCfg := zonepb.DefaultZoneConfig()
		zoneCfg.RangeMaxBytes = proto.Int64(int64(spec.RangeSplitSize))
		zoneCfg.RangeMinBytes = proto.Int64(*zoneCfg.RangeMaxBytes / 2)
		buf, err := protoutil.Marshal(&zoneCfg)
		if err != nil {
			t.Fatal(err)
		}
		objID := keys.RootNamespaceID
		r0.Exec(t, `UPDATE system.zones SET config = $2 WHERE id = $1`, objID, buf)
	}

	if testing.Verbose() || log.V(1) {
		log.Infof(t.ctx, "Creating database")
	}

	r0.Exec(t, "CREATE DATABASE test")
	for i := range t.clients {
		sqlutils.MakeSQLRunner(t.clients[i][0]).Exec(t, "SET DATABASE = test")
	}

	if testing.Verbose() || log.V(1) {
		log.Infof(t.ctx, "Test setup done")
	}
}

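// TestParallel runs every parallel test spec matched by the -partestdata glob,
// creating one subtest per test directory.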
func TestParallel(t *testing.T) {
	defer leaktest.AfterTest(t)()

	glob := *paralleltestdata
	paths, err := filepath.Glob(glob)
	if err != nil {
		t.Fatal(err)
	}
	if len(paths) == 0 {
		t.Fatalf("No test files found (glob: %s)", glob)
	}
	total := 0
	failed := 0
	for _, path := range paths {
		t.Run(filepath.Base(path), func(t *testing.T) {
			pt := parallelTest{T: t, ctx: context.Background()}
			pt.run(path)
			total++
			if t.Failed() {
				failed++
			}
		})
	}
	if failed == 0 {
		log.Infof(context.Background(), "%d parallel tests passed", total)
	} else {
		log.Infof(context.Background(), "%d out of %d parallel tests failed", failed, total)
	}
}