github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_enum.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"reflect"
    19  	"strings"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/dolthub/vitess/go/vt/proto/query"
    23  	"github.com/shopspring/decimal"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  )
    27  
    28  var systemEnumValueType = reflect.TypeOf(string(""))
    29  
    30  // systemEnumType is an internal enum type ONLY for system variables.
    31  type systemEnumType struct {
    32  	varName    string
    33  	valToIndex map[string]int
    34  	indexToVal []string
    35  }
    36  
    37  var _ sql.SystemVariableType = systemEnumType{}
    38  var _ sql.CollationCoercible = systemEnumType{}
    39  
    40  // NewSystemEnumType returns a new systemEnumType.
    41  func NewSystemEnumType(varName string, values ...string) sql.SystemVariableType {
    42  	if len(values) > 65535 { // system variables should NEVER hit this
    43  		panic(varName + " somehow has more than 65535 values")
    44  	}
    45  	valToIndex := make(map[string]int)
    46  	for i, value := range values {
    47  		valToIndex[strings.ToLower(value)] = i
    48  	}
    49  	return systemEnumType{varName, valToIndex, values}
    50  }
    51  
    52  // Compare implements Type interface.
    53  func (t systemEnumType) Compare(a interface{}, b interface{}) (int, error) {
    54  	as, _, err := t.Convert(a)
    55  	if err != nil {
    56  		return 0, err
    57  	}
    58  	bs, _, err := t.Convert(b)
    59  	if err != nil {
    60  		return 0, err
    61  	}
    62  	ai := as.(string)
    63  	bi := bs.(string)
    64  
    65  	if ai == bi {
    66  		return 0, nil
    67  	}
    68  	if ai < bi {
    69  		return -1, nil
    70  	}
    71  	return 1, nil
    72  }
    73  
    74  // Convert implements Type interface.
    75  func (t systemEnumType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
    76  	// Nil values are not accepted
    77  	switch value := v.(type) {
    78  	case int:
    79  		if value >= 0 && value < len(t.indexToVal) {
    80  			return t.indexToVal[value], sql.InRange, nil
    81  		}
    82  	case uint:
    83  		return t.Convert(int(value))
    84  	case int8:
    85  		return t.Convert(int(value))
    86  	case uint8:
    87  		return t.Convert(int(value))
    88  	case int16:
    89  		return t.Convert(int(value))
    90  	case uint16:
    91  		return t.Convert(int(value))
    92  	case int32:
    93  		return t.Convert(int(value))
    94  	case uint32:
    95  		return t.Convert(int(value))
    96  	case int64:
    97  		return t.Convert(int(value))
    98  	case uint64:
    99  		return t.Convert(int(value))
   100  	case float32:
   101  		return t.Convert(float64(value))
   102  	case float64:
   103  		// Float values aren't truly accepted, but the engine will give them when it should give ints.
   104  		// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
   105  		if value == float64(int(value)) {
   106  			return t.Convert(int(value))
   107  		}
   108  	case decimal.Decimal:
   109  		f, _ := value.Float64()
   110  		return t.Convert(f)
   111  	case decimal.NullDecimal:
   112  		if value.Valid {
   113  			f, _ := value.Decimal.Float64()
   114  			return t.Convert(f)
   115  		}
   116  	case string:
   117  		if idx, ok := t.valToIndex[strings.ToLower(value)]; ok {
   118  			return t.indexToVal[idx], sql.InRange, nil
   119  		}
   120  	}
   121  
   122  	return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v)
   123  }
   124  
   125  // MustConvert implements the Type interface.
   126  func (t systemEnumType) MustConvert(v interface{}) interface{} {
   127  	value, _, err := t.Convert(v)
   128  	if err != nil {
   129  		panic(err)
   130  	}
   131  	return value
   132  }
   133  
   134  // Equals implements the Type interface.
   135  func (t systemEnumType) Equals(otherType sql.Type) bool {
   136  	if ot, ok := otherType.(systemEnumType); ok && t.varName == ot.varName && len(t.indexToVal) == len(ot.indexToVal) {
   137  		for i, val := range t.indexToVal {
   138  			if ot.indexToVal[i] != val {
   139  				return false
   140  			}
   141  		}
   142  		return true
   143  	}
   144  	return false
   145  }
   146  
   147  // MaxTextResponseByteLength implements the Type interface
   148  func (t systemEnumType) MaxTextResponseByteLength(ctx *sql.Context) uint32 {
   149  	return t.UnderlyingType().MaxTextResponseByteLength(ctx)
   150  }
   151  
   152  // Promote implements the Type interface.
   153  func (t systemEnumType) Promote() sql.Type {
   154  	return t
   155  }
   156  
   157  // SQL implements Type interface.
   158  func (t systemEnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   159  	if v == nil {
   160  		return sqltypes.NULL, nil
   161  	}
   162  
   163  	v, _, err := t.Convert(v)
   164  	if err != nil {
   165  		return sqltypes.Value{}, err
   166  	}
   167  
   168  	val := AppendAndSliceString(dest, v.(string))
   169  
   170  	return sqltypes.MakeTrusted(t.Type(), val), nil
   171  }
   172  
   173  // String implements Type interface.
   174  func (t systemEnumType) String() string {
   175  	return "system_enum"
   176  }
   177  
   178  // Type implements Type interface.
   179  func (t systemEnumType) Type() query.Type {
   180  	return sqltypes.VarChar
   181  }
   182  
   183  // ValueType implements Type interface.
   184  func (t systemEnumType) ValueType() reflect.Type {
   185  	return systemEnumValueType
   186  }
   187  
   188  // Zero implements Type interface.
   189  func (t systemEnumType) Zero() interface{} {
   190  	return ""
   191  }
   192  
   193  // CollationCoercibility implements sql.CollationCoercible interface.
   194  func (systemEnumType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   195  	return sql.Collation_utf8mb3_general_ci, 3
   196  }
   197  
   198  // EncodeValue implements SystemVariableType interface.
   199  func (t systemEnumType) EncodeValue(val interface{}) (string, error) {
   200  	expectedVal, ok := val.(string)
   201  	if !ok {
   202  		return "", sql.ErrSystemVariableCodeFail.New(val, t.String())
   203  	}
   204  	return expectedVal, nil
   205  }
   206  
   207  // DecodeValue implements SystemVariableType interface.
   208  func (t systemEnumType) DecodeValue(val string) (interface{}, error) {
   209  	outVal, _, err := t.Convert(val)
   210  	if err != nil {
   211  		return nil, sql.ErrSystemVariableCodeFail.New(val, t.String())
   212  	}
   213  	return outVal, nil
   214  }
   215  
   216  func (t systemEnumType) UnderlyingType() sql.Type {
   217  	return EnumType{}
   218  }