github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/props/func_dep_rand_test.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package props
    12  
    13  // This file is the home of TestFuncDepOpsRandom, a randomized FD tester.
    14  
    15  import (
    16  	"bytes"
    17  	"fmt"
    18  	"math/rand"
    19  	"os"
    20  	"strings"
    21  	"testing"
    22  	"text/tabwriter"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    25  	"github.com/cockroachdb/cockroach/pkg/util/randutil"
    26  	"github.com/cockroachdb/errors"
    27  )
    28  
    29  const debug = false
    30  
    31  // testVal is a value in a row in a test relation.
    32  // Value of 0 is special and is treated as NULL.
    33  type testVal uint8
    34  
    35  const null testVal = 0
    36  
    37  func (v testVal) String() string {
    38  	if v == null {
    39  		return "NULL"
    40  	}
    41  	return fmt.Sprintf("%d", v)
    42  }
    43  
    44  // A testRow in a test relation. Value of 0 is special and is treated as NULL.
    45  // The first value corresponds to ColumnID 1, and so on.
    46  // A testRow is immutable after initial construction.
    47  type testRow []testVal
    48  
    49  func (tr testRow) value(col opt.ColumnID) testVal {
    50  	return tr[col-1]
    51  }
    52  
    53  func (tr testRow) String() string {
    54  	var b strings.Builder
    55  	for i, v := range tr {
    56  		if i > 0 {
    57  			b.WriteByte(' ')
    58  		}
    59  		b.WriteString(v.String())
    60  	}
    61  	return b.String()
    62  }
    63  
    64  // rowKey encodes the values of at most 16 columns.
    65  type rowKey uint64
    66  
    67  // key generates a rowKey from the values on the given columns. Also returns
    68  // whether there were any null values.
    69  func (tr testRow) key(cols opt.ColSet) (_ rowKey, hasNulls bool) {
    70  	if cols.Len() > 16 {
    71  		panic(errors.AssertionFailedf("max 16 columns supported"))
    72  	}
    73  	var key rowKey
    74  	cols.ForEach(func(c opt.ColumnID) {
    75  		val := tr.value(c)
    76  		if val == null {
    77  			hasNulls = true
    78  		}
    79  		if val > 15 {
    80  			panic(errors.AssertionFailedf("testVal must be <= 15"))
    81  		}
    82  		key = (key << 4) | rowKey(val)
    83  	})
    84  	return key, hasNulls
    85  }
    86  
    87  // hasNulls returns true if the row has a null value on any of the given
    88  // columns.
    89  func (tr testRow) hasNulls(cols opt.ColSet) bool {
    90  	var res bool
    91  	cols.ForEach(func(c opt.ColumnID) {
    92  		res = res || tr.value(c) == null
    93  	})
    94  	return res
    95  }
    96  
    97  // equalOn returns true if the two rows are equal on the given columns.
    98  func (tr testRow) equalOn(other testRow, cols opt.ColSet) bool {
    99  	eq := true
   100  	cols.ForEach(func(c opt.ColumnID) {
   101  		if tr.value(c) != other.value(c) {
   102  			eq = false
   103  		}
   104  	})
   105  	return eq
   106  }
   107  
   108  type testRelation []testRow
   109  
   110  // String prints out the test relation in the following format:
   111  //
   112  //   1     2     3
   113  //   -------------
   114  //   NULL  1     2
   115  //   3     NULL  4
   116  //
   117  func (tr testRelation) String() string {
   118  	if len(tr) == 0 {
   119  		return "  <empty>\n"
   120  	}
   121  	var buf bytes.Buffer
   122  	tw := tabwriter.NewWriter(&buf, 2, 1, 2, ' ', 0)
   123  	for i := range tr[0] {
   124  		fmt.Fprintf(tw, "%d\t", i+1)
   125  	}
   126  	fmt.Fprint(tw, "\n")
   127  	for _, r := range tr {
   128  		for _, v := range r {
   129  			fmt.Fprintf(tw, "%s\t", v)
   130  		}
   131  		fmt.Fprint(tw, "\n")
   132  	}
   133  	_ = tw.Flush()
   134  
   135  	rows := strings.Split(buf.String(), "\n")
   136  	buf.Reset()
   137  	fmt.Fprintf(&buf, "  %s\n  ", strings.TrimRight(rows[0], " "))
   138  	for range rows[0] {
   139  		buf.WriteByte('-')
   140  	}
   141  	buf.WriteString("\n")
   142  	for _, r := range rows[1:] {
   143  		fmt.Fprintf(&buf, "  %s\n", strings.TrimRight(r, " "))
   144  	}
   145  	return buf.String()
   146  }
   147  
   148  // checkKey verifies that a certain key (strict or lax) is valid for the test
   149  // relation.
   150  func (tr testRelation) checkKey(key opt.ColSet, typ keyType) error {
   151  	m := make(map[rowKey]testRow)
   152  	for _, r := range tr {
   153  		k, hasNulls := r.key(key)
   154  		// If it is a lax key, we can ignore any rows that contain a NULL on the
   155  		// key columns.
   156  		if typ == laxKey && hasNulls {
   157  			continue
   158  		}
   159  		if existingRow, ok := m[k]; ok {
   160  			keyStr := ""
   161  			if typ == laxKey {
   162  				keyStr = "lax-"
   163  			}
   164  			return fmt.Errorf(
   165  				"%skey%s doesn't hold on rows:\n%s", keyStr, key, testRelation{r, existingRow},
   166  			)
   167  		}
   168  		m[k] = r
   169  	}
   170  	return nil
   171  }
   172  
   173  // checkFD verifies that a certain FD holds for the test relation.
   174  func (tr testRelation) checkFD(dep funcDep) error {
   175  	if dep.equiv {
   176  		// An equivalence FD is easy to check row-by-row.
   177  		for _, r := range tr {
   178  			c, _ := dep.from.Next(0)
   179  			val := r.value(c)
   180  			fail := false
   181  			dep.to.ForEach(func(col opt.ColumnID) {
   182  				if r.value(col) != val {
   183  					fail = true
   184  				}
   185  			})
   186  			if fail {
   187  				return fmt.Errorf("FD %s doesn't hold on row %s", &dep, r)
   188  			}
   189  		}
   190  		return nil
   191  	}
   192  
   193  	// We split the rows into groups (keyed on the `from` columns), picking the
   194  	// first row in each group as the "representative" of that group. All other
   195  	// rows in the group are checked against the representative row.
   196  	m := make(map[rowKey]testRow)
   197  	for _, r := range tr {
   198  		k, hasNulls := r.key(dep.from)
   199  		// If it is not a strict FD, we can ignore any rows that contain a NULL on
   200  		// the 'from' columns.
   201  		if !dep.strict && hasNulls {
   202  			continue
   203  		}
   204  
   205  		if first, ok := m[k]; ok {
   206  			if !first.equalOn(r, dep.to) {
   207  				return fmt.Errorf("FD %s doesn't hold on rows:\n%s", &dep, testRelation{first, r})
   208  			}
   209  		} else {
   210  			m[k] = r
   211  		}
   212  	}
   213  	return nil
   214  }
   215  
   216  // checkFDs verifies that the given FDs hold against the test relation.
   217  func (tr testRelation) checkFDs(fd *FuncDepSet) error {
   218  	// Check deps.
   219  	for _, dep := range fd.deps {
   220  		if err := tr.checkFD(dep); err != nil {
   221  			return err
   222  		}
   223  	}
   224  
   225  	// Check keys.
   226  	if fd.hasKey != noKey {
   227  		if err := tr.checkKey(fd.key, fd.hasKey); err != nil {
   228  			return err
   229  		}
   230  	}
   231  
   232  	return nil
   233  }
   234  
   235  // notNullCols returns the set columns that have no nulls in the test relation.
   236  func (tr testRelation) notNullCols(numCols int) opt.ColSet {
   237  	var res opt.ColSet
   238  	for c := opt.ColumnID(1); c <= opt.ColumnID(numCols); c++ {
   239  		res.Add(c)
   240  		for _, r := range tr {
   241  			if r.value(c) == null {
   242  				res.Remove(c)
   243  				break
   244  			}
   245  		}
   246  	}
   247  	return res
   248  }
   249  
   250  // joinTestRelations creates a possible result of joining two testRelations,
   251  // specifically:
   252  //  - an inner join if both leftOuter and rightOuter are false;
   253  //  - a left/right outer join if one of them is true;
   254  //  - a full outer join if both are true.
   255  func joinTestRelations(
   256  	numLeftCols int,
   257  	left testRelation,
   258  	numRightCols int,
   259  	right testRelation,
   260  	filters []testOp,
   261  	leftOuter bool,
   262  	rightOuter bool,
   263  ) testRelation {
   264  	var res testRelation
   265  
   266  	// Adds the given rows to the join result.
   267  	add := func(l, r testRow) {
   268  		newRow := make(testRow, numLeftCols+numRightCols)
   269  		if l != nil {
   270  			copy(newRow, l)
   271  		}
   272  		if r != nil {
   273  			copy(newRow[numLeftCols:], r)
   274  		}
   275  		res = append(res, newRow)
   276  	}
   277  
   278  	// Perform a cross join between the left and right relations.
   279  	for _, leftRow := range left {
   280  		for _, rightRow := range right {
   281  			add(leftRow, rightRow)
   282  		}
   283  	}
   284  
   285  	// Apply the filters to the result of the cross join.
   286  	for i := range filters {
   287  		res = filters[i].FilterRelation(res)
   288  	}
   289  
   290  	// Walk through the now-filtered cartesian product and keep track of all
   291  	// unique left and right rows.
   292  	leftCols := makeCols(numLeftCols)
   293  	rightCols := makeCols(numRightCols)
   294  	matchedLeftRows := map[rowKey]struct{}{}
   295  	matchedRightRows := map[rowKey]struct{}{}
   296  	for _, row := range res {
   297  		leftKey, _ := row.key(leftCols)
   298  		rightKey, _ := row.key(shiftSet(rightCols, numLeftCols))
   299  		matchedLeftRows[leftKey] = struct{}{}
   300  		matchedRightRows[rightKey] = struct{}{}
   301  	}
   302  
   303  	// If leftOuter is true, add back any left rows that were filtered out,
   304  	// null-extending the right side.
   305  	if leftOuter {
   306  		for _, row := range left {
   307  			key, _ := row.key(leftCols)
   308  			if _, ok := matchedLeftRows[key]; !ok {
   309  				add(row, nil)
   310  			}
   311  		}
   312  	}
   313  
   314  	// If rightOuter is true, add back any right rows that were filtered out,
   315  	// null-extending the left side.
   316  	if rightOuter {
   317  		for _, row := range right {
   318  			key, _ := row.key(rightCols)
   319  			if _, ok := matchedRightRows[key]; !ok {
   320  				add(nil, row)
   321  			}
   322  		}
   323  	}
   324  
   325  	if debug {
   326  		fmt.Printf("left:\n%s", left)
   327  		fmt.Printf("right:\n%s", right)
   328  		fmt.Printf("filters:\n%s", filters)
   329  		fmt.Printf("join(leftOuter=%t, rightOuter=%t):\n%s", leftOuter, rightOuter, res)
   330  	}
   331  	return res
   332  }
   333  
   334  // testOp is an interface implemented by test operations. Each test operation
   335  // makes a call to an FD API and filters out some of the rows in the current
   336  // test relation in order for that API call to be correct. The resulting FDs are
   337  // checked to hold on the updated test relations (and various FD APIs are
   338  // checked as well).
   339  type testOp interface {
   340  	fmt.Stringer
   341  
   342  	// FilterRelation returns a new subset testRelation that is consistent with
   343  	// the FD operation.
   344  	FilterRelation(tr testRelation) testRelation
   345  
   346  	// ApplyToFDs returns a new FuncDepSet after the operation.
   347  	ApplyToFDs(fd FuncDepSet) FuncDepSet
   348  }
   349  
   350  var _ testOp = &addKeyOp{}
   351  
   352  type testConfig struct {
   353  	// numCols is the number of columns in the relation.
   354  	// Can be at most 8 (see rowKey).
   355  	numCols int
   356  	// valRange is the range of values in the test relation.
   357  	valRange testVal
   358  }
   359  
   360  func (tc *testConfig) randCol() opt.ColumnID {
   361  	return opt.ColumnID(rand.Intn(tc.numCols) + 1)
   362  }
   363  
   364  func (tc *testConfig) randColSet(minLen, maxLen int) opt.ColSet {
   365  	if maxLen > tc.numCols {
   366  		panic(errors.AssertionFailedf("maxLen > numCols"))
   367  	}
   368  	length := rand.Intn(maxLen-minLen+1) + minLen
   369  	// Use Robert Floyd's algorithm to generate <length> distinct integers between
   370  	// 0 and numCols-1, just because it's so cool!
   371  	var res opt.ColSet
   372  	for j := tc.numCols - length; j < tc.numCols; j++ {
   373  		if t := rand.Intn(j + 1); !res.Contains(opt.ColumnID(t + 1)) {
   374  			res.Add(opt.ColumnID(t + 1))
   375  		} else {
   376  			res.Add(opt.ColumnID(j + 1))
   377  		}
   378  	}
   379  	return res
   380  }
   381  
   382  func (tc *testConfig) allCols() opt.ColSet {
   383  	return makeCols(tc.numCols)
   384  }
   385  
   386  // initTestRelation creates a testRelation with all possible combinations of
   387  // values in the set {0 (null), 1, 2, ... valRange}.
   388  func (tc *testConfig) initTestRelation() testRelation {
   389  	var tr testRelation
   390  
   391  	// genRows takes a row prefix and recursively generates all rows with that
   392  	// prefix, appending them to tr.
   393  	var genRows func(vals []testVal)
   394  	genRows = func(vals []testVal) {
   395  		n := len(vals)
   396  		if n == tc.numCols {
   397  			// Add each row twice to have duplicates.
   398  			tr = append(tr, vals, vals)
   399  			return
   400  		}
   401  		for i := null; i <= tc.valRange; i++ {
   402  			genRows(append(vals[:n:n], i))
   403  		}
   404  	}
   405  	genRows(nil)
   406  	return tr
   407  }
   408  
   409  func (tc *testConfig) checkAPIs(fd *FuncDepSet, tr testRelation) error {
   410  	if fd.HasMax1Row() && len(tr) > 1 {
   411  		return fmt.Errorf("HasMax1Row() incorrectly returns true")
   412  	}
   413  
   414  	for t := 0; t < 5; t++ {
   415  		cols := tc.randColSet(1, tc.numCols)
   416  
   417  		if fd.ColsAreLaxKey(cols) {
   418  			if err := tr.checkKey(cols, laxKey); err != nil {
   419  				return fmt.Errorf("ColsAreLaxKey%s incorrectly returns true", cols)
   420  			}
   421  		}
   422  
   423  		if fd.ColsAreStrictKey(cols) {
   424  			if err := tr.checkKey(cols, strictKey); err != nil {
   425  				return fmt.Errorf("ColsAreStrictKey%s incorrectly returns true", cols)
   426  			}
   427  		}
   428  
   429  		closure := fd.ComputeClosure(cols)
   430  		if err := tr.checkFD(funcDep{
   431  			from:   cols,
   432  			to:     closure,
   433  			strict: true,
   434  		}); err != nil {
   435  			return fmt.Errorf("ComputeClosure%s incorrectly returns %s: %s", cols, closure, err)
   436  		}
   437  
   438  		reduced := fd.ReduceCols(cols)
   439  		if err := tr.checkFD(funcDep{
   440  			from:   reduced,
   441  			to:     cols,
   442  			strict: true,
   443  		}); err != nil {
   444  			return fmt.Errorf("ReduceCols%s incorrectly returns %s: %s", cols, reduced, err)
   445  		}
   446  
   447  		var proj FuncDepSet
   448  		proj.CopyFrom(fd)
   449  		proj.ProjectCols(cols)
   450  		// The FDs after projection should still hold on the table.
   451  		if err := tr.checkFDs(&proj); err != nil {
   452  			return fmt.Errorf("ProjectCols%s incorrectly returns %s: %s", cols, proj.String(), err)
   453  		}
   454  	}
   455  
   456  	return nil
   457  }
   458  
   459  // testOpGenerator generates a testOp, given the number of columns in the
   460  // relation.
   461  type testOpGenerator = func(tc *testConfig) testOp
   462  
   463  // addKeyOp is a test operation corresponding to AddStrictKey / AddLaxKey.
   464  type addKeyOp struct {
   465  	allCols opt.ColSet
   466  	key     opt.ColSet
   467  	typ     keyType
   468  }
   469  
   470  func genAddKey(minKeyCols, maxKeyCols int) testOpGenerator {
   471  	return func(tc *testConfig) testOp {
   472  		cols := tc.randColSet(minKeyCols, maxKeyCols)
   473  		typ := strictKey
   474  		if !cols.Empty() && rand.Int()%2 == 0 {
   475  			typ = laxKey
   476  		}
   477  		return &addKeyOp{
   478  			allCols: tc.allCols(),
   479  			key:     cols,
   480  			typ:     typ,
   481  		}
   482  	}
   483  }
   484  
   485  func (o *addKeyOp) String() string {
   486  	if o.typ == strictKey {
   487  		return fmt.Sprintf("AddStrictKey%s", o.key)
   488  	}
   489  	return fmt.Sprintf("AddLaxKey%s", o.key)
   490  }
   491  
   492  func (o *addKeyOp) FilterRelation(tr testRelation) testRelation {
   493  	var out testRelation
   494  	// Process the rows in random order and remove any duplicate rows.
   495  	m := make(map[rowKey]bool)
   496  	perm := rand.Perm(len(tr))
   497  	for _, rowIdx := range perm {
   498  		r := tr[perm[rowIdx]]
   499  		key, hasNulls := r.key(o.key)
   500  		// If it is a lax key, we can leave all rows that contain a NULL on the
   501  		// key columns.
   502  		if o.typ == laxKey && hasNulls {
   503  			out = append(out, r)
   504  			continue
   505  		}
   506  		if !m[key] {
   507  			out = append(out, r)
   508  			m[key] = true
   509  		}
   510  	}
   511  	return out
   512  }
   513  
   514  func (o *addKeyOp) ApplyToFDs(fd FuncDepSet) FuncDepSet {
   515  	var out FuncDepSet
   516  	out.CopyFrom(&fd)
   517  	if o.typ == strictKey {
   518  		out.AddStrictKey(o.key, o.allCols)
   519  	} else {
   520  		out.AddLaxKey(o.key, o.allCols)
   521  	}
   522  	return out
   523  }
   524  
   525  // makeNotNullOp is a test operation corresponding to MakeNotNull.
   526  type makeNotNullOp struct {
   527  	cols opt.ColSet
   528  }
   529  
   530  func genMakeNotNull(minCols, maxCols int) testOpGenerator {
   531  	return func(tc *testConfig) testOp {
   532  		return &makeNotNullOp{
   533  			cols: tc.randColSet(minCols, maxCols),
   534  		}
   535  	}
   536  }
   537  
   538  func (o *makeNotNullOp) String() string {
   539  	return fmt.Sprintf("MakeNotNull%s", o.cols)
   540  }
   541  
   542  func (o *makeNotNullOp) FilterRelation(tr testRelation) testRelation {
   543  	var out testRelation
   544  	for _, r := range tr {
   545  		if !r.hasNulls(o.cols) {
   546  			out = append(out, r)
   547  		}
   548  	}
   549  	return out
   550  }
   551  
   552  func (o *makeNotNullOp) ApplyToFDs(fd FuncDepSet) FuncDepSet {
   553  	var out FuncDepSet
   554  	out.CopyFrom(&fd)
   555  	out.MakeNotNull(o.cols)
   556  	return out
   557  }
   558  
   559  // addConstOp is a test operation corresponding to AddConstants.
   560  type addConstOp struct {
   561  	cols opt.ColSet
   562  	vals []testVal
   563  }
   564  
   565  func genAddConst(minCols, maxCols int) testOpGenerator {
   566  	return func(tc *testConfig) testOp {
   567  		cols := tc.randColSet(minCols, maxCols)
   568  		vals := make([]testVal, cols.Len())
   569  		for i := range vals {
   570  			vals[i] = testVal(rand.Intn(int(tc.valRange + 1)))
   571  		}
   572  		return &addConstOp{
   573  			cols: cols,
   574  			vals: vals,
   575  		}
   576  	}
   577  }
   578  
   579  func (o *addConstOp) String() string {
   580  	return fmt.Sprintf("AddConstants%s values {%v}", o.cols, testRow(o.vals).String())
   581  }
   582  
   583  func (o *addConstOp) FilterRelation(tr testRelation) testRelation {
   584  	var out testRelation
   585  	for _, r := range tr {
   586  		idx := 0
   587  		ok := true
   588  		o.cols.ForEach(func(c opt.ColumnID) {
   589  			if r[c-1] != o.vals[idx] {
   590  				ok = false
   591  			}
   592  			idx++
   593  		})
   594  		if ok {
   595  			out = append(out, r)
   596  		}
   597  	}
   598  	return out
   599  }
   600  
   601  func (o *addConstOp) ApplyToFDs(fd FuncDepSet) FuncDepSet {
   602  	var out FuncDepSet
   603  	out.CopyFrom(&fd)
   604  	out.AddConstants(o.cols)
   605  	return out
   606  }
   607  
   608  // addEquivOp is a test operation corresponding to AddEquivalency.
   609  type addEquivOp struct {
   610  	a, b opt.ColumnID
   611  }
   612  
   613  func genAddEquiv() testOpGenerator {
   614  	return func(tc *testConfig) testOp {
   615  		return &addEquivOp{
   616  			a: tc.randCol(),
   617  			b: tc.randCol(),
   618  		}
   619  	}
   620  }
   621  
   622  func (o *addEquivOp) String() string {
   623  	return fmt.Sprintf("AddEquivalency(%d,%d)", o.a, o.b)
   624  }
   625  
   626  func (o *addEquivOp) FilterRelation(tr testRelation) testRelation {
   627  	// Filter out rows where the equivalency doesn't hold.
   628  	var out testRelation
   629  	for _, r := range tr {
   630  		if r.value(o.a) == r.value(o.b) {
   631  			out = append(out, r)
   632  		}
   633  	}
   634  	return out
   635  }
   636  
   637  func (o *addEquivOp) ApplyToFDs(fd FuncDepSet) FuncDepSet {
   638  	var out FuncDepSet
   639  	out.CopyFrom(&fd)
   640  	out.AddEquivalency(o.a, o.b)
   641  	return out
   642  }
   643  
   644  // addSynthOp is a test operation corresponding to AddSynthesizedCol.
   645  type addSynthOp struct {
   646  	from opt.ColSet
   647  	to   opt.ColumnID
   648  }
   649  
   650  func genAddSynth(minCols, maxCols int) testOpGenerator {
   651  	return func(tc *testConfig) testOp {
   652  		from := tc.randColSet(minCols, maxCols)
   653  		to := tc.randCol()
   654  		from.Remove(to)
   655  		return &addSynthOp{
   656  			from: from,
   657  			to:   to,
   658  		}
   659  	}
   660  }
   661  
   662  func (o *addSynthOp) String() string {
   663  	return fmt.Sprintf("AddSynthesizedCol(%s, %d)", o.from, o.to)
   664  }
   665  
   666  func (o *addSynthOp) FilterRelation(tr testRelation) testRelation {
   667  	// Filter out rows where the from->to FD doesn't hold. The code here parallels
   668  	// that in testRelation.checkKey.
   669  	//
   670  	// We split the rows into groups (keyed on the `from` columns), picking the
   671  	// first row in each group as the "representative" of that group. All other
   672  	// rows in the group are checked against the representative row.
   673  	var out testRelation
   674  	m := make(map[rowKey]testRow)
   675  	perm := rand.Perm(len(tr))
   676  	for _, rowIdx := range perm {
   677  		r := tr[rowIdx]
   678  		k, _ := r.key(o.from)
   679  		if first, ok := m[k]; ok {
   680  			if first.value(o.to) != r.value(o.to) {
   681  				// Filter out row.
   682  				continue
   683  			}
   684  		} else {
   685  			m[k] = r
   686  		}
   687  		out = append(out, r)
   688  	}
   689  	return out
   690  }
   691  
   692  func (o *addSynthOp) ApplyToFDs(fd FuncDepSet) FuncDepSet {
   693  	var out FuncDepSet
   694  	out.CopyFrom(&fd)
   695  	out.AddSynthesizedCol(o.from, o.to)
   696  	return out
   697  }
   698  
   699  // testState corresponds to a chain of applied test operations. The head of a
   700  // testStates chain has no parent and no op and just corresponds to the initial
   701  // (empty) FDs and test relation.
   702  type testState struct {
   703  	parent *testState
   704  	cfg    *testConfig
   705  
   706  	op  testOp
   707  	fds FuncDepSet
   708  	rel testRelation
   709  }
   710  
   711  func (ts *testState) format(b *strings.Builder) {
   712  	if ts.parent == nil {
   713  		fmt.Fprintf(b, "initial numCols=%d valRange=%d\n", ts.cfg.numCols, ts.cfg.valRange)
   714  	} else {
   715  		ts.parent.format(b)
   716  		fmt.Fprintf(b, " => %s\n", ts.op.String())
   717  		fmt.Fprintf(b, "    FDs: %s\n", ts.fds.String())
   718  	}
   719  }
   720  
   721  // String describes the chain of operations and corresponding FDs.
   722  // For example:
   723  //   initial numCols=3 valRange=3
   724  //    => MakeNotNull(2)
   725  //       FDs:
   726  //    => AddConstants(1,3) values {NULL,1}
   727  //       FDs: ()-->(1,3)
   728  //    => AddLaxKey(3)
   729  //       FDs: ()-->(1-3)
   730  //
   731  func (ts *testState) String() string {
   732  	var b strings.Builder
   733  	ts.format(&b)
   734  	return b.String()
   735  }
   736  
   737  func newTestState(cfg *testConfig) *testState {
   738  	state := &testState{cfg: cfg}
   739  	state.rel = cfg.initTestRelation()
   740  	return state
   741  }
   742  
   743  func (ts *testState) child(t *testing.T, op testOp) *testState {
   744  	child := &testState{
   745  		parent: ts,
   746  		cfg:    ts.cfg,
   747  		op:     op,
   748  		rel:    op.FilterRelation(ts.rel),
   749  	}
   750  
   751  	err := func() (err error) {
   752  		defer func() {
   753  			if r := recover(); r != nil {
   754  				err = errors.AssertionFailedf("%v", r)
   755  			}
   756  		}()
   757  		child.fds = op.ApplyToFDs(ts.fds)
   758  		child.fds.Verify()
   759  		if err = child.rel.checkFDs(&child.fds); err != nil {
   760  			return err
   761  		}
   762  		if err = ts.cfg.checkAPIs(&child.fds, child.rel); err != nil {
   763  			return err
   764  		}
   765  		return nil
   766  	}()
   767  	if err != nil {
   768  		t.Fatalf("details below\n%s\n%+v", child.String(), err)
   769  	}
   770  
   771  	return child
   772  }
   773  
// TestFuncDepOpsRandom performs random FD operations and maintains a test
// relation in parallel, making sure that the FDs always hold w.r.t the test
// table. We start with a "full" table (all possible combinations of values in a
// certain range) and each operation filters out rows in conformance with the
// operation (e.g. if we add a key, we remove duplicate rows). We also test
// various FuncDepSet APIs at each stage.
//
// To reuse work, instead of generating one chain of operations at a time, we
// generate a tree of operations; each path from root to a leaf is a chain that
// is getting tested.
//
// The test then exercises the join APIs. Two random operation chains serve as
// the join inputs, and a random chain of filter ops serves as the join
// condition. The FDs computed by MakeProduct + AddFrom, MakeLeftOuter and
// MakeFullOuter are checked against the corresponding join of the two test
// relations, built by joinTestRelations.
//
func TestFuncDepOpsRandom(t *testing.T) {
	type testParams struct {
		testConfig

		// maxDepth is the maximum length of a chain of test operations.
		maxDepth int

		// branching is the number of test operations generated at each level,
		// i.e. for every state in the tree. Each repeat therefore applies
		// branching + branching^2 + ... + branching^maxDepth operations and
		// tests branching^maxDepth chains of length maxDepth. It is not used
		// for filterConfigs.
		branching int

		// ops are the generators that each test operation is drawn from,
		// uniformly at random.
		ops []testOpGenerator
	}

	testConfigs := []testParams{
		{
			testConfig: testConfig{
				numCols:  3,
				valRange: 2,
			},
			maxDepth:  2,
			branching: 3,
			ops: []testOpGenerator{
				genAddKey(0 /* minKeyCols */, 2 /* maxKeyCols */),
				genMakeNotNull(1 /* minCols */, 3 /* maxCols */),
				genAddConst(1 /* minCols */, 3 /* maxCols */),
				genAddEquiv(),
				genAddSynth(0 /* minCols */, 3 /* maxCols */),
			},
		},

		{
			testConfig: testConfig{
				numCols:  5,
				valRange: 2,
			},
			maxDepth:  5,
			branching: 2,
			ops: []testOpGenerator{
				genAddKey(1 /* minKeyCols */, 5 /* maxKeyCols */),
				genMakeNotNull(1 /* minCols */, 4 /* maxCols */),
				genAddConst(1 /* minCols */, 4 /* maxCols */),
				genAddEquiv(),
				genAddSynth(0 /* minCols */, 3 /* maxCols */),
			},
		},
	}

	// filterConfigs describe the join filters. Their numCols is the total
	// number of columns of the two join inputs, and maxDepth bounds the number
	// of filter ops.
	filterConfigs := []testParams{
		{
			testConfig: testConfig{
				numCols:  6,
				valRange: 2,
			},
			maxDepth:  3,
			branching: 3,
			ops: []testOpGenerator{
				genMakeNotNull(1 /* minCols */, 4 /* maxCols */),
				genAddConst(0 /* minCols */, 6 /* maxCols */),
				genAddEquiv(),
			},
		},

		{
			testConfig: testConfig{
				numCols:  8,
				valRange: 2,
			},
			maxDepth:  2,
			branching: 3,
			ops: []testOpGenerator{
				genMakeNotNull(0 /* minCols */, 8 /* maxCols */),
				genAddConst(3 /* minCols */, 5 /* maxCols */),
				genAddEquiv(),
			},
		},

		{
			testConfig: testConfig{
				numCols:  10,
				valRange: 2,
			},
			maxDepth:  3,
			branching: 6,
			ops: []testOpGenerator{
				genMakeNotNull(2 /* minCols */, 7 /* maxCols */),
				genAddConst(2 /* minCols */, 3 /* maxCols */),
				genAddEquiv(),
			},
		},
	}

	// Allows a set of filters to be chosen based on the total number of columns
	// in the input relations.
	filterIndex := map[int]int{6: 0, 8: 1, 10: 2}

	const repeats = 100

	for _, cfg := range testConfigs {
		for r := 0; r < repeats; r++ {
			// run generates cfg.branching random children of state and recurses
			// into each child until the chain reaches cfg.maxDepth operations.
			var run func(state *testState, depth int)
			run = func(state *testState, depth int) {
				for i := 0; i < cfg.branching; i++ {
					opGenIdx := rand.Intn(len(cfg.ops))
					op := cfg.ops[opGenIdx](&cfg.testConfig)

					child := state.child(t, op)

					if debug {
						fmt.Printf("%d: %s %s\n", depth, op, child.fds.String())
						for _, r := range child.rel {
							fmt.Printf("  %s\n", r)
						}
					}

					if depth < cfg.maxDepth {
						run(child, depth+1)
					}
				}
			}
			run(newTestState(&cfg.testConfig), 1 /* depth */)
		}
	}

	// Run tests for joins.
	for r := 0; r < repeats; r++ {
		// Generate left and right input op chains. Each chain uses a randomly
		// chosen testConfig and has between 1 and maxDepth operations.
		genRandChain := func() *testState {
			cfg := testConfigs[rand.Intn(len(testConfigs))]
			state := newTestState(&cfg.testConfig)

			steps := 1 + rand.Intn(cfg.maxDepth)
			for i := 0; i < steps; i++ {
				opGenIdx := rand.Intn(len(cfg.ops))
				op := cfg.ops[opGenIdx](&cfg.testConfig)
				state = state.child(t, op)
			}
			return state
		}
		left := genRandChain()
		nLeft := left.cfg.numCols
		leftCols := left.cfg.allCols()
		right := genRandChain()
		nRight := right.cfg.numCols
		// In the join output the right columns come after the left ones, so
		// the right side's columns and FDs are shifted by nLeft.
		rightCols := shiftSet(right.cfg.allCols(), nLeft)
		rightFDs := shiftColumns(right.fds, nLeft)

		// Generate filter ops.
		var filters []testOp
		filtersFDs := FuncDepSet{}
		cfg := filterConfigs[filterIndex[nLeft+nRight]]
		steps := 1 + rand.Intn(cfg.maxDepth)
		for i := 0; i < steps; i++ {
			opGenIdx := rand.Intn(len(cfg.ops))
			op := cfg.ops[opGenIdx](&cfg.testConfig)
			filters = append(filters, op)
			filtersFDs = op.ApplyToFDs(filtersFDs)
		}

		// Test inner join.
		join := joinTestRelations(
			nLeft, left.rel, nRight, right.rel, filters, false /* leftOuter */, false, /* rightOuter */
		)

		var fd FuncDepSet
		fd.CopyFrom(&left.fds)
		fd.MakeProduct(&rightFDs)
		fd.AddFrom(&filtersFDs)
		if err := join.checkFDs(&fd); err != nil {
			t.Fatalf(
				"MakeProduct and AddFrom returned incorrect FDs\n"+
					"left:    %s\n"+
					"right:   %s\n"+
					"filters: %s\n"+
					"result:  %s\n"+
					"error:   %+v\n\n"+
					"left side: %s\n"+
					"right side (cols shifted by %d): %s\n"+
					"filter steps: %s",
				&left.fds, &rightFDs, &filtersFDs, &fd, err, left, nLeft, right, filters,
			)
		}

		// Columns with no NULLs in the input relations (right side shifted into
		// join column positions), passed to the outer join APIs.
		leftNotNullCols := left.rel.notNullCols(nLeft)
		rightNotNullCols := shiftSet(right.rel.notNullCols(nRight), nLeft)
		notNullInputCols := leftNotNullCols.Union(rightNotNullCols)

		// Test left join.
		join = joinTestRelations(
			nLeft, left.rel, nRight, right.rel, filters, true /* leftOuter */, false, /* rightOuter */
		)

		fd.CopyFrom(&left.fds)
		fd.MakeProduct(&rightFDs)
		fd.MakeLeftOuter(&left.fds, &filtersFDs, leftCols, rightCols, notNullInputCols)
		if err := join.checkFDs(&fd); err != nil {
			t.Fatalf(
				"MakeLeftOuter(..., %s, %s, %s) returned incorrect FDs\n"+
					"left:    %s\n"+
					"right:   %s\n"+
					"filters: %s\n"+
					"result:  %s\n"+
					"error:   %+v\n\n"+
					"left side: %s\n"+
					"right side (cols shifted by %d): %s\n"+
					"filter steps: %s",
				leftCols, rightCols, notNullInputCols,
				&left.fds, &rightFDs, &filtersFDs, &fd, err, left, nLeft, right, filters,
			)
		}

		// Test full outer join. No filters are applied, so the result is the
		// cross product of the inputs. Rows of one side are null-extended only
		// when the other side is empty.
		join = joinTestRelations(
			nLeft, left.rel, nRight, right.rel, nil, true /* leftOuter */, true, /* rightOuter */
		)

		fd.CopyFrom(&left.fds)
		fd.MakeProduct(&rightFDs)
		fd.MakeFullOuter(leftCols, rightCols, notNullInputCols)
		if err := join.checkFDs(&fd); err != nil {
			t.Fatalf(
				"MakeFullOuter(%s, %s, %s) returned incorrect FDs\n"+
					"left:   %s\n"+
					"right:  %s\n"+
					"result: %s\n"+
					"error:  %+v\n\n"+
					"left side: %s\n"+
					"right side (cols shifted by %d): %s",
				leftCols, rightCols, notNullInputCols,
				&left.fds, &rightFDs, &fd, err, left, nLeft, right,
			)
		}
	}
}
  1019  
  1020  func TestMain(m *testing.M) {
  1021  	randutil.SeedForTests()
  1022  	os.Exit(m.Run())
  1023  }
  1024  
  1025  func shiftSet(cols opt.ColSet, delta int) opt.ColSet {
  1026  	var res opt.ColSet
  1027  	cols.ForEach(func(col opt.ColumnID) {
  1028  		res.Add(col + opt.ColumnID(delta))
  1029  	})
  1030  	return res
  1031  }
  1032  
  1033  func shiftColumns(fd FuncDepSet, delta int) FuncDepSet {
  1034  	var res FuncDepSet
  1035  	res.CopyFrom(&fd)
  1036  	for i := range res.deps {
  1037  		d := &res.deps[i]
  1038  		d.from = shiftSet(d.from, delta)
  1039  		d.to = shiftSet(d.to, delta)
  1040  	}
  1041  	res.key = shiftSet(res.key, delta)
  1042  	return res
  1043  }
  1044  
  1045  func makeCols(numCols int) opt.ColSet {
  1046  	var allCols opt.ColSet
  1047  	for i := opt.ColumnID(1); i <= opt.ColumnID(numCols); i++ {
  1048  		allCols.Add(i)
  1049  	}
  1050  	return allCols
  1051  }