github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/server/conn_stmt.go (about)

     1  // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     5  // You can obtain one at http://mozilla.org/MPL/2.0/.
     6  
     7  // The MIT License (MIT)
     8  //
     9  // Copyright (c) 2014 wandoulabs
    10  // Copyright (c) 2014 siddontang
    11  //
    12  // Permission is hereby granted, free of charge, to any person obtaining a copy of
    13  // this software and associated documentation files (the "Software"), to deal in
    14  // the Software without restriction, including without limitation the rights to
    15  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
    16  // the Software, and to permit persons to whom the Software is furnished to do so,
    17  // subject to the following conditions:
    18  //
    19  // The above copyright notice and this permission notice shall be included in all
    20  // copies or substantial portions of the Software.
    21  
    22  // Copyright 2015 PingCAP, Inc.
    23  //
    24  // Licensed under the Apache License, Version 2.0 (the "License");
    25  // you may not use this file except in compliance with the License.
    26  // You may obtain a copy of the License at
    27  //
    28  //     http://www.apache.org/licenses/LICENSE-2.0
    29  //
    30  // Unless required by applicable law or agreed to in writing, software
    31  // distributed under the License is distributed on an "AS IS" BASIS,
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package server
    36  
    37  import (
    38  	"encoding/binary"
    39  	"math"
    40  	"strconv"
    41  
    42  	"github.com/insionng/yougam/libraries/juju/errors"
    43  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    44  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/hack"
    45  )
    46  
    47  func (cc *clientConn) handleStmtPrepare(sql string) error {
    48  	stmt, columns, params, err := cc.ctx.Prepare(sql)
    49  	if err != nil {
    50  		return errors.Trace(err)
    51  	}
    52  	data := make([]byte, 4, 128)
    53  
    54  	//status ok
    55  	data = append(data, 0)
    56  	//stmt id
    57  	data = append(data, dumpUint32(uint32(stmt.ID()))...)
    58  	//number columns
    59  	data = append(data, dumpUint16(uint16(len(columns)))...)
    60  	//number params
    61  	data = append(data, dumpUint16(uint16(len(params)))...)
    62  	//filter [00]
    63  	data = append(data, 0)
    64  	//warning count
    65  	data = append(data, 0, 0) //TODO support warning count
    66  
    67  	if err := cc.writePacket(data); err != nil {
    68  		return errors.Trace(err)
    69  	}
    70  
    71  	if len(params) > 0 {
    72  		for i := 0; i < len(params); i++ {
    73  			data = data[0:4]
    74  			data = append(data, params[i].Dump(cc.alloc)...)
    75  
    76  			if err := cc.writePacket(data); err != nil {
    77  				return errors.Trace(err)
    78  			}
    79  		}
    80  
    81  		if err := cc.writeEOF(); err != nil {
    82  			return errors.Trace(err)
    83  		}
    84  	}
    85  
    86  	if len(columns) > 0 {
    87  		for i := 0; i < len(columns); i++ {
    88  			data = data[0:4]
    89  			data = append(data, columns[i].Dump(cc.alloc)...)
    90  
    91  			if err := cc.writePacket(data); err != nil {
    92  				return errors.Trace(err)
    93  			}
    94  		}
    95  
    96  		if err := cc.writeEOF(); err != nil {
    97  			return errors.Trace(err)
    98  		}
    99  
   100  	}
   101  	return errors.Trace(cc.flush())
   102  }
   103  
   104  func (cc *clientConn) handleStmtExecute(data []byte) (err error) {
   105  	if len(data) < 9 {
   106  		return mysql.ErrMalformPacket
   107  	}
   108  
   109  	pos := 0
   110  	stmtID := binary.LittleEndian.Uint32(data[0:4])
   111  	pos += 4
   112  
   113  	stmt := cc.ctx.GetStatement(int(stmtID))
   114  	if stmt == nil {
   115  		return mysql.NewErr(mysql.ErrUnknownStmtHandler,
   116  			strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
   117  	}
   118  
   119  	flag := data[pos]
   120  	pos++
   121  	//now we only support CURSOR_TYPE_NO_CURSOR flag
   122  	if flag != 0 {
   123  		return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag %d", flag)
   124  	}
   125  
   126  	//skip iteration-count, always 1
   127  	pos += 4
   128  
   129  	var (
   130  		nullBitmaps []byte
   131  		paramTypes  []byte
   132  		paramValues []byte
   133  	)
   134  	numParams := stmt.NumParams()
   135  	args := make([]interface{}, numParams)
   136  	if numParams > 0 {
   137  		nullBitmapLen := (numParams + 7) >> 3
   138  		if len(data) < (pos + nullBitmapLen + 1) {
   139  			return mysql.ErrMalformPacket
   140  		}
   141  		nullBitmaps = data[pos : pos+nullBitmapLen]
   142  		pos += nullBitmapLen
   143  
   144  		//new param bound flag
   145  		if data[pos] == 1 {
   146  			pos++
   147  			if len(data) < (pos + (numParams << 1)) {
   148  				return mysql.ErrMalformPacket
   149  			}
   150  
   151  			paramTypes = data[pos : pos+(numParams<<1)]
   152  			pos += (numParams << 1)
   153  			paramValues = data[pos:]
   154  		}
   155  
   156  		err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, paramTypes, paramValues)
   157  		if err != nil {
   158  			return errors.Trace(err)
   159  		}
   160  	}
   161  	rs, err := stmt.Execute(args...)
   162  	if err != nil {
   163  		return errors.Trace(err)
   164  	}
   165  	if rs == nil {
   166  		return errors.Trace(cc.writeOK())
   167  	}
   168  
   169  	return errors.Trace(cc.writeResultset(rs, true))
   170  }
   171  
   172  func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
   173  	pos := 0
   174  	var v []byte
   175  	var n int
   176  	var isNull bool
   177  
   178  	for i := 0; i < len(args); i++ {
   179  		if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
   180  			args[i] = nil
   181  			continue
   182  		}
   183  		if boundParams[i] != nil {
   184  			args[i] = boundParams[i]
   185  			continue
   186  		}
   187  
   188  		tp := paramTypes[i<<1]
   189  		isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
   190  
   191  		switch tp {
   192  		case mysql.TypeNull:
   193  			args[i] = nil
   194  			continue
   195  
   196  		case mysql.TypeTiny:
   197  			if len(paramValues) < (pos + 1) {
   198  				err = mysql.ErrMalformPacket
   199  				return
   200  			}
   201  
   202  			if isUnsigned {
   203  				args[i] = uint64(paramValues[pos])
   204  			} else {
   205  				args[i] = int64(paramValues[pos])
   206  			}
   207  
   208  			pos++
   209  			continue
   210  
   211  		case mysql.TypeShort, mysql.TypeYear:
   212  			if len(paramValues) < (pos + 2) {
   213  				err = mysql.ErrMalformPacket
   214  				return
   215  			}
   216  			valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2])
   217  			if isUnsigned {
   218  				args[i] = uint64(valU16)
   219  			} else {
   220  				args[i] = int64(valU16)
   221  			}
   222  			pos += 2
   223  			continue
   224  
   225  		case mysql.TypeInt24, mysql.TypeLong:
   226  			if len(paramValues) < (pos + 4) {
   227  				err = mysql.ErrMalformPacket
   228  				return
   229  			}
   230  			valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4])
   231  			if isUnsigned {
   232  				args[i] = uint64(valU32)
   233  			} else {
   234  				args[i] = int64(valU32)
   235  			}
   236  			pos += 4
   237  			continue
   238  
   239  		case mysql.TypeLonglong:
   240  			if len(paramValues) < (pos + 8) {
   241  				err = mysql.ErrMalformPacket
   242  				return
   243  			}
   244  			valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8])
   245  			if isUnsigned {
   246  				args[i] = valU64
   247  			} else {
   248  				args[i] = int64(valU64)
   249  			}
   250  			pos += 8
   251  			continue
   252  
   253  		case mysql.TypeFloat:
   254  			if len(paramValues) < (pos + 4) {
   255  				err = mysql.ErrMalformPacket
   256  				return
   257  			}
   258  
   259  			args[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])))
   260  			pos += 4
   261  			continue
   262  
   263  		case mysql.TypeDouble:
   264  			if len(paramValues) < (pos + 8) {
   265  				err = mysql.ErrMalformPacket
   266  				return
   267  			}
   268  
   269  			args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8]))
   270  			pos += 8
   271  			continue
   272  
   273  		case mysql.TypeDecimal, mysql.TypeNewDecimal, mysql.TypeVarchar,
   274  			mysql.TypeBit, mysql.TypeEnum, mysql.TypeSet, mysql.TypeTinyBlob,
   275  			mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob,
   276  			mysql.TypeVarString, mysql.TypeString, mysql.TypeGeometry,
   277  			mysql.TypeDate, mysql.TypeNewDate,
   278  			mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration:
   279  			if len(paramValues) < (pos + 1) {
   280  				err = mysql.ErrMalformPacket
   281  				return
   282  			}
   283  
   284  			v, isNull, n, err = parseLengthEncodedBytes(paramValues[pos:])
   285  			pos += n
   286  			if err != nil {
   287  				return
   288  			}
   289  
   290  			if !isNull {
   291  				args[i] = hack.String(v)
   292  			} else {
   293  				args[i] = nil
   294  			}
   295  			continue
   296  		default:
   297  			err = errors.Errorf("Stmt Unknown FieldType %d", tp)
   298  			return
   299  		}
   300  	}
   301  	return
   302  }
   303  
   304  func (cc *clientConn) handleStmtClose(data []byte) (err error) {
   305  	if len(data) < 4 {
   306  		return
   307  	}
   308  
   309  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   310  	stmt := cc.ctx.GetStatement(stmtID)
   311  	if stmt != nil {
   312  		return errors.Trace(stmt.Close())
   313  	}
   314  	return
   315  }
   316  
   317  func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) {
   318  	if len(data) < 6 {
   319  		return mysql.ErrMalformPacket
   320  	}
   321  
   322  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   323  
   324  	stmt := cc.ctx.GetStatement(stmtID)
   325  	if stmt == nil {
   326  		return mysql.NewErr(mysql.ErrUnknownStmtHandler,
   327  			strconv.Itoa(stmtID), "stmt_send_longdata")
   328  	}
   329  
   330  	paramID := int(binary.LittleEndian.Uint16(data[4:6]))
   331  	return stmt.AppendParam(paramID, data[6:])
   332  }
   333  
   334  func (cc *clientConn) handleStmtReset(data []byte) (err error) {
   335  	if len(data) < 4 {
   336  		return mysql.ErrMalformPacket
   337  	}
   338  
   339  	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
   340  	stmt := cc.ctx.GetStatement(stmtID)
   341  	if stmt == nil {
   342  		return mysql.NewErr(mysql.ErrUnknownStmtHandler,
   343  			strconv.Itoa(stmtID), "stmt_reset")
   344  	}
   345  	stmt.Reset()
   346  	return cc.writeOK()
   347  }