github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/milevadb-server/statistics/scalar.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package statistics
    15  
    16  import (
    17  	"encoding/binary"
    18  	"math"
    19  	"time"
    20  
    21  	"github.com/cznic/mathutil"
    22  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    23  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    24  	"github.com/whtcorpsinc/milevadb/types"
    25  )
    26  
    27  // calcFraction is used to calculate the fraction of the interval [lower, upper] that lies within the [lower, value]
    28  // using the continuous-value assumption.
    29  func calcFraction(lower, upper, value float64) float64 {
    30  	if upper <= lower {
    31  		return 0.5
    32  	}
    33  	if value <= lower {
    34  		return 0
    35  	}
    36  	if value >= upper {
    37  		return 1
    38  	}
    39  	frac := (value - lower) / (upper - lower)
    40  	if math.IsNaN(frac) || math.IsInf(frac, 0) || frac < 0 || frac > 1 {
    41  		return 0.5
    42  	}
    43  	return frac
    44  }
    45  
    46  func convertCausetToScalar(value *types.Causet, commonPfxLen int) float64 {
    47  	switch value.HoTT() {
    48  	case types.HoTTMysqlDecimal:
    49  		scalar, err := value.GetMysqlDecimal().ToFloat64()
    50  		if err != nil {
    51  			return 0
    52  		}
    53  		return scalar
    54  	case types.HoTTMysqlTime:
    55  		valueTime := value.GetMysqlTime()
    56  		var minTime types.Time
    57  		switch valueTime.Type() {
    58  		case allegrosql.TypeDate:
    59  			minTime = types.NewTime(types.MinDatetime, allegrosql.TypeDate, types.DefaultFsp)
    60  		case allegrosql.TypeDatetime:
    61  			minTime = types.NewTime(types.MinDatetime, allegrosql.TypeDatetime, types.DefaultFsp)
    62  		case allegrosql.TypeTimestamp:
    63  			minTime = types.MinTimestamp
    64  		}
    65  		sc := &stmtctx.StatementContext{TimeZone: types.BoundTimezone}
    66  		return float64(valueTime.Sub(sc, &minTime).Duration)
    67  	case types.HoTTString, types.HoTTBytes:
    68  		bytes := value.GetBytes()
    69  		if len(bytes) <= commonPfxLen {
    70  			return 0
    71  		}
    72  		return convertBytesToScalar(bytes[commonPfxLen:])
    73  	default:
    74  		// do not know how to convert
    75  		return 0
    76  	}
    77  }
    78  
    79  // PreCalculateScalar converts the lower and upper to scalar. When the causet type is HoTTString or HoTTBytes, we also
    80  // calculate their common prefix length, because when a value falls between lower and upper, the common prefix
    81  // of lower and upper equals to the common prefix of the lower, upper and the value. For some simple types like `Int64`,
    82  // we do not convert it because we can directly infer the scalar value.
    83  func (hg *Histogram) PreCalculateScalar() {
    84  	len := hg.Len()
    85  	if len == 0 {
    86  		return
    87  	}
    88  	switch hg.GetLower(0).HoTT() {
    89  	case types.HoTTMysqlDecimal, types.HoTTMysqlTime:
    90  		hg.scalars = make([]scalar, len)
    91  		for i := 0; i < len; i++ {
    92  			hg.scalars[i] = scalar{
    93  				lower: convertCausetToScalar(hg.GetLower(i), 0),
    94  				upper: convertCausetToScalar(hg.GetUpper(i), 0),
    95  			}
    96  		}
    97  	case types.HoTTBytes, types.HoTTString:
    98  		hg.scalars = make([]scalar, len)
    99  		for i := 0; i < len; i++ {
   100  			lower, upper := hg.GetLower(i), hg.GetUpper(i)
   101  			common := commonPrefixLength(lower.GetBytes(), upper.GetBytes())
   102  			hg.scalars[i] = scalar{
   103  				commonPfxLen: common,
   104  				lower:        convertCausetToScalar(lower, common),
   105  				upper:        convertCausetToScalar(upper, common),
   106  			}
   107  		}
   108  	}
   109  }
   110  
   111  func (hg *Histogram) calcFraction(index int, value *types.Causet) float64 {
   112  	lower, upper := hg.Bounds.GetRow(2*index), hg.Bounds.GetRow(2*index+1)
   113  	switch value.HoTT() {
   114  	case types.HoTTFloat32:
   115  		return calcFraction(float64(lower.GetFloat32(0)), float64(upper.GetFloat32(0)), float64(value.GetFloat32()))
   116  	case types.HoTTFloat64:
   117  		return calcFraction(lower.GetFloat64(0), upper.GetFloat64(0), value.GetFloat64())
   118  	case types.HoTTInt64:
   119  		return calcFraction(float64(lower.GetInt64(0)), float64(upper.GetInt64(0)), float64(value.GetInt64()))
   120  	case types.HoTTUint64:
   121  		return calcFraction(float64(lower.GetUint64(0)), float64(upper.GetUint64(0)), float64(value.GetUint64()))
   122  	case types.HoTTMysqlDuration:
   123  		return calcFraction(float64(lower.GetDuration(0, 0).Duration), float64(upper.GetDuration(0, 0).Duration), float64(value.GetMysqlDuration().Duration))
   124  	case types.HoTTMysqlDecimal, types.HoTTMysqlTime:
   125  		return calcFraction(hg.scalars[index].lower, hg.scalars[index].upper, convertCausetToScalar(value, 0))
   126  	case types.HoTTBytes, types.HoTTString:
   127  		return calcFraction(hg.scalars[index].lower, hg.scalars[index].upper, convertCausetToScalar(value, hg.scalars[index].commonPfxLen))
   128  	}
   129  	return 0.5
   130  }
   131  
   132  func commonPrefixLength(lower, upper []byte) int {
   133  	minLen := len(lower)
   134  	if minLen > len(upper) {
   135  		minLen = len(upper)
   136  	}
   137  	for i := 0; i < minLen; i++ {
   138  		if lower[i] != upper[i] {
   139  			return i
   140  		}
   141  	}
   142  	return minLen
   143  }
   144  
   145  func convertBytesToScalar(value []byte) float64 {
   146  	// Bytes type is viewed as a base-256 value, so we only consider at most 8 bytes.
   147  	var buf [8]byte
   148  	copy(buf[:], value)
   149  	return float64(binary.BigEndian.Uint64(buf[:]))
   150  }
   151  
   152  func calcFraction4Causets(lower, upper, value *types.Causet) float64 {
   153  	switch value.HoTT() {
   154  	case types.HoTTFloat32:
   155  		return calcFraction(float64(lower.GetFloat32()), float64(upper.GetFloat32()), float64(value.GetFloat32()))
   156  	case types.HoTTFloat64:
   157  		return calcFraction(lower.GetFloat64(), upper.GetFloat64(), value.GetFloat64())
   158  	case types.HoTTInt64:
   159  		return calcFraction(float64(lower.GetInt64()), float64(upper.GetInt64()), float64(value.GetInt64()))
   160  	case types.HoTTUint64:
   161  		return calcFraction(float64(lower.GetUint64()), float64(upper.GetUint64()), float64(value.GetUint64()))
   162  	case types.HoTTMysqlDuration:
   163  		return calcFraction(float64(lower.GetMysqlDuration().Duration), float64(upper.GetMysqlDuration().Duration), float64(value.GetMysqlDuration().Duration))
   164  	case types.HoTTMysqlDecimal, types.HoTTMysqlTime:
   165  		return calcFraction(convertCausetToScalar(lower, 0), convertCausetToScalar(upper, 0), convertCausetToScalar(value, 0))
   166  	case types.HoTTBytes, types.HoTTString:
   167  		commonPfxLen := commonPrefixLength(lower.GetBytes(), upper.GetBytes())
   168  		return calcFraction(convertCausetToScalar(lower, commonPfxLen), convertCausetToScalar(upper, commonPfxLen), convertCausetToScalar(value, commonPfxLen))
   169  	}
   170  	return 0.5
   171  }
   172  
   173  const maxNumStep = 10
   174  
   175  func enumRangeValues(low, high types.Causet, lowExclude, highExclude bool) []types.Causet {
   176  	if low.HoTT() != high.HoTT() {
   177  		return nil
   178  	}
   179  	exclude := 0
   180  	if lowExclude {
   181  		exclude++
   182  	}
   183  	if highExclude {
   184  		exclude++
   185  	}
   186  	switch low.HoTT() {
   187  	case types.HoTTInt64:
   188  		// Overflow check.
   189  		lowVal, highVal := low.GetInt64(), high.GetInt64()
   190  		if lowVal <= 0 && highVal >= 0 {
   191  			if lowVal < -maxNumStep || highVal > maxNumStep {
   192  				return nil
   193  			}
   194  		}
   195  		remaining := highVal - lowVal
   196  		if remaining >= maxNumStep+1 {
   197  			return nil
   198  		}
   199  		remaining = remaining + 1 - int64(exclude)
   200  		if remaining >= maxNumStep {
   201  			return nil
   202  		}
   203  		values := make([]types.Causet, 0, remaining)
   204  		startValue := lowVal
   205  		if lowExclude {
   206  			startValue++
   207  		}
   208  		for i := int64(0); i < remaining; i++ {
   209  			values = append(values, types.NewIntCauset(startValue+i))
   210  		}
   211  		return values
   212  	case types.HoTTUint64:
   213  		remaining := high.GetUint64() - low.GetUint64()
   214  		if remaining >= maxNumStep+1 {
   215  			return nil
   216  		}
   217  		remaining = remaining + 1 - uint64(exclude)
   218  		if remaining >= maxNumStep {
   219  			return nil
   220  		}
   221  		values := make([]types.Causet, 0, remaining)
   222  		startValue := low.GetUint64()
   223  		if lowExclude {
   224  			startValue++
   225  		}
   226  		for i := uint64(0); i < remaining; i++ {
   227  			values = append(values, types.NewUintCauset(startValue+i))
   228  		}
   229  		return values
   230  	case types.HoTTMysqlDuration:
   231  		lowDur, highDur := low.GetMysqlDuration(), high.GetMysqlDuration()
   232  		fsp := mathutil.MaxInt8(lowDur.Fsp, highDur.Fsp)
   233  		stepSize := int64(math.Pow10(int(types.MaxFsp-fsp))) * int64(time.Microsecond)
   234  		lowDur.Duration = lowDur.Duration.Round(time.Duration(stepSize))
   235  		remaining := int64(highDur.Duration-lowDur.Duration)/stepSize + 1 - int64(exclude)
   236  		if remaining >= maxNumStep {
   237  			return nil
   238  		}
   239  		startValue := int64(lowDur.Duration)
   240  		if lowExclude {
   241  			startValue += stepSize
   242  		}
   243  		values := make([]types.Causet, 0, remaining)
   244  		for i := int64(0); i < remaining; i++ {
   245  			values = append(values, types.NewDurationCauset(types.Duration{Duration: time.Duration(startValue + i*stepSize), Fsp: fsp}))
   246  		}
   247  		return values
   248  	case types.HoTTMysqlTime:
   249  		lowTime, highTime := low.GetMysqlTime(), high.GetMysqlTime()
   250  		if lowTime.Type() != highTime.Type() {
   251  			return nil
   252  		}
   253  		fsp := mathutil.MaxInt8(lowTime.Fsp(), highTime.Fsp())
   254  		var stepSize int64
   255  		sc := &stmtctx.StatementContext{TimeZone: time.UTC}
   256  		if lowTime.Type() == allegrosql.TypeDate {
   257  			stepSize = 24 * int64(time.Hour)
   258  			lowTime.SetCoreTime(types.FromDate(lowTime.Year(), lowTime.Month(), lowTime.Day(), 0, 0, 0, 0))
   259  		} else {
   260  			var err error
   261  			lowTime, err = lowTime.RoundFrac(sc, fsp)
   262  			if err != nil {
   263  				return nil
   264  			}
   265  			stepSize = int64(math.Pow10(int(types.MaxFsp-fsp))) * int64(time.Microsecond)
   266  		}
   267  		remaining := int64(highTime.Sub(sc, &lowTime).Duration)/stepSize + 1 - int64(exclude)
   268  		if remaining >= maxNumStep {
   269  			return nil
   270  		}
   271  		startValue := lowTime
   272  		var err error
   273  		if lowExclude {
   274  			startValue, err = lowTime.Add(sc, types.Duration{Duration: time.Duration(stepSize), Fsp: fsp})
   275  			if err != nil {
   276  				return nil
   277  			}
   278  		}
   279  		values := make([]types.Causet, 0, remaining)
   280  		for i := int64(0); i < remaining; i++ {
   281  			value, err := startValue.Add(sc, types.Duration{Duration: time.Duration(i * stepSize), Fsp: fsp})
   282  			if err != nil {
   283  				return nil
   284  			}
   285  			values = append(values, types.NewTimeCauset(value))
   286  		}
   287  		return values
   288  	}
   289  	return nil
   290  }