github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/conn_stmt.go (about)

     1  // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // The MIT License (MIT)
     8  //
     9  // Copyright (c) 2020 wandoulabs
    10  // Copyright (c) 2020 siddontang
    11  //
    12  // Permission is hereby granted, free of charge, to any person obtaining a copy of
    13  // this software and associated documentation files (the "Software"), to deal in
    14  // the Software without restriction, including without limitation the rights to
    15  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    16  // the Software, and to permit persons to whom the Software is furnished to do so,
    17  // subject to the following conditions:
    18  //
    19  // The above copyright notice and this permission notice shall be included in all
    20  // copies or substantial portions of the Software.
    21  
    22  // Copyright 2020 WHTCORPS INC, Inc.
    23  //
    24  // Licensed under the Apache License, Version 2.0 (the "License");
    25  // you may not use this file except in compliance with the License.
    26  // You may obtain a copy of the License at
    27  //
    28  //     http://www.apache.org/licenses/LICENSE-2.0
    29  //
    30  // Unless required by applicable law or agreed to in writing, software
    31  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package server
    36  
    37  import (
    38  	"context"
    39  	"encoding/binary"
    40  	"fmt"
    41  	"math"
    42  	"runtime/trace"
    43  	"strconv"
    44  	"time"
    45  
    46  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    47  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    48  	"github.com/whtcorpsinc/errors"
    49  	causetembedded "github.com/whtcorpsinc/milevadb/causet/embedded"
    50  	"github.com/whtcorpsinc/milevadb/config"
    51  	"github.com/whtcorpsinc/milevadb/soliton/execdetails"
    52  	"github.com/whtcorpsinc/milevadb/soliton/replog"
    53  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    54  	"github.com/whtcorpsinc/milevadb/types"
    55  )
    56  
    57  func (cc *clientConn) handleStmtPrepare(ctx context.Context, allegrosql string) error {
    58  	stmt, defCausumns, params, err := cc.ctx.Prepare(allegrosql)
    59  	if err != nil {
    60  		return err
    61  	}
    62  	data := make([]byte, 4, 128)
    63  
    64  	//status ok
    65  	data = append(data, 0)
    66  	//stmt id
    67  	data = dumpUint32(data, uint32(stmt.ID()))
    68  	//number defCausumns
    69  	data = dumpUint16(data, uint16(len(defCausumns)))
    70  	//number params
    71  	data = dumpUint16(data, uint16(len(params)))
    72  	//filter [00]
    73  	data = append(data, 0)
    74  	//warning count
    75  	data = append(data, 0, 0) //TODO support warning count
    76  
    77  	if err := cc.writePacket(data); err != nil {
    78  		return err
    79  	}
    80  
    81  	if len(params) > 0 {
    82  		for i := 0; i < len(params); i++ {
    83  			data = data[0:4]
    84  			data = params[i].Dump(data)
    85  
    86  			if err := cc.writePacket(data); err != nil {
    87  				return err
    88  			}
    89  		}
    90  
    91  		if err := cc.writeEOF(0); err != nil {
    92  			return err
    93  		}
    94  	}
    95  
    96  	if len(defCausumns) > 0 {
    97  		for i := 0; i < len(defCausumns); i++ {
    98  			data = data[0:4]
    99  			data = defCausumns[i].Dump(data)
   100  
   101  			if err := cc.writePacket(data); err != nil {
   102  				return err
   103  			}
   104  		}
   105  
   106  		if err := cc.writeEOF(0); err != nil {
   107  			return err
   108  		}
   109  
   110  	}
   111  	return cc.flush(ctx)
   112  }
   113  
   114  func (cc *clientConn) handleStmtInterDircute(ctx context.Context, data []byte) (err error) {
   115  	defer trace.StartRegion(ctx, "HandleStmtInterDircute").End()
   116  	if len(data) < 9 {
   117  		return allegrosql.ErrMalformPacket
   118  	}
   119  	pos := 0
   120  	stmtID := binary.LittleEndian.Uint32(data[0:4])
   121  	pos += 4
   122  
   123  	stmt := cc.ctx.GetStatement(int(stmtID))
   124  	if stmt == nil {
   125  		return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler,
   126  			strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
   127  	}
   128  
   129  	flag := data[pos]
   130  	pos++
   131  	// Please refer to https://dev.allegrosql.com/doc/internals/en/com-stmt-execute.html
   132  	// The client indicates that it wants to use cursor by setting this flag.
   133  	// 0x00 CURSOR_TYPE_NO_CURSOR
   134  	// 0x01 CURSOR_TYPE_READ_ONLY
   135  	// 0x02 CURSOR_TYPE_FOR_UFIDelATE
   136  	// 0x04 CURSOR_TYPE_SCROLLABLE
   137  	// Now we only support forward-only, read-only cursor.
   138  	var useCursor bool
   139  	switch flag {
   140  	case 0:
   141  		useCursor = false
   142  	case 1:
   143  		useCursor = true
   144  	default:
   145  		return allegrosql.NewErrf(allegrosql.ErrUnknown, "unsupported flag %d", flag)
   146  	}
   147  
   148  	// skip iteration-count, always 1
   149  	pos += 4
   150  
   151  	var (
   152  		nullBitmaps []byte
   153  		paramTypes  []byte
   154  		paramValues []byte
   155  	)
   156  	numParams := stmt.NumParams()
   157  	args := make([]types.Causet, numParams)
   158  	if numParams > 0 {
   159  		nullBitmapLen := (numParams + 7) >> 3
   160  		if len(data) < (pos + nullBitmapLen + 1) {
   161  			return allegrosql.ErrMalformPacket
   162  		}
   163  		nullBitmaps = data[pos : pos+nullBitmapLen]
   164  		pos += nullBitmapLen
   165  
   166  		// new param bound flag
   167  		if data[pos] == 1 {
   168  			pos++
   169  			if len(data) < (pos + (numParams << 1)) {
   170  				return allegrosql.ErrMalformPacket
   171  			}
   172  
   173  			paramTypes = data[pos : pos+(numParams<<1)]
   174  			pos += numParams << 1
   175  			paramValues = data[pos:]
   176  			// Just the first StmtInterDircute packet contain parameters type,
   177  			// we need save it for further use.
   178  			stmt.SetParamsType(paramTypes)
   179  		} else {
   180  			paramValues = data[pos+1:]
   181  		}
   182  
   183  		err = parseInterDircArgs(cc.ctx.GetStochastikVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues)
   184  		stmt.Reset()
   185  		if err != nil {
   186  			return errors.Annotate(err, cc.preparedStmt2String(stmtID))
   187  		}
   188  	}
   189  	ctx = context.WithValue(ctx, execdetails.StmtInterDircDetailKey, &execdetails.StmtInterDircDetails{})
   190  	rs, err := stmt.InterDircute(ctx, args)
   191  	if err != nil {
   192  		return errors.Annotate(err, cc.preparedStmt2String(stmtID))
   193  	}
   194  	if rs == nil {
   195  		return cc.writeOK(ctx)
   196  	}
   197  
   198  	// if the client wants to use cursor
   199  	// we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back DeferredCausetInfo.
   200  	// Tell the client cursor exists in server by setting proper serverStatus.
   201  	if useCursor {
   202  		stmt.StoreResultSet(rs)
   203  		err = cc.writeDeferredCausetInfo(rs.DeferredCausets(), allegrosql.ServerStatusCursorExists)
   204  		if err != nil {
   205  			return err
   206  		}
   207  		if cl, ok := rs.(fetchNotifier); ok {
   208  			cl.OnFetchReturned()
   209  		}
   210  		// explicitly flush defCausumnInfo to client.
   211  		return cc.flush(ctx)
   212  	}
   213  	defer terror.Call(rs.Close)
   214  	err = cc.writeResultset(ctx, rs, true, 0, 0)
   215  	if err != nil {
   216  		return errors.Annotate(err, cc.preparedStmt2String(stmtID))
   217  	}
   218  	return nil
   219  }
   220  
   221  // maxFetchSize constants
   222  const (
   223  	maxFetchSize = 1024
   224  )
   225  
   226  func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) {
   227  	cc.ctx.GetStochastikVars().StartTime = time.Now()
   228  
   229  	stmtID, fetchSize, err := parseStmtFetchCmd(data)
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	stmt := cc.ctx.GetStatement(int(stmtID))
   235  	if stmt == nil {
   236  		return errors.Annotate(allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler,
   237  			strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID))
   238  	}
   239  	allegrosql := ""
   240  	if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*MilevaDBStatement); ok {
   241  		allegrosql = prepared.allegrosql
   242  	}
   243  	cc.ctx.SetProcessInfo(allegrosql, time.Now(), allegrosql.ComStmtInterDircute, 0)
   244  	rs := stmt.GetResultSet()
   245  	if rs == nil {
   246  		return errors.Annotate(allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler,
   247  			strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID))
   248  	}
   249  
   250  	err = cc.writeResultset(ctx, rs, true, allegrosql.ServerStatusCursorExists, int(fetchSize))
   251  	if err != nil {
   252  		return errors.Annotate(err, cc.preparedStmt2String(stmtID))
   253  	}
   254  	return nil
   255  }
   256  
   257  func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
   258  	if len(data) != 8 {
   259  		return 0, 0, allegrosql.ErrMalformPacket
   260  	}
   261  	// Please refer to https://dev.allegrosql.com/doc/internals/en/com-stmt-fetch.html
   262  	stmtID := binary.LittleEndian.Uint32(data[0:4])
   263  	fetchSize := binary.LittleEndian.Uint32(data[4:8])
   264  	if fetchSize > maxFetchSize {
   265  		fetchSize = maxFetchSize
   266  	}
   267  	return stmtID, fetchSize, nil
   268  }
   269  
   270  func parseInterDircArgs(sc *stmtctx.StatementContext, args []types.Causet, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
   271  	pos := 0
   272  	var (
   273  		tmp    interface{}
   274  		v      []byte
   275  		n      int
   276  		isNull bool
   277  	)
   278  
   279  	for i := 0; i < len(args); i++ {
   280  		// if params had received via ComStmtSendLongData, use them directly.
   281  		// ref https://dev.allegrosql.com/doc/internals/en/com-stmt-send-long-data.html
   282  		// see clientConn#handleStmtSendLongData
   283  		if boundParams[i] != nil {
   284  			args[i] = types.NewBytesCauset(boundParams[i])
   285  			continue
   286  		}
   287  
   288  		// check nullBitMap to determine the NULL arguments.
   289  		// ref https://dev.allegrosql.com/doc/internals/en/com-stmt-execute.html
   290  		// notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData,
   291  		// so this check need place after boundParam's check.
   292  		if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
   293  			var nilCauset types.Causet
   294  			nilCauset.SetNull()
   295  			args[i] = nilCauset
   296  			continue
   297  		}
   298  
   299  		if (i<<1)+1 >= len(paramTypes) {
   300  			return allegrosql.ErrMalformPacket
   301  		}
   302  
   303  		tp := paramTypes[i<<1]
   304  		isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
   305  
   306  		switch tp {
   307  		case allegrosql.TypeNull:
   308  			var nilCauset types.Causet
   309  			nilCauset.SetNull()
   310  			args[i] = nilCauset
   311  			continue
   312  
   313  		case allegrosql.TypeTiny:
   314  			if len(paramValues) < (pos + 1) {
   315  				err = allegrosql.ErrMalformPacket
   316  				return
   317  			}
   318  
   319  			if isUnsigned {
   320  				args[i] = types.NewUintCauset(uint64(paramValues[pos]))
   321  			} else {
   322  				args[i] = types.NewIntCauset(int64(int8(paramValues[pos])))
   323  			}
   324  
   325  			pos++
   326  			continue
   327  
   328  		case allegrosql.TypeShort, allegrosql.TypeYear:
   329  			if len(paramValues) < (pos + 2) {
   330  				err = allegrosql.ErrMalformPacket
   331  				return
   332  			}
   333  			valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
   334  			if isUnsigned {
   335  				args[i] = types.NewUintCauset(uint64(valU16))
   336  			} else {
   337  				args[i] = types.NewIntCauset(int64(int16(valU16)))
   338  			}
   339  			pos += 2
   340  			continue
   341  
   342  		case allegrosql.TypeInt24, allegrosql.TypeLong:
   343  			if len(paramValues) < (pos + 4) {
   344  				err = allegrosql.ErrMalformPacket
   345  				return
   346  			}
   347  			valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
   348  			if isUnsigned {
   349  				args[i] = types.NewUintCauset(uint64(valU32))
   350  			} else {
   351  				args[i] = types.NewIntCauset(int64(int32(valU32)))
   352  			}
   353  			pos += 4
   354  			continue
   355  
   356  		case allegrosql.TypeLonglong:
   357  			if len(paramValues) < (pos + 8) {
   358  				err = allegrosql.ErrMalformPacket
   359  				return
   360  			}
   361  			valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8])
   362  			if isUnsigned {
   363  				args[i] = types.NewUintCauset(valU64)
   364  			} else {
   365  				args[i] = types.NewIntCauset(int64(valU64))
   366  			}
   367  			pos += 8
   368  			continue
   369  
   370  		case allegrosql.TypeFloat:
   371  			if len(paramValues) < (pos + 4) {
   372  				err = allegrosql.ErrMalformPacket
   373  				return
   374  			}
   375  
   376  			args[i] = types.NewFloat32Causet(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
   377  			pos += 4
   378  			continue
   379  
   380  		case allegrosql.TypeDouble:
   381  			if len(paramValues) < (pos + 8) {
   382  				err = allegrosql.ErrMalformPacket
   383  				return
   384  			}
   385  
   386  			args[i] = types.NewFloat64Causet(math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])))
   387  			pos += 8
   388  			continue
   389  
   390  		case allegrosql.TypeDate, allegrosql.TypeTimestamp, allegrosql.TypeDatetime:
   391  			if len(paramValues) < (pos + 1) {
   392  				err = allegrosql.ErrMalformPacket
   393  				return
   394  			}
   395  			// See https://dev.allegrosql.com/doc/internals/en/binary-protodefCaus-value.html
   396  			// for more details.
   397  			length := paramValues[pos]
   398  			pos++
   399  			switch length {
   400  			case 0:
   401  				tmp = types.ZeroDatetimeStr
   402  			case 4:
   403  				pos, tmp = parseBinaryDate(pos, paramValues)
   404  			case 7:
   405  				pos, tmp = parseBinaryDateTime(pos, paramValues)
   406  			case 11:
   407  				pos, tmp = parseBinaryTimestamp(pos, paramValues)
   408  			default:
   409  				err = allegrosql.ErrMalformPacket
   410  				return
   411  			}
   412  			args[i] = types.NewCauset(tmp) // FIXME: After check works!!!!!!
   413  			continue
   414  
   415  		case allegrosql.TypeDuration:
   416  			if len(paramValues) < (pos + 1) {
   417  				err = allegrosql.ErrMalformPacket
   418  				return
   419  			}
   420  			// See https://dev.allegrosql.com/doc/internals/en/binary-protodefCaus-value.html
   421  			// for more details.
   422  			length := paramValues[pos]
   423  			pos++
   424  			switch length {
   425  			case 0:
   426  				tmp = "0"
   427  			case 8:
   428  				isNegative := paramValues[pos]
   429  				if isNegative > 1 {
   430  					err = allegrosql.ErrMalformPacket
   431  					return
   432  				}
   433  				pos++
   434  				pos, tmp = parseBinaryDuration(pos, paramValues, isNegative)
   435  			case 12:
   436  				isNegative := paramValues[pos]
   437  				if isNegative > 1 {
   438  					err = allegrosql.ErrMalformPacket
   439  					return
   440  				}
   441  				pos++
   442  				pos, tmp = parseBinaryDurationWithMS(pos, paramValues, isNegative)
   443  			default:
   444  				err = allegrosql.ErrMalformPacket
   445  				return
   446  			}
   447  			args[i] = types.NewCauset(tmp)
   448  			continue
   449  		case allegrosql.TypeNewDecimal:
   450  			if len(paramValues) < (pos + 1) {
   451  				err = allegrosql.ErrMalformPacket
   452  				return
   453  			}
   454  
   455  			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
   456  			pos += n
   457  			if err != nil {
   458  				return
   459  			}
   460  
   461  			if isNull {
   462  				args[i] = types.NewDecimalCauset(nil)
   463  			} else {
   464  				var dec types.MyDecimal
   465  				err = sc.HandleTruncate(dec.FromString(v))
   466  				if err != nil {
   467  					return err
   468  				}
   469  				args[i] = types.NewDecimalCauset(&dec)
   470  			}
   471  			continue
   472  		case allegrosql.TypeBlob, allegrosql.TypeTinyBlob, allegrosql.TypeMediumBlob, allegrosql.TypeLongBlob:
   473  			if len(paramValues) < (pos + 1) {
   474  				err = allegrosql.ErrMalformPacket
   475  				return
   476  			}
   477  			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
   478  			pos += n
   479  			if err != nil {
   480  				return
   481  			}
   482  
   483  			if isNull {
   484  				args[i] = types.NewBytesCauset(nil)
   485  			} else {
   486  				args[i] = types.NewBytesCauset(v)
   487  			}
   488  			continue
   489  		case allegrosql.TypeUnspecified, allegrosql.TypeVarchar, allegrosql.TypeVarString, allegrosql.TypeString,
   490  			allegrosql.TypeEnum, allegrosql.TypeSet, allegrosql.TypeGeometry, allegrosql.TypeBit:
   491  			if len(paramValues) < (pos + 1) {
   492  				err = allegrosql.ErrMalformPacket
   493  				return
   494  			}
   495  
   496  			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
   497  			pos += n
   498  			if err != nil {
   499  				return
   500  			}
   501  
   502  			if !isNull {
   503  				tmp = string(replog.String(v))
   504  			} else {
   505  				tmp = nil
   506  			}
   507  			args[i] = types.NewCauset(tmp)
   508  			continue
   509  		default:
   510  			err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp)
   511  			return
   512  		}
   513  	}
   514  	return
   515  }
   516  
   517  func parseBinaryDate(pos int, paramValues []byte) (int, string) {
   518  	year := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
   519  	pos += 2
   520  	month := paramValues[pos]
   521  	pos++
   522  	day := paramValues[pos]
   523  	pos++
   524  	return pos, fmt.Sprintf("%04d-%02d-%02d", year, month, day)
   525  }
   526  
   527  func parseBinaryDateTime(pos int, paramValues []byte) (int, string) {
   528  	pos, date := parseBinaryDate(pos, paramValues)
   529  	hour := paramValues[pos]
   530  	pos++
   531  	minute := paramValues[pos]
   532  	pos++
   533  	second := paramValues[pos]
   534  	pos++
   535  	return pos, fmt.Sprintf("%s %02d:%02d:%02d", date, hour, minute, second)
   536  }
   537  
   538  func parseBinaryTimestamp(pos int, paramValues []byte) (int, string) {
   539  	pos, dateTime := parseBinaryDateTime(pos, paramValues)
   540  	microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
   541  	pos += 4
   542  	return pos, fmt.Sprintf("%s.%06d", dateTime, microSecond)
   543  }
   544  
   545  func parseBinaryDuration(pos int, paramValues []byte, isNegative uint8) (int, string) {
   546  	sign := ""
   547  	if isNegative == 1 {
   548  		sign = "-"
   549  	}
   550  	days := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
   551  	pos += 4
   552  	hours := paramValues[pos]
   553  	pos++
   554  	minutes := paramValues[pos]
   555  	pos++
   556  	seconds := paramValues[pos]
   557  	pos++
   558  	return pos, fmt.Sprintf("%s%d %02d:%02d:%02d", sign, days, hours, minutes, seconds)
   559  }
   560  
   561  func parseBinaryDurationWithMS(pos int, paramValues []byte,
   562  	isNegative uint8) (int, string) {
   563  	pos, dur := parseBinaryDuration(pos, paramValues, isNegative)
   564  	microSecond := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
   565  	pos += 4
   566  	return pos, fmt.Sprintf("%s.%06d", dur, microSecond)
   567  }
   568  
   569  func (cc *clientConn) handleStmtClose(data []byte) (err error) {
   570  	if len(data) < 4 {
   571  		return
   572  	}
   573  
   574  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   575  	stmt := cc.ctx.GetStatement(stmtID)
   576  	if stmt != nil {
   577  		return stmt.Close()
   578  	}
   579  	return
   580  }
   581  
   582  func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) {
   583  	if len(data) < 6 {
   584  		return allegrosql.ErrMalformPacket
   585  	}
   586  
   587  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   588  
   589  	stmt := cc.ctx.GetStatement(stmtID)
   590  	if stmt == nil {
   591  		return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler,
   592  			strconv.Itoa(stmtID), "stmt_send_longdata")
   593  	}
   594  
   595  	paramID := int(binary.LittleEndian.Uint16(data[4:6]))
   596  	return stmt.AppendParam(paramID, data[6:])
   597  }
   598  
   599  func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) {
   600  	if len(data) < 4 {
   601  		return allegrosql.ErrMalformPacket
   602  	}
   603  
   604  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   605  	stmt := cc.ctx.GetStatement(stmtID)
   606  	if stmt == nil {
   607  		return allegrosql.NewErr(allegrosql.ErrUnknownStmtHandler,
   608  			strconv.Itoa(stmtID), "stmt_reset")
   609  	}
   610  	stmt.Reset()
   611  	stmt.StoreResultSet(nil)
   612  	return cc.writeOK(ctx)
   613  }
   614  
   615  // handleSetOption refer to https://dev.allegrosql.com/doc/internals/en/com-set-option.html
   616  func (cc *clientConn) handleSetOption(ctx context.Context, data []byte) (err error) {
   617  	if len(data) < 2 {
   618  		return allegrosql.ErrMalformPacket
   619  	}
   620  
   621  	switch binary.LittleEndian.Uint16(data[:2]) {
   622  	case 0:
   623  		cc.capability |= allegrosql.ClientMultiStatements
   624  		cc.ctx.SetClientCapability(cc.capability)
   625  	case 1:
   626  		cc.capability &^= allegrosql.ClientMultiStatements
   627  		cc.ctx.SetClientCapability(cc.capability)
   628  	default:
   629  		return allegrosql.ErrMalformPacket
   630  	}
   631  	if err = cc.writeEOF(0); err != nil {
   632  		return err
   633  	}
   634  
   635  	return cc.flush(ctx)
   636  }
   637  
   638  func (cc *clientConn) preparedStmt2String(stmtID uint32) string {
   639  	sv := cc.ctx.GetStochastikVars()
   640  	if sv == nil {
   641  		return ""
   642  	}
   643  	if config.RedactLogEnabled() {
   644  		return cc.preparedStmt2StringNoArgs(stmtID)
   645  	}
   646  	return cc.preparedStmt2StringNoArgs(stmtID) + sv.PreparedParams.String()
   647  }
   648  
   649  func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string {
   650  	sv := cc.ctx.GetStochastikVars()
   651  	if sv == nil {
   652  		return ""
   653  	}
   654  	preparedPointer, ok := sv.PreparedStmts[stmtID]
   655  	if !ok {
   656  		return "prepared memex not found, ID: " + strconv.FormatUint(uint64(stmtID), 10)
   657  	}
   658  	preparedObj, ok := preparedPointer.(*causetembedded.CachedPrepareStmt)
   659  	if !ok {
   660  		return "invalidate CachedPrepareStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10)
   661  	}
   662  	preparedAst := preparedObj.PreparedAst
   663  	return preparedAst.Stmt.Text()
   664  }