github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/store/tikv/mock-tikv/cop_handler.go (about)

     1  // Copyright 2016 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package mocktikv
    15  
    16  import (
    17  	"bytes"
    18  	"encoding/binary"
    19  
    20  	"github.com/insionng/yougam/libraries/golang/protobuf/proto"
    21  	"github.com/insionng/yougam/libraries/juju/errors"
    22  	"github.com/insionng/yougam/libraries/pingcap/kvproto/pkg/coprocessor"
    23  	"github.com/insionng/yougam/libraries/pingcap/tidb/kv"
    24  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    25  	"github.com/insionng/yougam/libraries/pingcap/tidb/terror"
    26  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/codec"
    27  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    28  	"github.com/insionng/yougam/libraries/pingcap/tidb/xapi/tablecodec"
    29  	"github.com/insionng/yougam/libraries/pingcap/tidb/xapi/xeval"
    30  	"github.com/insionng/yougam/libraries/pingcap/tipb/go-tipb"
    31  )
    32  
// selectContext holds the state shared by the handlers while serving a single
// select or index coprocessor request.
type selectContext struct {
	sel          *tipb.SelectRequest        // the decoded request
	eval         *xeval.Evaluator           // set only when sel.Where != nil; its Row map is refilled for each candidate row by evalWhereForRow
	whereColumns map[int64]*tipb.ColumnInfo // set only when sel.Where != nil; column ID -> column info for columns referenced in sel.Where
}
    38  
    39  func (h *rpcHandler) handleCopRequest(req *coprocessor.Request) (*coprocessor.Response, error) {
    40  	resp := &coprocessor.Response{}
    41  	if err := h.checkContext(req.GetContext()); err != nil {
    42  		resp.RegionError = err
    43  		return resp, nil
    44  	}
    45  	if len(req.Ranges) == 0 {
    46  		return resp, nil
    47  	}
    48  	if req.GetTp() == kv.ReqTypeSelect || req.GetTp() == kv.ReqTypeIndex {
    49  		sel := new(tipb.SelectRequest)
    50  		err := proto.Unmarshal(req.Data, sel)
    51  		if err != nil {
    52  			return nil, errors.Trace(err)
    53  		}
    54  		ctx := &selectContext{
    55  			sel: sel,
    56  		}
    57  		if sel.Where != nil {
    58  			ctx.eval = &xeval.Evaluator{Row: make(map[int64]types.Datum)}
    59  			ctx.whereColumns = make(map[int64]*tipb.ColumnInfo)
    60  			collectColumnsInWhere(sel.Where, ctx)
    61  		}
    62  		var rows []*tipb.Row
    63  		if req.GetTp() == kv.ReqTypeSelect {
    64  			rows, err = h.getRowsFromSelectReq(ctx)
    65  		} else {
    66  			rows, err = h.getRowsFromIndexReq(ctx)
    67  		}
    68  		selResp := new(tipb.SelectResponse)
    69  		selResp.Error = toPBError(err)
    70  		selResp.Rows = rows
    71  		if err != nil {
    72  			resp.OtherError = proto.String(err.Error())
    73  		}
    74  		data, err := proto.Marshal(selResp)
    75  		if err != nil {
    76  			return nil, errors.Trace(err)
    77  		}
    78  		resp.Data = data
    79  	}
    80  	return resp, nil
    81  }
    82  
    83  func collectColumnsInWhere(expr *tipb.Expr, ctx *selectContext) error {
    84  	if expr == nil {
    85  		return nil
    86  	}
    87  	if expr.GetTp() == tipb.ExprType_ColumnRef {
    88  		_, i, err := codec.DecodeInt(expr.Val)
    89  		if err != nil {
    90  			return errors.Trace(err)
    91  		}
    92  		var columns []*tipb.ColumnInfo
    93  		if ctx.sel.TableInfo != nil {
    94  			columns = ctx.sel.TableInfo.Columns
    95  		} else {
    96  			columns = ctx.sel.IndexInfo.Columns
    97  		}
    98  		for _, c := range columns {
    99  			if c.GetColumnId() == i {
   100  				ctx.whereColumns[i] = c
   101  				return nil
   102  			}
   103  		}
   104  		return xeval.ErrInvalid.Gen("column %d not found", i)
   105  	}
   106  	for _, child := range expr.Children {
   107  		err := collectColumnsInWhere(child, ctx)
   108  		if err != nil {
   109  			return errors.Trace(err)
   110  		}
   111  	}
   112  	return nil
   113  }
   114  
   115  func toPBError(err error) *tipb.Error {
   116  	if err == nil {
   117  		return nil
   118  	}
   119  	perr := new(tipb.Error)
   120  	code := int32(1)
   121  	perr.Code = &code
   122  	errStr := err.Error()
   123  	perr.Msg = &errStr
   124  	return perr
   125  }
   126  
   127  func (h *rpcHandler) getRowsFromSelectReq(ctx *selectContext) ([]*tipb.Row, error) {
   128  	kvRanges, desc := h.extractKVRanges(ctx.sel)
   129  	var rows []*tipb.Row
   130  	limit := int64(-1)
   131  	if ctx.sel.Limit != nil {
   132  		limit = ctx.sel.GetLimit()
   133  	}
   134  	for _, ran := range kvRanges {
   135  		if limit == 0 {
   136  			break
   137  		}
   138  		ranRows, err := h.getRowsFromRange(ctx, ran, limit, desc)
   139  		if err != nil {
   140  			return nil, errors.Trace(err)
   141  		}
   142  		rows = append(rows, ranRows...)
   143  		limit -= int64(len(ranRows))
   144  	}
   145  	return rows, nil
   146  }
   147  
   148  // extractKVRanges extracts kv.KeyRanges slice from a SelectRequest, and also returns if it is in descending order.
   149  func (h *rpcHandler) extractKVRanges(sel *tipb.SelectRequest) (kvRanges []kv.KeyRange, desc bool) {
   150  	var (
   151  		tid   int64
   152  		idxID int64
   153  	)
   154  	if sel.IndexInfo != nil {
   155  		tid = sel.IndexInfo.GetTableId()
   156  		idxID = sel.IndexInfo.GetIndexId()
   157  	} else {
   158  		tid = sel.TableInfo.GetTableId()
   159  	}
   160  	for _, kran := range sel.Ranges {
   161  		var upperKey, lowerKey kv.Key
   162  		if idxID == 0 {
   163  			upperKey = tablecodec.EncodeRowKey(tid, kran.GetHigh())
   164  			if bytes.Compare(upperKey, h.startKey) <= 0 {
   165  				continue
   166  			}
   167  			lowerKey = tablecodec.EncodeRowKey(tid, kran.GetLow())
   168  		} else {
   169  			upperKey = tablecodec.EncodeIndexSeekKey(tid, idxID, kran.GetHigh())
   170  			if bytes.Compare(upperKey, h.startKey) <= 0 {
   171  				continue
   172  			}
   173  			lowerKey = tablecodec.EncodeIndexSeekKey(tid, idxID, kran.GetLow())
   174  		}
   175  		if len(h.endKey) != 0 && bytes.Compare([]byte(lowerKey), h.endKey) >= 0 {
   176  			break
   177  		}
   178  		var kvr kv.KeyRange
   179  		kvr.StartKey = kv.Key(maxStartKey(lowerKey, h.startKey))
   180  		kvr.EndKey = kv.Key(minEndKey(upperKey, h.endKey))
   181  		kvRanges = append(kvRanges, kvr)
   182  	}
   183  	if sel.OrderBy != nil {
   184  		desc = *sel.OrderBy[0].Desc
   185  	}
   186  	if desc {
   187  		reverseKVRanges(kvRanges)
   188  	}
   189  	return
   190  }
   191  
   192  func reverseKVRanges(kvRanges []kv.KeyRange) {
   193  	for i := 0; i < len(kvRanges)/2; i++ {
   194  		j := len(kvRanges) - i - 1
   195  		kvRanges[i], kvRanges[j] = kvRanges[j], kvRanges[i]
   196  	}
   197  }
   198  
   199  func (h *rpcHandler) getRowsFromRange(ctx *selectContext, ran kv.KeyRange, limit int64, desc bool) ([]*tipb.Row, error) {
   200  	startKey := maxStartKey(ran.StartKey, h.startKey)
   201  	endKey := minEndKey(ran.EndKey, h.endKey)
   202  	if limit == 0 || bytes.Compare(startKey, endKey) >= 0 {
   203  		return nil, nil
   204  	}
   205  	var rows []*tipb.Row
   206  	if ran.IsPoint() {
   207  		val, err := h.mvccStore.Get(startKey, ctx.sel.GetStartTs())
   208  		if len(val) == 0 {
   209  			return nil, nil
   210  		} else if err != nil {
   211  			return nil, errors.Trace(err)
   212  		}
   213  		handle, err := tablecodec.DecodeRowKey(kv.Key(startKey))
   214  		if err != nil {
   215  			return nil, errors.Trace(err)
   216  		}
   217  		match, err := h.evalWhereForRow(ctx, handle)
   218  		if err != nil {
   219  			return nil, errors.Trace(err)
   220  		}
   221  		if !match {
   222  			return nil, nil
   223  		}
   224  		row, err := h.getRowByHandle(ctx, handle)
   225  		if err != nil {
   226  			return nil, errors.Trace(err)
   227  		}
   228  		if row != nil {
   229  			rows = append(rows, row)
   230  		}
   231  		return rows, nil
   232  	}
   233  	var seekKey []byte
   234  	if desc {
   235  		seekKey = endKey
   236  	} else {
   237  		seekKey = startKey
   238  	}
   239  	for {
   240  		if limit == 0 {
   241  			break
   242  		}
   243  		var (
   244  			pairs []Pair
   245  			pair  Pair
   246  			err   error
   247  		)
   248  		if desc {
   249  			pairs = h.mvccStore.ReverseScan(startKey, seekKey, 1, ctx.sel.GetStartTs())
   250  		} else {
   251  			pairs = h.mvccStore.Scan(seekKey, endKey, 1, ctx.sel.GetStartTs())
   252  		}
   253  		if len(pairs) > 0 {
   254  			pair = pairs[0]
   255  		}
   256  		if pair.Err != nil {
   257  			// TODO: handle lock error.
   258  			return nil, errors.Trace(pair.Err)
   259  		}
   260  		if pair.Key == nil {
   261  			break
   262  		}
   263  		if desc {
   264  			if bytes.Compare(pair.Key, startKey) < 0 {
   265  				break
   266  			}
   267  			seekKey = pair.Key
   268  		} else {
   269  			if bytes.Compare(pair.Key, endKey) >= 0 {
   270  				break
   271  			}
   272  			seekKey = []byte(kv.Key(pair.Key).PrefixNext())
   273  		}
   274  		handle, err := tablecodec.DecodeRowKey(pair.Key)
   275  		if err != nil {
   276  			return nil, errors.Trace(err)
   277  		}
   278  		match, err := h.evalWhereForRow(ctx, handle)
   279  		if err != nil {
   280  			return nil, errors.Trace(err)
   281  		}
   282  		if !match {
   283  			continue
   284  		}
   285  		row, err := h.getRowByHandle(ctx, handle)
   286  		if err != nil {
   287  			return nil, errors.Trace(err)
   288  		}
   289  		if row != nil {
   290  			rows = append(rows, row)
   291  			limit--
   292  		}
   293  	}
   294  	return rows, nil
   295  }
   296  
   297  func (h *rpcHandler) getRowByHandle(ctx *selectContext, handle int64) (*tipb.Row, error) {
   298  	tid := ctx.sel.TableInfo.GetTableId()
   299  	columns := ctx.sel.TableInfo.Columns
   300  	row := new(tipb.Row)
   301  	var d types.Datum
   302  	d.SetInt64(handle)
   303  	var err error
   304  	row.Handle, err = codec.EncodeValue(nil, d)
   305  	if err != nil {
   306  		return nil, errors.Trace(err)
   307  	}
   308  	for _, col := range columns {
   309  		if col.GetPkHandle() {
   310  			if mysql.HasUnsignedFlag(uint(col.GetFlag())) {
   311  				row.Data, err = codec.EncodeValue(row.Data, types.NewUintDatum(uint64(handle)))
   312  				if err != nil {
   313  					return nil, errors.Trace(err)
   314  				}
   315  			} else {
   316  				row.Data = append(row.Data, row.Handle...)
   317  			}
   318  		} else {
   319  			colID := col.GetColumnId()
   320  			if ctx.whereColumns[colID] != nil {
   321  				// The column is saved in evaluator, use it directly.
   322  				datum := ctx.eval.Row[colID]
   323  				row.Data, err = codec.EncodeValue(row.Data, datum)
   324  				if err != nil {
   325  					return nil, errors.Trace(err)
   326  				}
   327  			} else {
   328  				key := tablecodec.EncodeColumnKey(tid, handle, colID)
   329  				data, err1 := h.mvccStore.Get(key, ctx.sel.GetStartTs())
   330  				if err1 != nil {
   331  					return nil, errors.Trace(err1)
   332  				}
   333  				if data == nil {
   334  					if mysql.HasNotNullFlag(uint(col.GetFlag())) {
   335  						return nil, errors.Trace(kv.ErrNotExist)
   336  					}
   337  					row.Data = append(row.Data, codec.NilFlag)
   338  				} else {
   339  					row.Data = append(row.Data, data...)
   340  				}
   341  			}
   342  		}
   343  	}
   344  	return row, nil
   345  }
   346  
   347  func (h *rpcHandler) evalWhereForRow(ctx *selectContext, handle int64) (bool, error) {
   348  	if ctx.sel.Where == nil {
   349  		return true, nil
   350  	}
   351  	tid := ctx.sel.TableInfo.GetTableId()
   352  	for colID, col := range ctx.whereColumns {
   353  		if col.GetPkHandle() {
   354  			if mysql.HasUnsignedFlag(uint(col.GetFlag())) {
   355  				ctx.eval.Row[colID] = types.NewUintDatum(uint64(handle))
   356  			} else {
   357  				ctx.eval.Row[colID] = types.NewIntDatum(handle)
   358  			}
   359  		} else {
   360  			key := tablecodec.EncodeColumnKey(tid, handle, colID)
   361  			data, err := h.mvccStore.Get(key, ctx.sel.GetStartTs())
   362  			if err != nil {
   363  				return false, errors.Trace(err)
   364  			}
   365  			if data == nil {
   366  				if mysql.HasNotNullFlag(uint(col.GetFlag())) {
   367  					return false, errors.Trace(kv.ErrNotExist)
   368  				}
   369  				ctx.eval.Row[colID] = types.Datum{}
   370  			} else {
   371  				_, datum, err := codec.DecodeOne(data)
   372  				if err != nil {
   373  					return false, errors.Trace(err)
   374  				}
   375  				ctx.eval.Row[colID] = datum
   376  			}
   377  		}
   378  	}
   379  	result, err := ctx.eval.Eval(ctx.sel.Where)
   380  	if err != nil {
   381  		return false, errors.Trace(err)
   382  	}
   383  	if result.Kind() == types.KindNull {
   384  		return false, nil
   385  	}
   386  	boolResult, err := result.ToBool()
   387  	if err != nil {
   388  		return false, errors.Trace(err)
   389  	}
   390  	return boolResult == 1, nil
   391  }
   392  
   393  func (h *rpcHandler) getRowsFromIndexReq(ctx *selectContext) ([]*tipb.Row, error) {
   394  	kvRanges, desc := h.extractKVRanges(ctx.sel)
   395  	var rows []*tipb.Row
   396  	limit := int64(-1)
   397  	if ctx.sel.Limit != nil {
   398  		limit = ctx.sel.GetLimit()
   399  	}
   400  	for _, ran := range kvRanges {
   401  		if limit == 0 {
   402  			break
   403  		}
   404  		ranRows, err := h.getIndexRowFromRange(ctx.sel, ran, desc, limit)
   405  		if err != nil {
   406  			return nil, errors.Trace(err)
   407  		}
   408  		rows = append(rows, ranRows...)
   409  		limit -= int64(len(ranRows))
   410  	}
   411  	return rows, nil
   412  }
   413  
   414  func (h *rpcHandler) getIndexRowFromRange(sel *tipb.SelectRequest, ran kv.KeyRange, desc bool, limit int64) ([]*tipb.Row, error) {
   415  	startKey := maxStartKey(ran.StartKey, h.startKey)
   416  	endKey := minEndKey(ran.EndKey, h.endKey)
   417  	if limit == 0 || bytes.Compare(startKey, endKey) >= 0 {
   418  		return nil, nil
   419  	}
   420  	var rows []*tipb.Row
   421  	var seekKey kv.Key
   422  	if desc {
   423  		seekKey = endKey
   424  	} else {
   425  		seekKey = startKey
   426  	}
   427  	for {
   428  		if limit == 0 {
   429  			break
   430  		}
   431  		var (
   432  			pairs []Pair
   433  			pair  Pair
   434  			err   error
   435  		)
   436  		if desc {
   437  			pairs = h.mvccStore.ReverseScan(startKey, seekKey, 1, sel.GetStartTs())
   438  		} else {
   439  			pairs = h.mvccStore.Scan(seekKey, endKey, 1, sel.GetStartTs())
   440  		}
   441  		if len(pairs) > 0 {
   442  			pair = pairs[0]
   443  		}
   444  		if pair.Err != nil {
   445  			// TODO: handle lock error.
   446  			return nil, errors.Trace(pair.Err)
   447  		}
   448  		if pair.Key == nil {
   449  			break
   450  		}
   451  		if desc {
   452  			if bytes.Compare(pair.Key, startKey) < 0 {
   453  				break
   454  			}
   455  			seekKey = pair.Key
   456  		} else {
   457  			if bytes.Compare(pair.Key, endKey) >= 0 {
   458  				break
   459  			}
   460  			seekKey = []byte(kv.Key(pair.Key).PrefixNext())
   461  		}
   462  
   463  		datums, err := tablecodec.DecodeIndexKey(pair.Key)
   464  		if err != nil {
   465  			return nil, errors.Trace(err)
   466  		}
   467  		var handle types.Datum
   468  		columns := sel.IndexInfo.Columns
   469  		if len(datums) > len(columns) {
   470  			handle = datums[len(columns)]
   471  			datums = datums[:len(columns)]
   472  		} else {
   473  			var intHandle int64
   474  			intHandle, err = decodeHandle(pair.Value)
   475  			if err != nil {
   476  				return nil, errors.Trace(err)
   477  			}
   478  			handle.SetInt64(intHandle)
   479  		}
   480  		data, err := codec.EncodeValue(nil, datums...)
   481  		if err != nil {
   482  			return nil, errors.Trace(err)
   483  		}
   484  		handleData, err := codec.EncodeValue(nil, handle)
   485  		if err != nil {
   486  			return nil, errors.Trace(err)
   487  		}
   488  		row := &tipb.Row{Handle: handleData, Data: data}
   489  		rows = append(rows, row)
   490  		limit--
   491  	}
   492  	return rows, nil
   493  }
   494  
   495  func maxStartKey(rangeStartKey kv.Key, regionStartKey []byte) []byte {
   496  	if bytes.Compare([]byte(rangeStartKey), regionStartKey) > 0 {
   497  		return []byte(rangeStartKey)
   498  	}
   499  	return regionStartKey
   500  }
   501  
   502  func minEndKey(rangeEndKey kv.Key, regionEndKey []byte) []byte {
   503  	if len(regionEndKey) == 0 || bytes.Compare([]byte(rangeEndKey), regionEndKey) < 0 {
   504  		return []byte(rangeEndKey)
   505  	}
   506  	return regionEndKey
   507  }
   508  
   509  func decodeHandle(data []byte) (int64, error) {
   510  	var h int64
   511  	buf := bytes.NewBuffer(data)
   512  	err := binary.Read(buf, binary.BigEndian, &h)
   513  	return h, errors.Trace(err)
   514  }
   515  
   516  func isDefaultNull(err error, col *tipb.ColumnInfo) bool {
   517  	return terror.ErrorEqual(err, kv.ErrNotExist) && !mysql.HasNotNullFlag(uint(col.GetFlag()))
   518  }