github.com/XiaoMi/Gaea@v1.2.5/proxy/server/executor_stmt.go (about)

     1  // Copyright 2016 The kingshard Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"): you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  // Copyright 2019 The Gaea Authors. All Rights Reserved.
    16  //
    17  // Licensed under the Apache License, Version 2.0 (the "License");
    18  // you may not use this file except in compliance with the License.
    19  // You may obtain a copy of the License at
    20  //
    21  //     http://www.apache.org/licenses/LICENSE-2.0
    22  //
    23  // Unless required by applicable law or agreed to in writing, software
    24  // distributed under the License is distributed on an "AS IS" BASIS,
    25  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    26  // See the License for the specific language governing permissions and
    27  // limitations under the License.
    28  
    29  package server
    30  
    31  import (
    32  	"encoding/binary"
    33  	"errors"
    34  	"fmt"
    35  	"math"
    36  	"strconv"
    37  
    38  	"github.com/XiaoMi/Gaea/mysql"
    39  	"github.com/XiaoMi/Gaea/util"
    40  )
    41  
    42  var p = &mysql.Field{Name: []byte("?")}
    43  var c = &mysql.Field{}
    44  
    45  func calcParams(sql string) (paramCount int, offsets []int, err error) {
    46  	count := 0
    47  	quoteChar := ""
    48  	paramCount = 0
    49  	offsets = make([]int, 0)
    50  
    51  	for i, elem := range []byte(sql) {
    52  		if elem == '\\' {
    53  			continue
    54  		} else if elem == '"' || elem == '\'' {
    55  			if quoteChar == "" {
    56  				quoteChar = string(elem)
    57  			} else if quoteChar == string(elem) {
    58  				quoteChar = ""
    59  			}
    60  		} else if quoteChar == "" && elem == '?' {
    61  			count++
    62  			offsets = append(offsets, i)
    63  		}
    64  
    65  	}
    66  	if quoteChar != "" {
    67  		err = fmt.Errorf("fatal situation")
    68  		return
    69  	}
    70  
    71  	paramCount = count
    72  
    73  	return
    74  }
    75  
    76  func escapeSQL(sql string) string {
    77  	t := make([]byte, 0, len(sql))
    78  	for _, elem := range []byte(sql) {
    79  		if elem == '\\' || elem == '\'' {
    80  			t = append(t, '\\')
    81  		}
    82  		t = append(t, elem)
    83  	}
    84  	return string(t)
    85  }
    86  
    87  // Stmt prepare statement struct
    88  type Stmt struct {
    89  	id          uint32
    90  	sql         string
    91  	args        []interface{}
    92  	columnCount int
    93  	paramCount  int
    94  	paramTypes  []byte
    95  	offsets     []int
    96  }
    97  
    98  // ResetParams reset args
    99  func (s *Stmt) ResetParams() {
   100  	s.args = make([]interface{}, s.paramCount)
   101  }
   102  
   103  func (s *Stmt) SetParamTypes(paramTypes []byte) {
   104  	s.paramTypes = paramTypes
   105  }
   106  
   107  func (s *Stmt) GetParamTypes() []byte {
   108  	return s.paramTypes
   109  }
   110  
   111  // GetRewriteSQL get rewrite sql
   112  func (s *Stmt) GetRewriteSQL() (string, error) {
   113  	sql := s.sql
   114  	var tmp = ""
   115  	var pos = 0
   116  	var offset = 0
   117  	var quote = false
   118  	for i := 0; i < s.paramCount; i++ {
   119  		quote, tmp = util.ItoString(s.args[i])
   120  		tmp = escapeSQL(tmp)
   121  		pos = s.offsets[i]
   122  		if quote {
   123  			sql = util.Concat(sql[:pos+offset], "'", tmp, "'", sql[pos+offset+1:])
   124  			offset = offset + len(tmp) - 1 + 2
   125  		} else {
   126  			sql = util.Concat(sql[:pos+offset], tmp, sql[pos+offset+1:])
   127  			offset = offset + len(tmp) - 1
   128  		}
   129  	}
   130  	return sql, nil
   131  }
   132  
   133  func (se *SessionExecutor) handleStmtExecute(data []byte) (*mysql.Result, error) {
   134  	if len(data) < 9 {
   135  		return nil, mysql.ErrMalformPacket
   136  	}
   137  
   138  	pos := 0
   139  	id := binary.LittleEndian.Uint32(data[0:4])
   140  	pos += 4
   141  
   142  	s, ok := se.stmts[id]
   143  	if !ok {
   144  		return nil, mysql.NewDefaultError(mysql.ErrUnknownStmtHandler,
   145  			strconv.FormatUint(uint64(id), 10), "stmt_execute")
   146  	}
   147  
   148  	flag := data[pos] & mysql.CursorTypeReadOnly
   149  	pos++
   150  	//now we only support CURSOR_TYPE_NO_CURSOR flag
   151  	if flag != 0 {
   152  		return nil, mysql.NewError(mysql.ErrUnknown, fmt.Sprintf("unsupported flag %d", flag))
   153  	}
   154  
   155  	//skip iteration-count, always 1
   156  	pos += 4
   157  
   158  	var nullBitmaps []byte
   159  	var paramTypes []byte
   160  	var paramValues []byte
   161  
   162  	paramNum := s.paramCount
   163  
   164  	var executeSQL string
   165  	var err error
   166  	if paramNum > 0 {
   167  		nullBitmapLen := (s.paramCount + 7) >> 3
   168  		if len(data) < (pos + nullBitmapLen + 1) {
   169  			return nil, mysql.ErrMalformPacket
   170  		}
   171  		nullBitmaps = data[pos : pos+nullBitmapLen]
   172  		pos += nullBitmapLen
   173  
   174  		//new param bound flag
   175  		if data[pos] == 1 {
   176  			pos++
   177  			if len(data) < (pos + (paramNum << 1)) {
   178  				return nil, mysql.ErrMalformPacket
   179  			}
   180  
   181  			paramTypes = data[pos : pos+(paramNum<<1)]
   182  			pos += (paramNum << 1)
   183  
   184  			paramValues = data[pos:]
   185  			s.SetParamTypes(paramTypes)
   186  		} else {
   187  			paramValues = data[pos+1:]
   188  		}
   189  
   190  		if err := se.bindStmtArgs(s, nullBitmaps, s.GetParamTypes(), paramValues); err != nil {
   191  			return nil, err
   192  		}
   193  
   194  		executeSQL, err = s.GetRewriteSQL()
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  	} else {
   199  		executeSQL = s.sql
   200  	}
   201  
   202  	defer s.ResetParams()
   203  
   204  	// execute sql using ComQuery
   205  	r, err := se.handleQuery(executeSQL)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	// build binary result set
   211  	if r != nil && r.Resultset != nil {
   212  		resultSet, err := mysql.BuildBinaryResultset(r.Fields, r.Values)
   213  		if err != nil {
   214  			return nil, err
   215  		}
   216  		r.Resultset = resultSet
   217  	}
   218  
   219  	return r, nil
   220  }
   221  
   222  // long data and generic args are all in s.args
   223  func (se *SessionExecutor) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error {
   224  	args := s.args
   225  
   226  	pos := 0
   227  
   228  	var v []byte
   229  	var isNull bool
   230  
   231  	for i := 0; i < s.paramCount; i++ {
   232  		if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
   233  			args[i] = nil
   234  			continue
   235  		}
   236  
   237  		if (i<<1)+1 >= len(paramTypes) {
   238  			return mysql.ErrMalformPacket
   239  		}
   240  
   241  		tp := paramTypes[i<<1]
   242  		isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
   243  
   244  		if s.args[i] != nil {
   245  			continue
   246  		}
   247  		switch tp {
   248  		case mysql.TypeNull:
   249  			args[i] = nil
   250  			continue
   251  
   252  		case mysql.TypeTiny:
   253  			if len(paramValues) < (pos + 1) {
   254  				return mysql.ErrMalformPacket
   255  			}
   256  
   257  			if isUnsigned {
   258  				args[i] = uint8(paramValues[pos])
   259  			} else {
   260  				args[i] = int8(paramValues[pos])
   261  			}
   262  
   263  			pos++
   264  			continue
   265  
   266  		case mysql.TypeShort, mysql.TypeYear:
   267  			if len(paramValues) < (pos + 2) {
   268  				return mysql.ErrMalformPacket
   269  			}
   270  
   271  			if isUnsigned {
   272  				args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2]))
   273  			} else {
   274  				args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2])))
   275  			}
   276  			pos += 2
   277  			continue
   278  
   279  		case mysql.TypeInt24, mysql.TypeLong:
   280  			if len(paramValues) < (pos + 4) {
   281  				return mysql.ErrMalformPacket
   282  			}
   283  
   284  			if isUnsigned {
   285  				args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
   286  			} else {
   287  				args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))
   288  			}
   289  			pos += 4
   290  			continue
   291  
   292  		case mysql.TypeLonglong:
   293  			if len(paramValues) < (pos + 8) {
   294  				return mysql.ErrMalformPacket
   295  			}
   296  
   297  			if isUnsigned {
   298  				args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8])
   299  			} else {
   300  				args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
   301  			}
   302  			pos += 8
   303  			continue
   304  
   305  		case mysql.TypeFloat:
   306  			if len(paramValues) < (pos + 4) {
   307  				return mysql.ErrMalformPacket
   308  			}
   309  
   310  			args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
   311  			pos += 4
   312  			continue
   313  
   314  		case mysql.TypeDouble:
   315  			if len(paramValues) < (pos + 8) {
   316  				return mysql.ErrMalformPacket
   317  			}
   318  
   319  			args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
   320  			pos += 8
   321  			continue
   322  
   323  		case mysql.TypeDecimal, mysql.TypeNewDecimal, mysql.TypeVarchar,
   324  			mysql.TypeBit, mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob,
   325  			mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob,
   326  			mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry,
   327  			mysql.TypeDate, mysql.TypeNewDate,
   328  			mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeJSON:
   329  			if len(paramValues) < (pos + 1) {
   330  				return mysql.ErrMalformPacket
   331  			}
   332  
   333  			var ok = false
   334  			v, pos, isNull, ok = mysql.ReadLenEncStringAsBytes(paramValues, pos)
   335  			if !ok {
   336  				return errors.New("ReadLenEncStringAsBytes in bindStmtArgs failed")
   337  			}
   338  
   339  			if !isNull {
   340  				args[i] = v
   341  				continue
   342  			} else {
   343  				args[i] = nil
   344  				continue
   345  			}
   346  		default:
   347  			return fmt.Errorf("Stmt Unknown FieldType %d", tp)
   348  		}
   349  	}
   350  	return nil
   351  }
   352  
   353  func (se *SessionExecutor) handleStmtSendLongData(data []byte) error {
   354  	if len(data) < 6 {
   355  		return mysql.ErrMalformPacket
   356  	}
   357  
   358  	id := binary.LittleEndian.Uint32(data[0:4])
   359  
   360  	s, ok := se.stmts[id]
   361  	if !ok {
   362  		return mysql.NewDefaultError(mysql.ErrUnknownStmtHandler,
   363  			strconv.FormatUint(uint64(id), 10), "stmt_send_longdata")
   364  	}
   365  
   366  	paramID := binary.LittleEndian.Uint16(data[4:6])
   367  	if paramID >= uint16(s.paramCount) {
   368  		return mysql.NewDefaultError(mysql.ErrWrongArguments, "stmt_send_longdata")
   369  	}
   370  
   371  	if s.args[paramID] == nil {
   372  		tmpSlice := make([]byte, len(data)-6)
   373  		copy(tmpSlice, data[6:])
   374  		s.args[paramID] = tmpSlice
   375  	} else {
   376  		if b, ok := s.args[paramID].([]byte); ok {
   377  			b = append(b, data[6:]...)
   378  			s.args[paramID] = b
   379  		} else {
   380  			return fmt.Errorf("invalid param long data type %T", s.args[paramID])
   381  		}
   382  	}
   383  
   384  	return nil
   385  }
   386  
   387  func (se *SessionExecutor) handleStmtReset(data []byte) error {
   388  	if len(data) < 4 {
   389  		return mysql.ErrMalformPacket
   390  	}
   391  
   392  	id := binary.LittleEndian.Uint32(data[0:4])
   393  
   394  	s, ok := se.stmts[id]
   395  	if !ok {
   396  		return mysql.NewDefaultError(mysql.ErrUnknownStmtHandler,
   397  			strconv.FormatUint(uint64(id), 10), "stmt_reset")
   398  	}
   399  
   400  	s.ResetParams()
   401  	return nil
   402  }