github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/soliton/chunk/compare.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package chunk
    15  
    16  import (
    17  	"bytes"
    18  	"sort"
    19  
    20  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    21  	"github.com/whtcorpsinc/milevadb/types"
    22  	"github.com/whtcorpsinc/milevadb/types/json"
    23  )
    24  
    25  // CompareFunc is a function to compare the two values in Row, the two defCausumns must have the same type.
    26  type CompareFunc = func(l Row, lDefCaus int, r Row, rDefCaus int) int
    27  
    28  // GetCompareFunc gets a compare function for the field type.
    29  func GetCompareFunc(tp *types.FieldType) CompareFunc {
    30  	switch tp.Tp {
    31  	case allegrosql.TypeTiny, allegrosql.TypeShort, allegrosql.TypeInt24, allegrosql.TypeLong, allegrosql.TypeLonglong, allegrosql.TypeYear:
    32  		if allegrosql.HasUnsignedFlag(tp.Flag) {
    33  			return cmpUint64
    34  		}
    35  		return cmpInt64
    36  	case allegrosql.TypeFloat:
    37  		return cmpFloat32
    38  	case allegrosql.TypeDouble:
    39  		return cmpFloat64
    40  	case allegrosql.TypeString, allegrosql.TypeVarString, allegrosql.TypeVarchar,
    41  		allegrosql.TypeBlob, allegrosql.TypeTinyBlob, allegrosql.TypeMediumBlob, allegrosql.TypeLongBlob:
    42  		return genCmpStringFunc(tp.DefCauslate)
    43  	case allegrosql.TypeDate, allegrosql.TypeDatetime, allegrosql.TypeTimestamp:
    44  		return cmpTime
    45  	case allegrosql.TypeDuration:
    46  		return cmFIDeluration
    47  	case allegrosql.TypeNewDecimal:
    48  		return cmpMyDecimal
    49  	case allegrosql.TypeSet, allegrosql.TypeEnum:
    50  		return cmpNameValue
    51  	case allegrosql.TypeBit:
    52  		return cmpBit
    53  	case allegrosql.TypeJSON:
    54  		return cmpJSON
    55  	}
    56  	return nil
    57  }
    58  
    59  func cmpNull(lNull, rNull bool) int {
    60  	if lNull && rNull {
    61  		return 0
    62  	}
    63  	if lNull {
    64  		return -1
    65  	}
    66  	return 1
    67  }
    68  
    69  func cmpInt64(l Row, lDefCaus int, r Row, rDefCaus int) int {
    70  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
    71  	if lNull || rNull {
    72  		return cmpNull(lNull, rNull)
    73  	}
    74  	return types.CompareInt64(l.GetInt64(lDefCaus), r.GetInt64(rDefCaus))
    75  }
    76  
    77  func cmpUint64(l Row, lDefCaus int, r Row, rDefCaus int) int {
    78  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
    79  	if lNull || rNull {
    80  		return cmpNull(lNull, rNull)
    81  	}
    82  	return types.CompareUint64(l.GetUint64(lDefCaus), r.GetUint64(rDefCaus))
    83  }
    84  
    85  func genCmpStringFunc(defCauslation string) func(l Row, lDefCaus int, r Row, rDefCaus int) int {
    86  	return func(l Row, lDefCaus int, r Row, rDefCaus int) int {
    87  		return cmpStringWithDefCauslationInfo(l, lDefCaus, r, rDefCaus, defCauslation)
    88  	}
    89  }
    90  
    91  func cmpStringWithDefCauslationInfo(l Row, lDefCaus int, r Row, rDefCaus int, defCauslation string) int {
    92  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
    93  	if lNull || rNull {
    94  		return cmpNull(lNull, rNull)
    95  	}
    96  	return types.CompareString(l.GetString(lDefCaus), r.GetString(rDefCaus), defCauslation)
    97  }
    98  
    99  func cmpFloat32(l Row, lDefCaus int, r Row, rDefCaus int) int {
   100  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   101  	if lNull || rNull {
   102  		return cmpNull(lNull, rNull)
   103  	}
   104  	return types.CompareFloat64(float64(l.GetFloat32(lDefCaus)), float64(r.GetFloat32(rDefCaus)))
   105  }
   106  
   107  func cmpFloat64(l Row, lDefCaus int, r Row, rDefCaus int) int {
   108  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   109  	if lNull || rNull {
   110  		return cmpNull(lNull, rNull)
   111  	}
   112  	return types.CompareFloat64(l.GetFloat64(lDefCaus), r.GetFloat64(rDefCaus))
   113  }
   114  
   115  func cmpMyDecimal(l Row, lDefCaus int, r Row, rDefCaus int) int {
   116  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   117  	if lNull || rNull {
   118  		return cmpNull(lNull, rNull)
   119  	}
   120  	lDec, rDec := l.GetMyDecimal(lDefCaus), r.GetMyDecimal(rDefCaus)
   121  	return lDec.Compare(rDec)
   122  }
   123  
   124  func cmpTime(l Row, lDefCaus int, r Row, rDefCaus int) int {
   125  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   126  	if lNull || rNull {
   127  		return cmpNull(lNull, rNull)
   128  	}
   129  	lTime, rTime := l.GetTime(lDefCaus), r.GetTime(rDefCaus)
   130  	return lTime.Compare(rTime)
   131  }
   132  
   133  func cmFIDeluration(l Row, lDefCaus int, r Row, rDefCaus int) int {
   134  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   135  	if lNull || rNull {
   136  		return cmpNull(lNull, rNull)
   137  	}
   138  	lDur, rDur := l.GetDuration(lDefCaus, 0).Duration, r.GetDuration(rDefCaus, 0).Duration
   139  	return types.CompareInt64(int64(lDur), int64(rDur))
   140  }
   141  
   142  func cmpNameValue(l Row, lDefCaus int, r Row, rDefCaus int) int {
   143  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   144  	if lNull || rNull {
   145  		return cmpNull(lNull, rNull)
   146  	}
   147  	_, lVal := l.getNameValue(lDefCaus)
   148  	_, rVal := r.getNameValue(rDefCaus)
   149  	return types.CompareUint64(lVal, rVal)
   150  }
   151  
   152  func cmpBit(l Row, lDefCaus int, r Row, rDefCaus int) int {
   153  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   154  	if lNull || rNull {
   155  		return cmpNull(lNull, rNull)
   156  	}
   157  	lBit := types.BinaryLiteral(l.GetBytes(lDefCaus))
   158  	rBit := types.BinaryLiteral(r.GetBytes(rDefCaus))
   159  	return lBit.Compare(rBit)
   160  }
   161  
   162  func cmpJSON(l Row, lDefCaus int, r Row, rDefCaus int) int {
   163  	lNull, rNull := l.IsNull(lDefCaus), r.IsNull(rDefCaus)
   164  	if lNull || rNull {
   165  		return cmpNull(lNull, rNull)
   166  	}
   167  	lJ, rJ := l.GetJSON(lDefCaus), r.GetJSON(rDefCaus)
   168  	return json.CompareBinary(lJ, rJ)
   169  }
   170  
   171  // Compare compares the value with ad.
   172  // We assume that the defCauslation information of the defCausumn is the same with the causet.
   173  func Compare(event Row, defCausIdx int, ad *types.Causet) int {
   174  	switch ad.HoTT() {
   175  	case types.HoTTNull:
   176  		if event.IsNull(defCausIdx) {
   177  			return 0
   178  		}
   179  		return 1
   180  	case types.HoTTMinNotNull:
   181  		if event.IsNull(defCausIdx) {
   182  			return -1
   183  		}
   184  		return 1
   185  	case types.HoTTMaxValue:
   186  		return -1
   187  	case types.HoTTInt64:
   188  		return types.CompareInt64(event.GetInt64(defCausIdx), ad.GetInt64())
   189  	case types.HoTTUint64:
   190  		return types.CompareUint64(event.GetUint64(defCausIdx), ad.GetUint64())
   191  	case types.HoTTFloat32:
   192  		return types.CompareFloat64(float64(event.GetFloat32(defCausIdx)), float64(ad.GetFloat32()))
   193  	case types.HoTTFloat64:
   194  		return types.CompareFloat64(event.GetFloat64(defCausIdx), ad.GetFloat64())
   195  	case types.HoTTString:
   196  		return types.CompareString(event.GetString(defCausIdx), ad.GetString(), ad.DefCauslation())
   197  	case types.HoTTBytes, types.HoTTBinaryLiteral, types.HoTTMysqlBit:
   198  		return bytes.Compare(event.GetBytes(defCausIdx), ad.GetBytes())
   199  	case types.HoTTMysqlDecimal:
   200  		l, r := event.GetMyDecimal(defCausIdx), ad.GetMysqlDecimal()
   201  		return l.Compare(r)
   202  	case types.HoTTMysqlDuration:
   203  		l, r := event.GetDuration(defCausIdx, 0).Duration, ad.GetMysqlDuration().Duration
   204  		return types.CompareInt64(int64(l), int64(r))
   205  	case types.HoTTMysqlEnum:
   206  		l, r := event.GetEnum(defCausIdx).Value, ad.GetMysqlEnum().Value
   207  		return types.CompareUint64(l, r)
   208  	case types.HoTTMysqlSet:
   209  		l, r := event.GetSet(defCausIdx).Value, ad.GetMysqlSet().Value
   210  		return types.CompareUint64(l, r)
   211  	case types.HoTTMysqlJSON:
   212  		l, r := event.GetJSON(defCausIdx), ad.GetMysqlJSON()
   213  		return json.CompareBinary(l, r)
   214  	case types.HoTTMysqlTime:
   215  		l, r := event.GetTime(defCausIdx), ad.GetMysqlTime()
   216  		return l.Compare(r)
   217  	default:
   218  		return 0
   219  	}
   220  }
   221  
   222  // LowerBound searches on the non-decreasing DeferredCauset defCausIdx,
   223  // returns the smallest index i such that the value at event i is not less than `d`.
   224  func (c *Chunk) LowerBound(defCausIdx int, d *types.Causet) (index int, match bool) {
   225  	index = sort.Search(c.NumRows(), func(i int) bool {
   226  		cmp := Compare(c.GetRow(i), defCausIdx, d)
   227  		if cmp == 0 {
   228  			match = true
   229  		}
   230  		return cmp >= 0
   231  	})
   232  	return
   233  }
   234  
   235  // UpperBound searches on the non-decreasing DeferredCauset defCausIdx,
   236  // returns the smallest index i such that the value at event i is larger than `d`.
   237  func (c *Chunk) UpperBound(defCausIdx int, d *types.Causet) int {
   238  	return sort.Search(c.NumRows(), func(i int) bool {
   239  		return Compare(c.GetRow(i), defCausIdx, d) > 0
   240  	})
   241  }