github.com/XiaoMi/Gaea@v1.2.5/proxy/server/executor_handle.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "bytes" 19 "encoding/binary" 20 "fmt" 21 "runtime" 22 "strings" 23 "time" 24 25 "github.com/XiaoMi/Gaea/backend" 26 "github.com/XiaoMi/Gaea/core/errors" 27 "github.com/XiaoMi/Gaea/log" 28 "github.com/XiaoMi/Gaea/mysql" 29 "github.com/XiaoMi/Gaea/parser" 30 "github.com/XiaoMi/Gaea/parser/ast" 31 "github.com/XiaoMi/Gaea/proxy/plan" 32 "github.com/XiaoMi/Gaea/util" 33 ) 34 35 // Parse parse sql 36 func (se *SessionExecutor) Parse(sql string) (ast.StmtNode, error) { 37 return se.parser.ParseOneStmt(sql, "", "") 38 } 39 40 // 处理query语句 41 func (se *SessionExecutor) handleQuery(sql string) (r *mysql.Result, err error) { 42 defer func() { 43 if e := recover(); e != nil { 44 log.Warn("handle query command failed, error: %v, sql: %s", e, sql) 45 46 if err, ok := e.(error); ok { 47 const size = 4096 48 buf := make([]byte, size) 49 buf = buf[:runtime.Stack(buf, false)] 50 51 log.Warn("handle query command catch panic error, sql: %s, error: %s, stack: %s", 52 sql, err.Error(), string(buf)) 53 } 54 55 err = errors.ErrInternalServer 56 return 57 } 58 }() 59 60 sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 61 62 reqCtx := util.NewRequestContext() 63 // check black sql 64 ns := se.GetNamespace() 65 if !ns.IsSQLAllowed(reqCtx, sql) { 66 fingerprint := mysql.GetFingerprint(sql) 67 log.Warn("catch black sql, sql: %s", sql) 68 se.manager.GetStatisticManager().RecordSQLForbidden(fingerprint, se.GetNamespace().GetName()) 69 err := mysql.NewError(mysql.ErrUnknown, "sql in blacklist") 70 return nil, err 71 } 72 73 startTime := time.Now() 74 stmtType := parser.Preview(sql) 75 reqCtx.Set(util.StmtType, stmtType) 76 77 r, err = se.doQuery(reqCtx, sql) 78 se.manager.RecordSessionSQLMetrics(reqCtx, se, sql, startTime, err) 79 return r, err 80 } 81 82 func (se *SessionExecutor) doQuery(reqCtx *util.RequestContext, sql string) (*mysql.Result, error) { 83 stmtType := reqCtx.Get("stmtType").(int) 84 85 if isSQLNotAllowedByUser(se, stmtType) { 86 return nil, fmt.Errorf("write DML is now allowed by read user") 87 } 88 89 if canHandleWithoutPlan(stmtType) { 90 return se.handleQueryWithoutPlan(reqCtx, sql) 91 } 92 93 db := se.db 94 95 p, err := se.getPlan(se.GetNamespace(), db, sql) 96 if err != nil { 97 return nil, fmt.Errorf("get plan error, db: %s, sql: %s, err: %v", db, sql, err) 98 } 99 100 if canExecuteFromSlave(se, sql) { 101 reqCtx.Set(util.FromSlave, 1) 102 } 103 104 reqCtx.Set(util.DefaultSlice, se.GetNamespace().GetDefaultSlice()) 105 r, err := p.ExecuteIn(reqCtx, se) 106 if err != nil { 107 log.Warn("execute select: %s", err.Error()) 108 return nil, err 109 } 110 111 modifyResultStatus(r, se) 112 113 return r, nil 114 } 115 116 // 处理逻辑较简单的SQL, 不走执行计划部分 117 func (se *SessionExecutor) handleQueryWithoutPlan(reqCtx *util.RequestContext, sql string) (*mysql.Result, error) { 118 n, err := se.Parse(sql) 119 if err != nil { 120 return nil, fmt.Errorf("parse sql error, sql: %s, err: %v", sql, err) 121 } 122 123 switch stmt := n.(type) { 124 case *ast.ShowStmt: 125 return se.handleShow(reqCtx, sql, stmt) 126 case *ast.SetStmt: 127 return se.handleSet(reqCtx, sql, stmt) 128 case *ast.BeginStmt: 129 return nil, se.handleBegin() 130 case *ast.CommitStmt: 131 return nil, se.handleCommit() 132 case *ast.RollbackStmt: 133 return nil, se.handleRollback(stmt) 134 case *ast.SavepointStmt: 135 return nil, se.handleSavepoint(stmt) 136 case *ast.UseStmt: 137 return nil, se.handleUseDB(stmt.DBName) 138 default: 139 return nil, fmt.Errorf("cannot handle sql without plan, ns: %s, sql: %s", se.namespace, sql) 140 } 141 } 142 143 func (se *SessionExecutor) handleUseDB(dbName string) error { 144 if len(dbName) == 0 { 145 return fmt.Errorf("must have database, the length of dbName is zero") 146 } 147 148 if se.GetNamespace().IsAllowedDB(dbName) { 149 se.db = dbName 150 return nil 151 } 152 153 return mysql.NewDefaultError(mysql.ErrNoDB) 154 } 155 156 func (se *SessionExecutor) getPlan(ns *Namespace, db string, sql string) (plan.Plan, error) { 157 n, err := se.Parse(sql) 158 if err != nil { 159 return nil, fmt.Errorf("parse sql error, sql: %s, err: %v", sql, err) 160 } 161 162 rt := ns.GetRouter() 163 seq := ns.GetSequences() 164 phyDBs := ns.GetPhysicalDBs() 165 p, err := plan.BuildPlan(n, phyDBs, db, sql, rt, seq) 166 if err != nil { 167 return nil, fmt.Errorf("create select plan error: %v", err) 168 } 169 170 return p, nil 171 } 172 173 func (se *SessionExecutor) handleShow(reqCtx *util.RequestContext, sql string, stmt *ast.ShowStmt) (*mysql.Result, error) { 174 switch stmt.Tp { 175 case ast.ShowDatabases: 176 dbs := se.GetNamespace().GetAllowedDBs() 177 return createShowDatabaseResult(dbs), nil 178 case ast.ShowVariables: 179 if strings.Contains(sql, gaeaGeneralLogVariable) { 180 return createShowGeneralLogResult(), nil 181 } 182 fallthrough 183 default: 184 r, err := se.ExecuteSQL(reqCtx, se.GetNamespace().GetDefaultSlice(), se.db, sql) 185 if err != nil { 186 return nil, fmt.Errorf("execute sql error, sql: %s, err: %v", sql, err) 187 } 188 modifyResultStatus(r, se) 189 return r, nil 190 } 191 } 192 193 func (se *SessionExecutor) handleSet(reqCtx *util.RequestContext, sql string, stmt *ast.SetStmt) (*mysql.Result, error) { 194 for _, v := range stmt.Variables { 195 if err := se.handleSetVariable(v); err != nil { 196 return nil, err 197 } 198 } 199 200 return nil, nil 201 } 202 203 func (se *SessionExecutor) handleSetVariable(v *ast.VariableAssignment) error { 204 if v.IsGlobal { 205 return fmt.Errorf("does not support set variable in global scope") 206 } 207 name := strings.ToLower(v.Name) 208 switch name { 209 case "character_set_results", "character_set_client", "character_set_connection": 210 charset := getVariableExprResult(v.Value) 211 if charset == "null" { // character_set_results允许设置成null, character_set_client和character_set_connection不允许 212 return nil 213 } 214 if charset == mysql.KeywordDefault { 215 se.charset = se.GetNamespace().GetDefaultCharset() 216 se.collation = se.GetNamespace().GetDefaultCollationID() 217 return nil 218 } 219 cid, ok := mysql.CharsetIds[charset] 220 if !ok { 221 return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset) 222 } 223 se.charset = charset 224 se.collation = cid 225 return nil 226 case "autocommit": 227 value := getVariableExprResult(v.Value) 228 if value == mysql.KeywordDefault || value == "on" || value == "1" { 229 return se.handleSetAutoCommit(true) // default set autocommit = 1 230 } else if value == "off" || value == "0" { 231 return se.handleSetAutoCommit(false) 232 } else { 233 return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value) 234 } 235 case "setnames": // SetNAMES represents SET NAMES 'xxx' COLLATE 'xxx' 236 charset := getVariableExprResult(v.Value) 237 if charset == mysql.KeywordDefault { 238 charset = se.GetNamespace().GetDefaultCharset() 239 } 240 241 var collationID mysql.CollationID 242 // if SET NAMES 'xxx' COLLATE DEFAULT, the parser treats it like SET NAMES 'xxx', and the ExtendValue is nil 243 if v.ExtendValue != nil { 244 collationName := getVariableExprResult(v.ExtendValue) 245 cid, ok := mysql.CollationNames[collationName] 246 if !ok { 247 return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset) 248 } 249 toCharset, ok := mysql.CollationNameToCharset[collationName] 250 if !ok { 251 return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset) 252 } 253 if toCharset != charset { // collation与charset不匹配 254 return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset) 255 } 256 collationID = cid 257 } else { 258 // if only set charset but not set collation, the collation is set to charset default collation implicitly. 259 cid, ok := mysql.CharsetIds[charset] 260 if !ok { 261 return mysql.NewDefaultError(mysql.ErrUnknownCharacterSet, charset) 262 } 263 collationID = cid 264 } 265 266 se.charset = charset 267 se.collation = collationID 268 return nil 269 case "sql_mode": 270 sqlMode := getVariableExprResult(v.Value) 271 return se.setStringSessionVariable(mysql.SQLModeStr, sqlMode) 272 case "sql_safe_updates": 273 value := getVariableExprResult(v.Value) 274 onOffValue, err := getOnOffVariable(value) 275 if err != nil { 276 return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value) 277 } 278 return se.setIntSessionVariable(mysql.SQLSafeUpdates, onOffValue) 279 case "time_zone": 280 value := getVariableExprResult(v.Value) 281 return se.setStringSessionVariable(mysql.TimeZone, value) 282 case "max_allowed_packet": 283 return mysql.NewDefaultError(mysql.ErrVariableIsReadonly, "SESSION", mysql.MaxAllowedPacket, "GLOBAL") 284 285 // do nothing 286 case "wait_timeout", "interactive_timeout", "net_write_timeout", "net_read_timeout": 287 return nil 288 case "sql_select_limit": 289 return nil 290 // unsupported 291 case "transaction": 292 return fmt.Errorf("does not support set transaction in gaea") 293 case gaeaGeneralLogVariable: 294 value := getVariableExprResult(v.Value) 295 onOffValue, err := getOnOffVariable(value) 296 if err != nil { 297 return mysql.NewDefaultError(mysql.ErrWrongValueForVar, name, value) 298 } 299 return se.setGeneralLogVariable(onOffValue) 300 default: 301 return nil 302 } 303 } 304 305 func (se *SessionExecutor) handleSetAutoCommit(autocommit bool) (err error) { 306 se.txLock.Lock() 307 defer se.txLock.Unlock() 308 309 if autocommit { 310 se.status |= mysql.ServerStatusAutocommit 311 if se.status&mysql.ServerStatusInTrans > 0 { 312 se.status &= ^mysql.ServerStatusInTrans 313 } 314 for _, pc := range se.txConns { 315 if e := pc.SetAutoCommit(1); e != nil { 316 err = fmt.Errorf("set autocommit error, %v", e) 317 } 318 pc.Recycle() 319 } 320 se.txConns = make(map[string]backend.PooledConnect) 321 return 322 } 323 324 se.status &= ^mysql.ServerStatusAutocommit 325 return 326 } 327 328 func (se *SessionExecutor) handleStmtPrepare(sql string) (*Stmt, error) { 329 log.Debug("namespace: %s use prepare, sql: %s", se.GetNamespace().GetName(), sql) 330 331 stmt := new(Stmt) 332 333 sql = strings.TrimRight(sql, ";") 334 stmt.sql = sql 335 336 paramCount, offsets, err := calcParams(stmt.sql) 337 if err != nil { 338 log.Warn("prepare calc params failed, namespace: %s, sql: %s", se.GetNamespace().GetName(), sql) 339 return nil, err 340 } 341 342 stmt.paramCount = paramCount 343 stmt.offsets = offsets 344 stmt.id = se.stmtID 345 stmt.columnCount = 0 346 se.stmtID++ 347 348 stmt.ResetParams() 349 se.stmts[stmt.id] = stmt 350 351 return stmt, nil 352 } 353 354 func (se *SessionExecutor) handleStmtClose(data []byte) error { 355 if len(data) < 4 { 356 return nil 357 } 358 359 id := binary.LittleEndian.Uint32(data[0:4]) 360 361 delete(se.stmts, id) 362 363 return nil 364 } 365 366 func (se *SessionExecutor) handleFieldList(data []byte) ([]*mysql.Field, error) { 367 index := bytes.IndexByte(data, 0x00) 368 table := string(data[0:index]) 369 wildcard := string(data[index+1:]) 370 371 sliceName := se.GetNamespace().GetRouter().GetRule(se.GetDatabase(), table).GetSlice(0) 372 373 pc, err := se.getBackendConn(sliceName, se.GetNamespace().IsRWSplit(se.user)) 374 if err != nil { 375 return nil, err 376 } 377 defer se.recycleBackendConn(pc, false) 378 379 phyDB, err := se.GetNamespace().GetDefaultPhyDB(se.GetDatabase()) 380 if err != nil { 381 return nil, err 382 } 383 384 if err = initBackendConn(pc, phyDB, se.GetCharset(), se.GetCollationID(), se.GetVariables()); err != nil { 385 return nil, err 386 } 387 388 fs, err := pc.FieldList(table, wildcard) 389 if err != nil { 390 return nil, err 391 } 392 393 return fs, nil 394 }