github.com/dolthub/go-mysql-server@v0.18.0/sql/information_schema/columns_table.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package information_schema 16 17 import ( 18 "bytes" 19 "encoding/hex" 20 "fmt" 21 "sort" 22 "strconv" 23 "strings" 24 "time" 25 26 "github.com/dolthub/vitess/go/sqltypes" 27 "github.com/dolthub/vitess/go/vt/proto/query" 28 29 "github.com/dolthub/go-mysql-server/sql" 30 "github.com/dolthub/go-mysql-server/sql/mysql_db" 31 "github.com/dolthub/go-mysql-server/sql/transform" 32 "github.com/dolthub/go-mysql-server/sql/types" 33 ) 34 35 const defaultColumnsTableRowCount = 1000 36 37 var typeToNumericPrecision = map[query.Type]int{ 38 sqltypes.Int8: 3, 39 sqltypes.Uint8: 3, 40 sqltypes.Int16: 5, 41 sqltypes.Uint16: 5, 42 sqltypes.Int24: 7, 43 sqltypes.Uint24: 7, 44 sqltypes.Int32: 10, 45 sqltypes.Uint32: 10, 46 sqltypes.Int64: 19, 47 sqltypes.Uint64: 20, 48 sqltypes.Float32: 12, 49 sqltypes.Float64: 22, 50 } 51 52 // ColumnsTable describes the information_schema.columns table. It implements both sql.Node and sql.Table 53 // as way to handle resolving column defaults. 54 type ColumnsTable struct { 55 name string 56 schema sql.Schema 57 catalog sql.Catalog 58 // allColsWithDefaultValue is the full schema of all tables in all databases. We need this during analysis in order 59 // to resolve the default values of some columns, so we pre-compute it. 60 allColsWithDefaultValue sql.Schema 61 62 rowIter func(*sql.Context, sql.Catalog, sql.Schema) (sql.RowIter, error) 63 } 64 65 var _ sql.Table = (*ColumnsTable)(nil) 66 var _ sql.StatisticsTable = (*ColumnsTable)(nil) 67 var _ sql.Databaseable = (*ColumnsTable)(nil) 68 var _ sql.DynamicColumnsTable = (*ColumnsTable)(nil) 69 70 // String implements the sql.Table interface. 71 func (c *ColumnsTable) String() string { 72 return printTable(ColumnsTableName, columnsSchema) 73 } 74 75 // Schema implements the sql.Table interface. 76 func (c *ColumnsTable) Schema() sql.Schema { 77 return columnsSchema 78 } 79 80 // Collation implements the sql.Table interface. 81 func (c *ColumnsTable) Collation() sql.CollationID { 82 return sql.Collation_Information_Schema_Default 83 } 84 85 // Name implements the sql.Table interface. 86 func (c *ColumnsTable) Name() string { 87 return ColumnsTableName 88 } 89 90 // Database implements the sql.Databaseable interface. 91 func (c *ColumnsTable) Database() string { 92 return sql.InformationSchemaDatabaseName 93 } 94 95 func (c *ColumnsTable) DataLength(_ *sql.Context) (uint64, error) { 96 return uint64(len(c.Schema()) * int(types.Text.MaxByteLength()) * defaultColumnsTableRowCount), nil 97 } 98 99 func (c *ColumnsTable) RowCount(ctx *sql.Context) (uint64, bool, error) { 100 return defaultColumnsTableRowCount, false, nil 101 } 102 103 func (c *ColumnsTable) AssignCatalog(cat sql.Catalog) sql.Table { 104 c.catalog = cat 105 return c 106 } 107 108 // Partitions implements the sql.Table interface. 109 func (c *ColumnsTable) Partitions(context *sql.Context) (sql.PartitionIter, error) { 110 return &informationSchemaPartitionIter{informationSchemaPartition: informationSchemaPartition{partitionKey(c.Name())}}, nil 111 } 112 113 // PartitionRows implements the sql.Table interface. 114 func (c *ColumnsTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) { 115 if !bytes.Equal(partition.Key(), partitionKey(c.Name())) { 116 return nil, sql.ErrPartitionNotFound.New(partition.Key()) 117 } 118 119 if c.catalog == nil { 120 return nil, fmt.Errorf("nil catalog for info schema table %s", c.Name()) 121 } 122 123 return columnsRowIter(context, c.catalog, c.allColsWithDefaultValue) 124 } 125 func (c *ColumnsTable) HasDynamicColumns() bool { 126 return true 127 } 128 129 // AllColumns returns all columns in the catalog, renamed to reflect their database and table names 130 func (c *ColumnsTable) AllColumns(ctx *sql.Context) (sql.Schema, error) { 131 if len(c.allColsWithDefaultValue) > 0 { 132 return c.allColsWithDefaultValue, nil 133 } 134 135 if c.catalog == nil { 136 return nil, fmt.Errorf("nil catalog for info schema table %s", c.Name()) 137 } 138 139 var allColumns sql.Schema 140 141 for _, db := range c.catalog.AllDatabases(ctx) { 142 err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { 143 tableSch := t.Schema() 144 for i := range tableSch { 145 newCol := tableSch[i].Copy() 146 newCol.DatabaseSource = db.Name() 147 allColumns = append(allColumns, newCol) 148 } 149 return true, nil 150 }) 151 152 if err != nil { 153 return nil, err 154 } 155 } 156 157 c.allColsWithDefaultValue = allColumns 158 return c.allColsWithDefaultValue, nil 159 } 160 161 func (c ColumnsTable) WithColumnDefaults(columnDefaults []sql.Expression) (sql.Table, error) { 162 if c.allColsWithDefaultValue == nil { 163 return nil, fmt.Errorf("WithColumnDefaults called with nil columns for table %s", c.Name()) 164 } 165 166 if len(columnDefaults) != len(c.allColsWithDefaultValue) { 167 return nil, sql.ErrInvalidChildrenNumber.New(c, len(columnDefaults), len(c.allColsWithDefaultValue)) 168 } 169 170 sch, err := transform.SchemaWithDefaults(c.allColsWithDefaultValue, columnDefaults) 171 if err != nil { 172 return nil, err 173 } 174 175 c.allColsWithDefaultValue = sch 176 return &c, nil 177 } 178 179 func (c ColumnsTable) WithDefaultsSchema(sch sql.Schema) (sql.Table, error) { 180 if c.allColsWithDefaultValue == nil { 181 return nil, fmt.Errorf("WithColumnDefaults called with nil columns for table %s", c.Name()) 182 } 183 184 if len(sch) != len(c.allColsWithDefaultValue) { 185 return nil, sql.ErrInvalidChildrenNumber.New(c, len(sch), len(c.allColsWithDefaultValue)) 186 } 187 188 // TODO: generated values 189 for i, col := range sch { 190 c.allColsWithDefaultValue[i].Default = col.Default 191 } 192 return &c, nil 193 } 194 195 // columnsRowIter implements the custom sql.RowIter for the information_schema.columns table. 196 func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultValue sql.Schema) (sql.RowIter, error) { 197 var ( 198 rows []sql.Row 199 globalPrivSetMap = make(map[string]struct{}) 200 ) 201 202 privSet, _ := ctx.GetPrivilegeSet() 203 if privSet == nil { 204 privSet = mysql_db.NewPrivilegeSet() 205 } 206 globalPrivSetMap = getCurrentPrivSetMapForColumn(privSet.ToSlice(), globalPrivSetMap) 207 208 for _, db := range catalog.AllDatabases(ctx) { 209 rs, err := getRowsFromDatabase(ctx, db, privSet, globalPrivSetMap, allColsWithDefaultValue) 210 if err != nil { 211 return nil, err 212 } 213 rows = append(rows, rs...) 214 215 rs, err = getRowsFromViews(ctx, db) 216 if err != nil { 217 return nil, err 218 } 219 rows = append(rows, rs...) 220 } 221 return sql.RowsToRowIter(rows...), nil 222 } 223 224 // getRowFromColumn returns a single row for given column. The arguments passed are used to define all row values. 225 // These include the current ordinal position, so this column will get the next position number, sql.Column object, 226 // database name, table name, column key and column privileges information through privileges set for the table. 227 func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row { 228 var ( 229 ordinalPos = uint32(curOrdPos + 1) 230 nullable = "NO" 231 datetimePrecision interface{} 232 srsId interface{} 233 ) 234 235 colType, dataType := getDtdIdAndDataType(col.Type) 236 237 if col.Nullable { 238 nullable = "YES" 239 } 240 241 if s, ok := col.Type.(sql.SpatialColumnType); ok { 242 if srid, d := s.GetSpatialTypeSRID(); d { 243 srsId = srid 244 } 245 } 246 247 charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(ctx, col.Type) 248 249 numericPrecision, numericScale := getColumnPrecisionAndScale(col.Type) 250 if types.IsDatetimeType(col.Type) || types.IsTimestampType(col.Type) { 251 datetimePrecision = 0 252 } else if types.IsTimespan(col.Type) { 253 // TODO: TIME length not yet supported 254 datetimePrecision = 6 255 } 256 257 columnDefault := getColumnDefault(ctx, col.Default) 258 259 extra := col.Extra 260 // If extra is not defined, fill it here. 261 if extra == "" && !col.Default.IsLiteral() { 262 extra = "DEFAULT_GENERATED" 263 } 264 265 var curColPrivStr []string 266 for p := range privSetMap { 267 curColPrivStr = append(curColPrivStr, p) 268 } 269 270 privSetCol := privSetTbl.Column(col.Name) 271 for _, pt := range privSetCol.ToSlice() { 272 priv := strings.ToLower(pt.String()) 273 if _, ok := privSetMap[priv]; !ok { 274 curColPrivStr = append(curColPrivStr, priv) 275 } 276 } 277 278 sort.Strings(curColPrivStr) 279 privileges := strings.Join(curColPrivStr, ",") 280 281 return sql.Row{ 282 "def", // table_catalog 283 dbName, // table_schema 284 tblName, // table_name 285 col.Name, // column_name 286 ordinalPos, // ordinal_position 287 columnDefault, // column_default 288 nullable, // is_nullable 289 dataType, // data_type 290 charMaxLen, // character_maximum_length 291 charOctetLen, // character_octet_length 292 numericPrecision, // numeric_precision 293 numericScale, // numeric_scale 294 datetimePrecision, // datetime_precision 295 charName, // character_set_name 296 collName, // collation_name 297 colType, // column_type 298 columnKey, // column_key 299 extra, // extra 300 privileges, // privileges 301 col.Comment, // column_comment 302 "", // generation_expression 303 srsId, // srs_id 304 } 305 } 306 307 // getRowsFromTable returns array of rows for all accessible columns of the given table. 308 func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { 309 var rows []sql.Row 310 311 privSetTbl := privSetDb.Table(t.Name()) 312 curPrivSetMap := getCurrentPrivSetMapForColumn(privSetTbl.ToSlice(), privSetMap) 313 314 columnKeyMap, hasPK, err := getIndexKeyInfo(ctx, t) 315 if err != nil { 316 return nil, err 317 } 318 319 tblName := t.Name() 320 for i, col := range schemaForTable(t, db, allColsWithDefaultValue) { 321 var columnKey string 322 // Check column PK here first because there are PKs from table implementations that don't implement sql.IndexedTable 323 if col.PrimaryKey { 324 columnKey = "PRI" 325 } else if val, ok := columnKeyMap[col.Name]; ok { 326 columnKey = val 327 // A UNIQUE index may be displayed as PRI if it cannot contain NULL values and there is no PRIMARY KEY in the table 328 if !col.Nullable && !hasPK && columnKey == "UNI" { 329 columnKey = "PRI" 330 hasPK = true 331 } 332 } 333 334 r := getRowFromColumn(ctx, i, col, db.Name(), tblName, columnKey, privSetTbl, curPrivSetMap) 335 if r != nil { 336 rows = append(rows, r) 337 } 338 } 339 340 return rows, nil 341 } 342 343 // getRowsFromViews returns array or rows for columns for all views for given database. 344 func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) { 345 var rows []sql.Row 346 // TODO: View Definition is lacking information to properly fill out these table 347 // TODO: Should somehow get reference to table(s) view is referencing 348 // TODO: Each column that view references should also show up as unique entries as well 349 views, err := viewsInDatabase(ctx, db) 350 if err != nil { 351 return nil, err 352 } 353 354 for _, view := range views { 355 rows = append(rows, sql.Row{ 356 "def", // table_catalog 357 db.Name(), // table_schema 358 view.Name, // table_name 359 "", // column_name 360 uint32(0), // ordinal_position 361 nil, // column_default 362 "", // is_nullable 363 nil, // data_type 364 nil, // character_maximum_length 365 nil, // character_octet_length 366 nil, // numeric_precision 367 nil, // numeric_scale 368 nil, // datetime_precision 369 "", // character_set_name 370 "", // collation_name 371 "", // column_type 372 "", // column_key 373 "", // extra 374 "select", // privileges 375 "", // column_comment 376 "", // generation_expression 377 nil, // srs_id 378 }) 379 } 380 381 return rows, nil 382 } 383 384 // getRowsFromDatabase returns array of rows for all accessible columns of accessible table of the given database. 385 func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) { 386 var rows []sql.Row 387 dbName := db.Name() 388 389 privSetDb := privSet.Database(dbName) 390 curPrivSetMap := getCurrentPrivSetMapForColumn(privSetDb.ToSlice(), privSetMap) 391 if dbName == sql.InformationSchemaDatabaseName { 392 curPrivSetMap["select"] = struct{}{} 393 } 394 395 err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) { 396 rs, err := getRowsFromTable(ctx, db, t, privSetDb, curPrivSetMap, allColsWithDefaultValue) 397 if err != nil { 398 return false, err 399 } 400 rows = append(rows, rs...) 401 return true, nil 402 }) 403 if err != nil { 404 return nil, err 405 } 406 407 return rows, nil 408 } 409 410 // getCurrentPrivSetMapForColumn returns a new privilege set map that contains what the given privilege set map has, 411 // and it adds any available privileges from given array of privilege type. For example, the given privilege set map 412 // may contain general privilege types for the database only, and the given array of privilege type will contain all 413 // privilege types defined for the table specifically. This function only add `select`, `insert`, `update` and 414 // `references` privileges to the new privilege set map if available. These are column level privileges only. 415 func getCurrentPrivSetMapForColumn(privs []sql.PrivilegeType, privSetMap map[string]struct{}) map[string]struct{} { 416 curPrivSetMap := make(map[string]struct{}) 417 for p := range privSetMap { 418 curPrivSetMap[p] = struct{}{} 419 } 420 for _, pt := range privs { 421 switch pt { 422 // columns can have 'select', 'insert', 'update', 'references' privileges only. 423 case sql.PrivilegeType_Select, sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_References: 424 curPrivSetMap[strings.ToLower(pt.String())] = struct{}{} 425 } 426 } 427 return curPrivSetMap 428 } 429 430 // getIndexKeyInfo returns map of columns and its index information whether this column is PK or unique index, etc. 431 func getIndexKeyInfo(ctx *sql.Context, t sql.Table) (map[string]string, bool, error) { 432 var columnKeyMap = make(map[string]string) 433 // Get UNIQUEs, PRIMARY KEYs 434 hasPK := false 435 if indexTable, ok := t.(sql.IndexAddressable); ok { 436 indexes, iErr := indexTable.GetIndexes(ctx) 437 if iErr != nil { 438 return columnKeyMap, hasPK, iErr 439 } 440 441 for _, index := range indexes { 442 idx := "" 443 if index.ID() == "PRIMARY" { 444 idx = "PRI" 445 hasPK = true 446 } else if index.IsUnique() { 447 idx = "UNI" 448 } else { 449 idx = "MUL" 450 } 451 452 colNames := getColumnNamesFromIndex(index, t) 453 // A UNIQUE index may display as MUL if several columns form a composite UNIQUE index 454 if idx == "UNI" && len(colNames) > 1 { 455 idx = "MUL" 456 columnKeyMap[colNames[0]] = idx 457 } else { 458 for _, colName := range colNames { 459 columnKeyMap[colName] = idx 460 } 461 } 462 } 463 } 464 465 return columnKeyMap, hasPK, nil 466 } 467 468 // getColumnDefault returns the column default value for given sql.ColumnDefaultValue 469 func getColumnDefault(ctx *sql.Context, cd *sql.ColumnDefaultValue) interface{} { 470 if cd == nil { 471 return nil 472 } 473 474 defStr := cd.String() 475 if defStr == "NULL" { 476 return nil 477 } 478 479 if !cd.IsLiteral() { 480 if strings.HasPrefix(defStr, "(") && strings.HasSuffix(defStr, ")") { 481 defStr = strings.TrimSuffix(strings.TrimPrefix(defStr, "("), ")") 482 } 483 if types.IsTime(cd.Type()) && (strings.HasPrefix(defStr, "NOW") || strings.HasPrefix(defStr, "CURRENT_TIMESTAMP")) { 484 defStr = strings.Replace(defStr, "NOW", "CURRENT_TIMESTAMP", -1) 485 defStr = strings.TrimSuffix(defStr, "()") 486 } 487 return fmt.Sprint(defStr) 488 } 489 490 if types.IsEnum(cd.Type()) || types.IsSet(cd.Type()) { 491 return strings.Trim(defStr, "'") 492 } 493 494 v, err := cd.Eval(ctx, nil) 495 if err != nil { 496 return "" 497 } 498 499 switch l := v.(type) { 500 case time.Time: 501 v = l.Format("2006-01-02 15:04:05") 502 case []uint8: 503 hexStr := hex.EncodeToString(l) 504 v = fmt.Sprintf("0x%s", hexStr) 505 } 506 507 if types.IsBit(cd.Type()) { 508 if i, ok := v.(uint64); ok { 509 bitStr := strconv.FormatUint(i, 2) 510 v = fmt.Sprintf("b'%s'", bitStr) 511 } 512 } 513 514 return fmt.Sprint(v) 515 } 516 517 func schemaForTable(t sql.Table, db sql.Database, allColsWithDefaultValue sql.Schema) sql.Schema { 518 start, end := -1, -1 519 tableName := strings.ToLower(t.Name()) 520 521 for i, col := range allColsWithDefaultValue { 522 dbName := strings.ToLower(db.Name()) 523 if start < 0 && strings.ToLower(col.Source) == tableName && strings.ToLower(col.DatabaseSource) == dbName { 524 start = i 525 } else if start >= 0 && (strings.ToLower(col.Source) != tableName || strings.ToLower(col.DatabaseSource) != dbName) { 526 end = i 527 break 528 } 529 } 530 531 if start < 0 { 532 return nil 533 } 534 535 if end < 0 { 536 end = len(allColsWithDefaultValue) 537 } 538 539 return allColsWithDefaultValue[start:end] 540 } 541 542 // get DtdIdAndDataType returns data types for given sql.Type but in two different ways. 543 // The DTD_IDENTIFIER value contains the type name and possibly other information such as the precision or length. 544 // The DATA_TYPE value is the type name only with no other information. 545 func getDtdIdAndDataType(colType sql.Type) (string, string) { 546 dtdId := strings.Split(strings.Split(colType.String(), " COLLATE")[0], " CHARACTER SET")[0] 547 548 // The DATA_TYPE value is the type name only with no other information 549 dataType := strings.Split(dtdId, "(")[0] 550 dataType = strings.Split(dataType, " ")[0] 551 552 return dtdId, dataType 553 } 554 555 // getColumnPrecisionAndScale returns the precision or a number of mysql type. For non-numeric or decimal types this 556 // function should return nil,nil. 557 func getColumnPrecisionAndScale(colType sql.Type) (interface{}, interface{}) { 558 var numericScale interface{} 559 switch t := colType.(type) { 560 case types.BitType: 561 return int(t.NumberOfBits()), numericScale 562 case sql.DecimalType: 563 return int(t.Precision()), int(t.Scale()) 564 case sql.NumberType: 565 switch colType.Type() { 566 case sqltypes.Float32, sqltypes.Float64: 567 numericScale = nil 568 default: 569 numericScale = 0 570 } 571 return typeToNumericPrecision[colType.Type()], numericScale 572 default: 573 return nil, nil 574 } 575 } 576 577 func getCharAndCollNamesAndCharMaxAndOctetLens(ctx *sql.Context, colType sql.Type) (interface{}, interface{}, interface{}, interface{}) { 578 var ( 579 charName interface{} 580 collName interface{} 581 charMaxLen interface{} 582 charOctetLen interface{} 583 ) 584 if twc, ok := colType.(sql.TypeWithCollation); ok && !types.IsBinaryType(colType) { 585 colColl := twc.Collation() 586 collName = colColl.Name() 587 charName = colColl.CharacterSet().String() 588 if types.IsEnum(colType) || types.IsSet(colType) { 589 charOctetLen = int64(colType.MaxTextResponseByteLength(ctx)) 590 charMaxLen = int64(colType.MaxTextResponseByteLength(ctx)) / colColl.CharacterSet().MaxLength() 591 } 592 } 593 if st, ok := colType.(sql.StringType); ok { 594 charMaxLen = st.MaxCharacterLength() 595 charOctetLen = st.MaxByteLength() 596 } 597 598 return charName, collName, charMaxLen, charOctetLen 599 }