github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/schema_override.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/plan"
    22  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    23  
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    29  )
    30  
    31  // resolveOverriddenNonexistentTable checks if there is an overridden schema commit set for this session, and if so
    32  // returns an empty table with that schema if |tblName| exists in the overridden schema commit. If no schema override
    33  // is set, this function returns a nil sql.Table and a false boolean return parameter.
    34  func resolveOverriddenNonexistentTable(ctx *sql.Context, tblName string, db Database) (sql.Table, bool, error) {
    35  	// Check to see if table schemas have been overridden
    36  	schemaRoot, err := resolveOverriddenSchemaRoot(ctx, db)
    37  	if err != nil {
    38  		return nil, false, err
    39  	}
    40  	if schemaRoot == nil {
    41  		return nil, false, nil
    42  	}
    43  
    44  	// If schema overrides are in place, see if the table exists in the overridden schema
    45  	t, _, ok, err := doltdb.GetTableInsensitive(ctx, schemaRoot, tblName)
    46  	if err != nil {
    47  		return nil, false, err
    48  	}
    49  	if !ok {
    50  		return nil, false, nil
    51  	}
    52  
    53  	// Load the overridden schema and convert it to a sql.Schema
    54  	// TODO: Loading the schema is an expensive operation, so it would be more
    55  	//       efficient to use the same schema cache from getTable() here. The
    56  	//       schemas are cached by root value, so it's safe to use the cache.
    57  	overriddenSchema, err := t.GetSchema(ctx)
    58  	if err != nil {
    59  		return nil, false, err
    60  	}
    61  	overriddenSqlSchema, err := sqlutil.FromDoltSchema(db.Name(), tblName, overriddenSchema)
    62  	if err != nil {
    63  		return nil, false, err
    64  	}
    65  
    66  	// Return an empty table with the overridden schema
    67  	emptyTable := plan.NewEmptyTableWithSchema(overriddenSqlSchema.Schema)
    68  	return emptyTable.(sql.Table), true, nil
    69  }
    70  
    71  // overrideSchemaForTable loads the schema from |overriddenSchemaRoot| for the table named |tableName| and sets the
    72  // override on |tbl|. If there are any problems loading the overridden schema, this function returns an error.
    73  func overrideSchemaForTable(ctx *sql.Context, tableName string, tbl *doltdb.Table, overriddenSchemaRoot doltdb.RootValue) error {
    74  	overriddenTable, _, ok, err := doltdb.GetTableInsensitive(ctx, overriddenSchemaRoot, tableName)
    75  	if err != nil {
    76  		return fmt.Errorf("unable to find table '%s' at overridden schema root: %s", tableName, err.Error())
    77  	}
    78  	if !ok {
    79  		return fmt.Errorf("unable to find table '%s' at overridden schema root", tableName)
    80  	}
    81  
    82  	// TODO: Loading the schema is an expensive operation, so it would be more
    83  	//       efficient to use the same schema cache from getTable() here. The
    84  	//       schemas are cached by root value, so it's safe to use the cache.
    85  	overriddenSchema, err := overriddenTable.GetSchema(ctx)
    86  	if err != nil {
    87  		return fmt.Errorf("unable to load overridden schema for table '%s': %s", tableName, err.Error())
    88  	}
    89  
    90  	tbl.OverrideSchema(overriddenSchema)
    91  	return nil
    92  }
    93  
    94  // getOverriddenSchemaValue returns a string value of the Dolt schema override session variable. If the
    95  // variable is not set (i.e. NULL or empty string) then this function returns an empty string.
    96  func getOverriddenSchemaValue(ctx *sql.Context) (string, error) {
    97  	doltSession := dsess.DSessFromSess(ctx.Session)
    98  	// TODO: Session variable lookups can be surprisingly expensive as well.
    99  	//       Check out DoltSession.dbSessionVarsStale() to see an example of how
   100  	//       we can use caching to make this more efficient.
   101  	varValue, err := doltSession.GetSessionVariable(ctx, dsess.DoltOverrideSchema)
   102  	if err != nil {
   103  		return "", err
   104  	}
   105  
   106  	if varValue == nil {
   107  		return "", nil
   108  	}
   109  
   110  	varString, ok := varValue.(string)
   111  	if !ok {
   112  		return "", fmt.Errorf("value of %s session variable is not a string", dsess.DoltOverrideSchema)
   113  	}
   114  	return varString, nil
   115  }
   116  
   117  // resolveOverriddenSchemaRoot loads the Dolt schema override session variable, resolves the commit reference, and
   118  // loads the RootValue for that commit. If the session variable is not set, this function returns nil. If there are
   119  // any problems resolving the commit or loading the root value, this function returns an error.
   120  func resolveOverriddenSchemaRoot(ctx *sql.Context, db Database) (doltdb.RootValue, error) {
   121  	overriddenSchemaValue, err := getOverriddenSchemaValue(ctx)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	if overriddenSchemaValue == "" {
   127  		return nil, nil
   128  	}
   129  
   130  	commitSpec, err := doltdb.NewCommitSpec(overriddenSchemaValue)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("invalid commit spec specified in %s: %s", dsess.DoltOverrideSchema, err.Error())
   133  	}
   134  
   135  	// Attempt to get a head ref if we can, but don't error out, if we don't. Commit and tag
   136  	// revision databases won't have a head ref, so it's okay to pass in nil for the head ref.
   137  	doltSession := dsess.DSessFromSess(ctx.Session)
   138  	headRef, _ := doltSession.CWBHeadRef(ctx, db.Name())
   139  
   140  	optionalCommit, err := db.GetDoltDB().Resolve(ctx, commitSpec, headRef)
   141  	if err != nil {
   142  		return nil, fmt.Errorf("unable to resolve schema override value: " + err.Error())
   143  	}
   144  
   145  	commit, ok := optionalCommit.ToCommit()
   146  	if !ok {
   147  		return nil, fmt.Errorf("unable to resolve schema override: "+
   148  			"commit '%s' is not present locally in the commit graph", optionalCommit.Addr.String())
   149  	}
   150  
   151  	rootValue, err := commit.GetRootValue(ctx)
   152  	if err != nil {
   153  		return nil, fmt.Errorf("unable to load root value for schema override commit: " + err.Error())
   154  	}
   155  
   156  	return rootValue, nil
   157  }
   158  
   159  // rowConverterByColTagAndName returns a function that converts a row from |srcSchema| to |targetSchema| using the
   160  // specified |projectedTags| and |projectedColNames|. Projected tags and projected column names are both
   161  // provided so that if a tag changes (such as when a column's type is changed) the mapping can fall back to
   162  // matching by column name.
   163  //
   164  // NOTE: This was forked from the dolt_history system table's rowConverter function, which has slightly different
   165  // behavior. It would be nice to resolve the differences and standardize on how we convert rows between schemas.
   166  // The main differences are:
   167  //  1. The dolt_history_ system tables only maps columns by name and doesn't take into account tags. This
   168  //     implementation prefers mapping by column tags, but will fall back to column names if a column with a specified
   169  //     tag can't be found. This behavior is similar to what we do in the diff system tables. Related to this, the
   170  //     columns to include in the projection are also only specified by name in the dolt_history system tables, but
   171  //     here they need to be specified by tag and then fallback to column name matching if a tag isn't found.
   172  //  2. The dolt_history_ system tables will not map columns unless their types are exactly identical. This is too
   173  //     strict for schema override mapping, so this implementation attempts to convert column values to the target
   174  //     type. If a column value is not compatible with the mapped column type, then an error is returned while mapping
   175  //     the schema. String types are currently the only exception: they will be truncated to fit into narrower types
   176  //     if necessary, and a warning will be logged in the session. This is similar to the behavior of the diff tables
   177  //     but instead of returning an error, they log a warning and return a NULL value.
   178  func rowConverterByColTagAndName(srcSchema, targetSchema schema.Schema, projectedTags []uint64, projectedColNames []string) func(ctx *sql.Context, row sql.Row) (sql.Row, error) {
   179  	srcIndexToTargetIndex := make(map[int]int)
   180  	srcIndexToTargetType := make(map[int]typeinfo.TypeInfo)
   181  	for i, targetColumn := range targetSchema.GetAllCols().GetColumns() {
   182  		sourceColumn, found := srcSchema.GetAllCols().GetByTag(targetColumn.Tag)
   183  		if !found {
   184  			sourceColumn, found = srcSchema.GetAllCols().GetByName(targetColumn.Name)
   185  		}
   186  
   187  		if found {
   188  			srcIndex := srcSchema.GetAllCols().IndexOf(sourceColumn.Name)
   189  			srcIndexToTargetIndex[srcIndex] = i
   190  			srcIndexToTargetType[srcIndex] = targetColumn.TypeInfo
   191  		}
   192  	}
   193  
   194  	return func(ctx *sql.Context, row sql.Row) (sql.Row, error) {
   195  		r := make(sql.Row, len(projectedColNames))
   196  		for i, tag := range projectedTags {
   197  			// First try to find the column in the src schema with the matching tag
   198  			// then fallback to a name match, since type changes will change the tag
   199  			srcColumn, found := srcSchema.GetAllCols().GetByTag(tag)
   200  			if !found {
   201  				srcColumn, found = srcSchema.GetAllCols().GetByName(projectedColNames[i])
   202  			}
   203  
   204  			if found {
   205  				srcIndex := srcSchema.GetAllCols().IndexOf(srcColumn.Name)
   206  				temp := row[srcIndex]
   207  
   208  				conversionType := srcIndexToTargetType[srcIndex]
   209  
   210  				convertedValue, err := convertWithTruncation(ctx, temp, conversionType)
   211  				if err != nil {
   212  					return nil, err
   213  				}
   214  
   215  				r[i] = convertedValue
   216  			}
   217  		}
   218  		return r, nil
   219  	}
   220  }
   221  
   222  // convertWithTruncation attempts to convert |value| to |typ| and returns the converted value. If the value is a string
   223  // and the type is a VARCHAR, CHAR, or TEXT type and the length of |value| is greater than the allowed lenght of |typ|,
   224  // then the value is truncated to the allowed length and a warning is logged in the session.
   225  // If the value is not compatible with |typ|, then an error is
   226  func convertWithTruncation(ctx *sql.Context, value any, typ typeinfo.TypeInfo) (any, error) {
   227  	if s, ok := value.(string); ok && gmstypes.IsTextOnly(typ.ToSqlType()) {
   228  		// For char/varchar/text values, we are more lenient with conversion and truncate the value
   229  		// if it is too long to fit into the target type.
   230  		stringType := typ.ToSqlType().(gmstypes.StringType)
   231  		if int64(len(s)) > stringType.MaxCharacterLength() {
   232  			value = s[:stringType.MaxCharacterLength()]
   233  			ctx.Warn(1246, "Value '%s' truncated to fit column of type %s", s, typ.String())
   234  		}
   235  	}
   236  
   237  	convertedValue, _, err := typ.ToSqlType().Convert(value)
   238  	if err != nil {
   239  		return nil, fmt.Errorf("unable to convert value to overridden schema: %s", err.Error())
   240  	}
   241  	return convertedValue, nil
   242  }
   243  
   244  // newMappingRowIter returns a RowIter that maps results from |wrappedIter| to the overridden schema on |t|.
   245  func newMappingRowIter(ctx *sql.Context, t *DoltTable, wrappedIter sql.RowIter) (sql.RowIter, error) {
   246  	rowConvFunc, err := newRowConverterForDoltTable(ctx, t)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	newRowIter := mappingRowIter{
   252  		child:       wrappedIter,
   253  		rowConvFunc: rowConvFunc,
   254  	}
   255  	return &newRowIter, nil
   256  }
   257  
   258  // newRowConverterForDoltTable returns a function that converts rows from the original schema of |t| to the overridden
   259  // schema of |t|.
   260  func newRowConverterForDoltTable(ctx *sql.Context, t *DoltTable) (func(ctx *sql.Context, row sql.Row) (sql.Row, error), error) {
   261  	// If there is a schema override, then we need to map the results
   262  	// from the old schema to the new schema
   263  	doltSession := dsess.DSessFromSess(ctx.Session)
   264  	roots, ok := doltSession.GetRoots(ctx, t.db.Name())
   265  	if !ok {
   266  		return nil, fmt.Errorf("unable to get roots for database '%s'", t.db.Name())
   267  	}
   268  
   269  	doltSchema, err := sqlutil.ToDoltSchema(ctx, roots.Working, t.Name(), t.sqlSch, roots.Head, t.Collation())
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	var projectedColNames []string
   275  	for _, tag := range t.projectedCols {
   276  		column, ok := t.overriddenSchema.GetAllCols().GetByTag(tag)
   277  		if !ok {
   278  			return nil, fmt.Errorf("unable to find column with tag %d in overridden schema", tag)
   279  		}
   280  		projectedColNames = append(projectedColNames, column.Name)
   281  	}
   282  
   283  	rowConvFunc := rowConverterByColTagAndName(doltSchema, t.overriddenSchema, t.projectedCols, projectedColNames)
   284  	return rowConvFunc, nil
   285  }
   286  
   287  // mappingRowIter is a RowIter that maps rows from a child RowIter to a new schema using a row conversion function.
   288  type mappingRowIter struct {
   289  	child       sql.RowIter
   290  	rowConvFunc func(ctx *sql.Context, row sql.Row) (sql.Row, error)
   291  }
   292  
   293  var _ sql.RowIter = (*mappingRowIter)(nil)
   294  
   295  // Next implements the sql.RowIter interface
   296  func (m *mappingRowIter) Next(ctx *sql.Context) (sql.Row, error) {
   297  	next, err := m.child.Next(ctx)
   298  	if err != nil {
   299  		return next, err
   300  	}
   301  
   302  	if m.rowConvFunc == nil {
   303  		return next, nil
   304  	} else {
   305  		return m.rowConvFunc(ctx, next)
   306  	}
   307  }
   308  
   309  // Close implements the sql.RowIter interface
   310  func (m *mappingRowIter) Close(ctx *sql.Context) error {
   311  	return m.child.Close(ctx)
   312  }