github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/ceil_round_floor.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "math" 20 21 "github.com/shopspring/decimal" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 // Ceil returns the smallest integer value not less than X. 29 type Ceil struct { 30 expression.UnaryExpression 31 } 32 33 var _ sql.FunctionExpression = (*Ceil)(nil) 34 var _ sql.CollationCoercible = (*Ceil)(nil) 35 36 // NewCeil creates a new Ceil expression. 37 func NewCeil(num sql.Expression) sql.Expression { 38 return &Ceil{expression.UnaryExpression{Child: num}} 39 } 40 41 // FunctionName implements sql.FunctionExpression 42 func (c *Ceil) FunctionName() string { 43 return "ceil" 44 } 45 46 // Description implements sql.FunctionExpression 47 func (c *Ceil) Description() string { 48 return "returns the smallest integer value that is greater than or equal to number." 49 } 50 51 // Type implements the Expression interface. 52 func (c *Ceil) Type() sql.Type { 53 childType := c.Child.Type() 54 if types.IsInteger(childType) { 55 return childType 56 } 57 return types.Int32 58 } 59 60 // CollationCoercibility implements the interface sql.CollationCoercible. 61 func (*Ceil) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 62 return sql.Collation_binary, 5 63 } 64 65 func (c *Ceil) String() string { 66 return fmt.Sprintf("%s(%s)", c.FunctionName(), c.Child) 67 } 68 69 // WithChildren implements the Expression interface. 70 func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) { 71 if len(children) != 1 { 72 return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) 73 } 74 return NewCeil(children[0]), nil 75 } 76 77 // Eval implements the Expression interface. 78 func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 79 child, err := c.Child.Eval(ctx, row) 80 81 if err != nil { 82 return nil, err 83 } 84 85 if child == nil { 86 return nil, nil 87 } 88 89 // non number type will be caught here 90 if !types.IsNumber(c.Child.Type()) { 91 child, _, err = types.Float64.Convert(child) 92 if err != nil { 93 return int32(0), nil 94 } 95 96 return int32(math.Ceil(child.(float64))), nil 97 } 98 99 // if it's number type and not float value, it does not need ceil-ing 100 switch num := child.(type) { 101 case float64: 102 return math.Ceil(num), nil 103 case float32: 104 return float32(math.Ceil(float64(num))), nil 105 case decimal.Decimal: 106 return num.Ceil(), nil 107 default: 108 return child, nil 109 } 110 } 111 112 // Floor returns the biggest integer value not less than X. 113 type Floor struct { 114 expression.UnaryExpression 115 } 116 117 var _ sql.FunctionExpression = (*Floor)(nil) 118 var _ sql.CollationCoercible = (*Floor)(nil) 119 120 // NewFloor returns a new Floor expression. 121 func NewFloor(num sql.Expression) sql.Expression { 122 return &Floor{expression.UnaryExpression{Child: num}} 123 } 124 125 // FunctionName implements sql.FunctionExpression 126 func (f *Floor) FunctionName() string { 127 return "floor" 128 } 129 130 // Description implements sql.FunctionExpression 131 func (f *Floor) Description() string { 132 return "returns the largest integer value that is less than or equal to number." 133 } 134 135 // Type implements the Expression interface. 136 func (f *Floor) Type() sql.Type { 137 childType := f.Child.Type() 138 if types.IsInteger(childType) { 139 return childType 140 } 141 return types.Int32 142 } 143 144 // CollationCoercibility implements the interface sql.CollationCoercible. 145 func (*Floor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 146 return sql.Collation_binary, 5 147 } 148 149 func (f *Floor) String() string { 150 return fmt.Sprintf("%s(%s)", f.FunctionName(), f.Child) 151 } 152 153 // WithChildren implements the Expression interface. 154 func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) { 155 if len(children) != 1 { 156 return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1) 157 } 158 return NewFloor(children[0]), nil 159 } 160 161 // Eval implements the Expression interface. 162 func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 163 child, err := f.Child.Eval(ctx, row) 164 165 if err != nil { 166 return nil, err 167 } 168 169 if child == nil { 170 return nil, nil 171 } 172 173 // non number type will be caught here 174 if !types.IsNumber(f.Child.Type()) { 175 child, _, err = types.Float64.Convert(child) 176 if err != nil { 177 return int32(0), nil 178 } 179 180 return int32(math.Floor(child.(float64))), nil 181 } 182 183 // if it's number type and not float value, it does not need floor-ing 184 switch num := child.(type) { 185 case float64: 186 return math.Floor(num), nil 187 case float32: 188 return float32(math.Floor(float64(num))), nil 189 case decimal.Decimal: 190 return num.Floor(), nil 191 default: 192 return child, nil 193 } 194 } 195 196 // Round returns the number (x) with (d) requested decimal places. 197 // If d is negative, the number is returned with the (abs(d)) least significant 198 // digits of it's integer part set to 0. If d is not specified or nil/null 199 // it defaults to 0. 200 type Round struct { 201 expression.BinaryExpressionStub 202 } 203 204 var _ sql.FunctionExpression = (*Round)(nil) 205 var _ sql.CollationCoercible = (*Round)(nil) 206 207 // NewRound returns a new Round expression. 208 func NewRound(args ...sql.Expression) (sql.Expression, error) { 209 argLen := len(args) 210 if argLen == 0 || argLen > 2 { 211 return nil, sql.ErrInvalidArgumentNumber.New("ROUND", "1 or 2", argLen) 212 } 213 214 var right sql.Expression 215 if len(args) == 2 { 216 right = args[1] 217 } 218 219 return &Round{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: right}}, nil 220 } 221 222 // FunctionName implements sql.FunctionExpression 223 func (r *Round) FunctionName() string { 224 return "round" 225 } 226 227 // Description implements sql.FunctionExpression 228 func (r *Round) Description() string { 229 return "rounds the number to decimals decimal places." 230 } 231 232 // Children implements the Expression interface. 233 func (r *Round) Children() []sql.Expression { 234 if r.RightChild == nil { 235 return []sql.Expression{r.LeftChild} 236 } 237 238 return r.BinaryExpressionStub.Children() 239 } 240 241 // Eval implements the Expression interface. 242 func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 243 val, err := r.LeftChild.Eval(ctx, row) 244 if err != nil { 245 return nil, err 246 } 247 248 if val == nil { 249 return nil, nil 250 } 251 252 decType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale) 253 val, _, err = decType.Convert(val) 254 if err != nil { 255 // TODO: truncate 256 return nil, err 257 } 258 259 prec := int32(0) 260 if r.RightChild != nil { 261 var tmp interface{} 262 tmp, err = r.RightChild.Eval(ctx, row) 263 if err != nil { 264 return nil, err 265 } 266 267 if tmp == nil { 268 return nil, nil 269 } 270 271 if tmp != nil { 272 tmp, _, err = types.Int32.Convert(tmp) 273 if err != nil { 274 // TODO: truncate 275 return nil, err 276 } 277 prec = tmp.(int32) 278 // MySQL cuts off at 30 for larger values 279 // TODO: these limits are fine only because we can't handle decimals larger than this 280 if prec > types.DecimalTypeMaxPrecision { 281 prec = types.DecimalTypeMaxPrecision 282 } 283 if prec < -types.DecimalTypeMaxScale { 284 prec = -types.DecimalTypeMaxScale 285 } 286 } 287 } 288 289 var res interface{} 290 tmp := val.(decimal.Decimal).Round(prec) 291 if types.IsSigned(r.LeftChild.Type()) { 292 res, _, err = types.Int64.Convert(tmp) 293 } else if types.IsUnsigned(r.LeftChild.Type()) { 294 res, _, err = types.Uint64.Convert(tmp) 295 } else if types.IsFloat(r.LeftChild.Type()) { 296 res, _, err = types.Float64.Convert(tmp) 297 } else if types.IsDecimal(r.LeftChild.Type()) { 298 res = tmp 299 } else if types.IsTextBlob(r.LeftChild.Type()) { 300 res, _, err = types.Float64.Convert(tmp) 301 } 302 303 return res, err 304 } 305 306 // IsNullable implements the Expression interface. 307 func (r *Round) IsNullable() bool { 308 return r.LeftChild.IsNullable() 309 } 310 311 func (r *Round) String() string { 312 if r.RightChild == nil { 313 return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.LeftChild.String()) 314 } 315 316 return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.LeftChild.String(), r.RightChild.String()) 317 } 318 319 // Resolved implements the Expression interface. 320 func (r *Round) Resolved() bool { 321 return r.LeftChild.Resolved() && (r.RightChild == nil || r.RightChild.Resolved()) 322 } 323 324 // Type implements the Expression interface. 325 func (r *Round) Type() sql.Type { 326 leftChildType := r.LeftChild.Type() 327 if types.IsNumber(leftChildType) { 328 return leftChildType 329 } 330 return types.Int32 331 } 332 333 // CollationCoercibility implements the interface sql.CollationCoercible. 334 func (*Round) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 335 return sql.Collation_binary, 5 336 } 337 338 // WithChildren implements the Expression interface. 339 func (r *Round) WithChildren(children ...sql.Expression) (sql.Expression, error) { 340 return NewRound(children...) 341 }