github.com/dolthub/go-mysql-server@v0.18.0/sql/types/bit.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"encoding/binary"
    19  	"fmt"
    20  	"reflect"
    21  
    22  	"github.com/dolthub/vitess/go/sqltypes"
    23  	"github.com/dolthub/vitess/go/vt/proto/query"
    24  	"github.com/shopspring/decimal"
    25  	"gopkg.in/src-d/go-errors.v1"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  )
    29  
const (
	// BitTypeMinBits is the smallest number of bits CreateBitType accepts.
	BitTypeMinBits = 1
	// BitTypeMaxBits is the largest number of bits CreateBitType accepts.
	// Converted BIT values are held in a uint64, so this is also its width.
	BitTypeMaxBits = 64
)

var (
	// promotedBitType is the widest BIT type, BIT(BitTypeMaxBits); it is what
	// BitType_.Promote returns.
	promotedBitType = MustCreateBitType(BitTypeMaxBits)
	// errBeyondMaxBit reports a value that does not fit in a BIT type.
	// Format arguments: the offending value, then the type's number of bits.
	errBeyondMaxBit = errors.NewKind("%v is beyond the maximum value that can be held by %v bits")
	// bitValueType is the reflect.Type of converted BIT values (uint64).
	bitValueType    = reflect.TypeOf(uint64(0))
)
    42  
// BitType represents the BIT type.
// https://dev.mysql.com/doc/refman/8.0/en/bit-type.html
// The type of the returned value is uint64.
type BitType interface {
	sql.Type
	// NumberOfBits returns the declared width of the type, in bits.
	NumberOfBits() uint8
}

// BitType_ is the concrete BitType implementation. Build it with CreateBitType
// or MustCreateBitType so that the width is validated.
type BitType_ struct {
	// numOfBits is the declared width; CreateBitType keeps it within
	// [BitTypeMinBits, BitTypeMaxBits].
	numOfBits uint8
}
    54  
    55  // CreateBitType creates a BitType.
    56  func CreateBitType(numOfBits uint8) (BitType, error) {
    57  	if numOfBits < BitTypeMinBits || numOfBits > BitTypeMaxBits {
    58  		return nil, fmt.Errorf("%v is an invalid number of bits", numOfBits)
    59  	}
    60  	return BitType_{
    61  		numOfBits: numOfBits,
    62  	}, nil
    63  }
    64  
    65  // MustCreateBitType is the same as CreateBitType except it panics on errors.
    66  func MustCreateBitType(numOfBits uint8) BitType {
    67  	bt, err := CreateBitType(numOfBits)
    68  	if err != nil {
    69  		panic(err)
    70  	}
    71  	return bt
    72  }
    73  
    74  // MaxTextResponseByteLength implements Type interface
    75  func (t BitType_) MaxTextResponseByteLength(_ *sql.Context) uint32 {
    76  	// Because this is a text serialization format, each bit requires one byte in the text response format
    77  	return uint32(t.numOfBits)
    78  }
    79  
    80  // Compare implements Type interface.
    81  func (t BitType_) Compare(a interface{}, b interface{}) (int, error) {
    82  	if hasNulls, res := CompareNulls(a, b); hasNulls {
    83  		return res, nil
    84  	}
    85  
    86  	ac, _, err := t.Convert(a)
    87  	if err != nil {
    88  		return 0, err
    89  	}
    90  	bc, _, err := t.Convert(b)
    91  	if err != nil {
    92  		return 0, err
    93  	}
    94  
    95  	ai := ac.(uint64)
    96  	bi := bc.(uint64)
    97  	if ai < bi {
    98  		return -1, nil
    99  	} else if ai > bi {
   100  		return 1, nil
   101  	}
   102  	return 0, nil
   103  }
   104  
   105  // Convert implements Type interface.
   106  func (t BitType_) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
   107  	if v == nil {
   108  		return nil, sql.InRange, nil
   109  	}
   110  
   111  	value := uint64(0)
   112  	switch val := v.(type) {
   113  	case bool:
   114  		if val {
   115  			value = 1
   116  		} else {
   117  			value = 0
   118  		}
   119  	case int:
   120  		value = uint64(val)
   121  	case uint:
   122  		value = uint64(val)
   123  	case int8:
   124  		value = uint64(val)
   125  	case uint8:
   126  		value = uint64(val)
   127  	case int16:
   128  		value = uint64(val)
   129  	case uint16:
   130  		value = uint64(val)
   131  	case int32:
   132  		value = uint64(val)
   133  	case uint32:
   134  		value = uint64(val)
   135  	case int64:
   136  		value = uint64(val)
   137  	case uint64:
   138  		value = val
   139  	case float32:
   140  		return t.Convert(float64(val))
   141  	case float64:
   142  		if val < 0 {
   143  			return nil, sql.InRange, fmt.Errorf(`negative floats cannot become bit values`)
   144  		}
   145  		value = uint64(val)
   146  	case decimal.NullDecimal:
   147  		if !val.Valid {
   148  			return nil, sql.InRange, nil
   149  		}
   150  		return t.Convert(val.Decimal)
   151  	case decimal.Decimal:
   152  		val = val.Round(0)
   153  		if val.GreaterThan(dec_uint64_max) {
   154  			return nil, sql.OutOfRange, errBeyondMaxBit.New(val.String(), t.numOfBits)
   155  		}
   156  		if val.LessThan(dec_int64_min) {
   157  			return nil, sql.OutOfRange, errBeyondMaxBit.New(val.String(), t.numOfBits)
   158  		}
   159  		value = uint64(val.IntPart())
   160  	case string:
   161  		return t.Convert([]byte(val))
   162  	case []byte:
   163  		if len(val) > 8 {
   164  			return nil, sql.OutOfRange, errBeyondMaxBit.New(value, t.numOfBits)
   165  		}
   166  		value = binary.BigEndian.Uint64(append(make([]byte, 8-len(val)), val...))
   167  	default:
   168  		return nil, sql.OutOfRange, sql.ErrInvalidType.New(t)
   169  	}
   170  
   171  	if value > uint64(1<<t.numOfBits-1) {
   172  		return nil, sql.OutOfRange, errBeyondMaxBit.New(value, t.numOfBits)
   173  	}
   174  	return value, sql.InRange, nil
   175  }
   176  
   177  // MustConvert implements the Type interface.
   178  func (t BitType_) MustConvert(v interface{}) interface{} {
   179  	value, _, err := t.Convert(v)
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  	return value
   184  }
   185  
   186  // Equals implements the Type interface.
   187  func (t BitType_) Equals(otherType sql.Type) bool {
   188  	if ot, ok := otherType.(BitType_); ok {
   189  		return t.numOfBits == ot.numOfBits
   190  	}
   191  	return false
   192  }
   193  
   194  // Promote implements the Type interface.
   195  func (t BitType_) Promote() sql.Type {
   196  	return promotedBitType
   197  }
   198  
   199  // SQL implements Type interface.
   200  func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
   201  	if v == nil {
   202  		return sqltypes.NULL, nil
   203  	}
   204  	value, _, err := t.Convert(v)
   205  	if err != nil {
   206  		return sqltypes.Value{}, err
   207  	}
   208  	bitVal := value.(uint64)
   209  
   210  	var data []byte
   211  	for i := uint64(0); i < uint64(t.numOfBits); i += 8 {
   212  		data = append(data, byte(bitVal>>i))
   213  	}
   214  	for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 {
   215  		data[i], data[j] = data[j], data[i]
   216  	}
   217  	val := AppendAndSliceBytes(dest, data)
   218  
   219  	return sqltypes.MakeTrusted(sqltypes.Bit, val), nil
   220  }
   221  
   222  // String implements Type interface.
   223  func (t BitType_) String() string {
   224  	return fmt.Sprintf("bit(%v)", t.numOfBits)
   225  }
   226  
   227  // Type implements Type interface.
   228  func (t BitType_) Type() query.Type {
   229  	return sqltypes.Bit
   230  }
   231  
   232  // ValueType implements Type interface.
   233  func (t BitType_) ValueType() reflect.Type {
   234  	return bitValueType
   235  }
   236  
   237  // CollationCoercibility implements sql.CollationCoercible interface.
   238  func (BitType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   239  	return sql.Collation_binary, 5
   240  }
   241  
   242  // Zero implements Type interface. Returns a uint64 value.
   243  func (t BitType_) Zero() interface{} {
   244  	return uint64(0)
   245  }
   246  
   247  // NumberOfBits returns the number of bits that this type may contain.
   248  func (t BitType_) NumberOfBits() uint8 {
   249  	return t.numOfBits
   250  }