github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/bit_count.go (about) 1 // Copyright 2024 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/types" 22 ) 23 24 // BitCount returns the smallest integer value not less than X. 25 type BitCount struct { 26 *UnaryFunc 27 } 28 29 var _ sql.FunctionExpression = (*BitCount)(nil) 30 var _ sql.CollationCoercible = (*BitCount)(nil) 31 32 // NewBitCount creates a new Ceil expression. 33 func NewBitCount(arg sql.Expression) sql.Expression { 34 return &BitCount{NewUnaryFunc(arg, "BIT_COUNT", types.Int32)} 35 } 36 37 // FunctionName implements sql.FunctionExpression 38 func (b *BitCount) FunctionName() string { 39 return "bit_count" 40 } 41 42 // Description implements sql.FunctionExpression 43 func (b *BitCount) Description() string { 44 return "returns the number of bits that are set." 45 } 46 47 // Type implements the Expression interface. 48 func (b *BitCount) Type() sql.Type { 49 return types.Int32 50 } 51 52 // CollationCoercibility implements the interface sql.CollationCoercible. 53 func (b *BitCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 54 return sql.Collation_binary, 5 55 } 56 57 func (b *BitCount) String() string { 58 return fmt.Sprintf("%s(%s)", b.FunctionName(), b.Child) 59 } 60 61 // WithChildren implements the Expression interface. 62 func (b *BitCount) WithChildren(children ...sql.Expression) (sql.Expression, error) { 63 if len(children) != 1 { 64 return nil, sql.ErrInvalidChildrenNumber.New(b, len(children), 1) 65 } 66 return NewBitCount(children[0]), nil 67 } 68 69 func countBits(n uint64) int32 { 70 var res int32 71 for n != 0 { 72 res++ 73 n &= n - 1 74 } 75 return res 76 } 77 78 // Eval implements the Expression interface. 79 func (b *BitCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 80 if b.Child == nil { 81 return nil, nil 82 } 83 84 child, err := b.Child.Eval(ctx, row) 85 if err != nil { 86 return nil, err 87 } 88 89 if child == nil { 90 return nil, nil 91 } 92 93 var res int32 94 switch val := child.(type) { 95 case []byte: 96 for _, v := range val { 97 res += countBits(uint64(v)) 98 } 99 default: 100 num, _, err := types.Int64.Convert(child) 101 if err != nil { 102 ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", child) 103 num = int64(0) 104 } 105 106 // Must convert to unsigned because shifting a negative signed value fills with 1s 107 res = countBits(uint64(num.(int64))) 108 } 109 110 return res, nil 111 }