github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/internal/sqlsmith/schema.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package sqlsmith 12 13 import ( 14 gosql "database/sql" 15 "fmt" 16 "strings" 17 18 // Import builtins so they are reflected in tree.FunDefs. 19 _ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" 20 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 21 "github.com/cockroachdb/cockroach/pkg/sql/types" 22 "github.com/lib/pq/oid" 23 ) 24 25 // tableRef represents a table and its columns. 26 type tableRef struct { 27 TableName *tree.TableName 28 Columns []*tree.ColumnTableDef 29 } 30 31 type aliasedTableRef struct { 32 *tableRef 33 indexFlags *tree.IndexFlags 34 } 35 36 type tableRefs []*tableRef 37 38 // ReloadSchemas loads tables from the database. 39 func (s *Smither) ReloadSchemas() error { 40 if s.db == nil { 41 return nil 42 } 43 s.lock.Lock() 44 defer s.lock.Unlock() 45 var err error 46 s.tables, err = extractTables(s.db) 47 if err != nil { 48 return err 49 } 50 s.indexes, err = extractIndexes(s.db, s.tables) 51 s.columns = make(map[tree.TableName]map[tree.Name]*tree.ColumnTableDef) 52 for _, ref := range s.tables { 53 s.columns[*ref.TableName] = make(map[tree.Name]*tree.ColumnTableDef) 54 for _, col := range ref.Columns { 55 s.columns[*ref.TableName][col.Name] = col 56 } 57 } 58 return err 59 } 60 61 func (s *Smither) getRandTable() (*aliasedTableRef, bool) { 62 s.lock.RLock() 63 defer s.lock.RUnlock() 64 if len(s.tables) == 0 { 65 return nil, false 66 } 67 table := s.tables[s.rnd.Intn(len(s.tables))] 68 indexes := s.indexes[*table.TableName] 69 var indexFlags tree.IndexFlags 70 if s.coin() { 71 indexNames := make([]tree.Name, 0, len(indexes)) 72 for _, index := range indexes { 73 if !index.Inverted { 74 indexNames = append(indexNames, index.Name) 75 } 76 } 77 if len(indexNames) > 0 { 78 indexFlags.Index = tree.UnrestrictedName(indexNames[s.rnd.Intn(len(indexNames))]) 79 } 80 } 81 aliased := &aliasedTableRef{ 82 tableRef: table, 83 indexFlags: &indexFlags, 84 } 85 return aliased, true 86 } 87 88 func (s *Smither) getRandTableIndex( 89 table, alias tree.TableName, 90 ) (*tree.TableIndexName, *tree.CreateIndex, colRefs, bool) { 91 s.lock.RLock() 92 indexes := s.indexes[table] 93 s.lock.RUnlock() 94 if len(indexes) == 0 { 95 return nil, nil, nil, false 96 } 97 names := make([]tree.Name, 0, len(indexes)) 98 for n := range indexes { 99 names = append(names, n) 100 } 101 idx := indexes[names[s.rnd.Intn(len(names))]] 102 var refs colRefs 103 s.lock.RLock() 104 defer s.lock.RUnlock() 105 for _, col := range idx.Columns { 106 refs = append(refs, &colRef{ 107 typ: tree.MustBeStaticallyKnownType(s.columns[table][col.Column].Type), 108 item: tree.NewColumnItem(&alias, col.Column), 109 }) 110 } 111 return &tree.TableIndexName{ 112 Table: alias, 113 Index: tree.UnrestrictedName(idx.Name), 114 }, idx, refs, true 115 } 116 117 func (s *Smither) getRandIndex() (*tree.TableIndexName, *tree.CreateIndex, colRefs, bool) { 118 tableRef, ok := s.getRandTable() 119 if !ok { 120 return nil, nil, nil, false 121 } 122 name := *tableRef.TableName 123 return s.getRandTableIndex(name, name) 124 } 125 126 func extractTables(db *gosql.DB) ([]*tableRef, error) { 127 rows, err := db.Query(` 128 SELECT 129 table_catalog, 130 table_schema, 131 table_name, 132 column_name, 133 crdb_sql_type, 134 generation_expression != '' AS computed, 135 is_nullable = 'YES' AS nullable, 136 is_hidden = 'YES' AS hidden 137 FROM 138 information_schema.columns 139 WHERE 140 table_schema = 'public' 141 ORDER BY 142 table_catalog, table_schema, table_name 143 `) 144 // TODO(justin): have a flag that includes system tables? 145 if err != nil { 146 return nil, err 147 } 148 defer rows.Close() 149 150 // This is a little gross: we want to operate on each segment of the results 151 // that corresponds to a single table. We could maybe json_agg the results 152 // or something for a cleaner processing step? 153 154 firstTime := true 155 var lastCatalog, lastSchema, lastName tree.Name 156 var tables []*tableRef 157 var currentCols []*tree.ColumnTableDef 158 emit := func() error { 159 if lastSchema != "public" { 160 return nil 161 } 162 if len(currentCols) == 0 { 163 return fmt.Errorf("zero columns for %s.%s", lastCatalog, lastName) 164 } 165 tables = append(tables, &tableRef{ 166 TableName: tree.NewTableName(lastCatalog, lastName), 167 Columns: currentCols, 168 }) 169 return nil 170 } 171 for rows.Next() { 172 var catalog, schema, name, col tree.Name 173 var typ string 174 var computed, nullable, hidden bool 175 if err := rows.Scan(&catalog, &schema, &name, &col, &typ, &computed, &nullable, &hidden); err != nil { 176 return nil, err 177 } 178 if hidden { 179 continue 180 } 181 182 if firstTime { 183 lastCatalog = catalog 184 lastSchema = schema 185 lastName = name 186 } 187 firstTime = false 188 189 if lastCatalog != catalog || lastSchema != schema || lastName != name { 190 if err := emit(); err != nil { 191 return nil, err 192 } 193 currentCols = nil 194 } 195 196 coltyp := typeFromName(typ) 197 column := tree.ColumnTableDef{ 198 Name: col, 199 Type: coltyp, 200 } 201 if nullable { 202 column.Nullable.Nullability = tree.Null 203 } 204 if computed { 205 column.Computed.Computed = true 206 } 207 currentCols = append(currentCols, &column) 208 lastCatalog = catalog 209 lastSchema = schema 210 lastName = name 211 } 212 if !firstTime { 213 if err := emit(); err != nil { 214 return nil, err 215 } 216 } 217 return tables, rows.Err() 218 } 219 220 func extractIndexes( 221 db *gosql.DB, tables tableRefs, 222 ) (map[tree.TableName]map[tree.Name]*tree.CreateIndex, error) { 223 ret := map[tree.TableName]map[tree.Name]*tree.CreateIndex{} 224 225 for _, t := range tables { 226 indexes := map[tree.Name]*tree.CreateIndex{} 227 // Ignore rowid indexes since those columns aren't known to 228 // sqlsmith. 229 rows, err := db.Query(fmt.Sprintf(` 230 SELECT 231 index_name, column_name, storing, direction = 'ASC' 232 FROM 233 [SHOW INDEXES FROM %s] 234 WHERE 235 column_name != 'rowid' 236 `, t.TableName)) 237 if err != nil { 238 return nil, err 239 } 240 for rows.Next() { 241 var idx, col tree.Name 242 var storing, ascending bool 243 if err := rows.Scan(&idx, &col, &storing, &ascending); err != nil { 244 rows.Close() 245 return nil, err 246 } 247 if _, ok := indexes[idx]; !ok { 248 indexes[idx] = &tree.CreateIndex{ 249 Name: idx, 250 Table: *t.TableName, 251 } 252 } 253 create := indexes[idx] 254 if storing { 255 create.Storing = append(create.Storing, col) 256 } else { 257 dir := tree.Ascending 258 if !ascending { 259 dir = tree.Descending 260 } 261 create.Columns = append(create.Columns, tree.IndexElem{ 262 Column: col, 263 Direction: dir, 264 }) 265 } 266 row := db.QueryRow(fmt.Sprintf(` 267 SELECT 268 is_inverted 269 FROM 270 crdb_internal.table_indexes 271 WHERE 272 descriptor_name = '%s' AND index_name = '%s' 273 `, t.TableName.Table(), idx)) 274 var isInverted bool 275 if err = row.Scan(&isInverted); err != nil { 276 // We got an error which likely indicates that 'is_inverted' column is 277 // not present in crdb_internal.table_indexes vtable (probably because 278 // we're running 19.2 version). We will use a heuristic to determine 279 // whether the index is inverted. 280 isInverted = strings.Contains(strings.ToLower(idx.String()), "jsonb") 281 } 282 indexes[idx].Inverted = isInverted 283 } 284 rows.Close() 285 if err := rows.Err(); err != nil { 286 return nil, err 287 } 288 ret[*t.TableName] = indexes 289 } 290 return ret, nil 291 } 292 293 type operator struct { 294 *tree.BinOp 295 Operator tree.BinaryOperator 296 } 297 298 var operators = func() map[oid.Oid][]operator { 299 m := map[oid.Oid][]operator{} 300 for BinaryOperator, overload := range tree.BinOps { 301 for _, ov := range overload { 302 bo := ov.(*tree.BinOp) 303 m[bo.ReturnType.Oid()] = append(m[bo.ReturnType.Oid()], operator{ 304 BinOp: bo, 305 Operator: BinaryOperator, 306 }) 307 } 308 } 309 return m 310 }() 311 312 type function struct { 313 def *tree.FunctionDefinition 314 overload *tree.Overload 315 } 316 317 var functions = func() map[tree.FunctionClass]map[oid.Oid][]function { 318 m := map[tree.FunctionClass]map[oid.Oid][]function{} 319 for _, def := range tree.FunDefs { 320 switch def.Name { 321 case "pg_sleep": 322 continue 323 } 324 if strings.Contains(def.Name, "crdb_internal.force_") { 325 continue 326 } 327 if _, ok := m[def.Class]; !ok { 328 m[def.Class] = map[oid.Oid][]function{} 329 } 330 // Ignore pg compat functions since many are unimplemented. 331 if def.Category == "Compatibility" { 332 continue 333 } 334 if def.Private { 335 continue 336 } 337 for _, ov := range def.Definition { 338 ov := ov.(*tree.Overload) 339 // Ignore documented unusable functions. 340 if strings.Contains(ov.Info, "Not usable") { 341 continue 342 } 343 typ := ov.FixedReturnType() 344 found := false 345 for _, scalarTyp := range types.Scalar { 346 if typ.Family() == scalarTyp.Family() { 347 found = true 348 } 349 } 350 if !found { 351 continue 352 } 353 m[def.Class][typ.Oid()] = append(m[def.Class][typ.Oid()], function{ 354 def: def, 355 overload: ov, 356 }) 357 } 358 } 359 return m 360 }()