github.com/dolthub/go-mysql-server@v0.18.0/sql/types/set.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "fmt" 19 "math" 20 "math/bits" 21 "reflect" 22 "strconv" 23 "strings" 24 "unicode/utf8" 25 26 "github.com/dolthub/vitess/go/sqltypes" 27 "github.com/dolthub/vitess/go/vt/proto/query" 28 "github.com/shopspring/decimal" 29 30 "github.com/dolthub/go-mysql-server/sql" 31 "github.com/dolthub/go-mysql-server/sql/encodings" 32 ) 33 34 const ( 35 // SetTypeMaxElements returns the maximum number of elements for the Set type. 36 SetTypeMaxElements = 64 37 ) 38 39 var ( 40 setValueType = reflect.TypeOf(uint64(0)) 41 ) 42 43 type SetType struct { 44 collation sql.CollationID 45 hashedValToBit map[uint64]uint64 46 bitToVal map[uint64]string 47 maxResponseByteLength uint32 48 } 49 50 var _ sql.SetType = SetType{} 51 var _ sql.TypeWithCollation = SetType{} 52 var _ sql.CollationCoercible = SetType{} 53 54 // CreateSetType creates a SetType. 55 func CreateSetType(values []string, collation sql.CollationID) (sql.SetType, error) { 56 if len(values) == 0 { 57 return nil, fmt.Errorf("number of values may not be zero") 58 } 59 // A SET column can have a maximum of 64 distinct members. 60 if len(values) > SetTypeMaxElements { 61 return nil, fmt.Errorf("number of values is too large") 62 } 63 64 hashedValToBit := make(map[uint64]uint64) 65 bitToVal := make(map[uint64]string) 66 var maxByteLength uint32 67 maxCharLength := collation.Collation().CharacterSet.MaxLength() 68 for i, value := range values { 69 // SET member values should not themselves contain commas. 70 if strings.Contains(value, ",") { 71 return nil, fmt.Errorf("values cannot contain a comma") 72 } 73 if collation != sql.Collation_binary { 74 // Trailing spaces are automatically deleted from SET member values in the table definition when a table is created. 75 value = strings.TrimRight(value, " ") 76 } 77 78 hashedVal, err := collation.HashToUint(value) 79 if err != nil { 80 return nil, err 81 } 82 if _, ok := hashedValToBit[hashedVal]; ok { 83 return nil, sql.ErrDuplicateEntrySet.New(value) 84 } 85 bit := uint64(1 << uint64(i)) 86 hashedValToBit[hashedVal] = bit 87 bitToVal[bit] = value 88 maxByteLength = maxByteLength + uint32(utf8.RuneCountInString(value)*int(maxCharLength)) 89 if i != 0 { 90 maxByteLength = maxByteLength + uint32(maxCharLength) 91 } 92 } 93 return SetType{ 94 collation: collation, 95 hashedValToBit: hashedValToBit, 96 bitToVal: bitToVal, 97 maxResponseByteLength: maxByteLength, 98 }, nil 99 } 100 101 // MustCreateSetType is the same as CreateSetType except it panics on errors. 102 func MustCreateSetType(values []string, collation sql.CollationID) sql.SetType { 103 et, err := CreateSetType(values, collation) 104 if err != nil { 105 panic(err) 106 } 107 return et 108 } 109 110 // Compare implements Type interface. 111 func (t SetType) Compare(a interface{}, b interface{}) (int, error) { 112 if hasNulls, res := CompareNulls(a, b); hasNulls { 113 return res, nil 114 } 115 116 ai, _, err := t.Convert(a) 117 if err != nil { 118 return 0, err 119 } 120 bi, _, err := t.Convert(b) 121 if err != nil { 122 return 0, err 123 } 124 au := ai.(uint64) 125 bu := bi.(uint64) 126 127 if au < bu { 128 return -1, nil 129 } else if au > bu { 130 return 1, nil 131 } 132 return 0, nil 133 } 134 135 // Convert implements Type interface. 136 // Returns the string representing the given value if applicable. 137 func (t SetType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { 138 if v == nil { 139 return nil, sql.InRange, nil 140 } 141 142 switch value := v.(type) { 143 case int: 144 return t.Convert(uint64(value)) 145 case uint: 146 return t.Convert(uint64(value)) 147 case int8: 148 return t.Convert(uint64(value)) 149 case uint8: 150 return t.Convert(uint64(value)) 151 case int16: 152 return t.Convert(uint64(value)) 153 case uint16: 154 return t.Convert(uint64(value)) 155 case int32: 156 return t.Convert(uint64(value)) 157 case uint32: 158 return t.Convert(uint64(value)) 159 case int64: 160 return t.Convert(uint64(value)) 161 case uint64: 162 if value <= t.allValuesBitField() { 163 return value, sql.InRange, nil 164 } 165 case float32: 166 return t.Convert(uint64(value)) 167 case float64: 168 return t.Convert(uint64(value)) 169 case decimal.Decimal: 170 return t.Convert(value.BigInt().Uint64()) 171 case decimal.NullDecimal: 172 if !value.Valid { 173 return nil, sql.InRange, nil 174 } 175 return t.Convert(value.Decimal.BigInt().Uint64()) 176 case string: 177 ret, err := t.convertStringToBitField(value) 178 return ret, sql.InRange, err 179 case []byte: 180 return t.Convert(string(value)) 181 } 182 183 return uint64(0), sql.OutOfRange, sql.ErrConvertingToSet.New(v) 184 } 185 186 // MaxTextResponseByteLength implements the Type interface 187 func (t SetType) MaxTextResponseByteLength(_ *sql.Context) uint32 { 188 return t.maxResponseByteLength 189 } 190 191 // MustConvert implements the Type interface. 192 func (t SetType) MustConvert(v interface{}) interface{} { 193 value, _, err := t.Convert(v) 194 if err != nil { 195 panic(err) 196 } 197 return value 198 } 199 200 // Equals implements the Type interface. 201 func (t SetType) Equals(otherType sql.Type) bool { 202 if ot, ok := otherType.(SetType); ok && t.collation.Equals(ot.collation) && len(t.bitToVal) == len(ot.bitToVal) { 203 for bit, val := range t.bitToVal { 204 if ot.bitToVal[bit] != val { 205 return false 206 } 207 } 208 return true 209 } 210 return false 211 } 212 213 // Promote implements the Type interface. 214 func (t SetType) Promote() sql.Type { 215 return t 216 } 217 218 // SQL implements Type interface. 219 func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { 220 if v == nil { 221 return sqltypes.NULL, nil 222 } 223 convertedValue, _, err := t.Convert(v) 224 if err != nil { 225 return sqltypes.Value{}, err 226 } 227 value, err := t.BitsToString(convertedValue.(uint64)) 228 if err != nil { 229 return sqltypes.Value{}, err 230 } 231 232 resultCharset := ctx.GetCharacterSetResults() 233 if resultCharset == sql.CharacterSet_Unspecified || resultCharset == sql.CharacterSet_binary { 234 resultCharset = t.collation.CharacterSet() 235 } 236 encodedBytes, ok := resultCharset.Encoder().Encode(encodings.StringToBytes(value)) 237 if !ok { 238 snippet := value 239 if len(snippet) > 50 { 240 snippet = snippet[:50] 241 } 242 snippet = strings.ToValidUTF8(snippet, string(utf8.RuneError)) 243 return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), snippet) 244 } 245 val := AppendAndSliceBytes(dest, encodedBytes) 246 247 return sqltypes.MakeTrusted(sqltypes.Set, val), nil 248 } 249 250 // String implements Type interface. 251 func (t SetType) String() string { 252 return t.StringWithTableCollation(sql.Collation_Default) 253 } 254 255 // Type implements Type interface. 256 func (t SetType) Type() query.Type { 257 return sqltypes.Set 258 } 259 260 // ValueType implements Type interface. 261 func (t SetType) ValueType() reflect.Type { 262 return setValueType 263 } 264 265 // Zero implements Type interface. 266 func (t SetType) Zero() interface{} { 267 return uint64(0) 268 } 269 270 // CollationCoercibility implements sql.CollationCoercible interface. 271 func (t SetType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 272 return t.collation, 4 273 } 274 275 // CharacterSet implements SetType interface. 276 func (t SetType) CharacterSet() sql.CharacterSetID { 277 return t.collation.CharacterSet() 278 } 279 280 // Collation implements SetType interface. 281 func (t SetType) Collation() sql.CollationID { 282 return t.collation 283 } 284 285 // NumberOfElements implements SetType interface. 286 func (t SetType) NumberOfElements() uint16 { 287 return uint16(len(t.hashedValToBit)) 288 } 289 290 // BitsToString implements SetType interface. 291 func (t SetType) BitsToString(v uint64) (string, error) { 292 return t.convertBitFieldToString(v) 293 } 294 295 // Values implements SetType interface. 296 func (t SetType) Values() []string { 297 bitEdge := 64 - bits.LeadingZeros64(t.allValuesBitField()) 298 valArray := make([]string, bitEdge) 299 for i := 0; i < bitEdge; i++ { 300 bit := uint64(1 << uint64(i)) 301 valArray[i] = t.bitToVal[bit] 302 } 303 return valArray 304 } 305 306 // WithNewCollation implements sql.TypeWithCollation interface. 307 func (t SetType) WithNewCollation(collation sql.CollationID) (sql.Type, error) { 308 return CreateSetType(t.Values(), collation) 309 } 310 311 // StringWithTableCollation implements sql.TypeWithCollation interface. 312 func (t SetType) StringWithTableCollation(tableCollation sql.CollationID) string { 313 s := fmt.Sprintf("set('%v')", strings.Join(t.Values(), `','`)) 314 if t.CharacterSet() != tableCollation.CharacterSet() { 315 s += " CHARACTER SET " + t.CharacterSet().String() 316 } 317 if t.collation != tableCollation { 318 s += " COLLATE " + t.collation.String() 319 } 320 return s 321 } 322 323 // allValuesBitField returns a bit field that references every value that the set contains. 324 func (t SetType) allValuesBitField() uint64 { 325 valCount := uint64(len(t.hashedValToBit)) 326 if valCount == 64 { 327 return math.MaxUint64 328 } 329 // A set with 3 values will have an upper bound of 8, or 0b1000. 330 // 8 - 1 == 7, and 7 is 0b0111, which would map to every value in the set. 331 return uint64(1<<valCount) - 1 332 } 333 334 // convertBitFieldToString converts the given bit field into the equivalent comma-delimited string. 335 func (t SetType) convertBitFieldToString(bitField uint64) (string, error) { 336 strBuilder := strings.Builder{} 337 bitEdge := 64 - bits.LeadingZeros64(bitField) 338 writeCommas := false 339 if bitEdge > len(t.bitToVal) { 340 return "", sql.ErrTooLargeForSet.New(bitField) 341 } 342 for i := 0; i < bitEdge; i++ { 343 bit := uint64(1 << uint64(i)) 344 if bit&bitField != 0 { 345 val, ok := t.bitToVal[bit] 346 if !ok { 347 return "", sql.ErrInvalidSetValue.New(bitField) 348 } 349 if len(val) == 0 { 350 continue 351 } 352 if writeCommas { 353 strBuilder.WriteByte(',') 354 } else { 355 writeCommas = true 356 } 357 strBuilder.WriteString(val) 358 } 359 } 360 return strBuilder.String(), nil 361 } 362 363 // convertStringToBitField converts the given string into a bit field. 364 func (t SetType) convertStringToBitField(str string) (uint64, error) { 365 if str == "" { 366 return 0, nil 367 } 368 var bitField uint64 369 vals := strings.Split(str, ",") 370 for _, val := range vals { 371 compareVal := val 372 if t.collation != sql.Collation_binary { 373 compareVal = strings.TrimRight(compareVal, " ") 374 } 375 hashedVal, err := t.collation.HashToUint(compareVal) 376 if err == nil { 377 if bit, ok := t.hashedValToBit[hashedVal]; ok { 378 bitField |= bit 379 continue 380 } 381 } 382 383 asUint, err := strconv.ParseUint(val, 10, 64) 384 if err == nil { 385 if asUint == 0 { 386 continue 387 } 388 if _, ok := t.bitToVal[asUint]; ok { 389 bitField |= asUint 390 continue 391 } 392 } 393 return 0, sql.ErrInvalidSetValue.New(val) 394 } 395 return bitField, nil 396 }