github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_int.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"reflect"
    19  	"strconv"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/dolthub/vitess/go/vt/proto/query"
    23  	"github.com/shopspring/decimal"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  )
    27  
// systemIntValueType is the reflect.Type of int64, computed once at package initialization and
// returned by systemIntType.ValueType.
var systemIntValueType = reflect.TypeOf(int64(0))
    29  
// systemIntType is an internal integer type ONLY for system variables.
// Values of this type are represented as int64.
type systemIntType struct {
	// varName is the system variable's name; it is included in the error returned when Convert
	// rejects a value.
	varName string
	// lowerbound is the smallest accepted value (inclusive).
	lowerbound int64
	// upperbound is the largest accepted value (inclusive).
	upperbound int64
	// negativeOne, when true, makes Convert additionally accept -1 even if it lies outside
	// [lowerbound, upperbound].
	negativeOne bool
}
    37  
// Compile-time assertions that systemIntType satisfies the interfaces it is used as.
var _ sql.SystemVariableType = systemIntType{}
var _ sql.CollationCoercible = systemIntType{}
    40  
    41  // NewSystemIntType returns a new systemIntType.
    42  func NewSystemIntType(varName string, lowerbound, upperbound int64, negativeOne bool) sql.SystemVariableType {
    43  	return systemIntType{varName, lowerbound, upperbound, negativeOne}
    44  }
    45  
    46  // Compare implements Type interface.
    47  func (t systemIntType) Compare(a interface{}, b interface{}) (int, error) {
    48  	as, _, err := t.Convert(a)
    49  	if err != nil {
    50  		return 0, err
    51  	}
    52  	bs, _, err := t.Convert(b)
    53  	if err != nil {
    54  		return 0, err
    55  	}
    56  	ai := as.(int64)
    57  	bi := bs.(int64)
    58  
    59  	if ai == bi {
    60  		return 0, nil
    61  	}
    62  	if ai < bi {
    63  		return -1, nil
    64  	}
    65  	return 1, nil
    66  }
    67  
    68  // Convert implements Type interface.
    69  func (t systemIntType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
    70  	// String nor nil values are accepted
    71  	switch value := v.(type) {
    72  	case int:
    73  		return t.Convert(int64(value))
    74  	case uint:
    75  		return t.Convert(int64(value))
    76  	case int8:
    77  		return t.Convert(int64(value))
    78  	case uint8:
    79  		return t.Convert(int64(value))
    80  	case int16:
    81  		return t.Convert(int64(value))
    82  	case uint16:
    83  		return t.Convert(int64(value))
    84  	case int32:
    85  		return t.Convert(int64(value))
    86  	case uint32:
    87  		return t.Convert(int64(value))
    88  	case int64:
    89  		if value >= t.lowerbound && value <= t.upperbound {
    90  			return value, sql.InRange, nil
    91  		}
    92  		if t.negativeOne && value == -1 {
    93  			return value, sql.InRange, nil
    94  		}
    95  	case uint64:
    96  		return t.Convert(int64(value))
    97  	case float32:
    98  		return t.Convert(float64(value))
    99  	case float64:
   100  		// Float values aren't truly accepted, but the engine will give them when it should give ints.
   101  		// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
   102  		if value == float64(int64(value)) {
   103  			return t.Convert(int64(value))
   104  		}
   105  	case decimal.Decimal:
   106  		f, _ := value.Float64()
   107  		return t.Convert(f)
   108  	case decimal.NullDecimal:
   109  		if value.Valid {
   110  			f, _ := value.Decimal.Float64()
   111  			return t.Convert(f)
   112  		}
   113  	case string:
   114  		// try getting int out of string value
   115  		i, err := strconv.ParseInt(value, 10, 64)
   116  		if err != nil {
   117  			return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v)
   118  		}
   119  		return t.Convert(i)
   120  	}
   121  
   122  	return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v)
   123  }
   124  
   125  // MustConvert implements the Type interface.
   126  func (t systemIntType) MustConvert(v interface{}) interface{} {
   127  	value, _, err := t.Convert(v)
   128  	if err != nil {
   129  		panic(err)
   130  	}
   131  	return value
   132  }
   133  
   134  // Equals implements the Type interface.
   135  func (t systemIntType) Equals(otherType sql.Type) bool {
   136  	if ot, ok := otherType.(systemIntType); ok {
   137  		return t.varName == ot.varName && t.lowerbound == ot.lowerbound && t.upperbound == ot.upperbound && t.negativeOne == ot.negativeOne
   138  	}
   139  	return false
   140  }
   141  
   142  // MaxTextResponseByteLength implements the Type interface
   143  func (t systemIntType) MaxTextResponseByteLength(ctx *sql.Context) uint32 {
   144  	return t.UnderlyingType().MaxTextResponseByteLength(ctx)
   145  }
   146  
// Promote implements the Type interface. It returns the receiver unchanged.
func (t systemIntType) Promote() sql.Type {
	return t
}
   151  
   152  // SQL implements Type interface.
   153  func (t systemIntType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   154  	if v == nil {
   155  		return sqltypes.NULL, nil
   156  	}
   157  
   158  	v, _, err := t.Convert(v)
   159  	if err != nil {
   160  		return sqltypes.Value{}, err
   161  	}
   162  
   163  	stop := len(dest)
   164  	dest = strconv.AppendInt(dest, v.(int64), 10)
   165  	val := dest[stop:]
   166  
   167  	return sqltypes.MakeTrusted(t.Type(), val), nil
   168  }
   169  
// String implements Type interface. It always returns "system_int", independent of the variable
// name and bounds.
func (t systemIntType) String() string {
	return "system_int"
}
   174  
// Type implements Type interface. It always returns sqltypes.Int64.
func (t systemIntType) Type() query.Type {
	return sqltypes.Int64
}
   179  
// ValueType implements Type interface. It returns the reflect.Type of int64.
func (t systemIntType) ValueType() reflect.Type {
	return systemIntValueType
}
   184  
// Zero implements Type interface. It returns int64(0) regardless of the configured bounds.
// NOTE(review): 0 is not checked against [lowerbound, upperbound] here — confirm callers do not
// assume the zero value is always a valid setting.
func (t systemIntType) Zero() interface{} {
	return int64(0)
}
   189  
// CollationCoercibility implements sql.CollationCoercible interface. It always reports the binary
// collation with a coercibility of 5.
// NOTE(review): 5 presumably corresponds to MySQL's coercibility level for numeric values — confirm
// against the MySQL COERCIBILITY() documentation.
func (systemIntType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
	return sql.Collation_binary, 5
}
   194  
   195  // EncodeValue implements SystemVariableType interface.
   196  func (t systemIntType) EncodeValue(val interface{}) (string, error) {
   197  	expectedVal, ok := val.(int64)
   198  	if !ok {
   199  		return "", sql.ErrSystemVariableCodeFail.New(val, t.String())
   200  	}
   201  	return strconv.FormatInt(expectedVal, 10), nil
   202  }
   203  
   204  // DecodeValue implements SystemVariableType interface.
   205  func (t systemIntType) DecodeValue(val string) (interface{}, error) {
   206  	parsedVal, err := strconv.ParseInt(val, 10, 64)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  	if parsedVal >= t.lowerbound && parsedVal <= t.upperbound {
   211  		return parsedVal, nil
   212  	}
   213  	return nil, sql.ErrSystemVariableCodeFail.New(val, t.String())
   214  }
   215  
// UnderlyingType returns Int64, the type that MaxTextResponseByteLength delegates to.
func (t systemIntType) UnderlyingType() sql.Type {
	return Int64
}