github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_enum.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "reflect" 19 "strings" 20 21 "github.com/dolthub/vitess/go/sqltypes" 22 "github.com/dolthub/vitess/go/vt/proto/query" 23 "github.com/shopspring/decimal" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 ) 27 28 var systemEnumValueType = reflect.TypeOf(string("")) 29 30 // systemEnumType is an internal enum type ONLY for system variables. 31 type systemEnumType struct { 32 varName string 33 valToIndex map[string]int 34 indexToVal []string 35 } 36 37 var _ sql.SystemVariableType = systemEnumType{} 38 var _ sql.CollationCoercible = systemEnumType{} 39 40 // NewSystemEnumType returns a new systemEnumType. 41 func NewSystemEnumType(varName string, values ...string) sql.SystemVariableType { 42 if len(values) > 65535 { // system variables should NEVER hit this 43 panic(varName + " somehow has more than 65535 values") 44 } 45 valToIndex := make(map[string]int) 46 for i, value := range values { 47 valToIndex[strings.ToLower(value)] = i 48 } 49 return systemEnumType{varName, valToIndex, values} 50 } 51 52 // Compare implements Type interface. 53 func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) { 54 as, _, err := t.Convert(a) 55 if err != nil { 56 return 0, err 57 } 58 bs, _, err := t.Convert(b) 59 if err != nil { 60 return 0, err 61 } 62 ai := as.(string) 63 bi := bs.(string) 64 65 if ai == bi { 66 return 0, nil 67 } 68 if ai < bi { 69 return -1, nil 70 } 71 return 1, nil 72 } 73 74 // Convert implements Type interface. 75 func (t systemEnumType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { 76 // Nil values are not accepted 77 switch value := v.(type) { 78 case int: 79 if value >= 0 && value < len(t.indexToVal) { 80 return t.indexToVal[value], sql.InRange, nil 81 } 82 case uint: 83 return t.Convert(int(value)) 84 case int8: 85 return t.Convert(int(value)) 86 case uint8: 87 return t.Convert(int(value)) 88 case int16: 89 return t.Convert(int(value)) 90 case uint16: 91 return t.Convert(int(value)) 92 case int32: 93 return t.Convert(int(value)) 94 case uint32: 95 return t.Convert(int(value)) 96 case int64: 97 return t.Convert(int(value)) 98 case uint64: 99 return t.Convert(int(value)) 100 case float32: 101 return t.Convert(float64(value)) 102 case float64: 103 // Float values aren't truly accepted, but the engine will give them when it should give ints. 104 // Therefore, if the float doesn't have a fractional portion, we treat it as an int. 105 if value == float64(int(value)) { 106 return t.Convert(int(value)) 107 } 108 case decimal.Decimal: 109 f, _ := value.Float64() 110 return t.Convert(f) 111 case decimal.NullDecimal: 112 if value.Valid { 113 f, _ := value.Decimal.Float64() 114 return t.Convert(f) 115 } 116 case string: 117 if idx, ok := t.valToIndex[strings.ToLower(value)]; ok { 118 return t.indexToVal[idx], sql.InRange, nil 119 } 120 } 121 122 return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v) 123 } 124 125 // MustConvert implements the Type interface. 126 func (t systemEnumType) MustConvert(v interface{}) interface{} { 127 value, _, err := t.Convert(v) 128 if err != nil { 129 panic(err) 130 } 131 return value 132 } 133 134 // Equals implements the Type interface. 135 func (t systemEnumType) Equals(otherType sql.Type) bool { 136 if ot, ok := otherType.(systemEnumType); ok && t.varName == ot.varName && len(t.indexToVal) == len(ot.indexToVal) { 137 for i, val := range t.indexToVal { 138 if ot.indexToVal[i] != val { 139 return false 140 } 141 } 142 return true 143 } 144 return false 145 } 146 147 // MaxTextResponseByteLength implements the Type interface 148 func (t systemEnumType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { 149 return t.UnderlyingType().MaxTextResponseByteLength(ctx) 150 } 151 152 // Promote implements the Type interface. 153 func (t systemEnumType) Promote() sql.Type { 154 return t 155 } 156 157 // SQL implements Type interface. 158 func (t systemEnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { 159 if v == nil { 160 return sqltypes.NULL, nil 161 } 162 163 v, _, err := t.Convert(v) 164 if err != nil { 165 return sqltypes.Value{}, err 166 } 167 168 val := AppendAndSliceString(dest, v.(string)) 169 170 return sqltypes.MakeTrusted(t.Type(), val), nil 171 } 172 173 // String implements Type interface. 174 func (t systemEnumType) String() string { 175 return "system_enum" 176 } 177 178 // Type implements Type interface. 179 func (t systemEnumType) Type() query.Type { 180 return sqltypes.VarChar 181 } 182 183 // ValueType implements Type interface. 184 func (t systemEnumType) ValueType() reflect.Type { 185 return systemEnumValueType 186 } 187 188 // Zero implements Type interface. 189 func (t systemEnumType) Zero() interface{} { 190 return "" 191 } 192 193 // CollationCoercibility implements sql.CollationCoercible interface. 194 func (systemEnumType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 195 return sql.Collation_utf8mb3_general_ci, 3 196 } 197 198 // EncodeValue implements SystemVariableType interface. 199 func (t systemEnumType) EncodeValue(val interface{}) (string, error) { 200 expectedVal, ok := val.(string) 201 if !ok { 202 return "", sql.ErrSystemVariableCodeFail.New(val, t.String()) 203 } 204 return expectedVal, nil 205 } 206 207 // DecodeValue implements SystemVariableType interface. 208 func (t systemEnumType) DecodeValue(val string) (interface{}, error) { 209 outVal, _, err := t.Convert(val) 210 if err != nil { 211 return nil, sql.ErrSystemVariableCodeFail.New(val, t.String()) 212 } 213 return outVal, nil 214 } 215 216 func (t systemEnumType) UnderlyingType() sql.Type { 217 return EnumType{} 218 }