github.com/XiaoMi/Gaea@v1.2.5/proxy/server/executor_handle.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/binary"
    20  	"fmt"
    21  	"runtime"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/XiaoMi/Gaea/backend"
    26  	"github.com/XiaoMi/Gaea/core/errors"
    27  	"github.com/XiaoMi/Gaea/log"
    28  	"github.com/XiaoMi/Gaea/mysql"
    29  	"github.com/XiaoMi/Gaea/parser"
    30  	"github.com/XiaoMi/Gaea/parser/ast"
    31  	"github.com/XiaoMi/Gaea/proxy/plan"
    32  	"github.com/XiaoMi/Gaea/util"
    33  )
    34  
    35  // Parse parse sql
    36  func (se *SessionExecutor) Parse(sql string) (ast.StmtNode, error) {
    37  	return se.parser.ParseOneStmt(sql, "", "")
    38  }
    39  
    40  // 处理query语句
    41  func (se *SessionExecutor) handleQuery(sql string) (r *mysql.Result, err error) {
    42  	defer func() {
    43  		if e := recover(); e != nil {
    44  			log.Warn("handle query command failed, error: %v, sql: %s", e, sql)
    45  
    46  			if err, ok := e.(error); ok {
    47  				const size = 4096
    48  				buf := make([]byte, size)
    49  				buf = buf[:runtime.Stack(buf, false)]
    50  
    51  				log.Warn("handle query command catch panic error, sql: %s, error: %s, stack: %s",
    52  					sql, err.Error(), string(buf))
    53  			}
    54  
    55  			err = errors.ErrInternalServer
    56  			return
    57  		}
    58  	}()
    59  
    60  	sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号
    61  
    62  	reqCtx := util.NewRequestContext()
    63  	// check black sql
    64  	ns := se.GetNamespace()
    65  	if !ns.IsSQLAllowed(reqCtx, sql) {
    66  		fingerprint := mysql.GetFingerprint(sql)
    67  		log.Warn("catch black sql, sql: %s", sql)
    68  		se.manager.GetStatisticManager().RecordSQLForbidden(fingerprint, se.GetNamespace().GetName())
    69  		err := mysql.NewError(mysql.ErrUnknown, "sql in blacklist")
    70  		return nil, err
    71  	}
    72  
    73  	startTime := time.Now()
    74  	stmtType := parser.Preview(sql)
    75  	reqCtx.Set(util.StmtType, stmtType)
    76  
    77  	r, err = se.doQuery(reqCtx, sql)
    78  	se.manager.RecordSessionSQLMetrics(reqCtx, se, sql, startTime, err)
    79  	return r, err
    80  }
    81  
    82  func (se *SessionExecutor) doQuery(reqCtx *util.RequestContext, sql string) (*mysql.Result, error) {
    83  	stmtType := reqCtx.Get("stmtType").(int)
    84  
    85  	if isSQLNotAllowedByUser(se, stmtType) {
    86  		return nil, fmt.Errorf("write DML is now allowed by read user")
    87  	}
    88  
    89  	if canHandleWithoutPlan(stmtType) {
    90  		return se.handleQueryWithoutPlan(reqCtx, sql)
    91  	}
    92  
    93  	db := se.db
    94  
    95  	p, err := se.getPlan(se.GetNamespace(), db, sql)
    96  	if err != nil {
    97  		return nil, fmt.Errorf("get plan error, db: %s, sql: %s, err: %v", db, sql, err)
    98  	}
    99  
   100  	if canExecuteFromSlave(se, sql) {
   101  		reqCtx.Set(util.FromSlave, 1)
   102  	}
   103  
   104  	reqCtx.Set(util.DefaultSlice, se.GetNamespace().GetDefaultSlice())
   105  	r, err := p.ExecuteIn(reqCtx, se)
   106  	if err != nil {
   107  		log.Warn("execute select: %s", err.Error())
   108  		return nil, err
   109  	}
   110  
   111  	modifyResultStatus(r, se)
   112  
   113  	return r, nil
   114  }
   115  
   116  // 处理逻辑较简单的SQL, 不走执行计划部分
   117  func (se *SessionExecutor) handleQueryWithoutPlan(reqCtx *util.RequestContext, sql string) (*mysql.Result, error) {
   118  	n, err := se.Parse(sql)
   119  	if err != nil {
   120  		return nil, fmt.Errorf("parse sql error, sql: %s, err: %v", sql, err)
   121  	}
   122  
   123  	switch stmt := n.(type) {
   124  	case *ast.ShowStmt:
   125  		return se.handleShow(reqCtx, sql, stmt)
   126  	case *ast.SetStmt:
   127  		return se.handleSet(reqCtx, sql, stmt)
   128  	case *ast.BeginStmt:
   129  		return nil, se.handleBegin()
   130  	case *ast.CommitStmt:
   131  		return nil, se.handleCommit()
   132  	case *ast.RollbackStmt:
   133  		return nil, se.handleRollback(stmt)
   134  	case *ast.SavepointStmt:
   135  		return nil, se.handleSavepoint(stmt)
   136  	case *ast.UseStmt:
   137  		return nil, se.handleUseDB(stmt.DBName)
   138  	default:
   139  		return nil, fmt.Errorf("cannot handle sql without plan, ns: %s, sql: %s", se.namespace, sql)
   140  	}
   141  }
   142  
   143  func (se *SessionExecutor) handleUseDB(dbName string) error {
   144  	if len(dbName) == 0 {
   145  		return fmt.Errorf("must have database, the length of dbName is zero")
   146  	}
   147  
   148  	if se.GetNamespace().IsAllowedDB(dbName) {
   149  		se.db = dbName
   150  		return nil
   151  	}
   152  
   153  	return mysql.NewDefaultError(mysql.ErrNoDB)
   154  }
   155  
   156  func (se *SessionExecutor) getPlan(ns *Namespace, db string, sql string) (plan.Plan, error) {
   157  	n, err := se.Parse(sql)
   158  	if err != nil {
   159  		return nil, fmt.Errorf("parse sql error, sql: %s, err: %v", sql, err)
   160  	}
   161  
   162  	rt := ns.GetRouter()
   163  	seq := ns.GetSequences()
   164  	phyDBs := ns.GetPhysicalDBs()
   165  	p, err := plan.BuildPlan(n, phyDBs, db, sql, rt, seq)
   166  	if err != nil {
   167  		return nil, fmt.Errorf("create select plan error: %v", err)
   168  	}
   169  
   170  	return p, nil
   171  }
   172  
   173  func (se *SessionExecutor) handleShow(reqCtx *util.RequestContext, sql string, stmt *ast.ShowStmt) (*mysql.Result, error) {
   174  	switch stmt.Tp {
   175  	case ast.ShowDatabases:
   176  		dbs := se.GetNamespace().GetAllowedDBs()
   177  		return createShowDatabaseResult(dbs), nil
   178  	case ast.ShowVariables:
   179  		if strings.Contains(sql, gaeaGeneralLogVariable) {
   180  			return createShowGeneralLogResult(), nil
   181  		}
   182  		fallthrough
   183  	default:
   184  		r, err := se.ExecuteSQL(reqCtx, se.GetNamespace().GetDefaultSlice(), se.db, sql)
   185  		if err != nil {
   186  			return nil, fmt.Errorf("execute sql error, sql: %s, err: %v", sql, err)
   187  		}
   188  		modifyResultStatus(r, se)
   189  		return r, nil
   190  	}
   191  }
   192  
   193  func (se *SessionExecutor) handleSet(reqCtx *util.RequestContext, sql string, stmt *ast.SetStmt) (*mysql.Result, error) {
   194  	for _, v := range stmt.Variables {
   195  		if err := se.handleSetVariable(v); err != nil {
   196  			return nil, err
   197  		}
   198  	}
   199  
   200  	return nil, nil
   201  }
   202  
   203  func (se *SessionExecutor) handleSetVariable(v *ast.VariableAssignment) error {
   204  	if v.IsGlobal {
   205  		return fmt.Errorf("does not support set variable in global scope")
   206  	}
   207  	name := strings.ToLower(v.Name)
   208  	switch name {
   209  	case "character_set_results", "character_set_client", "character_set_connection":
   210  		charset := getVariableExprResult(v.Value)
   211  		if charset == "null" { // character_set_results允许设置成null, character_set_client和character_set_connection不允许
   212  			return nil
   213  		}
   214  		if charset == mysql.KeywordDefault {
   215  			se.charset = se.GetNamespace().GetDefaultCharset()
   216  			se.collation = se.GetNamespace().GetDefaultCollationID()
   217  			return nil
   218  		}
   219  		cid, ok := mysql.CharsetIds[charset]
   220  		if !ok {
   221  			return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset)
   222  		}
   223  		se.charset = charset
   224  		se.collation = cid
   225  		return nil
   226  	case "autocommit":
   227  		value := getVariableExprResult(v.Value)
   228  		if value == mysql.KeywordDefault || value == "on" || value == "1" {
   229  			return se.handleSetAutoCommit(true) // default set autocommit = 1
   230  		} else if value == "off" || value == "0" {
   231  			return se.handleSetAutoCommit(false)
   232  		} else {
   233  			return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value)
   234  		}
   235  	case "setnames": // SetNAMES represents SET NAMES 'xxx' COLLATE 'xxx'
   236  		charset := getVariableExprResult(v.Value)
   237  		if charset == mysql.KeywordDefault {
   238  			charset = se.GetNamespace().GetDefaultCharset()
   239  		}
   240  
   241  		var collationID mysql.CollationID
   242  		// if SET NAMES 'xxx' COLLATE DEFAULT, the parser treats it like SET NAMES 'xxx', and the ExtendValue is nil
   243  		if v.ExtendValue != nil {
   244  			collationName := getVariableExprResult(v.ExtendValue)
   245  			cid, ok := mysql.CollationNames[collationName]
   246  			if !ok {
   247  				return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset)
   248  			}
   249  			toCharset, ok := mysql.CollationNameToCharset[collationName]
   250  			if !ok {
   251  				return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset)
   252  			}
   253  			if toCharset != charset { // collation与charset不匹配
   254  				return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset)
   255  			}
   256  			collationID = cid
   257  		} else {
   258  			// if only set charset but not set collation, the collation is set to charset default collation implicitly.
   259  			cid, ok := mysql.CharsetIds[charset]
   260  			if !ok {
   261  				return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset)
   262  			}
   263  			collationID = cid
   264  		}
   265  
   266  		se.charset = charset
   267  		se.collation = collationID
   268  		return nil
   269  	case "sql_mode":
   270  		sqlMode := getVariableExprResult(v.Value)
   271  		return se.setStringSessionVariable(mysql.SQLModeStr, sqlMode)
   272  	case "sql_safe_updates":
   273  		value := getVariableExprResult(v.Value)
   274  		onOffValue, err := getOnOffVariable(value)
   275  		if err != nil {
   276  			return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value)
   277  		}
   278  		return se.setIntSessionVariable(mysql.SQLSafeUpdates, onOffValue)
   279  	case "time_zone":
   280  		value := getVariableExprResult(v.Value)
   281  		return se.setStringSessionVariable(mysql.TimeZone, value)
   282  	case "max_allowed_packet":
   283  		return mysql.NewDefaultError(mysql.ErrVariableIsReadonly, "SESSION", mysql.MaxAllowedPacket, "GLOBAL")
   284  
   285  		// do nothing
   286  	case "wait_timeout", "interactive_timeout", "net_write_timeout", "net_read_timeout":
   287  		return nil
   288  	case "sql_select_limit":
   289  		return nil
   290  		// unsupported
   291  	case "transaction":
   292  		return fmt.Errorf("does not support set transaction in gaea")
   293  	case gaeaGeneralLogVariable:
   294  		value := getVariableExprResult(v.Value)
   295  		onOffValue, err := getOnOffVariable(value)
   296  		if err != nil {
   297  			return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value)
   298  		}
   299  		return se.setGeneralLogVariable(onOffValue)
   300  	default:
   301  		return nil
   302  	}
   303  }
   304  
   305  func (se *SessionExecutor) handleSetAutoCommit(autocommit bool) (err error) {
   306  	se.txLock.Lock()
   307  	defer se.txLock.Unlock()
   308  
   309  	if autocommit {
   310  		se.status |= mysql.ServerStatusAutocommit
   311  		if se.status&mysql.ServerStatusInTrans > 0 {
   312  			se.status &= ^mysql.ServerStatusInTrans
   313  		}
   314  		for _, pc := range se.txConns {
   315  			if e := pc.SetAutoCommit(1); e != nil {
   316  				err = fmt.Errorf("set autocommit error, %v", e)
   317  			}
   318  			pc.Recycle()
   319  		}
   320  		se.txConns = make(map[string]backend.PooledConnect)
   321  		return
   322  	}
   323  
   324  	se.status &= ^mysql.ServerStatusAutocommit
   325  	return
   326  }
   327  
   328  func (se *SessionExecutor) handleStmtPrepare(sql string) (*Stmt, error) {
   329  	log.Debug("namespace: %s use prepare, sql: %s", se.GetNamespace().GetName(), sql)
   330  
   331  	stmt := new(Stmt)
   332  
   333  	sql = strings.TrimRight(sql, ";")
   334  	stmt.sql = sql
   335  
   336  	paramCount, offsets, err := calcParams(stmt.sql)
   337  	if err != nil {
   338  		log.Warn("prepare calc params failed, namespace: %s, sql: %s", se.GetNamespace().GetName(), sql)
   339  		return nil, err
   340  	}
   341  
   342  	stmt.paramCount = paramCount
   343  	stmt.offsets = offsets
   344  	stmt.id = se.stmtID
   345  	stmt.columnCount = 0
   346  	se.stmtID++
   347  
   348  	stmt.ResetParams()
   349  	se.stmts[stmt.id] = stmt
   350  
   351  	return stmt, nil
   352  }
   353  
   354  func (se *SessionExecutor) handleStmtClose(data []byte) error {
   355  	if len(data) < 4 {
   356  		return nil
   357  	}
   358  
   359  	id := binary.LittleEndian.Uint32(data[0:4])
   360  
   361  	delete(se.stmts, id)
   362  
   363  	return nil
   364  }
   365  
   366  func (se *SessionExecutor) handleFieldList(data []byte) ([]*mysql.Field, error) {
   367  	index := bytes.IndexByte(data, 0x00)
   368  	table := string(data[0:index])
   369  	wildcard := string(data[index+1:])
   370  
   371  	sliceName := se.GetNamespace().GetRouter().GetRule(se.GetDatabase(), table).GetSlice(0)
   372  
   373  	pc, err := se.getBackendConn(sliceName, se.GetNamespace().IsRWSplit(se.user))
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  	defer se.recycleBackendConn(pc, false)
   378  
   379  	phyDB, err := se.GetNamespace().GetDefaultPhyDB(se.GetDatabase())
   380  	if err != nil {
   381  		return nil, err
   382  	}
   383  
   384  	if err = initBackendConn(pc, phyDB, se.GetCharset(), se.GetCollationID(), se.GetVariables()); err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	fs, err := pc.FieldList(table, wildcard)
   389  	if err != nil {
   390  		return nil, err
   391  	}
   392  
   393  	return fs, nil
   394  }