// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
	"encoding/binary"
	"fmt"
	"reflect"

	"github.com/dolthub/vitess/go/sqltypes"
	"github.com/dolthub/vitess/go/vt/proto/query"
	"github.com/shopspring/decimal"
	"gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/go-mysql-server/sql"
)

const (
	// BitTypeMinBits is the smallest number of bits CreateBitType accepts.
	BitTypeMinBits = 1
	// BitTypeMaxBits is the largest number of bits CreateBitType accepts.
	BitTypeMaxBits = 64
)

var (
	// promotedBitType is the BitTypeMaxBits-wide type that Promote returns
	// for every BitType_.
	promotedBitType = MustCreateBitType(BitTypeMaxBits)
	// errBeyondMaxBit is the error kind Convert uses when a value does not fit
	// in the type's number of bits. Its arguments are the value and the bit count.
	errBeyondMaxBit = errors.NewKind("%v is beyond the maximum value that can be held by %v bits")
	// bitValueType is the reflect.Type of uint64, as reported by ValueType.
	bitValueType = reflect.TypeOf(uint64(0))
)

// BitType represents the BIT type.
// https://dev.mysql.com/doc/refman/8.0/en/bit-type.html
// The type of the returned value is uint64.
type BitType interface {
	sql.Type
	// NumberOfBits returns the number of bits the type was created with.
	NumberOfBits() uint8
}

// BitType_ is the BitType implementation in this file. Values are normally
// obtained through CreateBitType or MustCreateBitType, which validate the width.
type BitType_ struct {
	// numOfBits is the width of the type in bits.
	numOfBits uint8
}

// CreateBitType creates a BitType.
56 func CreateBitType(numOfBits uint8) (BitType, error) { 57 if numOfBits < BitTypeMinBits || numOfBits > BitTypeMaxBits { 58 return nil, fmt.Errorf("%v is an invalid number of bits", numOfBits) 59 } 60 return BitType_{ 61 numOfBits: numOfBits, 62 }, nil 63 } 64 65 // MustCreateBitType is the same as CreateBitType except it panics on errors. 66 func MustCreateBitType(numOfBits uint8) BitType { 67 bt, err := CreateBitType(numOfBits) 68 if err != nil { 69 panic(err) 70 } 71 return bt 72 } 73 74 // MaxTextResponseByteLength implements Type interface 75 func (t BitType_) MaxTextResponseByteLength(_ *sql.Context) uint32 { 76 // Because this is a text serialization format, each bit requires one byte in the text response format 77 return uint32(t.numOfBits) 78 } 79 80 // Compare implements Type interface. 81 func (t BitType_) Compare(a interface{}, b interface{}) (int, error) { 82 if hasNulls, res := CompareNulls(a, b); hasNulls { 83 return res, nil 84 } 85 86 ac, _, err := t.Convert(a) 87 if err != nil { 88 return 0, err 89 } 90 bc, _, err := t.Convert(b) 91 if err != nil { 92 return 0, err 93 } 94 95 ai := ac.(uint64) 96 bi := bc.(uint64) 97 if ai < bi { 98 return -1, nil 99 } else if ai > bi { 100 return 1, nil 101 } 102 return 0, nil 103 } 104 105 // Convert implements Type interface. 
106 func (t BitType_) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { 107 if v == nil { 108 return nil, sql.InRange, nil 109 } 110 111 value := uint64(0) 112 switch val := v.(type) { 113 case bool: 114 if val { 115 value = 1 116 } else { 117 value = 0 118 } 119 case int: 120 value = uint64(val) 121 case uint: 122 value = uint64(val) 123 case int8: 124 value = uint64(val) 125 case uint8: 126 value = uint64(val) 127 case int16: 128 value = uint64(val) 129 case uint16: 130 value = uint64(val) 131 case int32: 132 value = uint64(val) 133 case uint32: 134 value = uint64(val) 135 case int64: 136 value = uint64(val) 137 case uint64: 138 value = val 139 case float32: 140 return t.Convert(float64(val)) 141 case float64: 142 if val < 0 { 143 return nil, sql.InRange, fmt.Errorf(`negative floats cannot become bit values`) 144 } 145 value = uint64(val) 146 case decimal.NullDecimal: 147 if !val.Valid { 148 return nil, sql.InRange, nil 149 } 150 return t.Convert(val.Decimal) 151 case decimal.Decimal: 152 val = val.Round(0) 153 if val.GreaterThan(dec_uint64_max) { 154 return nil, sql.OutOfRange, errBeyondMaxBit.New(val.String(), t.numOfBits) 155 } 156 if val.LessThan(dec_int64_min) { 157 return nil, sql.OutOfRange, errBeyondMaxBit.New(val.String(), t.numOfBits) 158 } 159 value = uint64(val.IntPart()) 160 case string: 161 return t.Convert([]byte(val)) 162 case []byte: 163 if len(val) > 8 { 164 return nil, sql.OutOfRange, errBeyondMaxBit.New(value, t.numOfBits) 165 } 166 value = binary.BigEndian.Uint64(append(make([]byte, 8-len(val)), val...)) 167 default: 168 return nil, sql.OutOfRange, sql.ErrInvalidType.New(t) 169 } 170 171 if value > uint64(1<<t.numOfBits-1) { 172 return nil, sql.OutOfRange, errBeyondMaxBit.New(value, t.numOfBits) 173 } 174 return value, sql.InRange, nil 175 } 176 177 // MustConvert implements the Type interface. 
178 func (t BitType_) MustConvert(v interface{}) interface{} { 179 value, _, err := t.Convert(v) 180 if err != nil { 181 panic(err) 182 } 183 return value 184 } 185 186 // Equals implements the Type interface. 187 func (t BitType_) Equals(otherType sql.Type) bool { 188 if ot, ok := otherType.(BitType_); ok { 189 return t.numOfBits == ot.numOfBits 190 } 191 return false 192 } 193 194 // Promote implements the Type interface. 195 func (t BitType_) Promote() sql.Type { 196 return promotedBitType 197 } 198 199 // SQL implements Type interface. 200 func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { 201 if v == nil { 202 return sqltypes.NULL, nil 203 } 204 value, _, err := t.Convert(v) 205 if err != nil { 206 return sqltypes.Value{}, err 207 } 208 bitVal := value.(uint64) 209 210 var data []byte 211 for i := uint64(0); i < uint64(t.numOfBits); i += 8 { 212 data = append(data, byte(bitVal>>i)) 213 } 214 for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { 215 data[i], data[j] = data[j], data[i] 216 } 217 val := AppendAndSliceBytes(dest, data) 218 219 return sqltypes.MakeTrusted(sqltypes.Bit, val), nil 220 } 221 222 // String implements Type interface. 223 func (t BitType_) String() string { 224 return fmt.Sprintf("bit(%v)", t.numOfBits) 225 } 226 227 // Type implements Type interface. 228 func (t BitType_) Type() query.Type { 229 return sqltypes.Bit 230 } 231 232 // ValueType implements Type interface. 233 func (t BitType_) ValueType() reflect.Type { 234 return bitValueType 235 } 236 237 // CollationCoercibility implements sql.CollationCoercible interface. 238 func (BitType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 239 return sql.Collation_binary, 5 240 } 241 242 // Zero implements Type interface. Returns a uint64 value. 243 func (t BitType_) Zero() interface{} { 244 return uint64(0) 245 } 246 247 // NumberOfBits returns the number of bits that this type may contain. 
248 func (t BitType_) NumberOfBits() uint8 { 249 return t.numOfBits 250 }