github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_uint.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"reflect"
    19  	"strconv"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/dolthub/vitess/go/vt/proto/query"
    23  	"github.com/shopspring/decimal"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  )
    27  
    28  var systemUintValueType = reflect.TypeOf(uint64(0))
    29  
    30  // systemUintType is an internal unsigned integer type ONLY for system variables.
    31  type systemUintType struct {
    32  	varName    string
    33  	lowerbound uint64
    34  	upperbound uint64
    35  }
    36  
    37  var _ sql.SystemVariableType = systemUintType{}
    38  var _ sql.CollationCoercible = systemUintType{}
    39  
    40  // NewSystemUintType returns a new systemUintType.
    41  func NewSystemUintType(varName string, lowerbound, upperbound uint64) sql.SystemVariableType {
    42  	return systemUintType{varName, lowerbound, upperbound}
    43  }
    44  
    45  // Compare implements Type interface.
    46  func (t systemUintType) Compare(a interface{}, b interface{}) (int, error) {
    47  	as, _, err := t.Convert(a)
    48  	if err != nil {
    49  		return 0, err
    50  	}
    51  	bs, _, err := t.Convert(b)
    52  	if err != nil {
    53  		return 0, err
    54  	}
    55  	ai := as.(uint64)
    56  	bi := bs.(uint64)
    57  
    58  	if ai == bi {
    59  		return 0, nil
    60  	}
    61  	if ai < bi {
    62  		return -1, nil
    63  	}
    64  	return 1, nil
    65  }
    66  
    67  // Convert implements Type interface.
    68  func (t systemUintType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
    69  	// Float, string, nor nil values are accepted
    70  	switch value := v.(type) {
    71  	case int:
    72  		return t.Convert(uint64(value))
    73  	case uint:
    74  		return t.Convert(uint64(value))
    75  	case int8:
    76  		return t.Convert(uint64(value))
    77  	case uint8:
    78  		return t.Convert(uint64(value))
    79  	case int16:
    80  		return t.Convert(uint64(value))
    81  	case uint16:
    82  		return t.Convert(uint64(value))
    83  	case int32:
    84  		return t.Convert(uint64(value))
    85  	case uint32:
    86  		return t.Convert(uint64(value))
    87  	case int64:
    88  		return t.Convert(uint64(value))
    89  	case uint64:
    90  		if value >= t.lowerbound && value <= t.upperbound {
    91  			return value, sql.InRange, nil
    92  		}
    93  	case float32:
    94  		return t.Convert(float64(value))
    95  	case float64:
    96  		// Float values aren't truly accepted, but the engine will give them when it should give ints.
    97  		// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
    98  		if value == float64(uint64(value)) {
    99  			return t.Convert(uint64(value))
   100  		}
   101  	case decimal.Decimal:
   102  		f, _ := value.Float64()
   103  		return t.Convert(f)
   104  	case decimal.NullDecimal:
   105  		if value.Valid {
   106  			f, _ := value.Decimal.Float64()
   107  			return t.Convert(f)
   108  		}
   109  	}
   110  
   111  	return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v)
   112  }
   113  
   114  // MustConvert implements the Type interface.
   115  func (t systemUintType) MustConvert(v interface{}) interface{} {
   116  	value, _, err := t.Convert(v)
   117  	if err != nil {
   118  		panic(err)
   119  	}
   120  	return value
   121  }
   122  
   123  // Equals implements the Type interface.
   124  func (t systemUintType) Equals(otherType sql.Type) bool {
   125  	if ot, ok := otherType.(systemUintType); ok {
   126  		return t.varName == ot.varName && t.lowerbound == ot.lowerbound && t.upperbound == ot.upperbound
   127  	}
   128  	return false
   129  }
   130  
   131  // MaxTextResponseByteLength implements the Type interface
   132  func (t systemUintType) MaxTextResponseByteLength(ctx *sql.Context) uint32 {
   133  	return t.UnderlyingType().MaxTextResponseByteLength(ctx)
   134  }
   135  
   136  // Promote implements the Type interface.
   137  func (t systemUintType) Promote() sql.Type {
   138  	return t
   139  }
   140  
   141  // SQL implements Type interface.
   142  func (t systemUintType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   143  	if v == nil {
   144  		return sqltypes.NULL, nil
   145  	}
   146  
   147  	v, _, err := t.Convert(v)
   148  	if err != nil {
   149  		return sqltypes.Value{}, err
   150  	}
   151  
   152  	stop := len(dest)
   153  	dest = strconv.AppendUint(dest, v.(uint64), 10)
   154  	val := dest[stop:]
   155  
   156  	return sqltypes.MakeTrusted(t.Type(), val), nil
   157  }
   158  
   159  // String implements Type interface.
   160  func (t systemUintType) String() string {
   161  	return "system_uint"
   162  }
   163  
   164  // Type implements Type interface.
   165  func (t systemUintType) Type() query.Type {
   166  	return sqltypes.Uint64
   167  }
   168  
   169  // ValueType implements Type interface.
   170  func (t systemUintType) ValueType() reflect.Type {
   171  	return systemUintValueType
   172  }
   173  
   174  // Zero implements Type interface.
   175  func (t systemUintType) Zero() interface{} {
   176  	return uint64(0)
   177  }
   178  
   179  // CollationCoercibility implements sql.CollationCoercible interface.
   180  func (systemUintType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   181  	return sql.Collation_binary, 5
   182  }
   183  
   184  // EncodeValue implements SystemVariableType interface.
   185  func (t systemUintType) EncodeValue(val interface{}) (string, error) {
   186  	expectedVal, ok := val.(uint64)
   187  	if !ok {
   188  		return "", sql.ErrSystemVariableCodeFail.New(val, t.String())
   189  	}
   190  	return strconv.FormatUint(expectedVal, 10), nil
   191  }
   192  
   193  // DecodeValue implements SystemVariableType interface.
   194  func (t systemUintType) DecodeValue(val string) (interface{}, error) {
   195  	parsedVal, err := strconv.ParseUint(val, 10, 64)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	if parsedVal >= t.lowerbound && parsedVal <= t.upperbound {
   200  		return parsedVal, nil
   201  	}
   202  	return nil, sql.ErrSystemVariableCodeFail.New(val, t.String())
   203  }
   204  
   205  func (t systemUintType) UnderlyingType() sql.Type {
   206  	return Uint64
   207  }