github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/internal/sqlsmith/schema.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sqlsmith
    12  
    13  import (
    14  	gosql "database/sql"
    15  	"fmt"
    16  	"strings"
    17  
    18  	// Import builtins so they are reflected in tree.FunDefs.
    19  	_ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    22  	"github.com/lib/pq/oid"
    23  )
    24  
// tableRef represents a table and its columns.
type tableRef struct {
	// TableName is the name of the table.
	TableName *tree.TableName
	// Columns holds the column definitions known for the table.
	Columns []*tree.ColumnTableDef
}
    30  
// aliasedTableRef is a tableRef paired with the index flags to use when
// referencing the table.
type aliasedTableRef struct {
	*tableRef
	// indexFlags optionally names an index of the table; the zero value
	// names no index.
	indexFlags *tree.IndexFlags
}
    35  
// tableRefs is a list of table references.
type tableRefs []*tableRef
    37  
    38  // ReloadSchemas loads tables from the database.
    39  func (s *Smither) ReloadSchemas() error {
    40  	if s.db == nil {
    41  		return nil
    42  	}
    43  	s.lock.Lock()
    44  	defer s.lock.Unlock()
    45  	var err error
    46  	s.tables, err = extractTables(s.db)
    47  	if err != nil {
    48  		return err
    49  	}
    50  	s.indexes, err = extractIndexes(s.db, s.tables)
    51  	s.columns = make(map[tree.TableName]map[tree.Name]*tree.ColumnTableDef)
    52  	for _, ref := range s.tables {
    53  		s.columns[*ref.TableName] = make(map[tree.Name]*tree.ColumnTableDef)
    54  		for _, col := range ref.Columns {
    55  			s.columns[*ref.TableName][col.Name] = col
    56  		}
    57  	}
    58  	return err
    59  }
    60  
    61  func (s *Smither) getRandTable() (*aliasedTableRef, bool) {
    62  	s.lock.RLock()
    63  	defer s.lock.RUnlock()
    64  	if len(s.tables) == 0 {
    65  		return nil, false
    66  	}
    67  	table := s.tables[s.rnd.Intn(len(s.tables))]
    68  	indexes := s.indexes[*table.TableName]
    69  	var indexFlags tree.IndexFlags
    70  	if s.coin() {
    71  		indexNames := make([]tree.Name, 0, len(indexes))
    72  		for _, index := range indexes {
    73  			if !index.Inverted {
    74  				indexNames = append(indexNames, index.Name)
    75  			}
    76  		}
    77  		if len(indexNames) > 0 {
    78  			indexFlags.Index = tree.UnrestrictedName(indexNames[s.rnd.Intn(len(indexNames))])
    79  		}
    80  	}
    81  	aliased := &aliasedTableRef{
    82  		tableRef:   table,
    83  		indexFlags: &indexFlags,
    84  	}
    85  	return aliased, true
    86  }
    87  
    88  func (s *Smither) getRandTableIndex(
    89  	table, alias tree.TableName,
    90  ) (*tree.TableIndexName, *tree.CreateIndex, colRefs, bool) {
    91  	s.lock.RLock()
    92  	indexes := s.indexes[table]
    93  	s.lock.RUnlock()
    94  	if len(indexes) == 0 {
    95  		return nil, nil, nil, false
    96  	}
    97  	names := make([]tree.Name, 0, len(indexes))
    98  	for n := range indexes {
    99  		names = append(names, n)
   100  	}
   101  	idx := indexes[names[s.rnd.Intn(len(names))]]
   102  	var refs colRefs
   103  	s.lock.RLock()
   104  	defer s.lock.RUnlock()
   105  	for _, col := range idx.Columns {
   106  		refs = append(refs, &colRef{
   107  			typ:  tree.MustBeStaticallyKnownType(s.columns[table][col.Column].Type),
   108  			item: tree.NewColumnItem(&alias, col.Column),
   109  		})
   110  	}
   111  	return &tree.TableIndexName{
   112  		Table: alias,
   113  		Index: tree.UnrestrictedName(idx.Name),
   114  	}, idx, refs, true
   115  }
   116  
   117  func (s *Smither) getRandIndex() (*tree.TableIndexName, *tree.CreateIndex, colRefs, bool) {
   118  	tableRef, ok := s.getRandTable()
   119  	if !ok {
   120  		return nil, nil, nil, false
   121  	}
   122  	name := *tableRef.TableName
   123  	return s.getRandTableIndex(name, name)
   124  }
   125  
   126  func extractTables(db *gosql.DB) ([]*tableRef, error) {
   127  	rows, err := db.Query(`
   128  SELECT
   129  	table_catalog,
   130  	table_schema,
   131  	table_name,
   132  	column_name,
   133  	crdb_sql_type,
   134  	generation_expression != '' AS computed,
   135  	is_nullable = 'YES' AS nullable,
   136  	is_hidden = 'YES' AS hidden
   137  FROM
   138  	information_schema.columns
   139  WHERE
   140  	table_schema = 'public'
   141  ORDER BY
   142  	table_catalog, table_schema, table_name
   143  	`)
   144  	// TODO(justin): have a flag that includes system tables?
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	defer rows.Close()
   149  
   150  	// This is a little gross: we want to operate on each segment of the results
   151  	// that corresponds to a single table. We could maybe json_agg the results
   152  	// or something for a cleaner processing step?
   153  
   154  	firstTime := true
   155  	var lastCatalog, lastSchema, lastName tree.Name
   156  	var tables []*tableRef
   157  	var currentCols []*tree.ColumnTableDef
   158  	emit := func() error {
   159  		if lastSchema != "public" {
   160  			return nil
   161  		}
   162  		if len(currentCols) == 0 {
   163  			return fmt.Errorf("zero columns for %s.%s", lastCatalog, lastName)
   164  		}
   165  		tables = append(tables, &tableRef{
   166  			TableName: tree.NewTableName(lastCatalog, lastName),
   167  			Columns:   currentCols,
   168  		})
   169  		return nil
   170  	}
   171  	for rows.Next() {
   172  		var catalog, schema, name, col tree.Name
   173  		var typ string
   174  		var computed, nullable, hidden bool
   175  		if err := rows.Scan(&catalog, &schema, &name, &col, &typ, &computed, &nullable, &hidden); err != nil {
   176  			return nil, err
   177  		}
   178  		if hidden {
   179  			continue
   180  		}
   181  
   182  		if firstTime {
   183  			lastCatalog = catalog
   184  			lastSchema = schema
   185  			lastName = name
   186  		}
   187  		firstTime = false
   188  
   189  		if lastCatalog != catalog || lastSchema != schema || lastName != name {
   190  			if err := emit(); err != nil {
   191  				return nil, err
   192  			}
   193  			currentCols = nil
   194  		}
   195  
   196  		coltyp := typeFromName(typ)
   197  		column := tree.ColumnTableDef{
   198  			Name: col,
   199  			Type: coltyp,
   200  		}
   201  		if nullable {
   202  			column.Nullable.Nullability = tree.Null
   203  		}
   204  		if computed {
   205  			column.Computed.Computed = true
   206  		}
   207  		currentCols = append(currentCols, &column)
   208  		lastCatalog = catalog
   209  		lastSchema = schema
   210  		lastName = name
   211  	}
   212  	if !firstTime {
   213  		if err := emit(); err != nil {
   214  			return nil, err
   215  		}
   216  	}
   217  	return tables, rows.Err()
   218  }
   219  
   220  func extractIndexes(
   221  	db *gosql.DB, tables tableRefs,
   222  ) (map[tree.TableName]map[tree.Name]*tree.CreateIndex, error) {
   223  	ret := map[tree.TableName]map[tree.Name]*tree.CreateIndex{}
   224  
   225  	for _, t := range tables {
   226  		indexes := map[tree.Name]*tree.CreateIndex{}
   227  		// Ignore rowid indexes since those columns aren't known to
   228  		// sqlsmith.
   229  		rows, err := db.Query(fmt.Sprintf(`
   230  			SELECT
   231  			    index_name, column_name, storing, direction = 'ASC'
   232  			FROM
   233  			    [SHOW INDEXES FROM %s]
   234  			WHERE
   235  			    column_name != 'rowid'
   236  			`, t.TableName))
   237  		if err != nil {
   238  			return nil, err
   239  		}
   240  		for rows.Next() {
   241  			var idx, col tree.Name
   242  			var storing, ascending bool
   243  			if err := rows.Scan(&idx, &col, &storing, &ascending); err != nil {
   244  				rows.Close()
   245  				return nil, err
   246  			}
   247  			if _, ok := indexes[idx]; !ok {
   248  				indexes[idx] = &tree.CreateIndex{
   249  					Name:  idx,
   250  					Table: *t.TableName,
   251  				}
   252  			}
   253  			create := indexes[idx]
   254  			if storing {
   255  				create.Storing = append(create.Storing, col)
   256  			} else {
   257  				dir := tree.Ascending
   258  				if !ascending {
   259  					dir = tree.Descending
   260  				}
   261  				create.Columns = append(create.Columns, tree.IndexElem{
   262  					Column:    col,
   263  					Direction: dir,
   264  				})
   265  			}
   266  			row := db.QueryRow(fmt.Sprintf(`
   267  			SELECT
   268  			    is_inverted
   269  			FROM
   270  			    crdb_internal.table_indexes
   271  			WHERE
   272  			    descriptor_name = '%s' AND index_name = '%s'
   273  `, t.TableName.Table(), idx))
   274  			var isInverted bool
   275  			if err = row.Scan(&isInverted); err != nil {
   276  				// We got an error which likely indicates that 'is_inverted' column is
   277  				// not present in crdb_internal.table_indexes vtable (probably because
   278  				// we're running 19.2 version). We will use a heuristic to determine
   279  				// whether the index is inverted.
   280  				isInverted = strings.Contains(strings.ToLower(idx.String()), "jsonb")
   281  			}
   282  			indexes[idx].Inverted = isInverted
   283  		}
   284  		rows.Close()
   285  		if err := rows.Err(); err != nil {
   286  			return nil, err
   287  		}
   288  		ret[*t.TableName] = indexes
   289  	}
   290  	return ret, nil
   291  }
   292  
// operator pairs a binary operator overload with the operator symbol it
// implements.
type operator struct {
	*tree.BinOp
	// Operator is the key in tree.BinOps under which the overload is
	// registered.
	Operator tree.BinaryOperator
}
   297  
   298  var operators = func() map[oid.Oid][]operator {
   299  	m := map[oid.Oid][]operator{}
   300  	for BinaryOperator, overload := range tree.BinOps {
   301  		for _, ov := range overload {
   302  			bo := ov.(*tree.BinOp)
   303  			m[bo.ReturnType.Oid()] = append(m[bo.ReturnType.Oid()], operator{
   304  				BinOp:    bo,
   305  				Operator: BinaryOperator,
   306  			})
   307  		}
   308  	}
   309  	return m
   310  }()
   311  
// function pairs a function definition with one of its overloads.
type function struct {
	// def is the definition the overload belongs to.
	def *tree.FunctionDefinition
	// overload is one entry of def.Definition.
	overload *tree.Overload
}
   316  
   317  var functions = func() map[tree.FunctionClass]map[oid.Oid][]function {
   318  	m := map[tree.FunctionClass]map[oid.Oid][]function{}
   319  	for _, def := range tree.FunDefs {
   320  		switch def.Name {
   321  		case "pg_sleep":
   322  			continue
   323  		}
   324  		if strings.Contains(def.Name, "crdb_internal.force_") {
   325  			continue
   326  		}
   327  		if _, ok := m[def.Class]; !ok {
   328  			m[def.Class] = map[oid.Oid][]function{}
   329  		}
   330  		// Ignore pg compat functions since many are unimplemented.
   331  		if def.Category == "Compatibility" {
   332  			continue
   333  		}
   334  		if def.Private {
   335  			continue
   336  		}
   337  		for _, ov := range def.Definition {
   338  			ov := ov.(*tree.Overload)
   339  			// Ignore documented unusable functions.
   340  			if strings.Contains(ov.Info, "Not usable") {
   341  				continue
   342  			}
   343  			typ := ov.FixedReturnType()
   344  			found := false
   345  			for _, scalarTyp := range types.Scalar {
   346  				if typ.Family() == scalarTyp.Family() {
   347  					found = true
   348  				}
   349  			}
   350  			if !found {
   351  				continue
   352  			}
   353  			m[def.Class][typ.Oid()] = append(m[def.Class][typ.Oid()], function{
   354  				def:      def,
   355  				overload: ov,
   356  			})
   357  		}
   358  	}
   359  	return m
   360  }()