github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/set.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "strings" 20 21 ast "github.com/dolthub/vitess/go/vt/sqlparser" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 func (b *Builder) buildSet(inScope *scope, n *ast.Set) (outScope *scope) { 30 var setVarExprs []*ast.SetVarExpr 31 for _, setExpr := range n.Exprs { 32 switch strings.ToLower(setExpr.Name.String()) { 33 case "names": 34 // Special case: SET NAMES expands to 3 different system variables. 35 setVarExprs = append(setVarExprs, getSetVarExprsFromSetNamesExpr(setExpr)...) 36 case "charset": 37 // Special case: SET CHARACTER SET (CHARSET) expands to 3 different system variables. 38 csd, err := b.ctx.GetSessionVariable(b.ctx, "character_set_database") 39 if err != nil { 40 b.handleErr(err) 41 } 42 setVarExprs = append(setVarExprs, getSetVarExprsFromSetCharsetExpr(setExpr, []byte(csd.(string)))...) 43 default: 44 setVarExprs = append(setVarExprs, setExpr) 45 } 46 } 47 48 exprs := b.setExprsToExpressions(inScope, setVarExprs) 49 50 outScope = inScope.push() 51 outScope.node = plan.NewSet(exprs) 52 return outScope 53 } 54 55 func getSetVarExprsFromSetNamesExpr(expr *ast.SetVarExpr) []*ast.SetVarExpr { 56 return []*ast.SetVarExpr{ 57 { 58 Name: ast.NewColName("character_set_client"), 59 Expr: expr.Expr, 60 }, 61 { 62 Name: ast.NewColName("character_set_connection"), 63 Expr: expr.Expr, 64 }, 65 { 66 Name: ast.NewColName("character_set_results"), 67 Expr: expr.Expr, 68 }, 69 // TODO (9/24/20 Zach): this should also set the collation_connection to the default collation for the character set named 70 } 71 } 72 73 func getSetVarExprsFromSetCharsetExpr(expr *ast.SetVarExpr, csd []byte) []*ast.SetVarExpr { 74 return []*ast.SetVarExpr{ 75 { 76 Name: ast.NewColName("character_set_client"), 77 Expr: expr.Expr, 78 }, 79 { 80 Name: ast.NewColName("character_set_results"), 81 Expr: expr.Expr, 82 }, 83 { 84 Name: ast.NewColName("character_set_connection"), 85 Expr: &ast.SQLVal{Type: ast.StrVal, Val: csd}, 86 }, 87 } 88 } 89 90 func (b *Builder) setExprsToExpressions(inScope *scope, e ast.SetVarExprs) []sql.Expression { 91 res := make([]sql.Expression, len(e)) 92 for i, setExpr := range e { 93 if expr, ok := setExpr.Expr.(*ast.SQLVal); ok && strings.ToLower(setExpr.Name.String()) == "transaction" && 94 (setExpr.Scope == ast.SetScope_Global || setExpr.Scope == ast.SetScope_Session || string(setExpr.Scope) == "") { 95 scope := sql.SystemVariableScope_Session 96 if setExpr.Scope == ast.SetScope_Global { 97 scope = sql.SystemVariableScope_Global 98 } 99 switch strings.ToLower(expr.String()) { 100 case "'isolation level repeatable read'": 101 varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) 102 res[i] = expression.NewSetField(varToSet, expression.NewLiteral("REPEATABLE-READ", types.LongText)) 103 continue 104 case "'isolation level read committed'": 105 varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) 106 res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-COMMITTED", types.LongText)) 107 continue 108 case "'isolation level read uncommitted'": 109 varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) 110 res[i] = expression.NewSetField(varToSet, expression.NewLiteral("READ-UNCOMMITTED", types.LongText)) 111 continue 112 case "'isolation level serializable'": 113 varToSet := expression.NewSystemVar("transaction_isolation", scope, string(scope)) 114 res[i] = expression.NewSetField(varToSet, expression.NewLiteral("SERIALIZABLE", types.LongText)) 115 continue 116 case "'read write'": 117 varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope)) 118 res[i] = expression.NewSetField(varToSet, expression.NewLiteral(false, types.Boolean)) 119 continue 120 case "'read only'": 121 varToSet := expression.NewSystemVar("transaction_read_only", scope, string(scope)) 122 res[i] = expression.NewSetField(varToSet, expression.NewLiteral(true, types.Boolean)) 123 continue 124 } 125 } 126 127 // left => convert to user var or system var expression, validate system var 128 // right => getSetExpr, not adapted for defaults yet, special keywords need to be converted, variables replaced 129 var setScope ast.SetScope 130 131 tblName := strings.ToLower(setExpr.Name.Qualifier.String()) 132 c, ok := inScope.resolveColumn("", tblName, strings.ToLower(setExpr.Name.Name.String()), true, false) 133 var setVar sql.Expression 134 if ok { 135 setVar = c.scalarGf() 136 } else { 137 setVar, setScope, ok = b.buildSysVar(setExpr.Name, setExpr.Scope) 138 if !ok { 139 switch setScope { 140 case ast.SetScope_None: 141 if tblName != "" && !inScope.hasTable(tblName) { 142 b.handleErr(sql.ErrTableNotFound.New(tblName)) 143 } 144 b.handleErr(sql.ErrColumnNotFound.New(setExpr.Name.String())) 145 case ast.SetScope_User: 146 b.handleErr(sql.ErrUnknownUserVariable.New(setExpr.Name.String())) 147 default: 148 b.handleErr(sql.ErrUnknownSystemVariable.New(setExpr.Name.String())) 149 } 150 } 151 } 152 153 sysVarType, _ := setVar.Type().(sql.SystemVariableType) 154 innerExpr, ok := b.simplifySetExpr(setExpr.Name, setScope, setExpr.Expr, sysVarType) 155 if !ok { 156 innerExpr = b.buildScalar(inScope, setExpr.Expr) 157 } 158 159 res[i] = expression.NewSetField(setVar, innerExpr) 160 } 161 return res 162 } 163 164 func (b *Builder) buildSysVar(colName *ast.ColName, scopeHint ast.SetScope) (sql.Expression, ast.SetScope, bool) { 165 // convert to system or user var, validate system var 166 table := colName.Qualifier.String() 167 col := colName.Name.String() 168 var varName string 169 var scope ast.SetScope 170 var err error 171 var specifiedScope string 172 173 if table == "" { 174 varName, scope, specifiedScope, err = ast.VarScope(col) 175 } else { 176 varName, scope, specifiedScope, err = ast.VarScope(table, col) 177 } 178 if err != nil { 179 b.handleErr(err) 180 } 181 182 if scope == "" { 183 scope = scopeHint 184 } 185 186 switch scope { 187 case ast.SetScope_Global: 188 _, _, ok := sql.SystemVariables.GetGlobal(varName) 189 if !ok { 190 return nil, scope, false 191 } 192 return expression.NewSystemVar(varName, sql.SystemVariableScope_Global, specifiedScope), scope, true 193 case ast.SetScope_None, ast.SetScope_Session: 194 switch strings.ToLower(varName) { 195 case "character_set_database", "collation_database": 196 sysVar := expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope) 197 sysVar.Collation = sql.Collation_Default 198 if db, err := b.cat.Database(b.ctx, b.ctx.GetCurrentDatabase()); err == nil { 199 sysVar.Collation = plan.GetDatabaseCollation(b.ctx, db) 200 } 201 return sysVar, scope, true 202 default: 203 _, err = b.ctx.GetSessionVariable(b.ctx, varName) 204 if err != nil { 205 return nil, scope, false 206 } 207 return expression.NewSystemVar(varName, sql.SystemVariableScope_Session, specifiedScope), scope, true 208 } 209 case ast.SetScope_User: 210 t, _, err := b.ctx.GetUserVariable(b.ctx, varName) 211 if err != nil { 212 b.handleErr(err) 213 } 214 if t != nil { 215 return expression.NewUserVarWithType(varName, t), scope, true 216 } 217 return expression.NewUserVar(varName), scope, true 218 case ast.SetScope_Persist: 219 return expression.NewSystemVar(varName, sql.SystemVariableScope_Persist, specifiedScope), scope, true 220 case ast.SetScope_PersistOnly: 221 return expression.NewSystemVar(varName, sql.SystemVariableScope_PersistOnly, specifiedScope), scope, true 222 default: // shouldn't happen 223 err := fmt.Errorf("unknown set scope %v", scope) 224 b.handleErr(err) 225 } 226 return nil, scope, false 227 } 228 229 func (b *Builder) simplifySetExpr(name *ast.ColName, varScope ast.SetScope, val ast.Expr, sysVarType sql.Type) (sql.Expression, bool) { 230 // can |val| be nested? 231 switch val := val.(type) { 232 case *ast.SQLVal: 233 if val.Type != ast.StrVal { 234 return nil, false 235 } 236 e := expression.NewLiteral(string(val.Val), types.Text) 237 res, err := e.Eval(b.ctx, nil) 238 if err != nil { 239 b.handleErr(err) 240 } 241 setVal, ok := res.(string) 242 if !ok { 243 return nil, false 244 } 245 246 switch strings.ToLower(setVal) { 247 case ast.KeywordString(ast.ON): 248 return expression.NewLiteral(true, types.Boolean), true 249 case ast.KeywordString(ast.TRUE): 250 return expression.NewLiteral(true, types.Boolean), true 251 case ast.KeywordString(ast.OFF): 252 return expression.NewLiteral(false, types.Boolean), true 253 case ast.KeywordString(ast.FALSE): 254 return expression.NewLiteral(false, types.Boolean), true 255 default: 256 } 257 258 if sysVarType == nil { 259 return nil, false 260 } 261 262 enum, _, err := sysVarType.Convert(setVal) 263 if err != nil { 264 b.handleErr(err) 265 } 266 return expression.NewLiteral(enum, sysVarType), true 267 case *ast.ColName: 268 // convert and eval 269 // todo check whether right side needs variable replacement 270 sysVar, _, ok := b.buildSysVar(val, ast.SetScope_None) 271 if ok { 272 return sysVar, true 273 } 274 e := expression.NewLiteral(val.Name.String(), types.Text) 275 res, err := e.Eval(b.ctx, nil) 276 if err != nil { 277 b.handleErr(err) 278 } 279 setVal, ok := res.(string) 280 if !ok { 281 return nil, false 282 } 283 284 switch strings.ToLower(setVal) { 285 case ast.KeywordString(ast.ON): 286 return expression.NewLiteral(true, types.Boolean), true 287 case ast.KeywordString(ast.TRUE): 288 return expression.NewLiteral(true, types.Boolean), true 289 case ast.KeywordString(ast.OFF): 290 return expression.NewLiteral(false, types.Boolean), true 291 case ast.KeywordString(ast.FALSE): 292 return expression.NewLiteral(false, types.Boolean), true 293 default: 294 } 295 296 if sysVarType == nil { 297 return nil, false 298 } 299 300 enum, _, err := sysVarType.Convert(setVal) 301 if err != nil { 302 b.handleErr(err) 303 } 304 return expression.NewLiteral(enum, sysVarType), true 305 case *ast.BoolVal: 306 // conv 307 e := expression.NewLiteral(val, types.Text) 308 res, err := e.Eval(b.ctx, nil) 309 if err != nil { 310 b.handleErr(err) 311 } 312 setVal, ok := res.(bool) 313 if !ok { 314 err := fmt.Errorf("expected *ast.BoolVal to evaluate to bool type, found: %T", val) 315 b.handleErr(err) 316 } 317 318 if setVal { 319 return expression.NewLiteral(1, types.Boolean), true 320 } else { 321 return expression.NewLiteral(0, types.Boolean), true 322 } 323 case *ast.Default: 324 // set back to default value 325 var err error 326 var varName string 327 table := name.Qualifier.String() 328 col := name.Name.Lowered() 329 if table != "" { 330 varName, _, _, err = ast.VarScope(table, col) 331 } else { 332 varName, _, _, err = ast.VarScope(col) 333 } 334 if err != nil { 335 b.handleErr(err) 336 } 337 338 switch varScope { 339 case ast.SetScope_None, ast.SetScope_Session, ast.SetScope_Global: 340 _, value, ok := sql.SystemVariables.GetGlobal(varName) 341 if ok { 342 return expression.NewLiteral(value, types.ApproximateTypeFromValue(value)), true 343 } 344 err = sql.ErrUnknownSystemVariable.New(varName) 345 case ast.SetScope_Persist, ast.SetScope_PersistOnly: 346 err = fmt.Errorf("%wsetting default for '%s'", sql.ErrUnsupportedFeature.New(), varScope) 347 case ast.SetScope_User: 348 err = sql.ErrUserVariableNoDefault.New(varName) 349 default: // shouldn't happen 350 err = fmt.Errorf("unknown set scope %v", varScope) 351 } 352 b.handleErr(err) 353 } 354 return nil, false 355 }