github.com/dolthub/go-mysql-server@v0.18.0/sql/types/system_set.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"reflect"
    19  
    20  	"github.com/dolthub/vitess/go/sqltypes"
    21  	"github.com/dolthub/vitess/go/vt/proto/query"
    22  	"github.com/shopspring/decimal"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  )
    26  
// systemSetType is an internal set type ONLY for system variables.
//
// It embeds sql.SetType for the actual set semantics (member lookup, bit field
// conversion, string rendering) and adds the owning variable's name.
type systemSetType struct {
	// SetType provides the underlying set behavior; values produced by it are
	// uint64 bit fields.
	sql.SetType
	// varName is the system variable this type belongs to. It is reported in
	// ErrInvalidSystemVariableValue errors and participates in Equals.
	varName string
}
    32  
    33  var _ sql.SystemVariableType = systemSetType{}
    34  var _ sql.CollationCoercible = systemSetType{}
    35  
    36  // NewSystemSetType returns a new systemSetType.
    37  func NewSystemSetType(varName string, values ...string) sql.SystemVariableType {
    38  	return systemSetType{MustCreateSetType(values, sql.Collation_Default), varName}
    39  }
    40  
    41  // Compare implements Type interface.
    42  func (t systemSetType) Compare(a interface{}, b interface{}) (int, error) {
    43  	if a == nil || b == nil {
    44  		return 0, sql.ErrInvalidSystemVariableValue.New(t.varName, nil)
    45  	}
    46  	ai, _, err := t.Convert(a)
    47  	if err != nil {
    48  		return 0, err
    49  	}
    50  	bi, _, err := t.Convert(b)
    51  	if err != nil {
    52  		return 0, err
    53  	}
    54  	au := ai.(uint64)
    55  	bu := bi.(uint64)
    56  
    57  	if au == bu {
    58  		return 0, nil
    59  	}
    60  	if au < bu {
    61  		return -1, nil
    62  	}
    63  	return 1, nil
    64  }
    65  
    66  // Convert implements Type interface.
    67  func (t systemSetType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
    68  	// Nil values are not accepted
    69  	switch value := v.(type) {
    70  	case int:
    71  		return t.SetType.Convert(value)
    72  	case uint:
    73  		return t.SetType.Convert(value)
    74  	case int8:
    75  		return t.SetType.Convert(value)
    76  	case uint8:
    77  		return t.SetType.Convert(value)
    78  	case int16:
    79  		return t.SetType.Convert(value)
    80  	case uint16:
    81  		return t.SetType.Convert(value)
    82  	case int32:
    83  		return t.SetType.Convert(value)
    84  	case uint32:
    85  		return t.SetType.Convert(value)
    86  	case int64:
    87  		return t.SetType.Convert(value)
    88  	case uint64:
    89  		return t.SetType.Convert(value)
    90  	case float32:
    91  		return t.Convert(float64(value))
    92  	case float64:
    93  		// Float values aren't truly accepted, but the engine will give them when it should give ints.
    94  		// Therefore, if the float doesn't have a fractional portion, we treat it as an int.
    95  		if value == float64(int64(value)) {
    96  			return t.SetType.Convert(int64(value))
    97  		}
    98  	case decimal.Decimal:
    99  		f, _ := value.Float64()
   100  		return t.Convert(f)
   101  	case decimal.NullDecimal:
   102  		if value.Valid {
   103  			f, _ := value.Decimal.Float64()
   104  			return t.Convert(f)
   105  		}
   106  	case string:
   107  		return t.SetType.Convert(value)
   108  	}
   109  
   110  	return nil, sql.OutOfRange, sql.ErrInvalidSystemVariableValue.New(t.varName, v)
   111  }
   112  
   113  // MustConvert implements the Type interface.
   114  func (t systemSetType) MustConvert(v interface{}) interface{} {
   115  	value, _, err := t.Convert(v)
   116  	if err != nil {
   117  		panic(err)
   118  	}
   119  	return value
   120  }
   121  
   122  // Equals implements the Type interface.
   123  func (t systemSetType) Equals(otherType sql.Type) bool {
   124  	if ot, ok := otherType.(systemSetType); ok {
   125  		return t.varName == ot.varName && t.SetType.Equals(ot.SetType)
   126  	}
   127  	return false
   128  }
   129  
   130  // MaxTextResponseByteLength implements the Type interface
   131  func (t systemSetType) MaxTextResponseByteLength(ctx *sql.Context) uint32 {
   132  	return t.UnderlyingType().MaxTextResponseByteLength(ctx)
   133  }
   134  
   135  // Promote implements the Type interface.
   136  func (t systemSetType) Promote() sql.Type {
   137  	return t
   138  }
   139  
   140  // SQL implements Type interface.
   141  func (t systemSetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   142  	if v == nil {
   143  		return sqltypes.NULL, nil
   144  	}
   145  	convertedValue, _, err := t.Convert(v)
   146  	if err != nil {
   147  		return sqltypes.Value{}, err
   148  	}
   149  	value, err := t.BitsToString(convertedValue.(uint64))
   150  	if err != nil {
   151  		return sqltypes.Value{}, err
   152  	}
   153  
   154  	val := AppendAndSliceString(dest, value)
   155  
   156  	return sqltypes.MakeTrusted(t.Type(), val), nil
   157  }
   158  
   159  // String implements Type interface.
   160  func (t systemSetType) String() string {
   161  	return "system_set"
   162  }
   163  
   164  // Type implements Type interface.
   165  func (t systemSetType) Type() query.Type {
   166  	return sqltypes.VarChar
   167  }
   168  
   169  // ValueType implements Type interface.
   170  func (t systemSetType) ValueType() reflect.Type {
   171  	return t.SetType.ValueType()
   172  }
   173  
   174  // Zero implements Type interface.
   175  func (t systemSetType) Zero() interface{} {
   176  	return ""
   177  }
   178  
   179  // CollationCoercibility implements sql.CollationCoercible interface.
   180  func (systemSetType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   181  	return sql.Collation_utf8mb3_general_ci, 3
   182  }
   183  
   184  // EncodeValue implements SystemVariableType interface.
   185  func (t systemSetType) EncodeValue(val interface{}) (string, error) {
   186  	expectedVal, ok := val.(uint64)
   187  	if !ok {
   188  		return "", sql.ErrSystemVariableCodeFail.New(val, t.String())
   189  	}
   190  	return t.BitsToString(expectedVal)
   191  }
   192  
   193  // DecodeValue implements SystemVariableType interface.
   194  func (t systemSetType) DecodeValue(val string) (interface{}, error) {
   195  	outVal, _, err := t.Convert(val)
   196  	if err != nil {
   197  		return nil, sql.ErrSystemVariableCodeFail.New(val, t.String())
   198  	}
   199  	return outVal, nil
   200  }
   201  
   202  func (t systemSetType) UnderlyingType() sql.Type {
   203  	return t.SetType
   204  }