github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/dependency/util/chunk/compare.go (about) 1 // Copyright 2017 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package chunk 15 16 import ( 17 "sort" 18 19 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/mysql" 20 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/terror" 21 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types" 22 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types/json" 23 ) 24 25 // CompareFunc is a function to compare the two values in Row, the two columns must have the same type. 26 type CompareFunc = func(l Row, lCol int, r Row, rCol int) int 27 28 // GetCompareFunc gets a compare function for the field type. 29 func GetCompareFunc(tp *types.FieldType) CompareFunc { 30 switch tp.Tp { 31 case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: 32 if mysql.HasUnsignedFlag(tp.Flag) { 33 return cmpUint64 34 } 35 return cmpInt64 36 case mysql.TypeFloat: 37 return cmpFloat32 38 case mysql.TypeDouble: 39 return cmpFloat64 40 case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, 41 mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: 42 return cmpString 43 case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: 44 return cmpTime 45 case mysql.TypeDuration: 46 return cmpDuration 47 case mysql.TypeNewDecimal: 48 return cmpMyDecimal 49 case mysql.TypeSet, mysql.TypeEnum: 50 return cmpNameValue 51 case mysql.TypeBit: 52 return cmpBit 53 case mysql.TypeJSON: 54 return cmpJSON 55 } 56 return nil 57 } 58 59 func cmpNull(lNull, rNull bool) int { 60 if lNull && rNull { 61 return 0 62 } 63 if lNull { 64 return -1 65 } 66 return 1 67 } 68 69 func cmpInt64(l Row, lCol int, r Row, rCol int) int { 70 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 71 if lNull || rNull { 72 return cmpNull(lNull, rNull) 73 } 74 return types.CompareInt64(l.GetInt64(lCol), r.GetInt64(rCol)) 75 } 76 77 func cmpUint64(l Row, lCol int, r Row, rCol int) int { 78 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 79 if lNull || rNull { 80 return cmpNull(lNull, rNull) 81 } 82 return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol)) 83 } 84 85 func cmpString(l Row, lCol int, r Row, rCol int) int { 86 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 87 if lNull || rNull { 88 return cmpNull(lNull, rNull) 89 } 90 return types.CompareString(l.GetString(lCol), r.GetString(rCol)) 91 } 92 93 func cmpFloat32(l Row, lCol int, r Row, rCol int) int { 94 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 95 if lNull || rNull { 96 return cmpNull(lNull, rNull) 97 } 98 return types.CompareFloat64(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol))) 99 } 100 101 func cmpFloat64(l Row, lCol int, r Row, rCol int) int { 102 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 103 if lNull || rNull { 104 return cmpNull(lNull, rNull) 105 } 106 return types.CompareFloat64(l.GetFloat64(lCol), r.GetFloat64(rCol)) 107 } 108 109 func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int { 110 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 111 if lNull || rNull { 112 return cmpNull(lNull, rNull) 113 } 114 lDec, rDec := l.GetMyDecimal(lCol), r.GetMyDecimal(rCol) 115 return lDec.Compare(rDec) 116 } 117 118 func cmpTime(l Row, lCol int, r Row, rCol int) int { 119 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 120 if lNull || rNull { 121 return cmpNull(lNull, rNull) 122 } 123 lTime, rTime := l.GetTime(lCol), r.GetTime(rCol) 124 return lTime.Compare(rTime) 125 } 126 127 func cmpDuration(l Row, lCol int, r Row, rCol int) int { 128 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 129 if lNull || rNull { 130 return cmpNull(lNull, rNull) 131 } 132 lDur, rDur := l.GetDuration(lCol), r.GetDuration(rCol) 133 return types.CompareInt64(int64(lDur.Duration), int64(rDur.Duration)) 134 } 135 136 func cmpNameValue(l Row, lCol int, r Row, rCol int) int { 137 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 138 if lNull || rNull { 139 return cmpNull(lNull, rNull) 140 } 141 _, lVal := l.getNameValue(lCol) 142 _, rVal := r.getNameValue(rCol) 143 return types.CompareUint64(lVal, rVal) 144 } 145 146 func cmpBit(l Row, lCol int, r Row, rCol int) int { 147 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 148 if lNull || rNull { 149 return cmpNull(lNull, rNull) 150 } 151 lBit := types.BinaryLiteral(l.GetBytes(lCol)) 152 rBit := types.BinaryLiteral(r.GetBytes(rCol)) 153 lUint, err := lBit.ToInt() 154 terror.Log(err) 155 rUint, err := rBit.ToInt() 156 terror.Log(err) 157 return types.CompareUint64(lUint, rUint) 158 } 159 160 func cmpJSON(l Row, lCol int, r Row, rCol int) int { 161 lNull, rNull := l.IsNull(lCol), r.IsNull(rCol) 162 if lNull || rNull { 163 return cmpNull(lNull, rNull) 164 } 165 lJ, rJ := l.GetJSON(lCol), r.GetJSON(rCol) 166 return json.CompareBinary(lJ, rJ) 167 } 168 169 func compare(row Row, colIdx int, ad *types.Datum) int { 170 switch ad.Kind() { 171 case types.KindNull: 172 if row.IsNull(colIdx) { 173 return 0 174 } 175 return 1 176 case types.KindMinNotNull: 177 if row.IsNull(colIdx) { 178 return -1 179 } 180 return 1 181 case types.KindMaxValue: 182 return -1 183 case types.KindInt64: 184 return types.CompareInt64(row.GetInt64(colIdx), ad.GetInt64()) 185 case types.KindUint64: 186 return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64()) 187 case types.KindFloat32: 188 return types.CompareFloat64(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32())) 189 case types.KindFloat64: 190 return types.CompareFloat64(row.GetFloat64(colIdx), ad.GetFloat64()) 191 case types.KindString, types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit: 192 return types.CompareString(row.GetString(colIdx), ad.GetString()) 193 case types.KindMysqlDecimal: 194 l, r := row.GetMyDecimal(colIdx), ad.GetMysqlDecimal() 195 return l.Compare(r) 196 case types.KindMysqlDuration: 197 l, r := row.GetDuration(colIdx), ad.GetMysqlDuration() 198 return types.CompareInt64(int64(l.Duration), int64(r.Duration)) 199 case types.KindMysqlEnum: 200 l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value 201 return types.CompareUint64(l, r) 202 case types.KindMysqlSet: 203 l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value 204 return types.CompareUint64(l, r) 205 case types.KindMysqlJSON: 206 l, r := row.GetJSON(colIdx), ad.GetMysqlJSON() 207 return json.CompareBinary(l, r) 208 case types.KindMysqlTime: 209 l, r := row.GetTime(colIdx), ad.GetMysqlTime() 210 return l.Compare(r) 211 default: 212 return 0 213 } 214 } 215 216 // LowerBound searches on the non-decreasing column colIdx, 217 // returns the smallest index i such that the value at row i is not less than `ad`. 218 func (c *Chunk) LowerBound(colIdx int, ad *types.Datum) (index int, match bool) { 219 index = sort.Search(c.NumRows(), func(i int) bool { 220 cmp := compare(c.GetRow(i), colIdx, ad) 221 if cmp == 0 { 222 match = true 223 } 224 return cmp >= 0 225 }) 226 return 227 }