// Source: github.com/XiaoMi/Gaea@v1.2.5/proxy/server/executor_stmt.go

// Copyright 2016 The kingshard Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// Copyright 2019 The Gaea Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
28 29 package server 30 31 import ( 32 "encoding/binary" 33 "errors" 34 "fmt" 35 "math" 36 "strconv" 37 38 "github.com/XiaoMi/Gaea/mysql" 39 "github.com/XiaoMi/Gaea/util" 40 ) 41 42 var p = &mysql.Field{Name: []byte("?")} 43 var c = &mysql.Field{} 44 45 func calcParams(sql string) (paramCount int, offsets []int, err error) { 46 count := 0 47 quoteChar := "" 48 paramCount = 0 49 offsets = make([]int, 0) 50 51 for i, elem := range []byte(sql) { 52 if elem == '\\' { 53 continue 54 } else if elem == '"' || elem == '\'' { 55 if quoteChar == "" { 56 quoteChar = string(elem) 57 } else if quoteChar == string(elem) { 58 quoteChar = "" 59 } 60 } else if quoteChar == "" && elem == '?' { 61 count++ 62 offsets = append(offsets, i) 63 } 64 65 } 66 if quoteChar != "" { 67 err = fmt.Errorf("fatal situation") 68 return 69 } 70 71 paramCount = count 72 73 return 74 } 75 76 func escapeSQL(sql string) string { 77 t := make([]byte, 0, len(sql)) 78 for _, elem := range []byte(sql) { 79 if elem == '\\' || elem == '\'' { 80 t = append(t, '\\') 81 } 82 t = append(t, elem) 83 } 84 return string(t) 85 } 86 87 // Stmt prepare statement struct 88 type Stmt struct { 89 id uint32 90 sql string 91 args []interface{} 92 columnCount int 93 paramCount int 94 paramTypes []byte 95 offsets []int 96 } 97 98 // ResetParams reset args 99 func (s *Stmt) ResetParams() { 100 s.args = make([]interface{}, s.paramCount) 101 } 102 103 func (s *Stmt) SetParamTypes(paramTypes []byte) { 104 s.paramTypes = paramTypes 105 } 106 107 func (s *Stmt) GetParamTypes() []byte { 108 return s.paramTypes 109 } 110 111 // GetRewriteSQL get rewrite sql 112 func (s *Stmt) GetRewriteSQL() (string, error) { 113 sql := s.sql 114 var tmp = "" 115 var pos = 0 116 var offset = 0 117 var quote = false 118 for i := 0; i < s.paramCount; i++ { 119 quote, tmp = util.ItoString(s.args[i]) 120 tmp = escapeSQL(tmp) 121 pos = s.offsets[i] 122 if quote { 123 sql = util.Concat(sql[:pos+offset], "'", tmp, "'", sql[pos+offset+1:]) 124 offset = 
offset + len(tmp) - 1 + 2 125 } else { 126 sql = util.Concat(sql[:pos+offset], tmp, sql[pos+offset+1:]) 127 offset = offset + len(tmp) - 1 128 } 129 } 130 return sql, nil 131 } 132 133 func (se *SessionExecutor) handleStmtExecute(data []byte) (*mysql.Result, error) { 134 if len(data) < 9 { 135 return nil, mysql.ErrMalformPacket 136 } 137 138 pos := 0 139 id := binary.LittleEndian.Uint32(data[0:4]) 140 pos += 4 141 142 s, ok := se.stmts[id] 143 if !ok { 144 return nil, mysql.NewDefaultError(mysql.ErrUnknownStmtHandler, 145 strconv.FormatUint(uint64(id), 10), "stmt_execute") 146 } 147 148 flag := data[pos] & mysql.CursorTypeReadOnly 149 pos++ 150 //now we only support CURSOR_TYPE_NO_CURSOR flag 151 if flag != 0 { 152 return nil, mysql.NewError(mysql.ErrUnknown, fmt.Sprintf("unsupported flag %d", flag)) 153 } 154 155 //skip iteration-count, always 1 156 pos += 4 157 158 var nullBitmaps []byte 159 var paramTypes []byte 160 var paramValues []byte 161 162 paramNum := s.paramCount 163 164 var executeSQL string 165 var err error 166 if paramNum > 0 { 167 nullBitmapLen := (s.paramCount + 7) >> 3 168 if len(data) < (pos + nullBitmapLen + 1) { 169 return nil, mysql.ErrMalformPacket 170 } 171 nullBitmaps = data[pos : pos+nullBitmapLen] 172 pos += nullBitmapLen 173 174 //new param bound flag 175 if data[pos] == 1 { 176 pos++ 177 if len(data) < (pos + (paramNum << 1)) { 178 return nil, mysql.ErrMalformPacket 179 } 180 181 paramTypes = data[pos : pos+(paramNum<<1)] 182 pos += (paramNum << 1) 183 184 paramValues = data[pos:] 185 s.SetParamTypes(paramTypes) 186 } else { 187 paramValues = data[pos+1:] 188 } 189 190 if err := se.bindStmtArgs(s, nullBitmaps, s.GetParamTypes(), paramValues); err != nil { 191 return nil, err 192 } 193 194 executeSQL, err = s.GetRewriteSQL() 195 if err != nil { 196 return nil, err 197 } 198 } else { 199 executeSQL = s.sql 200 } 201 202 defer s.ResetParams() 203 204 // execute sql using ComQuery 205 r, err := se.handleQuery(executeSQL) 206 if err != nil { 
207 return nil, err 208 } 209 210 // build binary result set 211 if r != nil && r.Resultset != nil { 212 resultSet, err := mysql.BuildBinaryResultset(r.Fields, r.Values) 213 if err != nil { 214 return nil, err 215 } 216 r.Resultset = resultSet 217 } 218 219 return r, nil 220 } 221 222 // long data and generic args are all in s.args 223 func (se *SessionExecutor) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error { 224 args := s.args 225 226 pos := 0 227 228 var v []byte 229 var isNull bool 230 231 for i := 0; i < s.paramCount; i++ { 232 if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 { 233 args[i] = nil 234 continue 235 } 236 237 if (i<<1)+1 >= len(paramTypes) { 238 return mysql.ErrMalformPacket 239 } 240 241 tp := paramTypes[i<<1] 242 isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 243 244 if s.args[i] != nil { 245 continue 246 } 247 switch tp { 248 case mysql.TypeNull: 249 args[i] = nil 250 continue 251 252 case mysql.TypeTiny: 253 if len(paramValues) < (pos + 1) { 254 return mysql.ErrMalformPacket 255 } 256 257 if isUnsigned { 258 args[i] = uint8(paramValues[pos]) 259 } else { 260 args[i] = int8(paramValues[pos]) 261 } 262 263 pos++ 264 continue 265 266 case mysql.TypeShort, mysql.TypeYear: 267 if len(paramValues) < (pos + 2) { 268 return mysql.ErrMalformPacket 269 } 270 271 if isUnsigned { 272 args[i] = uint16(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) 273 } else { 274 args[i] = int16((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) 275 } 276 pos += 2 277 continue 278 279 case mysql.TypeInt24, mysql.TypeLong: 280 if len(paramValues) < (pos + 4) { 281 return mysql.ErrMalformPacket 282 } 283 284 if isUnsigned { 285 args[i] = uint32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) 286 } else { 287 args[i] = int32(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) 288 } 289 pos += 4 290 continue 291 292 case mysql.TypeLonglong: 293 if len(paramValues) < (pos + 8) { 294 return mysql.ErrMalformPacket 295 } 296 297 if isUnsigned 
{ 298 args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8]) 299 } else { 300 args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) 301 } 302 pos += 8 303 continue 304 305 case mysql.TypeFloat: 306 if len(paramValues) < (pos + 4) { 307 return mysql.ErrMalformPacket 308 } 309 310 args[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4]))) 311 pos += 4 312 continue 313 314 case mysql.TypeDouble: 315 if len(paramValues) < (pos + 8) { 316 return mysql.ErrMalformPacket 317 } 318 319 args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) 320 pos += 8 321 continue 322 323 case mysql.TypeDecimal, mysql.TypeNewDecimal, mysql.TypeVarchar, 324 mysql.TypeBit, mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob, 325 mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob, 326 mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry, 327 mysql.TypeDate, mysql.TypeNewDate, 328 mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeJSON: 329 if len(paramValues) < (pos + 1) { 330 return mysql.ErrMalformPacket 331 } 332 333 var ok = false 334 v, pos, isNull, ok = mysql.ReadLenEncStringAsBytes(paramValues, pos) 335 if !ok { 336 return errors.New("ReadLenEncStringAsBytes in bindStmtArgs failed") 337 } 338 339 if !isNull { 340 args[i] = v 341 continue 342 } else { 343 args[i] = nil 344 continue 345 } 346 default: 347 return fmt.Errorf("Stmt Unknown FieldType %d", tp) 348 } 349 } 350 return nil 351 } 352 353 func (se *SessionExecutor) handleStmtSendLongData(data []byte) error { 354 if len(data) < 6 { 355 return mysql.ErrMalformPacket 356 } 357 358 id := binary.LittleEndian.Uint32(data[0:4]) 359 360 s, ok := se.stmts[id] 361 if !ok { 362 return mysql.NewDefaultError(mysql.ErrUnknownStmtHandler, 363 strconv.FormatUint(uint64(id), 10), "stmt_send_longdata") 364 } 365 366 paramID := binary.LittleEndian.Uint16(data[4:6]) 367 if paramID >= uint16(s.paramCount) { 368 return 
mysql.NewDefaultError(mysql.ErrWrongArguments, "stmt_send_longdata") 369 } 370 371 if s.args[paramID] == nil { 372 tmpSlice := make([]byte, len(data)-6) 373 copy(tmpSlice, data[6:]) 374 s.args[paramID] = tmpSlice 375 } else { 376 if b, ok := s.args[paramID].([]byte); ok { 377 b = append(b, data[6:]...) 378 s.args[paramID] = b 379 } else { 380 return fmt.Errorf("invalid param long data type %T", s.args[paramID]) 381 } 382 } 383 384 return nil 385 } 386 387 func (se *SessionExecutor) handleStmtReset(data []byte) error { 388 if len(data) < 4 { 389 return mysql.ErrMalformPacket 390 } 391 392 id := binary.LittleEndian.Uint32(data[0:4]) 393 394 s, ok := se.stmts[id] 395 if !ok { 396 return mysql.NewDefaultError(mysql.ErrUnknownStmtHandler, 397 strconv.FormatUint(uint64(id), 10), "stmt_reset") 398 } 399 400 s.ResetParams() 401 return nil 402 }