github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/logarithm.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"reflect"
    21  
    22  	"gopkg.in/src-d/go-errors.v1"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  // ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a
    30  // logarithm function
    31  var ErrInvalidArgumentForLogarithm = errors.NewKind("invalid argument value for logarithm: %v")
    32  
    33  // NewLogBaseFunc returns LogBase creator function with a specific base.
    34  func NewLogBaseFunc(base float64) func(e sql.Expression) sql.Expression {
    35  	return func(e sql.Expression) sql.Expression {
    36  		return NewLogBase(base, e)
    37  	}
    38  }
    39  
    40  // LogBase is a function that returns the logarithm of a value with a specific base.
    41  type LogBase struct {
    42  	expression.UnaryExpression
    43  	base float64
    44  }
    45  
    46  var _ sql.FunctionExpression = (*LogBase)(nil)
    47  var _ sql.CollationCoercible = (*LogBase)(nil)
    48  
    49  // NewLogBase creates a new LogBase expression.
    50  func NewLogBase(base float64, e sql.Expression) sql.Expression {
    51  	return &LogBase{UnaryExpression: expression.UnaryExpression{Child: e}, base: base}
    52  }
    53  
    54  // FunctionName implements sql.FunctionExpression
    55  func (l *LogBase) FunctionName() string {
    56  	switch l.base {
    57  	case float64(math.E):
    58  		return "ln"
    59  	case float64(10):
    60  		return "log10"
    61  	case float64(2):
    62  		return "log2"
    63  	default:
    64  		return "log"
    65  	}
    66  }
    67  
    68  // Description implements sql.FunctionExpression
    69  func (l *LogBase) Description() string {
    70  	switch l.base {
    71  	case float64(math.E):
    72  		return "returns the natural logarithm of X."
    73  	case float64(10):
    74  		return "returns the base-10 logarithm of X."
    75  	case float64(2):
    76  		return "returns the base-2 logarithm of X."
    77  	default:
    78  		return "if called with one parameter, this function returns the natural logarithm of X. If called with two parameters, this function returns the logarithm of X to the base B. If X is less than or equal to 0, or if B is less than or equal to 1, then NULL is returned."
    79  	}
    80  }
    81  
    82  func (l *LogBase) String() string {
    83  	switch l.base {
    84  	case float64(math.E):
    85  		return fmt.Sprintf("ln(%s)", l.Child)
    86  	case float64(10):
    87  		return fmt.Sprintf("log10(%s)", l.Child)
    88  	case float64(2):
    89  		return fmt.Sprintf("log2(%s)", l.Child)
    90  	default:
    91  		return fmt.Sprintf("log(%v, %s)", l.base, l.Child)
    92  	}
    93  }
    94  
    95  // WithChildren implements the Expression interface.
    96  func (l *LogBase) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    97  	if len(children) != 1 {
    98  		return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1)
    99  	}
   100  	return NewLogBase(l.base, children[0]), nil
   101  }
   102  
   103  // Type returns the resultant type of the function.
   104  func (l *LogBase) Type() sql.Type {
   105  	return types.Float64
   106  }
   107  
   108  // CollationCoercibility implements the interface sql.CollationCoercible.
   109  func (*LogBase) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   110  	return sql.Collation_binary, 5
   111  }
   112  
   113  // IsNullable implements the sql.Expression interface.
   114  func (l *LogBase) IsNullable() bool {
   115  	return l.base == float64(1) || l.base <= float64(0) || l.Child.IsNullable()
   116  }
   117  
   118  // Eval implements the Expression interface.
   119  func (l *LogBase) Eval(
   120  	ctx *sql.Context,
   121  	row sql.Row,
   122  ) (interface{}, error) {
   123  	v, err := l.Child.Eval(ctx, row)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	if v == nil {
   129  		return nil, nil
   130  	}
   131  
   132  	val, _, err := types.Float64.Convert(v)
   133  	if err != nil {
   134  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(v))
   135  	}
   136  	return computeLog(ctx, val.(float64), l.base)
   137  }
   138  
   139  // Log is a function that returns the natural logarithm of a value.
   140  type Log struct {
   141  	expression.BinaryExpressionStub
   142  }
   143  
   144  var _ sql.FunctionExpression = (*Log)(nil)
   145  var _ sql.CollationCoercible = (*Log)(nil)
   146  
   147  // NewLog creates a new Log expression.
   148  func NewLog(args ...sql.Expression) (sql.Expression, error) {
   149  	argLen := len(args)
   150  	if argLen == 0 || argLen > 2 {
   151  		return nil, sql.ErrInvalidArgumentNumber.New("LOG", "1 or 2", argLen)
   152  	}
   153  
   154  	if argLen == 1 {
   155  		return &Log{expression.BinaryExpressionStub{LeftChild: expression.NewLiteral(math.E, types.Float64), RightChild: args[0]}}, nil
   156  	} else {
   157  		return &Log{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: args[1]}}, nil
   158  	}
   159  }
   160  
   161  // FunctionName implements sql.FunctionExpression
   162  func (l *Log) FunctionName() string {
   163  	return "log"
   164  }
   165  
   166  // Description implements sql.FunctionExpression
   167  func (l *Log) Description() string {
   168  	return "if called with one parameter, this function returns the natural logarithm of X. If called with two parameters, this function returns the logarithm of X to the base B. If X is less than or equal to 0, or if B is less than or equal to 1, then NULL is returned."
   169  }
   170  
   171  func (l *Log) String() string {
   172  	return fmt.Sprintf("%s(%s,%s)", l.FunctionName(), l.LeftChild, l.RightChild)
   173  }
   174  
   175  // WithChildren implements the Expression interface.
   176  func (l *Log) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   177  	return NewLog(children...)
   178  }
   179  
   180  // Children implements the Expression interface.
   181  func (l *Log) Children() []sql.Expression {
   182  	return []sql.Expression{l.LeftChild, l.RightChild}
   183  }
   184  
   185  // Type returns the resultant type of the function.
   186  func (l *Log) Type() sql.Type {
   187  	return types.Float64
   188  }
   189  
   190  // CollationCoercibility implements the interface sql.CollationCoercible.
   191  func (*Log) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   192  	return sql.Collation_binary, 5
   193  }
   194  
   195  // IsNullable implements the Expression interface.
   196  func (l *Log) IsNullable() bool {
   197  	return l.LeftChild.IsNullable() || l.RightChild.IsNullable()
   198  }
   199  
   200  // Eval implements the Expression interface.
   201  func (l *Log) Eval(
   202  	ctx *sql.Context,
   203  	row sql.Row,
   204  ) (interface{}, error) {
   205  	left, err := l.LeftChild.Eval(ctx, row)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	if left == nil {
   211  		return nil, nil
   212  	}
   213  
   214  	lhs, _, err := types.Float64.Convert(left)
   215  	if err != nil {
   216  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(left))
   217  	}
   218  
   219  	right, err := l.RightChild.Eval(ctx, row)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	if right == nil {
   225  		return nil, nil
   226  	}
   227  
   228  	rhs, _, err := types.Float64.Convert(right)
   229  	if err != nil {
   230  		return nil, sql.ErrInvalidType.New(reflect.TypeOf(right))
   231  	}
   232  
   233  	// rhs becomes value, lhs becomes base
   234  	return computeLog(ctx, rhs.(float64), lhs.(float64))
   235  }
   236  
   237  func computeLog(ctx *sql.Context, v float64, base float64) (interface{}, error) {
   238  	if v <= 0 {
   239  		ctx.Warn(3020, ErrInvalidArgumentForLogarithm.New(v).Error())
   240  		return nil, nil
   241  	}
   242  	if base == float64(1) || base <= float64(0) {
   243  		ctx.Warn(3020, ErrInvalidArgumentForLogarithm.New(base).Error())
   244  		return nil, nil
   245  	}
   246  	switch base {
   247  	case float64(2):
   248  		return math.Log2(v), nil
   249  	case float64(10):
   250  		return math.Log10(v), nil
   251  	case math.E:
   252  		return math.Log(v), nil
   253  	default:
   254  		// LOG(BASE,V) is equivalent to LOG(V) / LOG(BASE).
   255  		return float64(math.Log(v) / math.Log(base)), nil
   256  	}
   257  }