github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_int.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "reflect" 19 "strconv" 20 21 "github.com/dolthub/vitess/go/sqltypes" 22 "github.com/dolthub/vitess/go/vt/proto/query" 23 "github.com/shopspring/decimal" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 ) 27 28 var systemIntValueType = reflect.TypeOf(int64(0)) 29 30 // systemIntType is an internal integer type ONLY for system variables. 31 type systemIntType struct { 32 varName string 33 lowerbound int64 34 upperbound int64 35 negativeOne bool 36 } 37 38 var _ sql.SystemVariableType = systemIntType{} 39 var _ sql.CollationCoercible = systemIntType{} 40 41 // NewSystemIntType returns a new systemIntType. 42 func NewSystemIntType(varName string, lowerbound, upperbound int64, negativeOne bool) sql.SystemVariableType { 43 return systemIntType{varName, lowerbound, upperbound, negativeOne} 44 } 45 46 // Compare implements Type interface. 47 func (t systemIntType) Compare(a interface{}, b interface{}) (int, error) { 48 as, _, err := t.Convert(a) 49 if err != nil { 50 return 0, err 51 } 52 bs, _, err := t.Convert(b) 53 if err != nil { 54 return 0, err 55 } 56 ai := as.(int64) 57 bi := bs.(int64) 58 59 if ai == bi { 60 return 0, nil 61 } 62 if ai < bi { 63 return -1, nil 64 } 65 return 1, nil 66 } 67 68 // Convert implements Type interface. 69 func (t systemIntType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { 70 // String nor nil values are accepted 71 switch value := v.(type) { 72 case int: 73 return t.Convert(int64(value)) 74 case uint: 75 return t.Convert(int64(value)) 76 case int8: 77 return t.Convert(int64(value)) 78 case uint8: 79 return t.Convert(int64(value)) 80 case int16: 81 return t.Convert(int64(value)) 82 case uint16: 83 return t.Convert(int64(value)) 84 case int32: 85 return t.Convert(int64(value)) 86 case uint32: 87 return t.Convert(int64(value)) 88 case int64: 89 if value >= t.lowerbound && value <= t.upperbound { 90 return value, sql.InRange, nil 91 } 92 if t.negativeOne && value == -1 { 93 return value, sql.InRange, nil 94 } 95 case uint64: 96 return t.Convert(int64(value)) 97 case float32: 98 return t.Convert(float64(value)) 99 case float64: 100 // Float values aren't truly accepted, but the engine will give them when it should give ints. 101 // Therefore, if the float doesn't have a fractional portion, we treat it as an int. 102 if value == float64(int64(value)) { 103 return t.Convert(int64(value)) 104 } 105 case decimal.Decimal: 106 f, _ := value.Float64() 107 return t.Convert(f) 108 case decimal.NullDecimal: 109 if value.Valid { 110 f, _ := value.Decimal.Float64() 111 return t.Convert(f) 112 } 113 case string: 114 // try getting int out of string value 115 i, err := strconv.ParseInt(value, 10, 64) 116 if err != nil { 117 return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v) 118 } 119 return t.Convert(i) 120 } 121 122 return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v) 123 } 124 125 // MustConvert implements the Type interface. 126 func (t systemIntType) MustConvert(v interface{}) interface{} { 127 value, _, err := t.Convert(v) 128 if err != nil { 129 panic(err) 130 } 131 return value 132 } 133 134 // Equals implements the Type interface. 135 func (t systemIntType) Equals(otherType sql.Type) bool { 136 if ot, ok := otherType.(systemIntType); ok { 137 return t.varName == ot.varName && t.lowerbound == ot.lowerbound && t.upperbound == ot.upperbound && t.negativeOne == ot.negativeOne 138 } 139 return false 140 } 141 142 // MaxTextResponseByteLength implements the Type interface 143 func (t systemIntType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { 144 return t.UnderlyingType().MaxTextResponseByteLength(ctx) 145 } 146 147 // Promote implements the Type interface. 148 func (t systemIntType) Promote() sql.Type { 149 return t 150 } 151 152 // SQL implements Type interface. 153 func (t systemIntType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { 154 if v == nil { 155 return sqltypes.NULL, nil 156 } 157 158 v, _, err := t.Convert(v) 159 if err != nil { 160 return sqltypes.Value{}, err 161 } 162 163 stop := len(dest) 164 dest = strconv.AppendInt(dest, v.(int64), 10) 165 val := dest[stop:] 166 167 return sqltypes.MakeTrusted(t.Type(), val), nil 168 } 169 170 // String implements Type interface. 171 func (t systemIntType) String() string { 172 return "system_int" 173 } 174 175 // Type implements Type interface. 176 func (t systemIntType) Type() query.Type { 177 return sqltypes.Int64 178 } 179 180 // ValueType implements Type interface. 181 func (t systemIntType) ValueType() reflect.Type { 182 return systemIntValueType 183 } 184 185 // Zero implements Type interface. 186 func (t systemIntType) Zero() interface{} { 187 return int64(0) 188 } 189 190 // CollationCoercibility implements sql.CollationCoercible interface. 191 func (systemIntType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 192 return sql.Collation_binary, 5 193 } 194 195 // EncodeValue implements SystemVariableType interface. 196 func (t systemIntType) EncodeValue(val interface{}) (string, error) { 197 expectedVal, ok := val.(int64) 198 if !ok { 199 return "", sql.ErrSystemVariableCodeFail.New(val, t.String()) 200 } 201 return strconv.FormatInt(expectedVal, 10), nil 202 } 203 204 // DecodeValue implements SystemVariableType interface. 205 func (t systemIntType) DecodeValue(val string) (interface{}, error) { 206 parsedVal, err := strconv.ParseInt(val, 10, 64) 207 if err != nil { 208 return nil, err 209 } 210 if parsedVal >= t.lowerbound && parsedVal <= t.upperbound { 211 return parsedVal, nil 212 } 213 return nil, sql.ErrSystemVariableCodeFail.New(val, t.String()) 214 } 215 216 func (t systemIntType) UnderlyingType() sql.Type { 217 return Int64 218 }