github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/ceil_round_floor.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  
    21  	"github.com/shopspring/decimal"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  // Ceil returns the smallest integer value not less than X.
    29  type Ceil struct {
    30  	expression.UnaryExpression
    31  }
    32  
    33  var _ sql.FunctionExpression = (*Ceil)(nil)
    34  var _ sql.CollationCoercible = (*Ceil)(nil)
    35  
    36  // NewCeil creates a new Ceil expression.
    37  func NewCeil(num sql.Expression) sql.Expression {
    38  	return &Ceil{expression.UnaryExpression{Child: num}}
    39  }
    40  
    41  // FunctionName implements sql.FunctionExpression
    42  func (c *Ceil) FunctionName() string {
    43  	return "ceil"
    44  }
    45  
    46  // Description implements sql.FunctionExpression
    47  func (c *Ceil) Description() string {
    48  	return "returns the smallest integer value that is greater than or equal to number."
    49  }
    50  
    51  // Type implements the Expression interface.
    52  func (c *Ceil) Type() sql.Type {
    53  	childType := c.Child.Type()
    54  	if types.IsInteger(childType) {
    55  		return childType
    56  	}
    57  	return types.Int32
    58  }
    59  
    60  // CollationCoercibility implements the interface sql.CollationCoercible.
    61  func (*Ceil) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    62  	return sql.Collation_binary, 5
    63  }
    64  
    65  func (c *Ceil) String() string {
    66  	return fmt.Sprintf("%s(%s)", c.FunctionName(), c.Child)
    67  }
    68  
    69  // WithChildren implements the Expression interface.
    70  func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    71  	if len(children) != 1 {
    72  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
    73  	}
    74  	return NewCeil(children[0]), nil
    75  }
    76  
    77  // Eval implements the Expression interface.
    78  func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    79  	child, err := c.Child.Eval(ctx, row)
    80  
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	if child == nil {
    86  		return nil, nil
    87  	}
    88  
    89  	// non number type will be caught here
    90  	if !types.IsNumber(c.Child.Type()) {
    91  		child, _, err = types.Float64.Convert(child)
    92  		if err != nil {
    93  			return int32(0), nil
    94  		}
    95  
    96  		return int32(math.Ceil(child.(float64))), nil
    97  	}
    98  
    99  	// if it's number type and not float value, it does not need ceil-ing
   100  	switch num := child.(type) {
   101  	case float64:
   102  		return math.Ceil(num), nil
   103  	case float32:
   104  		return float32(math.Ceil(float64(num))), nil
   105  	case decimal.Decimal:
   106  		return num.Ceil(), nil
   107  	default:
   108  		return child, nil
   109  	}
   110  }
   111  
   112  // Floor returns the biggest integer value not less than X.
   113  type Floor struct {
   114  	expression.UnaryExpression
   115  }
   116  
   117  var _ sql.FunctionExpression = (*Floor)(nil)
   118  var _ sql.CollationCoercible = (*Floor)(nil)
   119  
   120  // NewFloor returns a new Floor expression.
   121  func NewFloor(num sql.Expression) sql.Expression {
   122  	return &Floor{expression.UnaryExpression{Child: num}}
   123  }
   124  
   125  // FunctionName implements sql.FunctionExpression
   126  func (f *Floor) FunctionName() string {
   127  	return "floor"
   128  }
   129  
   130  // Description implements sql.FunctionExpression
   131  func (f *Floor) Description() string {
   132  	return "returns the largest integer value that is less than or equal to number."
   133  }
   134  
   135  // Type implements the Expression interface.
   136  func (f *Floor) Type() sql.Type {
   137  	childType := f.Child.Type()
   138  	if types.IsInteger(childType) {
   139  		return childType
   140  	}
   141  	return types.Int32
   142  }
   143  
   144  // CollationCoercibility implements the interface sql.CollationCoercible.
   145  func (*Floor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   146  	return sql.Collation_binary, 5
   147  }
   148  
   149  func (f *Floor) String() string {
   150  	return fmt.Sprintf("%s(%s)", f.FunctionName(), f.Child)
   151  }
   152  
   153  // WithChildren implements the Expression interface.
   154  func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   155  	if len(children) != 1 {
   156  		return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1)
   157  	}
   158  	return NewFloor(children[0]), nil
   159  }
   160  
   161  // Eval implements the Expression interface.
   162  func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   163  	child, err := f.Child.Eval(ctx, row)
   164  
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	if child == nil {
   170  		return nil, nil
   171  	}
   172  
   173  	// non number type will be caught here
   174  	if !types.IsNumber(f.Child.Type()) {
   175  		child, _, err = types.Float64.Convert(child)
   176  		if err != nil {
   177  			return int32(0), nil
   178  		}
   179  
   180  		return int32(math.Floor(child.(float64))), nil
   181  	}
   182  
   183  	// if it's number type and not float value, it does not need floor-ing
   184  	switch num := child.(type) {
   185  	case float64:
   186  		return math.Floor(num), nil
   187  	case float32:
   188  		return float32(math.Floor(float64(num))), nil
   189  	case decimal.Decimal:
   190  		return num.Floor(), nil
   191  	default:
   192  		return child, nil
   193  	}
   194  }
   195  
   196  // Round returns the number (x) with (d) requested decimal places.
   197  // If d is negative, the number is returned with the (abs(d)) least significant
   198  // digits of it's integer part set to 0. If d is not specified or nil/null
   199  // it defaults to 0.
   200  type Round struct {
   201  	expression.BinaryExpressionStub
   202  }
   203  
   204  var _ sql.FunctionExpression = (*Round)(nil)
   205  var _ sql.CollationCoercible = (*Round)(nil)
   206  
   207  // NewRound returns a new Round expression.
   208  func NewRound(args ...sql.Expression) (sql.Expression, error) {
   209  	argLen := len(args)
   210  	if argLen == 0 || argLen > 2 {
   211  		return nil, sql.ErrInvalidArgumentNumber.New("ROUND", "1 or 2", argLen)
   212  	}
   213  
   214  	var right sql.Expression
   215  	if len(args) == 2 {
   216  		right = args[1]
   217  	}
   218  
   219  	return &Round{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: right}}, nil
   220  }
   221  
   222  // FunctionName implements sql.FunctionExpression
   223  func (r *Round) FunctionName() string {
   224  	return "round"
   225  }
   226  
   227  // Description implements sql.FunctionExpression
   228  func (r *Round) Description() string {
   229  	return "rounds the number to decimals decimal places."
   230  }
   231  
   232  // Children implements the Expression interface.
   233  func (r *Round) Children() []sql.Expression {
   234  	if r.RightChild == nil {
   235  		return []sql.Expression{r.LeftChild}
   236  	}
   237  
   238  	return r.BinaryExpressionStub.Children()
   239  }
   240  
   241  // Eval implements the Expression interface.
   242  func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   243  	val, err := r.LeftChild.Eval(ctx, row)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	if val == nil {
   249  		return nil, nil
   250  	}
   251  
   252  	decType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)
   253  	val, _, err = decType.Convert(val)
   254  	if err != nil {
   255  		// TODO: truncate
   256  		return nil, err
   257  	}
   258  
   259  	prec := int32(0)
   260  	if r.RightChild != nil {
   261  		var tmp interface{}
   262  		tmp, err = r.RightChild.Eval(ctx, row)
   263  		if err != nil {
   264  			return nil, err
   265  		}
   266  
   267  		if tmp == nil {
   268  			return nil, nil
   269  		}
   270  
   271  		if tmp != nil {
   272  			tmp, _, err = types.Int32.Convert(tmp)
   273  			if err != nil {
   274  				// TODO: truncate
   275  				return nil, err
   276  			}
   277  			prec = tmp.(int32)
   278  			// MySQL cuts off at 30 for larger values
   279  			// TODO: these limits are fine only because we can't handle decimals larger than this
   280  			if prec > types.DecimalTypeMaxPrecision {
   281  				prec = types.DecimalTypeMaxPrecision
   282  			}
   283  			if prec < -types.DecimalTypeMaxScale {
   284  				prec = -types.DecimalTypeMaxScale
   285  			}
   286  		}
   287  	}
   288  
   289  	var res interface{}
   290  	tmp := val.(decimal.Decimal).Round(prec)
   291  	if types.IsSigned(r.LeftChild.Type()) {
   292  		res, _, err = types.Int64.Convert(tmp)
   293  	} else if types.IsUnsigned(r.LeftChild.Type()) {
   294  		res, _, err = types.Uint64.Convert(tmp)
   295  	} else if types.IsFloat(r.LeftChild.Type()) {
   296  		res, _, err = types.Float64.Convert(tmp)
   297  	} else if types.IsDecimal(r.LeftChild.Type()) {
   298  		res = tmp
   299  	} else if types.IsTextBlob(r.LeftChild.Type()) {
   300  		res, _, err = types.Float64.Convert(tmp)
   301  	}
   302  
   303  	return res, err
   304  }
   305  
   306  // IsNullable implements the Expression interface.
   307  func (r *Round) IsNullable() bool {
   308  	return r.LeftChild.IsNullable()
   309  }
   310  
   311  func (r *Round) String() string {
   312  	if r.RightChild == nil {
   313  		return fmt.Sprintf("%s(%s,0)", r.FunctionName(), r.LeftChild.String())
   314  	}
   315  
   316  	return fmt.Sprintf("%s(%s,%s)", r.FunctionName(), r.LeftChild.String(), r.RightChild.String())
   317  }
   318  
   319  // Resolved implements the Expression interface.
   320  func (r *Round) Resolved() bool {
   321  	return r.LeftChild.Resolved() && (r.RightChild == nil || r.RightChild.Resolved())
   322  }
   323  
   324  // Type implements the Expression interface.
   325  func (r *Round) Type() sql.Type {
   326  	leftChildType := r.LeftChild.Type()
   327  	if types.IsNumber(leftChildType) {
   328  		return leftChildType
   329  	}
   330  	return types.Int32
   331  }
   332  
   333  // CollationCoercibility implements the interface sql.CollationCoercible.
   334  func (*Round) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   335  	return sql.Collation_binary, 5
   336  }
   337  
   338  // WithChildren implements the Expression interface.
   339  func (r *Round) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   340  	return NewRound(children...)
   341  }