github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/executor/executor_simple.go (about)

     1  // Copyright 2016 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package executor
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  
    20  	"github.com/insionng/yougam/libraries/juju/errors"
    21  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/context"
    23  	"github.com/insionng/yougam/libraries/pingcap/tidb/evaluator"
    24  	"github.com/insionng/yougam/libraries/pingcap/tidb/infoschema"
    25  	"github.com/insionng/yougam/libraries/pingcap/tidb/model"
    26  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    27  	"github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx"
    28  	"github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx/db"
    29  	"github.com/insionng/yougam/libraries/pingcap/tidb/sessionctx/variable"
    30  	"github.com/insionng/yougam/libraries/pingcap/tidb/util"
    31  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/charset"
    32  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/sqlexec"
    33  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    34  )
    35  
// SimpleExec is the executor for simple statements that produce no result
// rows: Fields reports no columns and Next never returns a row.
// Next dispatches on the concrete type of Statement. The statements handled
// are `UseStmt`, `SetStmt`, `SetCharsetStmt`, `DoStmt`, `BeginStmt`,
// `CommitStmt`, `RollbackStmt`, `CreateUserStmt` and `SetPwdStmt`; any other
// statement type is accepted and does nothing.
type SimpleExec struct {
	// Statement is the parsed statement this executor runs.
	Statement ast.StmtNode
	// ctx is the session context the statement is executed against.
	ctx       context.Context
	// done is set after the statement has executed without error, so that
	// later Next calls do no further work.
	done      bool
}
    46  
    47  // Fields implements Executor Fields interface.
    48  func (e *SimpleExec) Fields() []*ast.ResultField {
    49  	return nil
    50  }
    51  
    52  // Next implements Execution Next interface.
    53  func (e *SimpleExec) Next() (*Row, error) {
    54  	if e.done {
    55  		return nil, nil
    56  	}
    57  	var err error
    58  	switch x := e.Statement.(type) {
    59  	case *ast.UseStmt:
    60  		err = e.executeUse(x)
    61  	case *ast.SetStmt:
    62  		err = e.executeSet(x)
    63  	case *ast.SetCharsetStmt:
    64  		err = e.executeSetCharset(x)
    65  	case *ast.DoStmt:
    66  		err = e.executeDo(x)
    67  	case *ast.BeginStmt:
    68  		err = e.executeBegin(x)
    69  	case *ast.CommitStmt:
    70  		err = e.executeCommit(x)
    71  	case *ast.RollbackStmt:
    72  		err = e.executeRollback(x)
    73  	case *ast.CreateUserStmt:
    74  		err = e.executeCreateUser(x)
    75  	case *ast.SetPwdStmt:
    76  		err = e.executeSetPwd(x)
    77  	}
    78  	if err != nil {
    79  		return nil, errors.Trace(err)
    80  	}
    81  	e.done = true
    82  	return nil, nil
    83  }
    84  
    85  // Close implements Executor Close interface.
    86  func (e *SimpleExec) Close() error {
    87  	return nil
    88  }
    89  
    90  func (e *SimpleExec) executeUse(s *ast.UseStmt) error {
    91  	dbname := model.NewCIStr(s.DBName)
    92  	dbinfo, exists := sessionctx.GetDomain(e.ctx).InfoSchema().SchemaByName(dbname)
    93  	if !exists {
    94  		return infoschema.ErrDatabaseNotExists.Gen("database %s not exists", dbname)
    95  	}
    96  	db.BindCurrentSchema(e.ctx, dbname.O)
    97  	// character_set_database is the character set used by the default database.
    98  	// The server sets this variable whenever the default database changes.
    99  	// See: http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_character_set_database
   100  	sessionVars := variable.GetSessionVars(e.ctx)
   101  	sessionVars.Systems[variable.CharsetDatabase] = dbinfo.Charset
   102  	sessionVars.Systems[variable.CollationDatabase] = dbinfo.Collate
   103  	return nil
   104  }
   105  
   106  func (e *SimpleExec) executeSet(s *ast.SetStmt) error {
   107  	sessionVars := variable.GetSessionVars(e.ctx)
   108  	globalVars := variable.GetGlobalVarAccessor(e.ctx)
   109  	for _, v := range s.Variables {
   110  		// Variable is case insensitive, we use lower case.
   111  		name := strings.ToLower(v.Name)
   112  		if !v.IsSystem {
   113  			// Set user variable.
   114  			value, err := evaluator.Eval(e.ctx, v.Value)
   115  			if err != nil {
   116  				return errors.Trace(err)
   117  			}
   118  
   119  			if value.Kind() == types.KindNull {
   120  				delete(sessionVars.Users, name)
   121  			} else {
   122  				svalue, err1 := value.ToString()
   123  				if err1 != nil {
   124  					return errors.Trace(err1)
   125  				}
   126  				sessionVars.Users[name] = fmt.Sprintf("%v", svalue)
   127  			}
   128  			continue
   129  		}
   130  
   131  		// Set system variable
   132  		sysVar := variable.GetSysVar(name)
   133  		if sysVar == nil {
   134  			return variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
   135  		}
   136  		if sysVar.Scope == variable.ScopeNone {
   137  			return errors.Errorf("Variable '%s' is a read only variable", name)
   138  		}
   139  		if v.IsGlobal {
   140  			// Set global scope system variable.
   141  			if sysVar.Scope&variable.ScopeGlobal == 0 {
   142  				return errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name)
   143  			}
   144  			value, err := evaluator.Eval(e.ctx, v.Value)
   145  			if err != nil {
   146  				return errors.Trace(err)
   147  			}
   148  			if value.Kind() == types.KindNull {
   149  				value.SetString("")
   150  			}
   151  			svalue, err := value.ToString()
   152  			if err != nil {
   153  				return errors.Trace(err)
   154  			}
   155  			err = globalVars.SetGlobalSysVar(e.ctx, name, svalue)
   156  			if err != nil {
   157  				return errors.Trace(err)
   158  			}
   159  		} else {
   160  			// Set session scope system variable.
   161  			if sysVar.Scope&variable.ScopeSession == 0 {
   162  				return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name)
   163  			}
   164  			if value, err := evaluator.Eval(e.ctx, v.Value); err != nil {
   165  				return errors.Trace(err)
   166  			} else if value.Kind() == types.KindNull {
   167  				sessionVars.Systems[name] = ""
   168  			} else {
   169  				svalue, err := value.ToString()
   170  				if err != nil {
   171  					return errors.Trace(err)
   172  				}
   173  				sessionVars.Systems[name] = fmt.Sprintf("%v", svalue)
   174  			}
   175  		}
   176  	}
   177  	return nil
   178  }
   179  
   180  func (e *SimpleExec) executeSetCharset(s *ast.SetCharsetStmt) error {
   181  	collation := s.Collate
   182  	if len(collation) == 0 {
   183  		var err error
   184  		collation, err = charset.GetDefaultCollation(s.Charset)
   185  		if err != nil {
   186  			return errors.Trace(err)
   187  		}
   188  	}
   189  	sessionVars := variable.GetSessionVars(e.ctx)
   190  	for _, v := range variable.SetNamesVariables {
   191  		sessionVars.Systems[v] = s.Charset
   192  	}
   193  	sessionVars.Systems[variable.CollationConnection] = collation
   194  	return nil
   195  }
   196  
   197  func (e *SimpleExec) executeDo(s *ast.DoStmt) error {
   198  	for _, expr := range s.Exprs {
   199  		_, err := evaluator.Eval(e.ctx, expr)
   200  		if err != nil {
   201  			return errors.Trace(err)
   202  		}
   203  	}
   204  	return nil
   205  }
   206  
   207  func (e *SimpleExec) executeBegin(s *ast.BeginStmt) error {
   208  	_, err := e.ctx.GetTxn(true)
   209  	if err != nil {
   210  		return errors.Trace(err)
   211  	}
   212  	// With START TRANSACTION, autocommit remains disabled until you end
   213  	// the transaction with COMMIT or ROLLBACK. The autocommit mode then
   214  	// reverts to its previous state.
   215  	variable.GetSessionVars(e.ctx).SetStatusFlag(mysql.ServerStatusInTrans, true)
   216  	return nil
   217  }
   218  
   219  func (e *SimpleExec) executeCommit(s *ast.CommitStmt) error {
   220  	err := e.ctx.FinishTxn(false)
   221  	variable.GetSessionVars(e.ctx).SetStatusFlag(mysql.ServerStatusInTrans, false)
   222  	return errors.Trace(err)
   223  }
   224  
   225  func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error {
   226  	err := e.ctx.FinishTxn(true)
   227  	variable.GetSessionVars(e.ctx).SetStatusFlag(mysql.ServerStatusInTrans, false)
   228  	return errors.Trace(err)
   229  }
   230  
   231  func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error {
   232  	users := make([]string, 0, len(s.Specs))
   233  	for _, spec := range s.Specs {
   234  		userName, host := parseUser(spec.User)
   235  		exists, err1 := userExists(e.ctx, userName, host)
   236  		if err1 != nil {
   237  			return errors.Trace(err1)
   238  		}
   239  		if exists {
   240  			if !s.IfNotExists {
   241  				return errors.New("Duplicate user")
   242  			}
   243  			continue
   244  		}
   245  		pwd := ""
   246  		if spec.AuthOpt.ByAuthString {
   247  			pwd = util.EncodePassword(spec.AuthOpt.AuthString)
   248  		} else {
   249  			pwd = util.EncodePassword(spec.AuthOpt.HashString)
   250  		}
   251  		user := fmt.Sprintf(`("%s", "%s", "%s")`, host, userName, pwd)
   252  		users = append(users, user)
   253  	}
   254  	if len(users) == 0 {
   255  		return nil
   256  	}
   257  	sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", "))
   258  	_, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
   259  	if err != nil {
   260  		return errors.Trace(err)
   261  	}
   262  	return nil
   263  }
   264  
   265  // parse user string into username and host
   266  // root@localhost -> roor, localhost
   267  func parseUser(user string) (string, string) {
   268  	strs := strings.Split(user, "@")
   269  	return strs[0], strs[1]
   270  }
   271  
   272  func userExists(ctx context.Context, name string, host string) (bool, error) {
   273  	sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, name, host)
   274  	rs, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql)
   275  	if err != nil {
   276  		return false, errors.Trace(err)
   277  	}
   278  	defer rs.Close()
   279  	row, err := rs.Next()
   280  	if err != nil {
   281  		return false, errors.Trace(err)
   282  	}
   283  	return row != nil, nil
   284  }
   285  
   286  func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error {
   287  	// TODO: If len(s.User) == 0, use CURRENT_USER()
   288  	userName, host := parseUser(s.User)
   289  	// Update mysql.user
   290  	sql := fmt.Sprintf(`UPDATE %s.%s SET password="%s" WHERE User="%s" AND Host="%s";`, mysql.SystemDB, mysql.UserTable, util.EncodePassword(s.Password), userName, host)
   291  	_, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql)
   292  	return errors.Trace(err)
   293  }