
     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package information_schema
    17  import (
    18  	"bytes"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"sort"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    26  	""
    27  	""
    29  	""
    30  	""
    31  	""
    32  	""
    33  )
    35  const defaultColumnsTableRowCount = 1000
    37  var typeToNumericPrecision = map[query.Type]int{
    38  	sqltypes.Int8:    3,
    39  	sqltypes.Uint8:   3,
    40  	sqltypes.Int16:   5,
    41  	sqltypes.Uint16:  5,
    42  	sqltypes.Int24:   7,
    43  	sqltypes.Uint24:  7,
    44  	sqltypes.Int32:   10,
    45  	sqltypes.Uint32:  10,
    46  	sqltypes.Int64:   19,
    47  	sqltypes.Uint64:  20,
    48  	sqltypes.Float32: 12,
    49  	sqltypes.Float64: 22,
    50  }
// ColumnsTable describes the information_schema.columns table. It implements sql.Table (plus the other interfaces
// asserted below) and keeps a precomputed schema of every column in the catalog so that column defaults can be
// resolved during analysis.
// NOTE(review): an earlier version of this comment said the type also implements sql.Node; no sql.Node assertion
// or Node methods appear alongside it here — confirm before relying on that.
type ColumnsTable struct {
	// NOTE(review): not read by any method in this section; Name() returns ColumnsTableName instead.
	name    string
	// NOTE(review): not read by any method in this section; Schema() returns columnsSchema instead.
	schema  sql.Schema
	// catalog supplies the databases and tables to report. It is set by AssignCatalog and must be non-nil
	// for AllColumns and PartitionRows to succeed.
	catalog sql.Catalog
	// allColsWithDefaultValue is the full schema of all tables in all databases. We need this during analysis in order
	// to resolve the default values of some columns, so we pre-compute it.
	// It is filled lazily by AllColumns and replaced by WithColumnDefaults / WithDefaultsSchema.
	allColsWithDefaultValue sql.Schema

	// NOTE(review): not called by any method in this section; PartitionRows calls columnsRowIter directly.
	rowIter func(*sql.Context, sql.Catalog, sql.Schema) (sql.RowIter, error)
}
    65  var _ sql.Table = (*ColumnsTable)(nil)
    66  var _ sql.StatisticsTable = (*ColumnsTable)(nil)
    67  var _ sql.Databaseable = (*ColumnsTable)(nil)
    68  var _ sql.DynamicColumnsTable = (*ColumnsTable)(nil)
    70  // String implements the sql.Table interface.
    71  func (c *ColumnsTable) String() string {
    72  	return printTable(ColumnsTableName, columnsSchema)
    73  }
    75  // Schema implements the sql.Table interface.
    76  func (c *ColumnsTable) Schema() sql.Schema {
    77  	return columnsSchema
    78  }
    80  // Collation implements the sql.Table interface.
    81  func (c *ColumnsTable) Collation() sql.CollationID {
    82  	return sql.Collation_Information_Schema_Default
    83  }
    85  // Name implements the sql.Table interface.
    86  func (c *ColumnsTable) Name() string {
    87  	return ColumnsTableName
    88  }
    90  // Database implements the sql.Databaseable interface.
    91  func (c *ColumnsTable) Database() string {
    92  	return sql.InformationSchemaDatabaseName
    93  }
    95  func (c *ColumnsTable) DataLength(_ *sql.Context) (uint64, error) {
    96  	return uint64(len(c.Schema()) * int(types.Text.MaxByteLength()) * defaultColumnsTableRowCount), nil
    97  }
    99  func (c *ColumnsTable) RowCount(ctx *sql.Context) (uint64, bool, error) {
   100  	return defaultColumnsTableRowCount, false, nil
   101  }
   103  func (c *ColumnsTable) AssignCatalog(cat sql.Catalog) sql.Table {
   104  	c.catalog = cat
   105  	return c
   106  }
   108  // Partitions implements the sql.Table interface.
   109  func (c *ColumnsTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
   110  	return &informationSchemaPartitionIter{informationSchemaPartition: informationSchemaPartition{partitionKey(c.Name())}}, nil
   111  }
   113  // PartitionRows implements the sql.Table interface.
   114  func (c *ColumnsTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
   115  	if !bytes.Equal(partition.Key(), partitionKey(c.Name())) {
   116  		return nil, sql.ErrPartitionNotFound.New(partition.Key())
   117  	}
   119  	if c.catalog == nil {
   120  		return nil, fmt.Errorf("nil catalog for info schema table %s", c.Name())
   121  	}
   123  	return columnsRowIter(context, c.catalog, c.allColsWithDefaultValue)
   124  }
   125  func (c *ColumnsTable) HasDynamicColumns() bool {
   126  	return true
   127  }
   129  // AllColumns returns all columns in the catalog, renamed to reflect their database and table names
   130  func (c *ColumnsTable) AllColumns(ctx *sql.Context) (sql.Schema, error) {
   131  	if len(c.allColsWithDefaultValue) > 0 {
   132  		return c.allColsWithDefaultValue, nil
   133  	}
   135  	if c.catalog == nil {
   136  		return nil, fmt.Errorf("nil catalog for info schema table %s", c.Name())
   137  	}
   139  	var allColumns sql.Schema
   141  	for _, db := range c.catalog.AllDatabases(ctx) {
   142  		err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
   143  			tableSch := t.Schema()
   144  			for i := range tableSch {
   145  				newCol := tableSch[i].Copy()
   146  				newCol.DatabaseSource = db.Name()
   147  				allColumns = append(allColumns, newCol)
   148  			}
   149  			return true, nil
   150  		})
   152  		if err != nil {
   153  			return nil, err
   154  		}
   155  	}
   157  	c.allColsWithDefaultValue = allColumns
   158  	return c.allColsWithDefaultValue, nil
   159  }
   161  func (c ColumnsTable) WithColumnDefaults(columnDefaults []sql.Expression) (sql.Table, error) {
   162  	if c.allColsWithDefaultValue == nil {
   163  		return nil, fmt.Errorf("WithColumnDefaults called with nil columns for table %s", c.Name())
   164  	}
   166  	if len(columnDefaults) != len(c.allColsWithDefaultValue) {
   167  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(columnDefaults), len(c.allColsWithDefaultValue))
   168  	}
   170  	sch, err := transform.SchemaWithDefaults(c.allColsWithDefaultValue, columnDefaults)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   175  	c.allColsWithDefaultValue = sch
   176  	return &c, nil
   177  }
   179  func (c ColumnsTable) WithDefaultsSchema(sch sql.Schema) (sql.Table, error) {
   180  	if c.allColsWithDefaultValue == nil {
   181  		return nil, fmt.Errorf("WithColumnDefaults called with nil columns for table %s", c.Name())
   182  	}
   184  	if len(sch) != len(c.allColsWithDefaultValue) {
   185  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(sch), len(c.allColsWithDefaultValue))
   186  	}
   188  	// TODO: generated values
   189  	for i, col := range sch {
   190  		c.allColsWithDefaultValue[i].Default = col.Default
   191  	}
   192  	return &c, nil
   193  }
   195  // columnsRowIter implements the custom sql.RowIter for the information_schema.columns table.
   196  func columnsRowIter(ctx *sql.Context, catalog sql.Catalog, allColsWithDefaultValue sql.Schema) (sql.RowIter, error) {
   197  	var (
   198  		rows             []sql.Row
   199  		globalPrivSetMap = make(map[string]struct{})
   200  	)
   202  	privSet, _ := ctx.GetPrivilegeSet()
   203  	if privSet == nil {
   204  		privSet = mysql_db.NewPrivilegeSet()
   205  	}
   206  	globalPrivSetMap = getCurrentPrivSetMapForColumn(privSet.ToSlice(), globalPrivSetMap)
   208  	for _, db := range catalog.AllDatabases(ctx) {
   209  		rs, err := getRowsFromDatabase(ctx, db, privSet, globalPrivSetMap, allColsWithDefaultValue)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  		rows = append(rows, rs...)
   215  		rs, err = getRowsFromViews(ctx, db)
   216  		if err != nil {
   217  			return nil, err
   218  		}
   219  		rows = append(rows, rs...)
   220  	}
   221  	return sql.RowsToRowIter(rows...), nil
   222  }
   224  // getRowFromColumn returns a single row for given column. The arguments passed are used to define all row values.
   225  // These include the current ordinal position, so this column will get the next position number, sql.Column object,
   226  // database name, table name, column key and column privileges information through privileges set for the table.
   227  func getRowFromColumn(ctx *sql.Context, curOrdPos int, col *sql.Column, dbName, tblName, columnKey string, privSetTbl sql.PrivilegeSetTable, privSetMap map[string]struct{}) sql.Row {
   228  	var (
   229  		ordinalPos        = uint32(curOrdPos + 1)
   230  		nullable          = "NO"
   231  		datetimePrecision interface{}
   232  		srsId             interface{}
   233  	)
   235  	colType, dataType := getDtdIdAndDataType(col.Type)
   237  	if col.Nullable {
   238  		nullable = "YES"
   239  	}
   241  	if s, ok := col.Type.(sql.SpatialColumnType); ok {
   242  		if srid, d := s.GetSpatialTypeSRID(); d {
   243  			srsId = srid
   244  		}
   245  	}
   247  	charName, collName, charMaxLen, charOctetLen := getCharAndCollNamesAndCharMaxAndOctetLens(ctx, col.Type)
   249  	numericPrecision, numericScale := getColumnPrecisionAndScale(col.Type)
   250  	if types.IsDatetimeType(col.Type) || types.IsTimestampType(col.Type) {
   251  		datetimePrecision = 0
   252  	} else if types.IsTimespan(col.Type) {
   253  		// TODO: TIME length not yet supported
   254  		datetimePrecision = 6
   255  	}
   257  	columnDefault := getColumnDefault(ctx, col.Default)
   259  	extra := col.Extra
   260  	// If extra is not defined, fill it here.
   261  	if extra == "" && !col.Default.IsLiteral() {
   262  		extra = "DEFAULT_GENERATED"
   263  	}
   265  	var curColPrivStr []string
   266  	for p := range privSetMap {
   267  		curColPrivStr = append(curColPrivStr, p)
   268  	}
   270  	privSetCol := privSetTbl.Column(col.Name)
   271  	for _, pt := range privSetCol.ToSlice() {
   272  		priv := strings.ToLower(pt.String())
   273  		if _, ok := privSetMap[priv]; !ok {
   274  			curColPrivStr = append(curColPrivStr, priv)
   275  		}
   276  	}
   278  	sort.Strings(curColPrivStr)
   279  	privileges := strings.Join(curColPrivStr, ",")
   281  	return sql.Row{
   282  		"def",             // table_catalog
   283  		dbName,            // table_schema
   284  		tblName,           // table_name
   285  		col.Name,          // column_name
   286  		ordinalPos,        // ordinal_position
   287  		columnDefault,     // column_default
   288  		nullable,          // is_nullable
   289  		dataType,          // data_type
   290  		charMaxLen,        // character_maximum_length
   291  		charOctetLen,      // character_octet_length
   292  		numericPrecision,  // numeric_precision
   293  		numericScale,      // numeric_scale
   294  		datetimePrecision, // datetime_precision
   295  		charName,          // character_set_name
   296  		collName,          // collation_name
   297  		colType,           // column_type
   298  		columnKey,         // column_key
   299  		extra,             // extra
   300  		privileges,        // privileges
   301  		col.Comment,       // column_comment
   302  		"",                // generation_expression
   303  		srsId,             // srs_id
   304  	}
   305  }
   307  // getRowsFromTable returns array of rows for all accessible columns of the given table.
   308  func getRowsFromTable(ctx *sql.Context, db sql.Database, t sql.Table, privSetDb sql.PrivilegeSetDatabase, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
   309  	var rows []sql.Row
   311  	privSetTbl := privSetDb.Table(t.Name())
   312  	curPrivSetMap := getCurrentPrivSetMapForColumn(privSetTbl.ToSlice(), privSetMap)
   314  	columnKeyMap, hasPK, err := getIndexKeyInfo(ctx, t)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   319  	tblName := t.Name()
   320  	for i, col := range schemaForTable(t, db, allColsWithDefaultValue) {
   321  		var columnKey string
   322  		// Check column PK here first because there are PKs from table implementations that don't implement sql.IndexedTable
   323  		if col.PrimaryKey {
   324  			columnKey = "PRI"
   325  		} else if val, ok := columnKeyMap[col.Name]; ok {
   326  			columnKey = val
   327  			// A UNIQUE index may be displayed as PRI if it cannot contain NULL values and there is no PRIMARY KEY in the table
   328  			if !col.Nullable && !hasPK && columnKey == "UNI" {
   329  				columnKey = "PRI"
   330  				hasPK = true
   331  			}
   332  		}
   334  		r := getRowFromColumn(ctx, i, col, db.Name(), tblName, columnKey, privSetTbl, curPrivSetMap)
   335  		if r != nil {
   336  			rows = append(rows, r)
   337  		}
   338  	}
   340  	return rows, nil
   341  }
   343  // getRowsFromViews returns array or rows for columns for all views for given database.
   344  func getRowsFromViews(ctx *sql.Context, db sql.Database) ([]sql.Row, error) {
   345  	var rows []sql.Row
   346  	// TODO: View Definition is lacking information to properly fill out these table
   347  	// TODO: Should somehow get reference to table(s) view is referencing
   348  	// TODO: Each column that view references should also show up as unique entries as well
   349  	views, err := viewsInDatabase(ctx, db)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   354  	for _, view := range views {
   355  		rows = append(rows, sql.Row{
   356  			"def",     // table_catalog
   357  			db.Name(), // table_schema
   358  			view.Name, // table_name
   359  			"",        // column_name
   360  			uint32(0), // ordinal_position
   361  			nil,       // column_default
   362  			"",        // is_nullable
   363  			nil,       // data_type
   364  			nil,       // character_maximum_length
   365  			nil,       // character_octet_length
   366  			nil,       // numeric_precision
   367  			nil,       // numeric_scale
   368  			nil,       // datetime_precision
   369  			"",        // character_set_name
   370  			"",        // collation_name
   371  			"",        // column_type
   372  			"",        // column_key
   373  			"",        // extra
   374  			"select",  // privileges
   375  			"",        // column_comment
   376  			"",        // generation_expression
   377  			nil,       // srs_id
   378  		})
   379  	}
   381  	return rows, nil
   382  }
   384  // getRowsFromDatabase returns array of rows for all accessible columns of accessible table of the given database.
   385  func getRowsFromDatabase(ctx *sql.Context, db sql.Database, privSet sql.PrivilegeSet, privSetMap map[string]struct{}, allColsWithDefaultValue sql.Schema) ([]sql.Row, error) {
   386  	var rows []sql.Row
   387  	dbName := db.Name()
   389  	privSetDb := privSet.Database(dbName)
   390  	curPrivSetMap := getCurrentPrivSetMapForColumn(privSetDb.ToSlice(), privSetMap)
   391  	if dbName == sql.InformationSchemaDatabaseName {
   392  		curPrivSetMap["select"] = struct{}{}
   393  	}
   395  	err := sql.DBTableIter(ctx, db, func(t sql.Table) (cont bool, err error) {
   396  		rs, err := getRowsFromTable(ctx, db, t, privSetDb, curPrivSetMap, allColsWithDefaultValue)
   397  		if err != nil {
   398  			return false, err
   399  		}
   400  		rows = append(rows, rs...)
   401  		return true, nil
   402  	})
   403  	if err != nil {
   404  		return nil, err
   405  	}
   407  	return rows, nil
   408  }
   410  // getCurrentPrivSetMapForColumn returns a new privilege set map that contains what the given privilege set map has,
   411  // and it adds any available privileges from given array of privilege type. For example, the given privilege set map
   412  // may contain general privilege types for the database only, and the given array of privilege type will contain all
   413  // privilege types defined for the table specifically. This function only add `select`, `insert`, `update` and
   414  // `references` privileges to the new privilege set map if available. These are column level privileges only.
   415  func getCurrentPrivSetMapForColumn(privs []sql.PrivilegeType, privSetMap map[string]struct{}) map[string]struct{} {
   416  	curPrivSetMap := make(map[string]struct{})
   417  	for p := range privSetMap {
   418  		curPrivSetMap[p] = struct{}{}
   419  	}
   420  	for _, pt := range privs {
   421  		switch pt {
   422  		// columns can have 'select', 'insert', 'update', 'references' privileges only.
   423  		case sql.PrivilegeType_Select, sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_References:
   424  			curPrivSetMap[strings.ToLower(pt.String())] = struct{}{}
   425  		}
   426  	}
   427  	return curPrivSetMap
   428  }
   430  // getIndexKeyInfo returns map of columns and its index information whether this column is PK or unique index, etc.
   431  func getIndexKeyInfo(ctx *sql.Context, t sql.Table) (map[string]string, bool, error) {
   432  	var columnKeyMap = make(map[string]string)
   433  	// Get UNIQUEs, PRIMARY KEYs
   434  	hasPK := false
   435  	if indexTable, ok := t.(sql.IndexAddressable); ok {
   436  		indexes, iErr := indexTable.GetIndexes(ctx)
   437  		if iErr != nil {
   438  			return columnKeyMap, hasPK, iErr
   439  		}
   441  		for _, index := range indexes {
   442  			idx := ""
   443  			if index.ID() == "PRIMARY" {
   444  				idx = "PRI"
   445  				hasPK = true
   446  			} else if index.IsUnique() {
   447  				idx = "UNI"
   448  			} else {
   449  				idx = "MUL"
   450  			}
   452  			colNames := getColumnNamesFromIndex(index, t)
   453  			// A UNIQUE index may display as MUL if several columns form a composite UNIQUE index
   454  			if idx == "UNI" && len(colNames) > 1 {
   455  				idx = "MUL"
   456  				columnKeyMap[colNames[0]] = idx
   457  			} else {
   458  				for _, colName := range colNames {
   459  					columnKeyMap[colName] = idx
   460  				}
   461  			}
   462  		}
   463  	}
   465  	return columnKeyMap, hasPK, nil
   466  }
   468  // getColumnDefault returns the column default value for given sql.ColumnDefaultValue
   469  func getColumnDefault(ctx *sql.Context, cd *sql.ColumnDefaultValue) interface{} {
   470  	if cd == nil {
   471  		return nil
   472  	}
   474  	defStr := cd.String()
   475  	if defStr == "NULL" {
   476  		return nil
   477  	}
   479  	if !cd.IsLiteral() {
   480  		if strings.HasPrefix(defStr, "(") && strings.HasSuffix(defStr, ")") {
   481  			defStr = strings.TrimSuffix(strings.TrimPrefix(defStr, "("), ")")
   482  		}
   483  		if types.IsTime(cd.Type()) && (strings.HasPrefix(defStr, "NOW") || strings.HasPrefix(defStr, "CURRENT_TIMESTAMP")) {
   484  			defStr = strings.Replace(defStr, "NOW", "CURRENT_TIMESTAMP", -1)
   485  			defStr = strings.TrimSuffix(defStr, "()")
   486  		}
   487  		return fmt.Sprint(defStr)
   488  	}
   490  	if types.IsEnum(cd.Type()) || types.IsSet(cd.Type()) {
   491  		return strings.Trim(defStr, "'")
   492  	}
   494  	v, err := cd.Eval(ctx, nil)
   495  	if err != nil {
   496  		return ""
   497  	}
   499  	switch l := v.(type) {
   500  	case time.Time:
   501  		v = l.Format("2006-01-02 15:04:05")
   502  	case []uint8:
   503  		hexStr := hex.EncodeToString(l)
   504  		v = fmt.Sprintf("0x%s", hexStr)
   505  	}
   507  	if types.IsBit(cd.Type()) {
   508  		if i, ok := v.(uint64); ok {
   509  			bitStr := strconv.FormatUint(i, 2)
   510  			v = fmt.Sprintf("b'%s'", bitStr)
   511  		}
   512  	}
   514  	return fmt.Sprint(v)
   515  }
   517  func schemaForTable(t sql.Table, db sql.Database, allColsWithDefaultValue sql.Schema) sql.Schema {
   518  	start, end := -1, -1
   519  	tableName := strings.ToLower(t.Name())
   521  	for i, col := range allColsWithDefaultValue {
   522  		dbName := strings.ToLower(db.Name())
   523  		if start < 0 && strings.ToLower(col.Source) == tableName && strings.ToLower(col.DatabaseSource) == dbName {
   524  			start = i
   525  		} else if start >= 0 && (strings.ToLower(col.Source) != tableName || strings.ToLower(col.DatabaseSource) != dbName) {
   526  			end = i
   527  			break
   528  		}
   529  	}
   531  	if start < 0 {
   532  		return nil
   533  	}
   535  	if end < 0 {
   536  		end = len(allColsWithDefaultValue)
   537  	}
   539  	return allColsWithDefaultValue[start:end]
   540  }
   542  // get DtdIdAndDataType returns data types for given sql.Type but in two different ways.
   543  // The DTD_IDENTIFIER value contains the type name and possibly other information such as the precision or length.
   544  // The DATA_TYPE value is the type name only with no other information.
   545  func getDtdIdAndDataType(colType sql.Type) (string, string) {
   546  	dtdId := strings.Split(strings.Split(colType.String(), " COLLATE")[0], " CHARACTER SET")[0]
   548  	// The DATA_TYPE value is the type name only with no other information
   549  	dataType := strings.Split(dtdId, "(")[0]
   550  	dataType = strings.Split(dataType, " ")[0]
   552  	return dtdId, dataType
   553  }
   555  // getColumnPrecisionAndScale returns the precision or a number of mysql type. For non-numeric or decimal types this
   556  // function should return nil,nil.
   557  func getColumnPrecisionAndScale(colType sql.Type) (interface{}, interface{}) {
   558  	var numericScale interface{}
   559  	switch t := colType.(type) {
   560  	case types.BitType:
   561  		return int(t.NumberOfBits()), numericScale
   562  	case sql.DecimalType:
   563  		return int(t.Precision()), int(t.Scale())
   564  	case sql.NumberType:
   565  		switch colType.Type() {
   566  		case sqltypes.Float32, sqltypes.Float64:
   567  			numericScale = nil
   568  		default:
   569  			numericScale = 0
   570  		}
   571  		return typeToNumericPrecision[colType.Type()], numericScale
   572  	default:
   573  		return nil, nil
   574  	}
   575  }
   577  func getCharAndCollNamesAndCharMaxAndOctetLens(ctx *sql.Context, colType sql.Type) (interface{}, interface{}, interface{}, interface{}) {
   578  	var (
   579  		charName     interface{}
   580  		collName     interface{}
   581  		charMaxLen   interface{}
   582  		charOctetLen interface{}
   583  	)
   584  	if twc, ok := colType.(sql.TypeWithCollation); ok && !types.IsBinaryType(colType) {
   585  		colColl := twc.Collation()
   586  		collName = colColl.Name()
   587  		charName = colColl.CharacterSet().String()
   588  		if types.IsEnum(colType) || types.IsSet(colType) {
   589  			charOctetLen = int64(colType.MaxTextResponseByteLength(ctx))
   590  			charMaxLen = int64(colType.MaxTextResponseByteLength(ctx)) / colColl.CharacterSet().MaxLength()
   591  		}
   592  	}
   593  	if st, ok := colType.(sql.StringType); ok {
   594  		charMaxLen = st.MaxCharacterLength()
   595  		charOctetLen = st.MaxByteLength()
   596  	}
   598  	return charName, collName, charMaxLen, charOctetLen
   599  }