github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/plsql_interpreter.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frontend
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strconv"
    21  	"strings"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    24  	"github.com/matrixorigin/matrixone/pkg/defines"
    25  	"github.com/matrixorigin/matrixone/pkg/logutil"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/parsers"
    27  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
    28  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
    29  )
    30  
    31  type SpStatus int
    32  
    33  const (
    34  	SpOk        SpStatus = 0
    35  	SpNotOk     SpStatus = 1
    36  	SpBranchHit SpStatus = 2
    37  	SpLeaveLoop SpStatus = 3
    38  	SpIterLoop  SpStatus = 4
    39  )
    40  
// Interpreter executes a stored procedure body statement by statement.
// Control flow (IF/CASE/loops/LEAVE/ITERATE) is handled in Go; conditions and
// ordinary SQL statements are sent to the background executor bh.
type Interpreter struct {
	// ctx is passed to every bh.Exec call; while executing, it is tagged with
	// varScope and the in-stored-procedure flag.
	ctx         context.Context
	ses         *Session
	bh          BackgroundExec
	// varScope is the stack of procedure variable scopes, innermost last.
	// GetSpVar/SetSpVar walk it from the end using lowercased names. It is
	// shared by pointer with the executor through ctx for SP variable
	// replacement.
	varScope    *[]map[string]interface{}
	// fmtctx is a reusable formatter used to print expressions and
	// statements back to SQL text.
	fmtctx      *tree.FmtCtx
	// result collects the first result set of each executed SQL statement
	// that returned data.
	result      []ExecResult
	argsAttr    map[string]tree.InOutArgType // parameter name -> IN / OUT / INOUT, used for direction checks
	argsMap     map[string]tree.Expr         // parameter name -> argument expression supplied by the caller
	outParamMap map[string]interface{}       // OUT parameter name -> current value, updated by SET and flushed at the end
}
    52  
    53  func (interpreter *Interpreter) GetResult() []ExecResult {
    54  	return interpreter.result
    55  }
    56  
    57  func (interpreter *Interpreter) GetExprString(input tree.Expr) string {
    58  	interpreter.fmtctx.Reset()
    59  	input.Format(interpreter.fmtctx)
    60  	return interpreter.fmtctx.String()
    61  }
    62  
    63  func (interpreter *Interpreter) GetStatementString(input tree.Statement) string {
    64  	interpreter.fmtctx.Reset()
    65  	input.Format(interpreter.fmtctx)
    66  	return interpreter.fmtctx.String()
    67  }
    68  
    69  func (interpreter *Interpreter) GetSpVar(varName string) (interface{}, error) {
    70  	for i := len(*interpreter.varScope) - 1; i >= 0; i-- {
    71  		curScope := (*interpreter.varScope)[i]
    72  		val, ok := curScope[strings.ToLower(varName)]
    73  		if ok {
    74  			return val, nil
    75  		}
    76  	}
    77  	return "", nil
    78  }
    79  
    80  // Return error if variable is not declared yet. PARAM is an exception!
    81  func (interpreter *Interpreter) SetSpVar(name string, value interface{}) error {
    82  	for i := len(*interpreter.varScope) - 1; i >= 0; i-- {
    83  		curScope := (*interpreter.varScope)[i]
    84  		if _, ok := curScope[strings.ToLower(name)]; ok {
    85  			curScope[strings.ToLower(name)] = value
    86  			return nil
    87  		}
    88  	}
    89  	// loop up OUT param and SET in-place
    90  	if _, ok := interpreter.outParamMap[name]; ok {
    91  		// save at local
    92  		interpreter.outParamMap[name] = value
    93  		return nil
    94  	}
    95  	return moerr.NewNotSupported(interpreter.ctx, fmt.Sprintf("variable %s has to be declared using DECLARE.", name))
    96  }
    97  
    98  func (interpreter *Interpreter) FlushParam() error {
    99  	for k, v := range (*interpreter.varScope)[0] {
   100  		if _, ok := interpreter.argsMap[k]; ok && interpreter.argsAttr[k] == tree.TYPE_INOUT {
   101  			// save INOUT at session
   102  			interpreter.bh.ClearExecResultSet()
   103  			// system setvar execution
   104  			err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v, "")
   105  			if err != nil {
   106  				return err
   107  			}
   108  		}
   109  	}
   110  
   111  	for k, v := range interpreter.outParamMap {
   112  		// save at session
   113  		interpreter.bh.ClearExecResultSet()
   114  		// system setvar execution
   115  		err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v, "")
   116  		if err != nil {
   117  			return err
   118  		}
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  func (interpreter *Interpreter) GetSimpleExprValueWithSpVar(e tree.Expr) (interface{}, error) {
   125  	newExpr, err := interpreter.MatchExpr(e)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	retStmt, err := parsers.ParseOne(interpreter.ctx, dialect.MYSQL, "select "+interpreter.GetExprString(newExpr), 1, 0)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	retExpr := retStmt.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	return GetSimpleExprValue(interpreter.ctx, retExpr, interpreter.ses)
   138  }
   139  
   140  // Currently we support only binary, unary and comparison expression.
   141  func (interpreter *Interpreter) MatchExpr(expr tree.Expr) (tree.Expr, error) {
   142  	switch e := expr.(type) {
   143  	case *tree.BinaryExpr:
   144  		leftExpr, err := interpreter.MatchExpr(e.Left)
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  		rightExpr, err := interpreter.MatchExpr(e.Right)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  		return &tree.BinaryExpr{
   153  			Op:    e.Op,
   154  			Left:  leftExpr,
   155  			Right: rightExpr,
   156  		}, nil
   157  	case *tree.UnaryExpr:
   158  	case *tree.ComparisonExpr:
   159  		leftExpr, err := interpreter.MatchExpr(e.Left)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		rightExpr, err := interpreter.MatchExpr(e.Right)
   164  		if err != nil {
   165  			return nil, err
   166  		}
   167  		return &tree.ComparisonExpr{
   168  			Op:     e.Op,
   169  			SubOp:  e.SubOp,
   170  			Left:   leftExpr,
   171  			Right:  rightExpr,
   172  			Escape: e.Escape,
   173  		}, nil
   174  	case *tree.AndExpr:
   175  	case *tree.XorExpr:
   176  	case *tree.OrExpr:
   177  	case *tree.NotExpr:
   178  	case *tree.IsNullExpr:
   179  	case *tree.IsNotNullExpr:
   180  	case *tree.IsUnknownExpr:
   181  	case *tree.IsNotUnknownExpr:
   182  	case *tree.IsTrueExpr:
   183  	case *tree.IsNotTrueExpr:
   184  	case *tree.IsFalseExpr:
   185  	case *tree.IsNotFalseExpr:
   186  	case *tree.FuncExpr:
   187  	case *tree.UnresolvedName:
   188  		// change column name to var name
   189  		val, err := interpreter.GetSpVar(e.Parts[0])
   190  		if err != nil {
   191  			return nil, err
   192  		}
   193  		retName := &tree.UnresolvedName{
   194  			NumParts: e.NumParts,
   195  			Star:     e.Star,
   196  			Parts:    e.Parts,
   197  		}
   198  		retName.Parts[0] = fmt.Sprintf("%v", val)
   199  		return retName, nil
   200  	default:
   201  		return e, nil
   202  	}
   203  	return nil, nil
   204  }
   205  
   206  // Evaluate condition by sending it to bh with a select
   207  func (interpreter *Interpreter) EvalCond(cond string) (int, error) {
   208  	interpreter.bh.ClearExecResultSet()
   209  	interpreter.ctx = context.WithValue(interpreter.ctx, defines.VarScopeKey{}, interpreter.varScope)
   210  	interpreter.ctx = context.WithValue(interpreter.ctx, defines.InSp{}, true)
   211  	err := interpreter.bh.Exec(interpreter.ctx, "select "+cond)
   212  	if err != nil {
   213  		return 0, err
   214  	}
   215  	erArray, err := getResultSet(interpreter.ctx, interpreter.bh)
   216  	if err != nil {
   217  		return 0, err
   218  	}
   219  
   220  	if execResultArrayHasData(erArray) {
   221  		cond, err := erArray[0].GetInt64(interpreter.ctx, 0, 0)
   222  		if err != nil {
   223  			return 0, err
   224  		}
   225  		return int(cond), nil
   226  	}
   227  	return 0, nil
   228  }
   229  
   230  func (interpreter *Interpreter) ExecuteSp(stmt tree.Statement, dbName string) (err error) {
   231  	curScope := make(map[string]interface{})
   232  	interpreter.bh.ClearExecResultSet()
   233  
   234  	// use current database as default
   235  	err = interpreter.bh.Exec(interpreter.ctx, "use "+dbName)
   236  	if err != nil {
   237  		return err
   238  	}
   239  
   240  	// make sure the entire sp is in a single transaction
   241  	err = interpreter.bh.Exec(interpreter.ctx, "begin;")
   242  	defer func() {
   243  		err = finishTxn(interpreter.ctx, interpreter.bh, err)
   244  	}()
   245  	if err != nil {
   246  		return err
   247  	}
   248  
   249  	// save parameters as local variables
   250  	*interpreter.varScope = append(*interpreter.varScope, curScope)
   251  	for k, v := range interpreter.argsMap {
   252  		var value interface{}
   253  		if varParam, ok := v.(*tree.VarExpr); ok {
   254  			// For OUT type, store it in a separate map only for SET to update it and flush at the end
   255  			if interpreter.argsAttr[k] == tree.TYPE_OUT {
   256  				interpreter.outParamMap[k] = 0
   257  			} else { // For INOUT and IN type, fetch store its previous value
   258  				interpreter.bh.ClearExecResultSet()
   259  				_, value, _ := interpreter.ses.GetUserDefinedVar(varParam.Name)
   260  				if value == nil {
   261  					// raise an error as INOUT / IN type param has to have a value
   262  					return moerr.NewNotSupported(interpreter.ctx, fmt.Sprintf("parameter %s with type INOUT or IN has to have a specified value.", k))
   263  				}
   264  				// save param to local var scope
   265  				(*interpreter.varScope)[len(*interpreter.varScope)-1][strings.ToLower(k)] = value.Value
   266  			}
   267  		} else {
   268  			// if param type is INOUT or OUT and the param is not provided with variable expr, raise an error
   269  			if interpreter.argsAttr[k] == tree.TYPE_INOUT || interpreter.argsAttr[k] == tree.TYPE_OUT {
   270  				return moerr.NewNotSupported(interpreter.ctx, fmt.Sprintf("parameter %s with type INOUT or OUT has to be passed in using @.", k))
   271  			}
   272  			// evaluate the param
   273  			value, err = interpreter.GetSimpleExprValueWithSpVar(v)
   274  			if err != nil {
   275  				return err
   276  			}
   277  			// save param to local var scope
   278  			(*interpreter.varScope)[len(*interpreter.varScope)-1][strings.ToLower(k)] = value
   279  		}
   280  	}
   281  
   282  	_, err = interpreter.interpret(stmt)
   283  
   284  	if err != nil {
   285  		return err
   286  	}
   287  
   288  	// // commit the param flush part of sp
   289  	// err = interpreter.bh.Exec(interpreter.ctx, "begin;")
   290  	// if err != nil {
   291  	// 	return err
   292  	// }
   293  
   294  	err = interpreter.FlushParam()
   295  	if err != nil {
   296  		return err
   297  	}
   298  
   299  	// err = interpreter.bh.Exec(interpreter.ctx, "commit;")
   300  	// if err != nil {
   301  	// 	return err
   302  	// }
   303  
   304  	return nil
   305  }
   306  
   307  func (interpreter *Interpreter) interpret(stmt tree.Statement) (SpStatus, error) {
   308  	if stmt == nil {
   309  		return SpOk, nil
   310  	}
   311  	switch st := stmt.(type) {
   312  	case *tree.CompoundStmt:
   313  		// create new variable scope and push it
   314  		curScope := make(map[string]interface{})
   315  		*interpreter.varScope = append(*interpreter.varScope, curScope)
   316  		logutil.Info("current scope level: " + strconv.Itoa(len(*interpreter.varScope)))
   317  		// recursively execute
   318  		for _, innerSt := range st.Stmts {
   319  			_, err := interpreter.interpret(innerSt)
   320  			if err != nil {
   321  				return SpNotOk, err
   322  			}
   323  		}
   324  		// pop current scope
   325  		*interpreter.varScope = (*interpreter.varScope)[:len(*interpreter.varScope)-1]
   326  		return SpOk, nil
   327  	case *tree.RepeatStmt:
   328  		for {
   329  			// first execute body
   330  			for _, stmt := range st.Body {
   331  				_, err := interpreter.interpret(stmt)
   332  				if err != nil {
   333  					return SpNotOk, err
   334  				}
   335  			}
   336  			// then evaluate condition
   337  			condStr := interpreter.GetExprString(st.Cond)
   338  			condVal, err := interpreter.EvalCond(condStr)
   339  			if err != nil {
   340  				return SpNotOk, err
   341  			}
   342  			if condVal == 1 {
   343  				break
   344  			}
   345  		}
   346  	case *tree.WhileStmt:
   347  		for {
   348  			// first evaluate
   349  			condStr := interpreter.GetExprString(st.Cond)
   350  			condVal, err := interpreter.EvalCond(condStr)
   351  			if err != nil {
   352  				return SpNotOk, err
   353  			}
   354  			if condVal == 0 {
   355  				break
   356  			}
   357  			// then execute body
   358  			for _, stmt := range st.Body {
   359  				_, err := interpreter.interpret(stmt)
   360  				if err != nil {
   361  					return SpNotOk, err
   362  				}
   363  			}
   364  		}
   365  	case *tree.LoopStmt:
   366  	start:
   367  		for {
   368  			for _, stmt := range st.Body {
   369  				status, err := interpreter.interpret(stmt)
   370  				if err != nil {
   371  					return SpNotOk, err
   372  				}
   373  				if status == SpLeaveLoop {
   374  					// check label here using stmt
   375  					goto exit
   376  				}
   377  				if status == SpIterLoop {
   378  					// check label here using stmt
   379  					goto start
   380  				}
   381  			}
   382  		}
   383  	exit:
   384  		return SpOk, nil
   385  	case *tree.IterateStmt:
   386  		return SpIterLoop, nil
   387  	case *tree.LeaveStmt:
   388  		return SpLeaveLoop, nil
   389  	case *tree.ElseIfStmt:
   390  		// evaluate condition
   391  		condStr := interpreter.GetExprString(st.Cond)
   392  		condVal, err := interpreter.EvalCond(condStr)
   393  		if err != nil {
   394  			return SpNotOk, err
   395  		}
   396  		if condVal == 1 {
   397  			// execute current else-if branch, remember to terminate other else-if
   398  			for _, bodyStmt := range st.Body {
   399  				status, err := interpreter.interpret(bodyStmt)
   400  				if err != nil {
   401  					return SpNotOk, err
   402  				}
   403  				if status == SpBranchHit || status == SpIterLoop || status == SpLeaveLoop {
   404  					return status, nil
   405  				}
   406  			}
   407  			return SpBranchHit, nil
   408  		} else {
   409  			return SpOk, nil
   410  		}
   411  	case *tree.IfStmt:
   412  		// evaluate condition
   413  		condStr := interpreter.GetExprString(st.Cond)
   414  		condVal, err := interpreter.EvalCond(condStr)
   415  		if err != nil {
   416  			return SpNotOk, err
   417  		}
   418  		if condVal == 1 {
   419  			// execute current branch
   420  			for _, bodyStmt := range st.Body {
   421  				status, err := interpreter.interpret(bodyStmt)
   422  				if err != nil {
   423  					return SpNotOk, err
   424  				}
   425  				if status == SpBranchHit || status == SpIterLoop || status == SpLeaveLoop {
   426  					return status, nil
   427  				}
   428  			}
   429  		} else {
   430  			if len(st.Elifs) != 0 {
   431  				// bunch of elif branch
   432  				for _, elifStmt := range st.Elifs {
   433  					status, err := interpreter.interpret(elifStmt)
   434  					if err != nil {
   435  						return SpNotOk, err
   436  					}
   437  					if status == SpBranchHit {
   438  						// this means this else-if branch gets executed, no need to execute the rest elseif and else.
   439  						goto end
   440  					}
   441  					if status == SpIterLoop || status == SpLeaveLoop {
   442  						return status, nil
   443  					}
   444  				}
   445  			}
   446  			// else branch
   447  			for _, elseStmt := range st.Else {
   448  				status, err := interpreter.interpret(elseStmt)
   449  				if err != nil {
   450  					return SpNotOk, err
   451  				}
   452  				if status == SpBranchHit || status == SpIterLoop || status == SpLeaveLoop {
   453  					return status, nil
   454  				}
   455  			}
   456  		end:
   457  			break
   458  		}
   459  	case *tree.WhenStmt:
   460  		// any whenstmt that comes here will get executed, as we've already evaluated the condition in casestmt
   461  		for _, stmt := range st.Body {
   462  			// we use this branch
   463  			_, err := interpreter.interpret(stmt)
   464  			if err != nil {
   465  				return SpNotOk, err
   466  			}
   467  		}
   468  	case *tree.CaseStmt:
   469  		// match case expression with all of its whens
   470  		for _, whenStmt := range st.Whens {
   471  			// build equality checker
   472  			equalityExpr := &tree.ComparisonExpr{
   473  				Op:    tree.EQUAL,
   474  				Left:  st.Expr,
   475  				Right: whenStmt.Cond,
   476  			}
   477  			condVal, err := interpreter.EvalCond(interpreter.GetExprString(equalityExpr))
   478  			if err != nil {
   479  				return SpNotOk, nil
   480  			}
   481  			if condVal == 1 {
   482  				// we use this branch
   483  				_, err := interpreter.interpret(whenStmt)
   484  				if err != nil {
   485  					return SpNotOk, err
   486  				}
   487  				return SpOk, nil
   488  			}
   489  		}
   490  
   491  		// none of the WHEN branch hit, we execute ELSE
   492  		for _, stmt := range st.Else {
   493  			_, err := interpreter.interpret(stmt)
   494  			if err != nil {
   495  				return SpNotOk, err
   496  			}
   497  		}
   498  		return SpOk, nil
   499  	case *tree.Declare:
   500  		var err error
   501  		var value interface{}
   502  		// store variables into current scope
   503  		if st.DefaultVal != nil {
   504  			value, err = GetSimpleExprValue(interpreter.ctx, st.DefaultVal, interpreter.ses)
   505  			if err != nil {
   506  				return SpNotOk, nil
   507  			}
   508  		}
   509  		for _, v := range st.Variables {
   510  			(*interpreter.varScope)[len(*interpreter.varScope)-1][v] = value
   511  		}
   512  		return SpOk, nil
   513  	case *tree.SetVar:
   514  		for _, assign := range st.Assignments {
   515  			name := assign.Name
   516  
   517  			// if this is a system set, ignore if it's not a INOUT/OUT arg
   518  			if strings.Contains(interpreter.GetExprString(st), "@") {
   519  				str := interpreter.GetExprString(st)
   520  				interpreter.bh.ClearExecResultSet()
   521  				// system setvar execution
   522  				err := interpreter.bh.Exec(interpreter.ctx, str)
   523  				if err != nil {
   524  					return SpNotOk, err
   525  				}
   526  			} else {
   527  				// custom defined variable
   528  				var value interface{}
   529  				// get updated value
   530  				value, err := interpreter.GetSimpleExprValueWithSpVar(assign.Value)
   531  				if err != nil {
   532  					return SpNotOk, err
   533  				}
   534  
   535  				// update local value
   536  				err = interpreter.SetSpVar(name, value)
   537  				if err != nil {
   538  					return SpNotOk, err
   539  				}
   540  			}
   541  		}
   542  	default: // normal sql. Since we don't support SELECT INTO for now, we don't have to worry about updating variables
   543  		str := interpreter.GetStatementString(st)
   544  		interpreter.bh.ClearExecResultSet()
   545  		// For sp variable replacement
   546  		interpreter.ctx = context.WithValue(interpreter.ctx, defines.VarScopeKey{}, interpreter.varScope)
   547  		interpreter.ctx = context.WithValue(interpreter.ctx, defines.InSp{}, true)
   548  		err := interpreter.bh.Exec(interpreter.ctx, str)
   549  		if err != nil {
   550  			return SpNotOk, err
   551  		}
   552  		erArray, err := getResultSet(interpreter.ctx, interpreter.bh)
   553  		if err != nil {
   554  			return SpNotOk, err
   555  		}
   556  		if execResultArrayHasData(erArray) {
   557  			interpreter.result = append(interpreter.result, erArray[0])
   558  		}
   559  		return SpOk, nil
   560  	}
   561  	return SpOk, nil
   562  }