github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/util.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "bytes" 19 "context" 20 "fmt" 21 "math/rand" 22 "os" 23 "runtime" 24 "strconv" 25 "strings" 26 "sync" 27 "sync/atomic" 28 "time" 29 30 "github.com/google/uuid" 31 "go.uber.org/zap" 32 33 "github.com/matrixorigin/matrixone/pkg/common/log" 34 moruntime "github.com/matrixorigin/matrixone/pkg/common/runtime" 35 "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" 36 db_holder "github.com/matrixorigin/matrixone/pkg/util/export/etl/db" 37 "github.com/matrixorigin/matrixone/pkg/util/trace" 38 "github.com/matrixorigin/matrixone/pkg/util/trace/impl/motrace" 39 "github.com/matrixorigin/matrixone/pkg/vm/engine/memoryengine" 40 41 "github.com/matrixorigin/matrixone/pkg/frontend/constant" 42 43 "github.com/matrixorigin/matrixone/pkg/container/batch" 44 "github.com/matrixorigin/matrixone/pkg/pb/plan" 45 "github.com/matrixorigin/matrixone/pkg/sql/colexec" 46 47 "github.com/matrixorigin/matrixone/pkg/defines" 48 49 "github.com/BurntSushi/toml" 50 51 "github.com/matrixorigin/matrixone/pkg/common/moerr" 52 mo_config "github.com/matrixorigin/matrixone/pkg/config" 53 "github.com/matrixorigin/matrixone/pkg/container/types" 54 "github.com/matrixorigin/matrixone/pkg/container/vector" 55 "github.com/matrixorigin/matrixone/pkg/logutil" 56 "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" 57 plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan" 58 "github.com/matrixorigin/matrixone/pkg/vm/engine" 59 ) 60 61 type CloseFlag struct { 62 //closed flag 63 closed uint32 64 } 65 66 // 1 for closed 67 // 0 for others 68 func (cf *CloseFlag) setClosed(value uint32) { 69 atomic.StoreUint32(&cf.closed, value) 70 } 71 72 func (cf *CloseFlag) Open() { 73 cf.setClosed(0) 74 } 75 76 func (cf *CloseFlag) Close() { 77 cf.setClosed(1) 78 } 79 80 func (cf *CloseFlag) IsClosed() bool { 81 return atomic.LoadUint32(&cf.closed) != 0 82 } 83 84 func (cf *CloseFlag) IsOpened() bool { 85 return atomic.LoadUint32(&cf.closed) == 0 86 } 87 88 func Min(a int, b int) int { 89 if a < b { 90 return a 91 } else { 92 return b 93 } 94 } 95 96 func Max(a int, b int) int { 97 if a < b { 98 return b 99 } else { 100 return a 101 } 102 } 103 104 // GetRoutineId gets the routine id 105 func GetRoutineId() uint64 { 106 data := make([]byte, 64) 107 data = data[:runtime.Stack(data, false)] 108 data = bytes.TrimPrefix(data, []byte("goroutine ")) 109 data = data[:bytes.IndexByte(data, ' ')] 110 id, _ := strconv.ParseUint(string(data), 10, 64) 111 return id 112 } 113 114 type Timeout struct { 115 //last record of the time 116 lastTime atomic.Value //time.Time 117 118 //period 119 timeGap time.Duration 120 121 //auto update 122 autoUpdate bool 123 } 124 125 func NewTimeout(tg time.Duration, autoUpdateWhenChecked bool) *Timeout { 126 ret := &Timeout{ 127 timeGap: tg, 128 autoUpdate: autoUpdateWhenChecked, 129 } 130 ret.lastTime.Store(time.Now()) 131 return ret 132 } 133 134 func (t *Timeout) UpdateTime(tn time.Time) { 135 t.lastTime.Store(tn) 136 } 137 138 /* 139 ----------+---------+------------------+-------- 140 141 lastTime Now lastTime + timeGap 142 143 return true : is timeout. the lastTime has been updated. 144 return false : is not timeout. the lastTime has not been updated. 145 */ 146 func (t *Timeout) isTimeout() bool { 147 if time.Since(t.lastTime.Load().(time.Time)) <= t.timeGap { 148 return false 149 } 150 151 if t.autoUpdate { 152 t.lastTime.Store(time.Now()) 153 } 154 return true 155 } 156 157 /* 158 length: 159 -1, complete string. 160 0, empty string 161 >0 , length of characters at the header of the string. 162 */ 163 func SubStringFromBegin(str string, length int) string { 164 if length == 0 || length < -1 { 165 return "" 166 } 167 168 if length == -1 { 169 return str 170 } 171 172 l := Min(len(str), length) 173 if l != len(str) { 174 return str[:l] + "..." 175 } 176 return str[:l] 177 } 178 179 /* 180 path exists in the system 181 return: 182 true/false - exists or not. 183 true/false - file or directory 184 error 185 */ 186 var PathExists = func(path string) (bool, bool, error) { 187 fi, err := os.Stat(path) 188 if err == nil { 189 return true, !fi.IsDir(), nil 190 } 191 if os.IsNotExist(err) { 192 return false, false, err 193 } 194 195 return false, false, err 196 } 197 198 func getSystemVariables(configFile string) (*mo_config.FrontendParameters, error) { 199 sv := &mo_config.FrontendParameters{} 200 var err error 201 _, err = toml.DecodeFile(configFile, sv) 202 if err != nil { 203 return nil, err 204 } 205 return sv, err 206 } 207 208 func getParameterUnit(configFile string, eng engine.Engine, txnClient TxnClient) (*mo_config.ParameterUnit, error) { 209 sv, err := getSystemVariables(configFile) 210 if err != nil { 211 return nil, err 212 } 213 pu := mo_config.NewParameterUnit(sv, eng, txnClient, engine.Nodes{}) 214 215 return pu, nil 216 } 217 218 // WildcardMatch implements wildcard pattern match algorithm. 219 // pattern and target are ascii characters 220 // TODO: add \_ and \% 221 func WildcardMatch(pattern, target string) bool { 222 var p = 0 223 var t = 0 224 var positionOfPercentPlusOne int = -1 225 var positionOfTargetEncounterPercent int = -1 226 plen := len(pattern) 227 tlen := len(target) 228 for t < tlen { 229 //% 230 if p < plen && pattern[p] == '%' { 231 p++ 232 positionOfPercentPlusOne = p 233 if p >= plen { 234 //pattern end with % 235 return true 236 } 237 //means % matches empty 238 positionOfTargetEncounterPercent = t 239 } else if p < plen && (pattern[p] == '_' || pattern[p] == target[t]) { //match or _ 240 p++ 241 t++ 242 } else { 243 if positionOfPercentPlusOne == -1 { 244 //have not matched a % 245 return false 246 } 247 if positionOfTargetEncounterPercent == -1 { 248 return false 249 } 250 //backtrace to last % position + 1 251 p = positionOfPercentPlusOne 252 //means % matches multiple characters 253 positionOfTargetEncounterPercent++ 254 t = positionOfTargetEncounterPercent 255 } 256 } 257 //skip % 258 for p < plen && pattern[p] == '%' { 259 p++ 260 } 261 return p >= plen 262 } 263 264 // getExprValue executes the expression and returns the value. 265 func getExprValue(e tree.Expr, ses *Session, execCtx *ExecCtx) (interface{}, error) { 266 /* 267 CORNER CASE: 268 SET character_set_results = utf8; // e = tree.UnresolvedName{'utf8'}. 269 270 tree.UnresolvedName{'utf8'} can not be resolved as the column of some table. 271 */ 272 switch v := e.(type) { 273 case *tree.UnresolvedName: 274 // set @a = on, type of a is bool. 275 return v.Parts[0], nil 276 } 277 278 var err error 279 280 table := &tree.TableName{} 281 table.ObjectName = "dual" 282 283 //1.composite the 'select (expr) from dual' 284 compositedSelect := &tree.Select{ 285 Select: &tree.SelectClause{ 286 Exprs: tree.SelectExprs{ 287 tree.SelectExpr{ 288 Expr: e, 289 }, 290 }, 291 From: &tree.From{ 292 Tables: tree.TableExprs{ 293 &tree.JoinTableExpr{ 294 JoinType: tree.JOIN_TYPE_CROSS, 295 Left: &tree.AliasedTableExpr{ 296 Expr: table, 297 }, 298 }, 299 }, 300 }, 301 }, 302 } 303 304 //2.run the select 305 306 //run the statement in the same session 307 ses.ClearResultBatches() 308 //!!!different ExecCtx 309 tempExecCtx := ExecCtx{ 310 reqCtx: execCtx.reqCtx, 311 ses: ses, 312 } 313 err = executeStmtInSameSession(tempExecCtx.reqCtx, ses, &tempExecCtx, compositedSelect) 314 if err != nil { 315 return nil, err 316 } 317 318 batches := ses.GetResultBatches() 319 if len(batches) == 0 { 320 return nil, moerr.NewInternalError(execCtx.reqCtx, "the expr %s does not generate a value", e.String()) 321 } 322 323 if batches[0].VectorCount() > 1 { 324 return nil, moerr.NewInternalError(execCtx.reqCtx, "the expr %s generates multi columns value", e.String()) 325 } 326 327 //evaluate the count of rows, the count of columns 328 count := 0 329 var resultVec *vector.Vector 330 for _, b := range batches { 331 if b.RowCount() == 0 { 332 continue 333 } 334 count += b.RowCount() 335 if count > 1 { 336 return nil, moerr.NewInternalError(execCtx.reqCtx, "the expr %s generates multi rows value", e.String()) 337 } 338 if resultVec == nil && b.GetVector(0).Length() != 0 { 339 resultVec = b.GetVector(0) 340 } 341 } 342 343 if resultVec == nil { 344 return nil, moerr.NewInternalError(execCtx.reqCtx, "the expr %s does not generate a value", e.String()) 345 } 346 347 // for the decimal type, we need the type of expr 348 //!!!NOTE: the type here may be different from the one in the result vector. 349 var planExpr *plan.Expr 350 oid := resultVec.GetType().Oid 351 if oid == types.T_decimal64 || oid == types.T_decimal128 { 352 builder := plan2.NewQueryBuilder(plan.Query_SELECT, ses.GetTxnCompileCtx(), false) 353 bindContext := plan2.NewBindContext(builder, nil) 354 binder := plan2.NewSetVarBinder(builder, bindContext) 355 planExpr, err = binder.BindExpr(e, 0, false) 356 if err != nil { 357 return nil, err 358 } 359 } 360 361 return getValueFromVector(execCtx.reqCtx, resultVec, ses, planExpr) 362 } 363 364 // only support single value and unary minus 365 func GetSimpleExprValue(ctx context.Context, e tree.Expr, ses *Session) (interface{}, error) { 366 switch v := e.(type) { 367 case *tree.UnresolvedName: 368 // set @a = on, type of a is bool. 369 return v.Parts[0], nil 370 default: 371 builder := plan2.NewQueryBuilder(plan.Query_SELECT, ses.GetTxnCompileCtx(), false) 372 bindContext := plan2.NewBindContext(builder, nil) 373 binder := plan2.NewSetVarBinder(builder, bindContext) 374 planExpr, err := binder.BindExpr(e, 0, false) 375 if err != nil { 376 return nil, err 377 } 378 // set @a = 'on', type of a is bool. And mo cast rule does not fit set variable rule so delay to convert type. 379 // Here the evalExpr may execute some function that needs engine.Engine. 380 ses.txnCompileCtx.GetProcess().Ctx = attachValue(ses.txnCompileCtx.GetProcess().Ctx, 381 defines.EngineKey{}, 382 ses.GetTxnHandler().GetStorage()) 383 384 vec, err := colexec.EvalExpressionOnce(ses.txnCompileCtx.GetProcess(), planExpr, []*batch.Batch{batch.EmptyForConstFoldBatch}) 385 if err != nil { 386 return nil, err 387 } 388 389 value, err := getValueFromVector(ctx, vec, ses, planExpr) 390 vec.Free(ses.txnCompileCtx.GetProcess().Mp()) 391 return value, err 392 } 393 } 394 395 func getValueFromVector(ctx context.Context, vec *vector.Vector, ses *Session, expr *plan2.Expr) (interface{}, error) { 396 if vec.IsConstNull() || vec.GetNulls().Contains(0) { 397 return nil, nil 398 } 399 switch vec.GetType().Oid { 400 case types.T_bool: 401 return vector.MustFixedCol[bool](vec)[0], nil 402 case types.T_bit: 403 return vector.MustFixedCol[uint64](vec)[0], nil 404 case types.T_int8: 405 return vector.MustFixedCol[int8](vec)[0], nil 406 case types.T_int16: 407 return vector.MustFixedCol[int16](vec)[0], nil 408 case types.T_int32: 409 return vector.MustFixedCol[int32](vec)[0], nil 410 case types.T_int64: 411 return vector.MustFixedCol[int64](vec)[0], nil 412 case types.T_uint8: 413 return vector.MustFixedCol[uint8](vec)[0], nil 414 case types.T_uint16: 415 return vector.MustFixedCol[uint16](vec)[0], nil 416 case types.T_uint32: 417 return vector.MustFixedCol[uint32](vec)[0], nil 418 case types.T_uint64: 419 return vector.MustFixedCol[uint64](vec)[0], nil 420 case types.T_float32: 421 return vector.MustFixedCol[float32](vec)[0], nil 422 case types.T_float64: 423 return vector.MustFixedCol[float64](vec)[0], nil 424 case types.T_char, types.T_varchar, types.T_binary, types.T_varbinary, types.T_text, types.T_blob: 425 return vec.GetStringAt(0), nil 426 case types.T_array_float32: 427 return vector.GetArrayAt[float32](vec, 0), nil 428 case types.T_array_float64: 429 return vector.GetArrayAt[float64](vec, 0), nil 430 case types.T_decimal64: 431 val := vector.GetFixedAt[types.Decimal64](vec, 0) 432 return val.Format(expr.Typ.Scale), nil 433 case types.T_decimal128: 434 val := vector.GetFixedAt[types.Decimal128](vec, 0) 435 return val.Format(expr.Typ.Scale), nil 436 case types.T_json: 437 val := vec.GetBytesAt(0) 438 byteJson := types.DecodeJson(val) 439 return byteJson.String(), nil 440 case types.T_uuid: 441 val := vector.MustFixedCol[types.Uuid](vec)[0] 442 return val.ToString(), nil 443 case types.T_date: 444 val := vector.MustFixedCol[types.Date](vec)[0] 445 return val.String(), nil 446 case types.T_time: 447 val := vector.MustFixedCol[types.Time](vec)[0] 448 return val.String(), nil 449 case types.T_datetime: 450 val := vector.MustFixedCol[types.Datetime](vec)[0] 451 return val.String(), nil 452 case types.T_timestamp: 453 val := vector.MustFixedCol[types.Timestamp](vec)[0] 454 return val.String2(ses.GetTimeZone(), vec.GetType().Scale), nil 455 case types.T_enum: 456 return vector.MustFixedCol[types.Enum](vec)[0], nil 457 default: 458 return nil, moerr.NewInvalidArg(ctx, "variable type", vec.GetType().Oid.String()) 459 } 460 } 461 462 type statementStatus int 463 464 const ( 465 success statementStatus = iota 466 fail 467 sessionId = "session_id" 468 469 txnId = "txn_id" 470 statementId = "statement_id" 471 ) 472 473 func (s statementStatus) String() string { 474 switch s { 475 case success: 476 return "success" 477 case fail: 478 return "fail" 479 } 480 return "running" 481 } 482 483 // logStatementStatus prints the status of the statement into the log. 484 func logStatementStatus(ctx context.Context, ses FeSession, stmt tree.Statement, status statementStatus, err error) { 485 var stmtStr string 486 stm := motrace.StatementFromContext(ctx) 487 if stm == nil { 488 fmtCtx := tree.NewFmtCtx(dialect.MYSQL) 489 stmt.Format(fmtCtx) 490 stmtStr = fmtCtx.String() 491 } else { 492 stmtStr = stm.Statement 493 } 494 logStatementStringStatus(ctx, ses, stmtStr, status, err) 495 } 496 497 func logStatementStringStatus(ctx context.Context, ses FeSession, stmtStr string, status statementStatus, err error) { 498 str := SubStringFromBegin(stmtStr, int(getGlobalPu().SV.LengthOfQueryPrinted)) 499 outBytes, outPacket := ses.GetMysqlProtocol().CalculateOutTrafficBytes(true) 500 if status == success { 501 logDebug(ses, ses.GetDebugString(), "query trace status", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.StatementField(str), logutil.StatusField(status.String()), trace.ContextField(ctx)) 502 err = nil // make sure: it is nil for EndStatement 503 } else { 504 logError(ses, ses.GetDebugString(), "query trace status", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.StatementField(str), logutil.StatusField(status.String()), logutil.ErrorField(err), trace.ContextField(ctx)) 505 } 506 507 // pls make sure: NO ONE use the ses.tStmt after EndStatement 508 if !ses.IsBackgroundSession() { 509 motrace.EndStatement(ctx, err, ses.SendRows(), outBytes, outPacket) 510 } 511 512 // need just below EndStatement 513 ses.SetTStmt(nil) 514 } 515 516 var logger *log.MOLogger 517 var loggerOnce sync.Once 518 519 func getLogger() *log.MOLogger { 520 loggerOnce.Do(initLogger) 521 return logger 522 } 523 524 func initLogger() { 525 rt := moruntime.ProcessLevelRuntime() 526 if rt == nil { 527 rt = moruntime.DefaultRuntime() 528 } 529 logger = rt.Logger().Named("frontend") 530 } 531 532 func appendSessionField(fields []zap.Field, ses FeSession) []zap.Field { 533 if ses != nil { 534 if ses.GetStmtInfo() != nil { 535 fields = append(fields, zap.String(sessionId, uuid.UUID(ses.GetStmtInfo().SessionID).String())) 536 fields = append(fields, zap.String(statementId, uuid.UUID(ses.GetStmtInfo().StatementID).String())) 537 txnInfo := ses.GetTxnInfo() 538 if txnInfo != "" { 539 fields = append(fields, zap.String(txnId, txnInfo)) 540 } 541 } else { 542 fields = append(fields, zap.String(sessionId, uuid.UUID(ses.GetUUID()).String())) 543 } 544 } 545 return fields 546 } 547 548 func logInfo(ses FeSession, info string, msg string, fields ...zap.Field) { 549 if ses != nil && ses.GetTenantInfo() != nil && ses.GetTenantInfo().GetUser() == db_holder.MOLoggerUser { 550 return 551 } 552 fields = append(fields, zap.String("session_info", info)) 553 fields = appendSessionField(fields, ses) 554 getLogger().Log(msg, log.DefaultLogOptions().WithLevel(zap.InfoLevel).AddCallerSkip(1), fields...) 555 } 556 557 func logInfof(info string, msg string, fields ...zap.Field) { 558 if logutil.GetSkip1Logger().Core().Enabled(zap.InfoLevel) { 559 fields = append(fields, zap.String("session_info", info)) 560 getLogger().Log(msg, log.DefaultLogOptions().WithLevel(zap.InfoLevel).AddCallerSkip(1), fields...) 561 } 562 } 563 564 func logDebug(ses FeSession, info string, msg string, fields ...zap.Field) { 565 if ses != nil && ses.GetTenantInfo() != nil && ses.GetTenantInfo().GetUser() == db_holder.MOLoggerUser { 566 return 567 } 568 fields = append(fields, zap.String("session_info", info)) 569 fields = appendSessionField(fields, ses) 570 getLogger().Log(msg, log.DefaultLogOptions().WithLevel(zap.DebugLevel).AddCallerSkip(1), fields...) 571 } 572 573 func logError(ses FeSession, info string, msg string, fields ...zap.Field) { 574 if ses != nil && ses.GetTenantInfo() != nil && ses.GetTenantInfo().GetUser() == db_holder.MOLoggerUser { 575 return 576 } 577 fields = append(fields, zap.String("session_info", info)) 578 fields = appendSessionField(fields, ses) 579 getLogger().Log(msg, log.DefaultLogOptions().WithLevel(zap.ErrorLevel).AddCallerSkip(1), fields...) 580 } 581 582 // todo: remove this function after all the logDebugf are replaced by logDebug 583 func logDebugf(info string, msg string, fields ...interface{}) { 584 if logutil.GetSkip1Logger().Core().Enabled(zap.DebugLevel) { 585 fields = append(fields, info) 586 logutil.Debugf(msg+" %s", fields...) 587 } 588 } 589 590 // isCmdFieldListSql checks the sql is the cmdFieldListSql or not. 591 func isCmdFieldListSql(sql string) bool { 592 if len(sql) < cmdFieldListSqlLen { 593 return false 594 } 595 prefix := sql[:cmdFieldListSqlLen] 596 return strings.Compare(strings.ToLower(prefix), cmdFieldListSql) == 0 597 } 598 599 // makeCmdFieldListSql makes the internal CMD_FIELD_LIST sql 600 func makeCmdFieldListSql(query string) string { 601 nullIdx := strings.IndexRune(query, rune(0)) 602 if nullIdx != -1 { 603 query = query[:nullIdx] 604 } 605 return cmdFieldListSql + " " + query 606 } 607 608 // parseCmdFieldList parses the internal cmd field list 609 func parseCmdFieldList(ctx context.Context, sql string) (*InternalCmdFieldList, error) { 610 if !isCmdFieldListSql(sql) { 611 return nil, moerr.NewInternalError(ctx, "it is not the CMD_FIELD_LIST") 612 } 613 tableName := strings.TrimSpace(sql[len(cmdFieldListSql):]) 614 return &InternalCmdFieldList{tableName: tableName}, nil 615 } 616 617 func getVariableValue(varDefault interface{}) string { 618 switch val := varDefault.(type) { 619 case int64: 620 return fmt.Sprintf("%d", val) 621 case uint64: 622 return fmt.Sprintf("%d", val) 623 case int8: 624 return fmt.Sprintf("%d", val) 625 case float64: 626 // 0.1 => 0.100000 627 // 0.0000001 -> 1.000000e-7 628 if val >= 1e-6 { 629 return fmt.Sprintf("%.6f", val) 630 } else { 631 return fmt.Sprintf("%.6e", val) 632 } 633 case string: 634 return val 635 default: 636 return "" 637 } 638 } 639 640 func makeServerVersion(pu *mo_config.ParameterUnit, version string) string { 641 return pu.SV.ServerVersionPrefix + version 642 } 643 644 func copyBytes(src []byte, needCopy bool) []byte { 645 if needCopy { 646 if len(src) > 0 { 647 dst := make([]byte, len(src)) 648 copy(dst, src) 649 return dst 650 } else { 651 return []byte{} 652 } 653 } 654 return src 655 } 656 657 // getUserProfile returns the account, user, role of the account 658 func getUserProfile(account *TenantInfo) (string, string, string) { 659 var ( 660 accountName string 661 userName string 662 roleName string 663 ) 664 665 if account != nil { 666 accountName = account.GetTenant() 667 userName = account.GetUser() 668 roleName = account.GetDefaultRole() 669 } else { 670 accountName = sysAccountName 671 userName = rootName 672 roleName = moAdminRoleName 673 } 674 return accountName, userName, roleName 675 } 676 677 // RewriteError rewrites the error info 678 func RewriteError(err error, username string) (uint16, string, string) { 679 if err == nil { 680 return moerr.ER_INTERNAL_ERROR, "", "" 681 } 682 var errorCode uint16 683 var sqlState string 684 var msg string 685 686 errMsg := strings.ToLower(err.Error()) 687 if needConvertedToAccessDeniedError(errMsg) { 688 failed := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR] 689 if len(username) > 0 { 690 tipsFormat := "Access denied for user %s. %s" 691 msg = fmt.Sprintf(tipsFormat, getUserPart(username), err.Error()) 692 } else { 693 msg = err.Error() 694 } 695 errorCode = failed.ErrorCode 696 sqlState = failed.SqlStates[0] 697 } else { 698 //Reference To : https://github.com/matrixorigin/matrixone/pull/12396/files#r1374443578 699 switch errImpl := err.(type) { 700 case *moerr.Error: 701 if errImpl.MySQLCode() != moerr.ER_UNKNOWN_ERROR { 702 errorCode = errImpl.MySQLCode() 703 } else { 704 errorCode = errImpl.ErrorCode() 705 } 706 msg = err.Error() 707 sqlState = errImpl.SqlState() 708 default: 709 failed := moerr.MysqlErrorMsgRefer[moerr.ER_INTERNAL_ERROR] 710 msg = err.Error() 711 errorCode = failed.ErrorCode 712 sqlState = failed.SqlStates[0] 713 } 714 715 } 716 return errorCode, sqlState, msg 717 } 718 719 func needConvertedToAccessDeniedError(errMsg string) bool { 720 if strings.Contains(errMsg, "check password failed") || 721 /* 722 following two cases are suggested by the peers from the mo cloud team. 723 we keep the consensus with them. 724 */ 725 strings.Contains(errMsg, "suspended") || 726 strings.Contains(errMsg, "source address") && 727 strings.Contains(errMsg, "is not authorized") { 728 return true 729 } 730 return false 731 } 732 733 const ( 734 quitStr = "MysqlClientQuit" 735 ) 736 737 // makeExecuteSql appends the PREPARE sql and its values of parameters for the EXECUTE statement. 738 // Format 1: execute ... using ... 739 // execute.... // prepare stmt1 from .... ; set var1 = val1 ; set var2 = val2 ; 740 // Format 2: COM_STMT_EXECUTE 741 // execute.... // prepare stmt1 from .... ; param0 ; param1 ... 742 func makeExecuteSql(ctx context.Context, ses *Session, stmt tree.Statement) string { 743 if ses == nil || stmt == nil { 744 return "" 745 } 746 preSql := "" 747 bb := &strings.Builder{} 748 //fill prepare parameters 749 switch t := stmt.(type) { 750 case *tree.Execute: 751 name := string(t.Name) 752 prepareStmt, err := ses.GetPrepareStmt(ctx, name) 753 if err != nil || prepareStmt == nil { 754 break 755 } 756 preSql = strings.TrimSpace(prepareStmt.Sql) 757 bb.WriteString(preSql) 758 bb.WriteString(" ; ") 759 if len(t.Variables) != 0 { 760 //for EXECUTE ... USING statement. append variables if there is. 761 //get SET VAR sql 762 setVarSqls := make([]string, len(t.Variables)) 763 for i, v := range t.Variables { 764 _, userVal, err := ses.GetUserDefinedVar(v.Name) 765 if err == nil && userVal != nil && len(userVal.Sql) != 0 { 766 setVarSqls[i] = userVal.Sql 767 } 768 } 769 bb.WriteString(strings.Join(setVarSqls, " ; ")) 770 } else if prepareStmt.params != nil { 771 //for COM_STMT_EXECUTE 772 //get value of parameters 773 paramCnt := prepareStmt.params.Length() 774 paramValues := make([]string, paramCnt) 775 vs := vector.MustFixedCol[types.Varlena](prepareStmt.params) 776 for i := 0; i < paramCnt; i++ { 777 isNull := prepareStmt.params.GetNulls().Contains(uint64(i)) 778 if isNull { 779 paramValues[i] = "NULL" 780 } else { 781 paramValues[i] = vs[i].GetString(prepareStmt.params.GetArea()) 782 } 783 } 784 bb.WriteString(strings.Join(paramValues, " ; ")) 785 } 786 default: 787 return "" 788 } 789 return bb.String() 790 } 791 792 func mysqlColDef2PlanResultColDef(mr *MysqlResultSet) *plan.ResultColDef { 793 if mr == nil { 794 return nil 795 } 796 797 resultCols := make([]*plan.ColDef, len(mr.Columns)) 798 for i, col := range mr.Columns { 799 resultCols[i] = &plan.ColDef{ 800 Name: col.Name(), 801 } 802 switch col.ColumnType() { 803 case defines.MYSQL_TYPE_VAR_STRING: 804 resultCols[i].Typ = plan.Type{ 805 Id: int32(types.T_varchar), 806 } 807 case defines.MYSQL_TYPE_LONG: 808 resultCols[i].Typ = plan.Type{ 809 Id: int32(types.T_int32), 810 } 811 case defines.MYSQL_TYPE_LONGLONG: 812 resultCols[i].Typ = plan.Type{ 813 Id: int32(types.T_int64), 814 } 815 case defines.MYSQL_TYPE_DOUBLE: 816 resultCols[i].Typ = plan.Type{ 817 Id: int32(types.T_float64), 818 } 819 case defines.MYSQL_TYPE_FLOAT: 820 resultCols[i].Typ = plan.Type{ 821 Id: int32(types.T_float32), 822 } 823 case defines.MYSQL_TYPE_DATE: 824 resultCols[i].Typ = plan.Type{ 825 Id: int32(types.T_date), 826 } 827 case defines.MYSQL_TYPE_TIME: 828 resultCols[i].Typ = plan.Type{ 829 Id: int32(types.T_time), 830 } 831 case defines.MYSQL_TYPE_DATETIME: 832 resultCols[i].Typ = plan.Type{ 833 Id: int32(types.T_datetime), 834 } 835 case defines.MYSQL_TYPE_TIMESTAMP: 836 resultCols[i].Typ = plan.Type{ 837 Id: int32(types.T_timestamp), 838 } 839 default: 840 panic(fmt.Sprintf("unsupported mysql type %d", col.ColumnType())) 841 } 842 } 843 return &plan.ResultColDef{ 844 ResultCols: resultCols, 845 } 846 } 847 848 // errCodeRollbackWholeTxn denotes that the error code 849 // that should rollback the whole txn 850 var errCodeRollbackWholeTxn = map[uint16]bool{ 851 moerr.ErrDeadLockDetected: false, 852 moerr.ErrLockTableBindChanged: false, 853 moerr.ErrLockTableNotFound: false, 854 moerr.ErrDeadlockCheckBusy: false, 855 moerr.ErrLockConflict: false, 856 } 857 858 func isErrorRollbackWholeTxn(inputErr error) bool { 859 if inputErr == nil { 860 return false 861 } 862 me, ok := inputErr.(*moerr.Error) 863 if !ok { 864 // This is not a moerr 865 return false 866 } 867 if _, has := errCodeRollbackWholeTxn[me.ErrorCode()]; has { 868 return true 869 } 870 return false 871 } 872 873 func getRandomErrorRollbackWholeTxn() error { 874 rand.NewSource(time.Now().UnixNano()) 875 x := rand.Intn(len(errCodeRollbackWholeTxn)) 876 arr := make([]uint16, 0, len(errCodeRollbackWholeTxn)) 877 for k := range errCodeRollbackWholeTxn { 878 arr = append(arr, k) 879 } 880 switch arr[x] { 881 case moerr.ErrDeadLockDetected: 882 return moerr.NewDeadLockDetectedNoCtx() 883 case moerr.ErrLockTableBindChanged: 884 return moerr.NewLockTableBindChangedNoCtx() 885 case moerr.ErrLockTableNotFound: 886 return moerr.NewLockTableNotFoundNoCtx() 887 case moerr.ErrDeadlockCheckBusy: 888 return moerr.NewDeadlockCheckBusyNoCtx() 889 case moerr.ErrLockConflict: 890 return moerr.NewLockConflictNoCtx() 891 default: 892 panic(fmt.Sprintf("usp error code %d", arr[x])) 893 } 894 } 895 896 func skipClientQuit(info string) bool { 897 return strings.Contains(info, quitStr) 898 } 899 900 // UserInput 901 // normally, just use the sql. 902 // for some special statement, like 'set_var', we need to use the stmt. 903 // if the stmt is not nil, we neglect the sql. 904 type UserInput struct { 905 sql string 906 stmt tree.Statement 907 sqlSourceType []string 908 isRestore bool 909 // operator account, the account executes restoration 910 // e.g. sys takes a snapshot sn1 for acc1, then restores acc1 from snapshot sn1. In this scenario, sys is the operator account 911 opAccount uint32 912 toAccount uint32 913 } 914 915 func (ui *UserInput) getSql() string { 916 return ui.sql 917 } 918 919 // getStmt if the stmt is not nil, we neglect the sql. 920 func (ui *UserInput) getStmt() tree.Statement { 921 return ui.stmt 922 } 923 924 func (ui *UserInput) getSqlSourceTypes() []string { 925 return ui.sqlSourceType 926 } 927 928 // isInternal return true if the stmt is not nil. 929 // it means the statement is not from any client. 930 // currently, we use it to handle the 'set_var' statement. 931 func (ui *UserInput) isInternal() bool { 932 return ui.getStmt() != nil 933 } 934 935 func (ui *UserInput) genSqlSourceType(ses FeSession) { 936 sql := ui.getSql() 937 ui.sqlSourceType = nil 938 if ui.getStmt() != nil { 939 ui.sqlSourceType = append(ui.sqlSourceType, constant.InternalSql) 940 return 941 } 942 tenant := ses.GetTenantInfo() 943 if tenant == nil || strings.HasPrefix(sql, cmdFieldListSql) { 944 ui.sqlSourceType = append(ui.sqlSourceType, constant.InternalSql) 945 return 946 } 947 flag, _, _ := isSpecialUser(tenant.GetUser()) 948 if flag { 949 ui.sqlSourceType = append(ui.sqlSourceType, constant.InternalSql) 950 return 951 } 952 if tenant.GetTenant() == sysAccountName && tenant.GetUser() == "internal" { 953 ui.sqlSourceType = append(ui.sqlSourceType, constant.InternalSql) 954 return 955 } 956 for len(sql) > 0 { 957 p1 := strings.Index(sql, "/*") 958 p2 := strings.Index(sql, "*/") 959 if p1 < 0 || p2 < 0 || p2 <= p1+1 { 960 ui.sqlSourceType = append(ui.sqlSourceType, constant.ExternSql) 961 return 962 } 963 source := strings.TrimSpace(sql[p1+2 : p2]) 964 if source == cloudUserTag { 965 ui.sqlSourceType = append(ui.sqlSourceType, constant.CloudUserSql) 966 } else if source == cloudNoUserTag { 967 ui.sqlSourceType = append(ui.sqlSourceType, constant.CloudNoUserSql) 968 } else if source == saveResultTag { 969 ui.sqlSourceType = append(ui.sqlSourceType, constant.CloudUserSql) 970 } else { 971 ui.sqlSourceType = append(ui.sqlSourceType, constant.ExternSql) 972 } 973 sql = sql[p2+2:] 974 } 975 } 976 977 func (ui *UserInput) getSqlSourceType(i int) string { 978 sqlType := constant.ExternSql 979 if i < len(ui.sqlSourceType) { 980 sqlType = ui.sqlSourceType[i] 981 } 982 return sqlType 983 } 984 985 func unboxExprStr(ctx context.Context, expr tree.Expr) (string, error) { 986 if e, ok := expr.(*tree.NumVal); ok && e.ValType == tree.P_char { 987 return e.OrigString(), nil 988 } 989 return "", moerr.NewInternalError(ctx, "invalid expr type") 990 } 991 992 type strParamBinder struct { 993 ctx context.Context 994 params *vector.Vector 995 err error 996 } 997 998 func (b *strParamBinder) bind(e tree.Expr) string { 999 if b.err != nil { 1000 return "" 1001 } 1002 1003 switch val := e.(type) { 1004 case *tree.NumVal: 1005 return val.OrigString() 1006 case *tree.ParamExpr: 1007 return b.params.GetStringAt(val.Offset - 1) 1008 default: 1009 b.err = moerr.NewInternalError(b.ctx, "invalid params type %T", e) 1010 return "" 1011 } 1012 } 1013 1014 func (b *strParamBinder) bindIdentStr(ident *tree.AccountIdentified) string { 1015 if b.err != nil { 1016 return "" 1017 } 1018 1019 switch ident.Typ { 1020 case tree.AccountIdentifiedByPassword, 1021 tree.AccountIdentifiedWithSSL: 1022 return b.bind(ident.Str) 1023 default: 1024 return "" 1025 } 1026 } 1027 1028 func resetBits(t *uint32, val uint32) { 1029 if t == nil { 1030 return 1031 } 1032 *t = val 1033 } 1034 1035 func setBits(t *uint32, bit uint32) { 1036 if t == nil { 1037 return 1038 } 1039 *t |= bit 1040 } 1041 1042 func clearBits(t *uint32, bit uint32) { 1043 if t == nil { 1044 return 1045 } 1046 *t &= ^bit 1047 } 1048 1049 func bitsIsSet(t uint32, bit uint32) bool { 1050 return t&bit != 0 1051 } 1052 1053 func attachValue(ctx context.Context, key, val any) context.Context { 1054 if ctx == nil { 1055 panic("context is nil") 1056 } 1057 1058 return context.WithValue(ctx, key, val) 1059 } 1060 1061 func updateTempEngine(storage engine.Engine, te *memoryengine.Engine) { 1062 if ee, ok := storage.(*engine.EntireEngine); ok && ee != nil { 1063 ee.TempEngine = te 1064 } 1065 } 1066 1067 func genKey(dbName, tblName string) string { 1068 return fmt.Sprintf("%s#%s", dbName, tblName) 1069 } 1070 1071 type topsort struct { 1072 next map[string][]string 1073 } 1074 1075 func (g *topsort) addVertex(v string) { 1076 g.next[v] = make([]string, 0) 1077 } 1078 1079 func (g *topsort) addEdge(from, to string) { 1080 g.next[from] = append(g.next[from], to) 1081 } 1082 1083 func (g *topsort) sort() (ans []string, ok bool) { 1084 inDegree := make(map[string]uint) 1085 for u := range g.next { 1086 inDegree[u] = 0 1087 } 1088 for _, nextVertices := range g.next { 1089 for _, v := range nextVertices { 1090 inDegree[v] += 1 1091 } 1092 } 1093 1094 var noPreVertices []string 1095 for v, deg := range inDegree { 1096 if deg == 0 { 1097 noPreVertices = append(noPreVertices, v) 1098 } 1099 } 1100 1101 for len(noPreVertices) > 0 { 1102 // find vertex whose inDegree = 0 1103 v := noPreVertices[0] 1104 noPreVertices = noPreVertices[1:] 1105 ans = append(ans, v) 1106 1107 // update the next vertices from v 1108 for _, to := range g.next[v] { 1109 inDegree[to] -= 1 1110 if inDegree[to] == 0 { 1111 noPreVertices = append(noPreVertices, to) 1112 } 1113 } 1114 } 1115 1116 if len(ans) == len(inDegree) { 1117 ok = true 1118 } 1119 return 1120 }