github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/set.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  func (b *Builder) buildSet(inScope *scope, n *ast.Set) (outScope *scope) {
    30  	var setVarExprs []*ast.SetVarExpr
    31  	for _, setExpr := range n.Exprs {
    32  		switch strings.ToLower(setExpr.Name.String()) {
    33  		case "names":
    34  			// Special case: SET NAMES expands to 3 different system variables.
    35  			setVarExprs = append(setVarExprs, getSetVarExprsFromSetNamesExpr(setExpr)...)
    36  		case "charset":
    37  			// Special case: SET CHARACTER SET (CHARSET) expands to 3 different system variables.
    38  			csd, err := b.ctx.GetSessionVariable(b.ctx, "character_set_database")
    39  			if err != nil {
    40  				b.handleErr(err)
    41  			}
    42  			setVarExprs = append(setVarExprs, getSetVarExprsFromSetCharsetExpr(setExpr, []byte(csd.(string)))...)
    43  		default:
    44  			setVarExprs = append(setVarExprs, setExpr)
    45  		}
    46  	}
    47  
    48  	exprs := b.setExprsToExpressions(inScope, setVarExprs)
    49  
    50  	outScope = inScope.push()
    51  	outScope.node = plan.NewSet(exprs)
    52  	return outScope
    53  }
    54  
    55  func getSetVarExprsFromSetNamesExpr(expr *ast.SetVarExpr) []*ast.SetVarExpr {
    56  	return []*ast.SetVarExpr{
    57  		{
    58  			Name: ast.NewColName("character_set_client"),
    59  			Expr: expr.Expr,
    60  		},
    61  		{
    62  			Name: ast.NewColName("character_set_connection"),
    63  			Expr: expr.Expr,
    64  		},
    65  		{
    66  			Name: ast.NewColName("character_set_results"),
    67  			Expr: expr.Expr,
    68  		},
    69  		// TODO (9/24/20 Zach): this should also set the collation_connection to the default collation for the character set named
    70  	}
    71  }
    72  
    73  func getSetVarExprsFromSetCharsetExpr(expr *ast.SetVarExpr, csd []byte) []*ast.SetVarExpr {
    74  	return []*ast.SetVarExpr{
    75  		{
    76  			Name: ast.NewColName("character_set_client"),
    77  			Expr: expr.Expr,
    78  		},
    79  		{
    80  			Name: ast.NewColName("character_set_results"),
    81  			Expr: expr.Expr,
    82  		},
    83  		{
    84  			Name: ast.NewColName("character_set_connection"),
    85  			Expr: &ast.SQLVal{Type: ast.StrVal, Val: csd},
    86  		},
    87  	}
    88  }
    89  
    90  func (b *Builder) setExprsToExpressions(inScope *scope, e ast.SetVarExprs) []sql.Expression {
    91  	res := make([]sql.Expression, len(e))
    92  	for i, setExpr := range e {
    93  		if expr, ok := setExpr.Expr.(*ast.SQLVal); ok && strings.ToLower(setExpr.Name.String()) == "transaction" &&
    94  			(setExpr.Scope == ast.SetScope_Global || setExpr.Scope == ast.SetScope_Session || string(setExpr.Scope) == "") {
    95  			scope := sql.SystemVariableScope_Session
    96  			if setExpr.Scope == ast.SetScope_Global {
    97  				scope = sql.SystemVariableScope_Global
    98  			}
    99  			switch strings.ToLower(expr.String()) {
   100  			case "'isolation level repeatable read'":
   101  				varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope))
   102  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral("REPEATABLE-READ", types.LongText))
   103  				continue
   104  			case "'isolation level read committed'":
   105  				varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope))
   106  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-COMMITTED", types.LongText))
   107  				continue
   108  			case "'isolation level read uncommitted'":
   109  				varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope))
   110  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-UNCOMMITTED", types.LongText))
   111  				continue
   112  			case "'isolation level serializable'":
   113  				varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope))
   114  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral("SERIALIZABLE", types.LongText))
   115  				continue
   116  			case "'read write'":
   117  				varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope))
   118  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral(false, types.Boolean))
   119  				continue
   120  			case "'read only'":
   121  				varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope))
   122  				res[i] = expression.NewSetField(varToSet, expression.NewLiteral(true, types.Boolean))
   123  				continue
   124  			}
   125  		}
   126  
   127  		// left => convert to user var or system var expression, validate system var
   128  		// right => getSetExpr, not adapted for defaults yet, special keywords need to be converted, variables replaced
   129  		var setScope ast.SetScope
   130  
   131  		tblName := strings.ToLower(setExpr.Name.Qualifier.String())
   132  		c, ok := inScope.resolveColumn("", tblName, strings.ToLower(setExpr.Name.Name.String()), true, false)
   133  		var setVar sql.Expression
   134  		if ok {
   135  			setVar = c.scalarGf()
   136  		} else {
   137  			setVar, setScope, ok = b.buildSysVar(setExpr.Name, setExpr.Scope)
   138  			if !ok {
   139  				switch setScope {
   140  				case ast.SetScope_None:
   141  					if tblName != "" && !inScope.hasTable(tblName) {
   142  						b.handleErr(sql.ErrTableNotFound.New(tblName))
   143  					}
   144  					b.handleErr(sql.ErrColumnNotFound.New(setExpr.Name.String()))
   145  				case ast.SetScope_User:
   146  					b.handleErr(sql.ErrUnknownUserVariable.New(setExpr.Name.String()))
   147  				default:
   148  					b.handleErr(sql.ErrUnknownSystemVariable.New(setExpr.Name.String()))
   149  				}
   150  			}
   151  		}
   152  
   153  		sysVarType, _ := setVar.Type().(sql.SystemVariableType)
   154  		innerExpr, ok := b.simplifySetExpr(setExpr.Name, setScope, setExpr.Expr, sysVarType)
   155  		if !ok {
   156  			innerExpr = b.buildScalar(inScope, setExpr.Expr)
   157  		}
   158  
   159  		res[i] = expression.NewSetField(setVar, innerExpr)
   160  	}
   161  	return res
   162  }
   163  
   164  func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql.Expression, ast.SetScope, bool) {
   165  	// convert to system or user var, validate system var
   166  	table := colName.Qualifier.String()
   167  	col := colName.Name.String()
   168  	var varName string
   169  	var scope ast.SetScope
   170  	var err error
   171  	var specifiedScope string
   172  
   173  	if table == "" {
   174  		varName, scope, specifiedScope, err = ast.VarScope(col)
   175  	} else {
   176  		varName, scope, specifiedScope, err = ast.VarScope(table, col)
   177  	}
   178  	if err != nil {
   179  		b.handleErr(err)
   180  	}
   181  
   182  	if scope == "" {
   183  		scope = scopeHint
   184  	}
   185  
   186  	switch scope {
   187  	case ast.SetScope_Global:
   188  		_, _, ok := sql.SystemVariables.GetGlobal(varName)
   189  		if !ok {
   190  			return nil, scope, false
   191  		}
   192  		return expression.NewSystemVar(varName, sql.SystemVariableScope_Global, specifiedScope), scope, true
   193  	case ast.SetScope_None, ast.SetScope_Session:
   194  		switch strings.ToLower(varName) {
   195  		case "character_set_database", "collation_database":
   196  			sysVar := expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope)
   197  			sysVar.Collation = sql.Collation_Default
   198  			if db, err := b.cat.Database(b.ctx, b.ctx.GetCurrentDatabase()); err == nil {
   199  				sysVar.Collation = plan.GetDatabaseCollation(b.ctx, db)
   200  			}
   201  			return sysVar, scope, true
   202  		default:
   203  			_, err = b.ctx.GetSessionVariable(b.ctx, varName)
   204  			if err != nil {
   205  				return nil, scope, false
   206  			}
   207  			return expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope), scope, true
   208  		}
   209  	case ast.SetScope_User:
   210  		t, _, err := b.ctx.GetUserVariable(b.ctx, varName)
   211  		if err != nil {
   212  			b.handleErr(err)
   213  		}
   214  		if t != nil {
   215  			return expression.NewUserVarWithType(varName, t), scope, true
   216  		}
   217  		return expression.NewUserVar(varName), scope, true
   218  	case ast.SetScope_Persist:
   219  		return expression.NewSystemVar(varName, sql.SystemVariableScope_Persist, specifiedScope), scope, true
   220  	case ast.SetScope_PersistOnly:
   221  		return expression.NewSystemVar(varName, sql.SystemVariableScope_PersistOnly, specifiedScope), scope, true
   222  	default: // shouldn't happen
   223  		err := fmt.Errorf("unknown set scope %v", scope)
   224  		b.handleErr(err)
   225  	}
   226  	return nil, scope, false
   227  }
   228  
   229  func (b *Builder) simplifySetExpr(name *ast.ColName, varScope ast.SetScope, val ast.Expr, sysVarType sql.Type) (sql.Expression, bool) {
   230  	// can |val| be nested?
   231  	switch val := val.(type) {
   232  	case *ast.SQLVal:
   233  		if val.Type != ast.StrVal {
   234  			return nil, false
   235  		}
   236  		e := expression.NewLiteral(string(val.Val), types.Text)
   237  		res, err := e.Eval(b.ctx, nil)
   238  		if err != nil {
   239  			b.handleErr(err)
   240  		}
   241  		setVal, ok := res.(string)
   242  		if !ok {
   243  			return nil, false
   244  		}
   245  
   246  		switch strings.ToLower(setVal) {
   247  		case ast.KeywordString(ast.ON):
   248  			return expression.NewLiteral(true, types.Boolean), true
   249  		case ast.KeywordString(ast.TRUE):
   250  			return expression.NewLiteral(true, types.Boolean), true
   251  		case ast.KeywordString(ast.OFF):
   252  			return expression.NewLiteral(false, types.Boolean), true
   253  		case ast.KeywordString(ast.FALSE):
   254  			return expression.NewLiteral(false, types.Boolean), true
   255  		default:
   256  		}
   257  
   258  		if sysVarType == nil {
   259  			return nil, false
   260  		}
   261  
   262  		enum, _, err := sysVarType.Convert(setVal)
   263  		if err != nil {
   264  			b.handleErr(err)
   265  		}
   266  		return expression.NewLiteral(enum, sysVarType), true
   267  	case *ast.ColName:
   268  		// convert and eval
   269  		// todo check whether right side needs variable replacement
   270  		sysVar, _, ok := b.buildSysVar(val, ast.SetScope_None)
   271  		if ok {
   272  			return sysVar, true
   273  		}
   274  		e := expression.NewLiteral(val.Name.String(), types.Text)
   275  		res, err := e.Eval(b.ctx, nil)
   276  		if err != nil {
   277  			b.handleErr(err)
   278  		}
   279  		setVal, ok := res.(string)
   280  		if !ok {
   281  			return nil, false
   282  		}
   283  
   284  		switch strings.ToLower(setVal) {
   285  		case ast.KeywordString(ast.ON):
   286  			return expression.NewLiteral(true, types.Boolean), true
   287  		case ast.KeywordString(ast.TRUE):
   288  			return expression.NewLiteral(true, types.Boolean), true
   289  		case ast.KeywordString(ast.OFF):
   290  			return expression.NewLiteral(false, types.Boolean), true
   291  		case ast.KeywordString(ast.FALSE):
   292  			return expression.NewLiteral(false, types.Boolean), true
   293  		default:
   294  		}
   295  
   296  		if sysVarType == nil {
   297  			return nil, false
   298  		}
   299  
   300  		enum, _, err := sysVarType.Convert(setVal)
   301  		if err != nil {
   302  			b.handleErr(err)
   303  		}
   304  		return expression.NewLiteral(enum, sysVarType), true
   305  	case *ast.BoolVal:
   306  		// conv
   307  		e := expression.NewLiteral(val, types.Text)
   308  		res, err := e.Eval(b.ctx, nil)
   309  		if err != nil {
   310  			b.handleErr(err)
   311  		}
   312  		setVal, ok := res.(bool)
   313  		if !ok {
   314  			err := fmt.Errorf("expected *ast.BoolVal to evaluate to bool type, found: %T", val)
   315  			b.handleErr(err)
   316  		}
   317  
   318  		if setVal {
   319  			return expression.NewLiteral(1, types.Boolean), true
   320  		} else {
   321  			return expression.NewLiteral(0, types.Boolean), true
   322  		}
   323  	case *ast.Default:
   324  		// set back to default value
   325  		var err error
   326  		var varName string
   327  		table := name.Qualifier.String()
   328  		col := name.Name.Lowered()
   329  		if table != "" {
   330  			varName, _, _, err = ast.VarScope(table, col)
   331  		} else {
   332  			varName, _, _, err = ast.VarScope(col)
   333  		}
   334  		if err != nil {
   335  			b.handleErr(err)
   336  		}
   337  
   338  		switch varScope {
   339  		case ast.SetScope_None, ast.SetScope_Session, ast.SetScope_Global:
   340  			_, value, ok := sql.SystemVariables.GetGlobal(varName)
   341  			if ok {
   342  				return expression.NewLiteral(value, types.ApproximateTypeFromValue(value)), true
   343  			}
   344  			err = sql.ErrUnknownSystemVariable.New(varName)
   345  		case ast.SetScope_Persist, ast.SetScope_PersistOnly:
   346  			err = fmt.Errorf("%wsetting default for '%s'", sql.ErrUnsupportedFeature.New(), varScope)
   347  		case ast.SetScope_User:
   348  			err = sql.ErrUserVariableNoDefault.New(varName)
   349  		default: // shouldn't happen
   350  			err = fmt.Errorf("unknown set scope %v", varScope)
   351  		}
   352  		b.handleErr(err)
   353  	}
   354  	return nil, false
   355  }