github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/conn_stmt.go (about) 1 // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved. 2 // 3 // This Source Code Form is subject to the terms of the Mozilla Public 4 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 // You can obtain one at http://mozilla.org/MPL/2.0/. 6 7 // The MIT License (MIT) 8 // 9 // Copyright (c) 2020 wandoulabs 10 // Copyright (c) 2020 siddontang 11 // 12 // Permission is hereby granted, free of charge, to any person obtaining a copy of 13 // this software and associated documentation files (the "Software"), to deal in 14 // the Software without restriction, including without limitation the rights to 15 // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 16 // the Software, and to permit persons to whom the Software is furnished to do so, 17 // subject to the following conditions: 18 // 19 // The above copyright notice and this permission notice shall be included in all 20 // copies or substantial portions of the Software. 21 22 // Copyright 2020 WHTCORPS INC, Inc. 23 // 24 // Licensed under the Apache License, Version 2.0 (the "License"); 25 // you may not use this file except in compliance with the License. 26 // You may obtain a copy of the License at 27 // 28 // http://www.apache.org/licenses/LICENSE-2.0 29 // 30 // Unless required by applicable law or agreed to in writing, software 31 // distributed under the License is distributed on an "AS IS" BASIS, 32 // See the License for the specific language governing permissions and 33 // limitations under the License. 34 35 package server 36 37 import ( 38 "context" 39 "encoding/binary" 40 "fmt" 41 "math" 42 "runtime/trace" 43 "strconv" 44 "time" 45 46 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 47 "github.com/whtcorpsinc/BerolinaSQL/terror" 48 "github.com/whtcorpsinc/errors" 49 causetembedded "github.com/whtcorpsinc/milevadb/causet/embedded" 50 "github.com/whtcorpsinc/milevadb/config" 51 "github.com/whtcorpsinc/milevadb/soliton/execdetails" 52 "github.com/whtcorpsinc/milevadb/soliton/replog" 53 "github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx" 54 "github.com/whtcorpsinc/milevadb/types" 55 ) 56 57 func (cc *clientConn) handleStmtPrepare(ctx context.Context, allegrosql string) error { 58 stmt, defCausumns, params, err := cc.ctx.Prepare(allegrosql) 59 if err != nil { 60 return err 61 } 62 data := make([]byte, 4, 128) 63 64 //status ok 65 data = append(data, 0) 66 //stmt id 67 data = dumpUint32(data, uint32(stmt.ID())) 68 //number defCausumns 69 data = dumpUint16(data, uint16(len(defCausumns))) 70 //number params 71 data = dumpUint16(data, uint16(len(params))) 72 //filter [00] 73 data = append(data, 0) 74 //warning count 75 data = append(data, 0, 0) //TODO support warning count 76 77 if err := cc.writePacket(data); err != nil { 78 return err 79 } 80 81 if len(params) > 0 { 82 for i := 0; i < len(params); i++ { 83 data = data[0:4] 84 data = params[i].Dump(data) 85 86 if err := cc.writePacket(data); err != nil { 87 return err 88 } 89 } 90 91 if err := cc.writeEOF(0); err != nil { 92 return err 93 } 94 } 95 96 if len(defCausumns) > 0 { 97 for i := 0; i < len(defCausumns); i++ { 98 data = data[0:4] 99 data = defCausumns[i].Dump(data) 100 101 if err := cc.writePacket(data); err != nil { 102 return err 103 } 104 } 105 106 if err := cc.writeEOF(0); err != nil { 107 return err 108 } 109 110 } 111 return cc.flush(ctx) 112 } 113 114 func (cc *clientConn) handleStmtInterDircute(ctx context.Context, data []byte) (err error) { 115 defer trace.StartRegion(ctx, "HandleStmtInterDircute").End() 116 if len(data) < 9 { 117 return allegrosql.ErrMalformPacket 118 } 119 pos := 0 120 stmtID := binary.LittleEndian.Uint32(data[0:4]) 121 pos += 4 122 123 stmt := cc.ctx.GetStatement(int(stmtID)) 124 if stmt == nil { 125 return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler, 126 strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") 127 } 128 129 flag := data[pos] 130 pos++ 131 // Please refer to https://dev.allegrosql.com/doc/internals/en/com-stmt-execute.html 132 // The client indicates that it wants to use cursor by setting this flag. 133 // 0x00 CURSOR_TYPE_NO_CURSOR 134 // 0x01 CURSOR_TYPE_READ_ONLY 135 // 0x02 CURSOR_TYPE_FOR_UFIDelATE 136 // 0x04 CURSOR_TYPE_SCROLLABLE 137 // Now we only support forward-only, read-only cursor. 138 var useCursor bool 139 switch flag { 140 case 0: 141 useCursor = false 142 case 1: 143 useCursor = true 144 default: 145 return allegrosql.NewErrf(allegrosql.ErrUnknown, "unsupported flag %d", flag) 146 } 147 148 // skip iteration-count, always 1 149 pos += 4 150 151 var ( 152 nullBitmaps []byte 153 paramTypes []byte 154 paramValues []byte 155 ) 156 numParams := stmt.NumParams() 157 args := make([]types.Causet, numParams) 158 if numParams > 0 { 159 nullBitmapLen := (numParams + 7) >> 3 160 if len(data) < (pos + nullBitmapLen + 1) { 161 return allegrosql.ErrMalformPacket 162 } 163 nullBitmaps = data[pos : pos+nullBitmapLen] 164 pos += nullBitmapLen 165 166 // new param bound flag 167 if data[pos] == 1 { 168 pos++ 169 if len(data) < (pos + (numParams << 1)) { 170 return allegrosql.ErrMalformPacket 171 } 172 173 paramTypes = data[pos : pos+(numParams<<1)] 174 pos += numParams << 1 175 paramValues = data[pos:] 176 // Just the first StmtInterDircute packet contain parameters type, 177 // we need save it for further use. 178 stmt.SetParamsType(paramTypes) 179 } else { 180 paramValues = data[pos+1:] 181 } 182 183 err = parseInterDircArgs(cc.ctx.GetStochastikVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues) 184 stmt.Reset() 185 if err != nil { 186 return errors.Annotate(err, cc.preparedStmt2String(stmtID)) 187 } 188 } 189 ctx = context.WithValue(ctx, execdetails.StmtInterDircDetailKey, &execdetails.StmtInterDircDetails{}) 190 rs, err := stmt.InterDircute(ctx, args) 191 if err != nil { 192 return errors.Annotate(err, cc.preparedStmt2String(stmtID)) 193 } 194 if rs == nil { 195 return cc.writeOK(ctx) 196 } 197 198 // if the client wants to use cursor 199 // we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back DeferredCausetInfo. 200 // Tell the client cursor exists in server by setting proper serverStatus. 201 if useCursor { 202 stmt.StoreResultSet(rs) 203 err = cc.writeDeferredCausetInfo(rs.DeferredCausets(), allegrosql.ServerStatusCursorExists) 204 if err != nil { 205 return err 206 } 207 if cl, ok := rs.(fetchNotifier); ok { 208 cl.OnFetchReturned() 209 } 210 // explicitly flush defCausumnInfo to client. 211 return cc.flush(ctx) 212 } 213 defer terror.Call(rs.Close) 214 err = cc.writeResultset(ctx, rs, true, 0, 0) 215 if err != nil { 216 return errors.Annotate(err, cc.preparedStmt2String(stmtID)) 217 } 218 return nil 219 } 220 221 // maxFetchSize constants 222 const ( 223 maxFetchSize = 1024 224 ) 225 226 func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { 227 cc.ctx.GetStochastikVars().StartTime = time.Now() 228 229 stmtID, fetchSize, err := parseStmtFetchCmd(data) 230 if err != nil { 231 return err 232 } 233 234 stmt := cc.ctx.GetStatement(int(stmtID)) 235 if stmt == nil { 236 return errors.Annotate(allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler, 237 strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) 238 } 239 allegrosql := "" 240 if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*MilevaDBStatement); ok { 241 allegrosql = prepared.allegrosql 242 } 243 cc.ctx.SetProcessInfo(allegrosql, time.Now(), allegrosql.ComStmtInterDircute, 0) 244 rs := stmt.GetResultSet() 245 if rs == nil { 246 return errors.Annotate(allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler, 247 strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID)) 248 } 249 250 err = cc.writeResultset(ctx, rs, true, allegrosql.ServerStatusCursorExists, int(fetchSize)) 251 if err != nil { 252 return errors.Annotate(err, cc.preparedStmt2String(stmtID)) 253 } 254 return nil 255 } 256 257 func parseStmtFetchCmd(data []byte) (uint32, uint32, error) { 258 if len(data) != 8 { 259 return 0, 0, allegrosql.ErrMalformPacket 260 } 261 // Please refer to https://dev.allegrosql.com/doc/internals/en/com-stmt-fetch.html 262 stmtID := binary.LittleEndian.Uint32(data[0:4]) 263 fetchSize := binary.LittleEndian.Uint32(data[4:8]) 264 if fetchSize > maxFetchSize { 265 fetchSize = maxFetchSize 266 } 267 return stmtID, fetchSize, nil 268 } 269 270 func parseInterDircArgs(sc *stmtctx.StatementContext, args []types.Causet, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) { 271 pos := 0 272 var ( 273 tmp interface{} 274 v []byte 275 n int 276 isNull bool 277 ) 278 279 for i := 0; i < len(args); i++ { 280 // if params had received via ComStmtSendLongData, use them directly. 281 // ref https://dev.allegrosql.com/doc/internals/en/com-stmt-send-long-data.html 282 // see clientConn#handleStmtSendLongData 283 if boundParams[i] != nil { 284 args[i] = types.NewBytesCauset(boundParams[i]) 285 continue 286 } 287 288 // check nullBitMap to determine the NULL arguments. 289 // ref https://dev.allegrosql.com/doc/internals/en/com-stmt-execute.html 290 // notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData, 291 // so this check need place after boundParam's check. 292 if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { 293 var nilCauset types.Causet 294 nilCauset.SetNull() 295 args[i] = nilCauset 296 continue 297 } 298 299 if (i<<1)+1 >= len(paramTypes) { 300 return allegrosql.ErrMalformPacket 301 } 302 303 tp := paramTypes[i<<1] 304 isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 305 306 switch tp { 307 case allegrosql.TypeNull: 308 var nilCauset types.Causet 309 nilCauset.SetNull() 310 args[i] = nilCauset 311 continue 312 313 case allegrosql.TypeTiny: 314 if len(paramValues) < (pos + 1) { 315 err = allegrosql.ErrMalformPacket 316 return 317 } 318 319 if isUnsigned { 320 args[i] = types.NewUintCauset(uint64(paramValues[pos])) 321 } else { 322 args[i] = types.NewIntCauset(int64(int8(paramValues[pos]))) 323 } 324 325 pos++ 326 continue 327 328 case allegrosql.TypeShort, allegrosql.TypeYear: 329 if len(paramValues) < (pos + 2) { 330 err = allegrosql.ErrMalformPacket 331 return 332 } 333 valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) 334 if isUnsigned { 335 args[i] = types.NewUintCauset(uint64(valU16)) 336 } else { 337 args[i] = types.NewIntCauset(int64(int16(valU16))) 338 } 339 pos += 2 340 continue 341 342 case allegrosql.TypeInt24, allegrosql.TypeLong: 343 if len(paramValues) < (pos + 4) { 344 err = allegrosql.ErrMalformPacket 345 return 346 } 347 valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) 348 if isUnsigned { 349 args[i] = types.NewUintCauset(uint64(valU32)) 350 } else { 351 args[i] = types.NewIntCauset(int64(int32(valU32))) 352 } 353 pos += 4 354 continue 355 356 case allegrosql.TypeLonglong: 357 if len(paramValues) < (pos + 8) { 358 err = allegrosql.ErrMalformPacket 359 return 360 } 361 valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8]) 362 if isUnsigned { 363 args[i] = types.NewUintCauset(valU64) 364 } else { 365 args[i] = types.NewIntCauset(int64(valU64)) 366 } 367 pos += 8 368 continue 369 370 case allegrosql.TypeFloat: 371 if len(paramValues) < (pos + 4) { 372 err = allegrosql.ErrMalformPacket 373 return 374 } 375 376 args[i] = types.NewFloat32Causet(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))) 377 pos += 4 378 continue 379 380 case allegrosql.TypeDouble: 381 if len(paramValues) < (pos + 8) { 382 err = allegrosql.ErrMalformPacket 383 return 384 } 385 386 args[i] = types.NewFloat64Causet(math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))) 387 pos += 8 388 continue 389 390 case allegrosql.TypeDate, allegrosql.TypeTimestamp, allegrosql.TypeDatetime: 391 if len(paramValues) < (pos + 1) { 392 err = allegrosql.ErrMalformPacket 393 return 394 } 395 // See https://dev.allegrosql.com/doc/internals/en/binary-protodefCaus-value.html 396 // for more details. 397 length := paramValues[pos] 398 pos++ 399 switch length { 400 case 0: 401 tmp = types.ZeroDatetimeStr 402 case 4: 403 pos, tmp = parseBinaryDate(pos, paramValues) 404 case 7: 405 pos, tmp = parseBinaryDateTime(pos, paramValues) 406 case 11: 407 pos, tmp = parseBinaryTimestamp(pos, paramValues) 408 default: 409 err = allegrosql.ErrMalformPacket 410 return 411 } 412 args[i] = types.NewCauset(tmp) // FIXME: After check works!!!!!! 413 continue 414 415 case allegrosql.TypeDuration: 416 if len(paramValues) < (pos + 1) { 417 err = allegrosql.ErrMalformPacket 418 return 419 } 420 // See https://dev.allegrosql.com/doc/internals/en/binary-protodefCaus-value.html 421 // for more details. 422 length := paramValues[pos] 423 pos++ 424 switch length { 425 case 0: 426 tmp = "0" 427 case 8: 428 isNegative := paramValues[pos] 429 if isNegative > 1 { 430 err = allegrosql.ErrMalformPacket 431 return 432 } 433 pos++ 434 pos, tmp = parseBinaryDuration(pos, paramValues, isNegative) 435 case 12: 436 isNegative := paramValues[pos] 437 if isNegative > 1 { 438 err = allegrosql.ErrMalformPacket 439 return 440 } 441 pos++ 442 pos, tmp = parseBinaryDurationWithMS(pos, paramValues, isNegative) 443 default: 444 err = allegrosql.ErrMalformPacket 445 return 446 } 447 args[i] = types.NewCauset(tmp) 448 continue 449 case allegrosql.TypeNewDecimal: 450 if len(paramValues) < (pos + 1) { 451 err = allegrosql.ErrMalformPacket 452 return 453 } 454 455 v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:]) 456 pos += n 457 if err != nil { 458 return 459 } 460 461 if isNull { 462 args[i] = types.NewDecimalCauset(nil) 463 } else { 464 var dec types.MyDecimal 465 err = sc.HandleTruncate(dec.FromString(v)) 466 if err != nil { 467 return err 468 } 469 args[i] = types.NewDecimalCauset(&dec) 470 } 471 continue 472 case allegrosql.TypeBlob, allegrosql.TypeTinyBlob, allegrosql.TypeMediumBlob, allegrosql.TypeLongBlob: 473 if len(paramValues) < (pos + 1) { 474 err = allegrosql.ErrMalformPacket 475 return 476 } 477 v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:]) 478 pos += n 479 if err != nil { 480 return 481 } 482 483 if isNull { 484 args[i] = types.NewBytesCauset(nil) 485 } else { 486 args[i] = types.NewBytesCauset(v) 487 } 488 continue 489 case allegrosql.TypeUnspecified, allegrosql.TypeVarchar, allegrosql.TypeVarString, allegrosql.TypeString, 490 allegrosql.TypeEnum, allegrosql.TypeSet, allegrosql.TypeGeometry, allegrosql.TypeBit: 491 if len(paramValues) < (pos + 1) { 492 err = allegrosql.ErrMalformPacket 493 return 494 } 495 496 v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:]) 497 pos += n 498 if err != nil { 499 return 500 } 501 502 if !isNull { 503 tmp = string(replog.String(v)) 504 } else { 505 tmp = nil 506 } 507 args[i] = types.NewCauset(tmp) 508 continue 509 default: 510 err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp) 511 return 512 } 513 } 514 return 515 } 516 517 func parseBinaryDate(pos int, paramValues []byte) (int, string) { 518 year := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) 519 pos += 2 520 month := paramValues[pos] 521 pos++ 522 day := paramValues[pos] 523 pos++ 524 return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day) 525 } 526 527 func parseBinaryDateTime(pos int, paramValues []byte) (int, string) { 528 pos, date := parseBinaryDate(pos, paramValues) 529 hour := paramValues[pos] 530 pos++ 531 minute := paramValues[pos] 532 pos++ 533 second := paramValues[pos] 534 pos++ 535 return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second) 536 } 537 538 func parseBinaryTimestamp(pos int, paramValues []byte) (int, string) { 539 pos, dateTime := parseBinaryDateTime(pos, paramValues) 540 microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) 541 pos += 4 542 return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond) 543 } 544 545 func parseBinaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) { 546 sign := "" 547 if isNegative == 1 { 548 sign = "-" 549 } 550 days := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) 551 pos += 4 552 hours := paramValues[pos] 553 pos++ 554 minutes := paramValues[pos] 555 pos++ 556 seconds := paramValues[pos] 557 pos++ 558 return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds) 559 } 560 561 func parseBinaryDurationWithMS(pos int, paramValues []byte, 562 isNegative uint8) (int, string) { 563 pos, dur := parseBinaryDuration(pos, paramValues, isNegative) 564 microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) 565 pos += 4 566 return pos, fmt.Sprintf("%s.%06d", dur, microSecond) 567 } 568 569 func (cc *clientConn) handleStmtClose(data []byte) (err error) { 570 if len(data) < 4 { 571 return 572 } 573 574 stmtID := int(binary.LittleEndian.Uint32(data[0:4])) 575 stmt := cc.ctx.GetStatement(stmtID) 576 if stmt != nil { 577 return stmt.Close() 578 } 579 return 580 } 581 582 func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { 583 if len(data) < 6 { 584 return allegrosql.ErrMalformPacket 585 } 586 587 stmtID := int(binary.LittleEndian.Uint32(data[0:4])) 588 589 stmt := cc.ctx.GetStatement(stmtID) 590 if stmt == nil { 591 return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler, 592 strconv.Itoa(stmtID), "stmt_send_longdata") 593 } 594 595 paramID := int(binary.LittleEndian.Uint16(data[4:6])) 596 return stmt.AppendParam(paramID, data[6:]) 597 } 598 599 func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) { 600 if len(data) < 4 { 601 return allegrosql.ErrMalformPacket 602 } 603 604 stmtID := int(binary.LittleEndian.Uint32(data[0:4])) 605 stmt := cc.ctx.GetStatement(stmtID) 606 if stmt == nil { 607 return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler, 608 strconv.Itoa(stmtID), "stmt_reset") 609 } 610 stmt.Reset() 611 stmt.StoreResultSet(nil) 612 return cc.writeOK(ctx) 613 } 614 615 // handleSetOption refer to https://dev.allegrosql.com/doc/internals/en/com-set-option.html 616 func (cc *clientConn) handleSetOption(ctx context.Context, data []byte) (err error) { 617 if len(data) < 2 { 618 return allegrosql.ErrMalformPacket 619 } 620 621 switch binary.LittleEndian.Uint16(data[:2]) { 622 case 0: 623 cc.capability |= allegrosql.ClientMultiStatements 624 cc.ctx.SetClientCapability(cc.capability) 625 case 1: 626 cc.capability &^= allegrosql.ClientMultiStatements 627 cc.ctx.SetClientCapability(cc.capability) 628 default: 629 return allegrosql.ErrMalformPacket 630 } 631 if err = cc.writeEOF(0); err != nil { 632 return err 633 } 634 635 return cc.flush(ctx) 636 } 637 638 func (cc *clientConn) preparedStmt2String(stmtID uint32) string { 639 sv := cc.ctx.GetStochastikVars() 640 if sv == nil { 641 return "" 642 } 643 if config.RedactLogEnabled() { 644 return cc.preparedStmt2StringNoArgs(stmtID) 645 } 646 return cc.preparedStmt2StringNoArgs(stmtID) + sv.PreparedParams.String() 647 } 648 649 func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string { 650 sv := cc.ctx.GetStochastikVars() 651 if sv == nil { 652 return "" 653 } 654 preparedPointer, ok := sv.PreparedStmts[stmtID] 655 if !ok { 656 return "prepared memex not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) 657 } 658 preparedObj, ok := preparedPointer.(*causetembedded.CachedPrepareStmt) 659 if !ok { 660 return "invalidate CachedPrepareStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10) 661 } 662 preparedAst := preparedObj.PreparedAst 663 return preparedAst.Stmt.Text() 664 }