github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/dependency/util/chunk/compare.go (about)

     1  // Copyright 2017 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package chunk
    15  
    16  import (
    17  	"sort"
    18  
    19  	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/mysql"
    20  	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/terror"
    21  	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types"
    22  	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types/json"
    23  )
    24  
    25  // CompareFunc is a function to compare the two values in Row, the two columns must have the same type.
    26  type CompareFunc = func(l Row, lCol int, r Row, rCol int) int
    27  
    28  // GetCompareFunc gets a compare function for the field type.
    29  func GetCompareFunc(tp *types.FieldType) CompareFunc {
    30  	switch tp.Tp {
    31  	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
    32  		if mysql.HasUnsignedFlag(tp.Flag) {
    33  			return cmpUint64
    34  		}
    35  		return cmpInt64
    36  	case mysql.TypeFloat:
    37  		return cmpFloat32
    38  	case mysql.TypeDouble:
    39  		return cmpFloat64
    40  	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
    41  		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
    42  		return cmpString
    43  	case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
    44  		return cmpTime
    45  	case mysql.TypeDuration:
    46  		return cmpDuration
    47  	case mysql.TypeNewDecimal:
    48  		return cmpMyDecimal
    49  	case mysql.TypeSet, mysql.TypeEnum:
    50  		return cmpNameValue
    51  	case mysql.TypeBit:
    52  		return cmpBit
    53  	case mysql.TypeJSON:
    54  		return cmpJSON
    55  	}
    56  	return nil
    57  }
    58  
    59  func cmpNull(lNull, rNull bool) int {
    60  	if lNull && rNull {
    61  		return 0
    62  	}
    63  	if lNull {
    64  		return -1
    65  	}
    66  	return 1
    67  }
    68  
    69  func cmpInt64(l Row, lCol int, r Row, rCol int) int {
    70  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
    71  	if lNull || rNull {
    72  		return cmpNull(lNull, rNull)
    73  	}
    74  	return types.CompareInt64(l.GetInt64(lCol), r.GetInt64(rCol))
    75  }
    76  
    77  func cmpUint64(l Row, lCol int, r Row, rCol int) int {
    78  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
    79  	if lNull || rNull {
    80  		return cmpNull(lNull, rNull)
    81  	}
    82  	return types.CompareUint64(l.GetUint64(lCol), r.GetUint64(rCol))
    83  }
    84  
    85  func cmpString(l Row, lCol int, r Row, rCol int) int {
    86  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
    87  	if lNull || rNull {
    88  		return cmpNull(lNull, rNull)
    89  	}
    90  	return types.CompareString(l.GetString(lCol), r.GetString(rCol))
    91  }
    92  
    93  func cmpFloat32(l Row, lCol int, r Row, rCol int) int {
    94  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
    95  	if lNull || rNull {
    96  		return cmpNull(lNull, rNull)
    97  	}
    98  	return types.CompareFloat64(float64(l.GetFloat32(lCol)), float64(r.GetFloat32(rCol)))
    99  }
   100  
   101  func cmpFloat64(l Row, lCol int, r Row, rCol int) int {
   102  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   103  	if lNull || rNull {
   104  		return cmpNull(lNull, rNull)
   105  	}
   106  	return types.CompareFloat64(l.GetFloat64(lCol), r.GetFloat64(rCol))
   107  }
   108  
   109  func cmpMyDecimal(l Row, lCol int, r Row, rCol int) int {
   110  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   111  	if lNull || rNull {
   112  		return cmpNull(lNull, rNull)
   113  	}
   114  	lDec, rDec := l.GetMyDecimal(lCol), r.GetMyDecimal(rCol)
   115  	return lDec.Compare(rDec)
   116  }
   117  
   118  func cmpTime(l Row, lCol int, r Row, rCol int) int {
   119  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   120  	if lNull || rNull {
   121  		return cmpNull(lNull, rNull)
   122  	}
   123  	lTime, rTime := l.GetTime(lCol), r.GetTime(rCol)
   124  	return lTime.Compare(rTime)
   125  }
   126  
   127  func cmpDuration(l Row, lCol int, r Row, rCol int) int {
   128  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   129  	if lNull || rNull {
   130  		return cmpNull(lNull, rNull)
   131  	}
   132  	lDur, rDur := l.GetDuration(lCol), r.GetDuration(rCol)
   133  	return types.CompareInt64(int64(lDur.Duration), int64(rDur.Duration))
   134  }
   135  
   136  func cmpNameValue(l Row, lCol int, r Row, rCol int) int {
   137  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   138  	if lNull || rNull {
   139  		return cmpNull(lNull, rNull)
   140  	}
   141  	_, lVal := l.getNameValue(lCol)
   142  	_, rVal := r.getNameValue(rCol)
   143  	return types.CompareUint64(lVal, rVal)
   144  }
   145  
   146  func cmpBit(l Row, lCol int, r Row, rCol int) int {
   147  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   148  	if lNull || rNull {
   149  		return cmpNull(lNull, rNull)
   150  	}
   151  	lBit := types.BinaryLiteral(l.GetBytes(lCol))
   152  	rBit := types.BinaryLiteral(r.GetBytes(rCol))
   153  	lUint, err := lBit.ToInt()
   154  	terror.Log(err)
   155  	rUint, err := rBit.ToInt()
   156  	terror.Log(err)
   157  	return types.CompareUint64(lUint, rUint)
   158  }
   159  
   160  func cmpJSON(l Row, lCol int, r Row, rCol int) int {
   161  	lNull, rNull := l.IsNull(lCol), r.IsNull(rCol)
   162  	if lNull || rNull {
   163  		return cmpNull(lNull, rNull)
   164  	}
   165  	lJ, rJ := l.GetJSON(lCol), r.GetJSON(rCol)
   166  	return json.CompareBinary(lJ, rJ)
   167  }
   168  
   169  func compare(row Row, colIdx int, ad *types.Datum) int {
   170  	switch ad.Kind() {
   171  	case types.KindNull:
   172  		if row.IsNull(colIdx) {
   173  			return 0
   174  		}
   175  		return 1
   176  	case types.KindMinNotNull:
   177  		if row.IsNull(colIdx) {
   178  			return -1
   179  		}
   180  		return 1
   181  	case types.KindMaxValue:
   182  		return -1
   183  	case types.KindInt64:
   184  		return types.CompareInt64(row.GetInt64(colIdx), ad.GetInt64())
   185  	case types.KindUint64:
   186  		return types.CompareUint64(row.GetUint64(colIdx), ad.GetUint64())
   187  	case types.KindFloat32:
   188  		return types.CompareFloat64(float64(row.GetFloat32(colIdx)), float64(ad.GetFloat32()))
   189  	case types.KindFloat64:
   190  		return types.CompareFloat64(row.GetFloat64(colIdx), ad.GetFloat64())
   191  	case types.KindString, types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit:
   192  		return types.CompareString(row.GetString(colIdx), ad.GetString())
   193  	case types.KindMysqlDecimal:
   194  		l, r := row.GetMyDecimal(colIdx), ad.GetMysqlDecimal()
   195  		return l.Compare(r)
   196  	case types.KindMysqlDuration:
   197  		l, r := row.GetDuration(colIdx), ad.GetMysqlDuration()
   198  		return types.CompareInt64(int64(l.Duration), int64(r.Duration))
   199  	case types.KindMysqlEnum:
   200  		l, r := row.GetEnum(colIdx).Value, ad.GetMysqlEnum().Value
   201  		return types.CompareUint64(l, r)
   202  	case types.KindMysqlSet:
   203  		l, r := row.GetSet(colIdx).Value, ad.GetMysqlSet().Value
   204  		return types.CompareUint64(l, r)
   205  	case types.KindMysqlJSON:
   206  		l, r := row.GetJSON(colIdx), ad.GetMysqlJSON()
   207  		return json.CompareBinary(l, r)
   208  	case types.KindMysqlTime:
   209  		l, r := row.GetTime(colIdx), ad.GetMysqlTime()
   210  		return l.Compare(r)
   211  	default:
   212  		return 0
   213  	}
   214  }
   215  
   216  // LowerBound searches on the non-decreasing column colIdx,
   217  // returns the smallest index i such that the value at row i is not less than `ad`.
   218  func (c *Chunk) LowerBound(colIdx int, ad *types.Datum) (index int, match bool) {
   219  	index = sort.Search(c.NumRows(), func(i int) bool {
   220  		cmp := compare(c.GetRow(i), colIdx, ad)
   221  		if cmp == 0 {
   222  			match = true
   223  		}
   224  		return cmp >= 0
   225  	})
   226  	return
   227  }