github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/sqrt_power.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  )
    25  
    26  // Sqrt is a function that returns the square value of the number provided.
    27  type Sqrt struct {
    28  	expression.UnaryExpression
    29  }
    30  
    31  var _ sql.FunctionExpression = (*Sqrt)(nil)
    32  var _ sql.CollationCoercible = (*Sqrt)(nil)
    33  
    34  // NewSqrt creates a new Sqrt expression.
    35  func NewSqrt(e sql.Expression) sql.Expression {
    36  	return &Sqrt{expression.UnaryExpression{Child: e}}
    37  }
    38  
    39  // FunctionName implements sql.FunctionExpression
    40  func (s *Sqrt) FunctionName() string {
    41  	return "sqrt"
    42  }
    43  
    44  // Description implements sql.FunctionExpression
    45  func (s *Sqrt) Description() string {
    46  	return "returns the square root of a nonnegative number X."
    47  }
    48  
    49  func (s *Sqrt) String() string {
    50  	return fmt.Sprintf("sqrt(%s)", s.Child.String())
    51  }
    52  
    53  // Type implements the Expression interface.
    54  func (s *Sqrt) Type() sql.Type {
    55  	return types.Float64
    56  }
    57  
    58  // CollationCoercibility implements the interface sql.CollationCoercible.
    59  func (*Sqrt) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    60  	return sql.Collation_binary, 5
    61  }
    62  
    63  // IsNullable implements the Expression interface.
    64  func (s *Sqrt) IsNullable() bool {
    65  	return s.Child.IsNullable()
    66  }
    67  
    68  // WithChildren implements the Expression interface.
    69  func (s *Sqrt) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    70  	if len(children) != 1 {
    71  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1)
    72  	}
    73  	return NewSqrt(children[0]), nil
    74  }
    75  
    76  // Eval implements the Expression interface.
    77  func (s *Sqrt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    78  	child, err := s.Child.Eval(ctx, row)
    79  
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	if child == nil {
    85  		return nil, nil
    86  	}
    87  
    88  	child, _, err = types.Float64.Convert(child)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	res := math.Sqrt(child.(float64))
    94  	if math.IsNaN(res) || math.IsInf(res, 0) {
    95  		return nil, nil
    96  	}
    97  
    98  	return res, nil
    99  }
   100  
   101  // Power is a function that returns value of X raised to the power of Y.
   102  type Power struct {
   103  	expression.BinaryExpressionStub
   104  }
   105  
   106  var _ sql.FunctionExpression = (*Power)(nil)
   107  var _ sql.CollationCoercible = (*Power)(nil)
   108  
   109  // NewPower creates a new Power expression.
   110  func NewPower(e1, e2 sql.Expression) sql.Expression {
   111  	return &Power{
   112  		expression.BinaryExpressionStub{
   113  			LeftChild:  e1,
   114  			RightChild: e2,
   115  		},
   116  	}
   117  }
   118  
   119  // FunctionName implements sql.FunctionExpression
   120  func (p *Power) FunctionName() string {
   121  	return "power"
   122  }
   123  
   124  // Description implements sql.FunctionExpression
   125  func (p *Power) Description() string {
   126  	return "returns the value of X raised to the power of Y."
   127  }
   128  
   129  // Type implements the Expression interface.
   130  func (p *Power) Type() sql.Type { return types.Float64 }
   131  
   132  // CollationCoercibility implements the interface sql.CollationCoercible.
   133  func (*Power) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   134  	return sql.Collation_binary, 5
   135  }
   136  
   137  // IsNullable implements the Expression interface.
   138  func (p *Power) IsNullable() bool { return p.LeftChild.IsNullable() || p.RightChild.IsNullable() }
   139  
   140  func (p *Power) String() string {
   141  	return fmt.Sprintf("power(%s, %s)", p.LeftChild, p.RightChild)
   142  }
   143  
   144  // WithChildren implements the Expression interface.
   145  func (p *Power) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   146  	if len(children) != 2 {
   147  		return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 2)
   148  	}
   149  	return NewPower(children[0], children[1]), nil
   150  }
   151  
   152  // Eval implements the Expression interface.
   153  func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   154  	left, err := p.LeftChild.Eval(ctx, row)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  
   159  	if left == nil {
   160  		return nil, nil
   161  	}
   162  
   163  	left, _, err = types.Float64.Convert(left)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	right, err := p.RightChild.Eval(ctx, row)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	if right == nil {
   174  		return nil, nil
   175  	}
   176  
   177  	right, _, err = types.Float64.Convert(right)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  
   182  	res := math.Pow(left.(float64), right.(float64))
   183  	if math.IsNaN(res) || math.IsInf(res, 0) {
   184  		return nil, nil
   185  	}
   186  
   187  	return res, nil
   188  }