github.com/dolthub/go-mysql-server@v0.18.0/sql/types/set.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"math/bits"
    21  	"reflect"
    22  	"strconv"
    23  	"strings"
    24  	"unicode/utf8"
    25  
    26  	"github.com/dolthub/vitess/go/sqltypes"
    27  	"github.com/dolthub/vitess/go/vt/proto/query"
    28  	"github.com/shopspring/decimal"
    29  
    30  	"github.com/dolthub/go-mysql-server/sql"
    31  	"github.com/dolthub/go-mysql-server/sql/encodings"
    32  )
    33  
    34  const (
    35  	// SetTypeMaxElements returns the maximum number of elements for the Set type.
    36  	SetTypeMaxElements = 64
    37  )
    38  
    39  var (
    40  	setValueType = reflect.TypeOf(uint64(0))
    41  )
    42  
// SetType is the SET column type. Values are handled as uint64 bit fields in which the bit
// 1 << i stands for the i-th member given to CreateSetType.
//
// Fields:
//   - collation: the collation used to hash member strings for lookup; trailing spaces of
//     members are trimmed unless this is the binary collation.
//   - hashedValToBit: maps the collation hash of each member to that member's bit.
//   - bitToVal: maps each member's bit back to the (trimmed) member string.
//   - maxResponseByteLength: computed by CreateSetType from the members' rune counts, the
//     character set's maximum bytes per character, and one separator between members.
type SetType struct {
	collation             sql.CollationID
	hashedValToBit        map[uint64]uint64
	bitToVal              map[uint64]string
	maxResponseByteLength uint32
}
    49  
    50  var _ sql.SetType = SetType{}
    51  var _ sql.TypeWithCollation = SetType{}
    52  var _ sql.CollationCoercible = SetType{}
    53  
    54  // CreateSetType creates a SetType.
    55  func CreateSetType(values []string, collation sql.CollationID) (sql.SetType, error) {
    56  	if len(values) == 0 {
    57  		return nil, fmt.Errorf("number of values may not be zero")
    58  	}
    59  	// A SET column can have a maximum of 64 distinct members.
    60  	if len(values) > SetTypeMaxElements {
    61  		return nil, fmt.Errorf("number of values is too large")
    62  	}
    63  
    64  	hashedValToBit := make(map[uint64]uint64)
    65  	bitToVal := make(map[uint64]string)
    66  	var maxByteLength uint32
    67  	maxCharLength := collation.Collation().CharacterSet.MaxLength()
    68  	for i, value := range values {
    69  		// SET member values should not themselves contain commas.
    70  		if strings.Contains(value, ",") {
    71  			return nil, fmt.Errorf("values cannot contain a comma")
    72  		}
    73  		if collation != sql.Collation_binary {
    74  			// Trailing spaces are automatically deleted from SET member values in the table definition when a table is created.
    75  			value = strings.TrimRight(value, " ")
    76  		}
    77  
    78  		hashedVal, err := collation.HashToUint(value)
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  		if _, ok := hashedValToBit[hashedVal]; ok {
    83  			return nil, sql.ErrDuplicateEntrySet.New(value)
    84  		}
    85  		bit := uint64(1 << uint64(i))
    86  		hashedValToBit[hashedVal] = bit
    87  		bitToVal[bit] = value
    88  		maxByteLength = maxByteLength + uint32(utf8.RuneCountInString(value)*int(maxCharLength))
    89  		if i != 0 {
    90  			maxByteLength = maxByteLength + uint32(maxCharLength)
    91  		}
    92  	}
    93  	return SetType{
    94  		collation:             collation,
    95  		hashedValToBit:        hashedValToBit,
    96  		bitToVal:              bitToVal,
    97  		maxResponseByteLength: maxByteLength,
    98  	}, nil
    99  }
   100  
   101  // MustCreateSetType is the same as CreateSetType except it panics on errors.
   102  func MustCreateSetType(values []string, collation sql.CollationID) sql.SetType {
   103  	et, err := CreateSetType(values, collation)
   104  	if err != nil {
   105  		panic(err)
   106  	}
   107  	return et
   108  }
   109  
   110  // Compare implements Type interface.
   111  func (t SetType) Compare(a interface{}, b interface{}) (int, error) {
   112  	if hasNulls, res := CompareNulls(a, b); hasNulls {
   113  		return res, nil
   114  	}
   115  
   116  	ai, _, err := t.Convert(a)
   117  	if err != nil {
   118  		return 0, err
   119  	}
   120  	bi, _, err := t.Convert(b)
   121  	if err != nil {
   122  		return 0, err
   123  	}
   124  	au := ai.(uint64)
   125  	bu := bi.(uint64)
   126  
   127  	if au < bu {
   128  		return -1, nil
   129  	} else if au > bu {
   130  		return 1, nil
   131  	}
   132  	return 0, nil
   133  }
   134  
   135  // Convert implements Type interface.
   136  // Returns the string representing the given value if applicable.
   137  func (t SetType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
   138  	if v == nil {
   139  		return nil, sql.InRange, nil
   140  	}
   141  
   142  	switch value := v.(type) {
   143  	case int:
   144  		return t.Convert(uint64(value))
   145  	case uint:
   146  		return t.Convert(uint64(value))
   147  	case int8:
   148  		return t.Convert(uint64(value))
   149  	case uint8:
   150  		return t.Convert(uint64(value))
   151  	case int16:
   152  		return t.Convert(uint64(value))
   153  	case uint16:
   154  		return t.Convert(uint64(value))
   155  	case int32:
   156  		return t.Convert(uint64(value))
   157  	case uint32:
   158  		return t.Convert(uint64(value))
   159  	case int64:
   160  		return t.Convert(uint64(value))
   161  	case uint64:
   162  		if value <= t.allValuesBitField() {
   163  			return value, sql.InRange, nil
   164  		}
   165  	case float32:
   166  		return t.Convert(uint64(value))
   167  	case float64:
   168  		return t.Convert(uint64(value))
   169  	case decimal.Decimal:
   170  		return t.Convert(value.BigInt().Uint64())
   171  	case decimal.NullDecimal:
   172  		if !value.Valid {
   173  			return nil, sql.InRange, nil
   174  		}
   175  		return t.Convert(value.Decimal.BigInt().Uint64())
   176  	case string:
   177  		ret, err := t.convertStringToBitField(value)
   178  		return ret, sql.InRange, err
   179  	case []byte:
   180  		return t.Convert(string(value))
   181  	}
   182  
   183  	return uint64(0), sql.OutOfRange, sql.ErrConvertingToSet.New(v)
   184  }
   185  
   186  // MaxTextResponseByteLength implements the Type interface
   187  func (t SetType) MaxTextResponseByteLength(_ *sql.Context) uint32 {
   188  	return t.maxResponseByteLength
   189  }
   190  
   191  // MustConvert implements the Type interface.
   192  func (t SetType) MustConvert(v interface{}) interface{} {
   193  	value, _, err := t.Convert(v)
   194  	if err != nil {
   195  		panic(err)
   196  	}
   197  	return value
   198  }
   199  
   200  // Equals implements the Type interface.
   201  func (t SetType) Equals(otherType sql.Type) bool {
   202  	if ot, ok := otherType.(SetType); ok && t.collation.Equals(ot.collation) && len(t.bitToVal) == len(ot.bitToVal) {
   203  		for bit, val := range t.bitToVal {
   204  			if ot.bitToVal[bit] != val {
   205  				return false
   206  			}
   207  		}
   208  		return true
   209  	}
   210  	return false
   211  }
   212  
   213  // Promote implements the Type interface.
   214  func (t SetType) Promote() sql.Type {
   215  	return t
   216  }
   217  
   218  // SQL implements Type interface.
   219  func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   220  	if v == nil {
   221  		return sqltypes.NULL, nil
   222  	}
   223  	convertedValue, _, err := t.Convert(v)
   224  	if err != nil {
   225  		return sqltypes.Value{}, err
   226  	}
   227  	value, err := t.BitsToString(convertedValue.(uint64))
   228  	if err != nil {
   229  		return sqltypes.Value{}, err
   230  	}
   231  
   232  	resultCharset := ctx.GetCharacterSetResults()
   233  	if resultCharset == sql.CharacterSet_Unspecified || resultCharset == sql.CharacterSet_binary {
   234  		resultCharset = t.collation.CharacterSet()
   235  	}
   236  	encodedBytes, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value))
   237  	if !ok {
   238  		snippet := value
   239  		if len(snippet) > 50 {
   240  			snippet = snippet[:50]
   241  		}
   242  		snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError))
   243  		return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet)
   244  	}
   245  	val := AppendAndSliceBytes(dest, encodedBytes)
   246  
   247  	return sqltypes.MakeTrusted(sqltypes.Set, val), nil
   248  }
   249  
   250  // String implements Type interface.
   251  func (t SetType) String() string {
   252  	return t.StringWithTableCollation(sql.Collation_Default)
   253  }
   254  
   255  // Type implements Type interface.
   256  func (t SetType) Type() query.Type {
   257  	return sqltypes.Set
   258  }
   259  
   260  // ValueType implements Type interface.
   261  func (t SetType) ValueType() reflect.Type {
   262  	return setValueType
   263  }
   264  
   265  // Zero implements Type interface.
   266  func (t SetType) Zero() interface{} {
   267  	return uint64(0)
   268  }
   269  
   270  // CollationCoercibility implements sql.CollationCoercible interface.
   271  func (t SetType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   272  	return t.collation, 4
   273  }
   274  
   275  // CharacterSet implements SetType interface.
   276  func (t SetType) CharacterSet() sql.CharacterSetID {
   277  	return t.collation.CharacterSet()
   278  }
   279  
   280  // Collation implements SetType interface.
   281  func (t SetType) Collation() sql.CollationID {
   282  	return t.collation
   283  }
   284  
   285  // NumberOfElements implements SetType interface.
   286  func (t SetType) NumberOfElements() uint16 {
   287  	return uint16(len(t.hashedValToBit))
   288  }
   289  
   290  // BitsToString implements SetType interface.
   291  func (t SetType) BitsToString(v uint64) (string, error) {
   292  	return t.convertBitFieldToString(v)
   293  }
   294  
   295  // Values implements SetType interface.
   296  func (t SetType) Values() []string {
   297  	bitEdge := 64 - bits.LeadingZeros64(t.allValuesBitField())
   298  	valArray := make([]string, bitEdge)
   299  	for i := 0; i < bitEdge; i++ {
   300  		bit := uint64(1 << uint64(i))
   301  		valArray[i] = t.bitToVal[bit]
   302  	}
   303  	return valArray
   304  }
   305  
   306  // WithNewCollation implements sql.TypeWithCollation interface.
   307  func (t SetType) WithNewCollation(collation sql.CollationID) (sql.Type, error) {
   308  	return CreateSetType(t.Values(), collation)
   309  }
   310  
   311  // StringWithTableCollation implements sql.TypeWithCollation interface.
   312  func (t SetType) StringWithTableCollation(tableCollation sql.CollationID) string {
   313  	s := fmt.Sprintf("set('%v')", strings.Join(t.Values(), `','`))
   314  	if t.CharacterSet() != tableCollation.CharacterSet() {
   315  		s += " CHARACTER SET " + t.CharacterSet().String()
   316  	}
   317  	if t.collation != tableCollation {
   318  		s += " COLLATE " + t.collation.String()
   319  	}
   320  	return s
   321  }
   322  
   323  // allValuesBitField returns a bit field that references every value that the set contains.
   324  func (t SetType) allValuesBitField() uint64 {
   325  	valCount := uint64(len(t.hashedValToBit))
   326  	if valCount == 64 {
   327  		return math.MaxUint64
   328  	}
   329  	// A set with 3 values will have an upper bound of 8, or 0b1000.
   330  	// 8 - 1 == 7, and 7 is 0b0111, which would map to every value in the set.
   331  	return uint64(1<<valCount) - 1
   332  }
   333  
   334  // convertBitFieldToString converts the given bit field into the equivalent comma-delimited string.
   335  func (t SetType) convertBitFieldToString(bitField uint64) (string, error) {
   336  	strBuilder := strings.Builder{}
   337  	bitEdge := 64 - bits.LeadingZeros64(bitField)
   338  	writeCommas := false
   339  	if bitEdge > len(t.bitToVal) {
   340  		return "", sql.ErrTooLargeForSet.New(bitField)
   341  	}
   342  	for i := 0; i < bitEdge; i++ {
   343  		bit := uint64(1 << uint64(i))
   344  		if bit&bitField != 0 {
   345  			val, ok := t.bitToVal[bit]
   346  			if !ok {
   347  				return "", sql.ErrInvalidSetValue.New(bitField)
   348  			}
   349  			if len(val) == 0 {
   350  				continue
   351  			}
   352  			if writeCommas {
   353  				strBuilder.WriteByte(',')
   354  			} else {
   355  				writeCommas = true
   356  			}
   357  			strBuilder.WriteString(val)
   358  		}
   359  	}
   360  	return strBuilder.String(), nil
   361  }
   362  
   363  // convertStringToBitField converts the given string into a bit field.
   364  func (t SetType) convertStringToBitField(str string) (uint64, error) {
   365  	if str == "" {
   366  		return 0, nil
   367  	}
   368  	var bitField uint64
   369  	vals := strings.Split(str, ",")
   370  	for _, val := range vals {
   371  		compareVal := val
   372  		if t.collation != sql.Collation_binary {
   373  			compareVal = strings.TrimRight(compareVal, " ")
   374  		}
   375  		hashedVal, err := t.collation.HashToUint(compareVal)
   376  		if err == nil {
   377  			if bit, ok := t.hashedValToBit[hashedVal]; ok {
   378  				bitField |= bit
   379  				continue
   380  			}
   381  		}
   382  
   383  		asUint, err := strconv.ParseUint(val, 10, 64)
   384  		if err == nil {
   385  			if asUint == 0 {
   386  				continue
   387  			}
   388  			if _, ok := t.bitToVal[asUint]; ok {
   389  				bitField |= asUint
   390  				continue
   391  			}
   392  		}
   393  		return 0, sql.ErrInvalidSetValue.New(val)
   394  	}
   395  	return bitField, nil
   396  }