github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/timediff.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strings"
    21  	"time"
    22  
    23  	"gopkg.in/src-d/go-errors.v1"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  // TimeDiff subtracts the second argument from the first expressed as a time value.
    31  type TimeDiff struct {
    32  	expression.BinaryExpressionStub
    33  }
    34  
    35  var _ sql.FunctionExpression = (*TimeDiff)(nil)
    36  var _ sql.CollationCoercible = (*TimeDiff)(nil)
    37  
    38  // NewTimeDiff creates a new NewTimeDiff expression.
    39  func NewTimeDiff(e1, e2 sql.Expression) sql.Expression {
    40  	return &TimeDiff{
    41  		expression.BinaryExpressionStub{
    42  			LeftChild:  e1,
    43  			RightChild: e2,
    44  		},
    45  	}
    46  }
    47  
    48  // FunctionName implements sql.FunctionExpression
    49  func (td *TimeDiff) FunctionName() string {
    50  	return "timediff"
    51  }
    52  
    53  // Description implements sql.FunctionExpression
    54  func (td *TimeDiff) Description() string {
    55  	return "returns expr1 − expr2 expressed as a time value. expr1 and expr2 are time or date-and-time expressions, but both must be of the same type."
    56  }
    57  
    58  // Type implements the Expression interface.
    59  func (td *TimeDiff) Type() sql.Type { return types.Time }
    60  
    61  // CollationCoercibility implements the interface sql.CollationCoercible.
    62  func (*TimeDiff) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    63  	return sql.Collation_binary, 5
    64  }
    65  
    66  func (td *TimeDiff) String() string {
    67  	return fmt.Sprintf("%s(%s,%s)", td.FunctionName(), td.LeftChild, td.RightChild)
    68  }
    69  
    70  // WithChildren implements the Expression interface.
    71  func (td *TimeDiff) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    72  	if len(children) != 2 {
    73  		return nil, sql.ErrInvalidChildrenNumber.New(td, len(children), 2)
    74  	}
    75  	return NewTimeDiff(children[0], children[1]), nil
    76  }
    77  
    78  func convToDateOrTime(val interface{}) (interface{}, error) {
    79  	date, _, err := types.DatetimeMaxPrecision.Convert(val)
    80  	if err == nil {
    81  		return date, nil
    82  	}
    83  	tim, _, err := types.Time.Convert(val)
    84  	if err == nil {
    85  		return tim, err
    86  	}
    87  	return nil, err
    88  }
    89  
    90  // Eval implements the Expression interface.
    91  func (td *TimeDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    92  	if td.LeftChild == nil || td.RightChild == nil {
    93  		return nil, nil
    94  	}
    95  
    96  	left, err := td.LeftChild.Eval(ctx, row)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	right, err := td.RightChild.Eval(ctx, row)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	if left == nil || right == nil {
   107  		return nil, nil
   108  	}
   109  
   110  	// always convert string types
   111  	if _, ok := left.(string); ok {
   112  		left, err = convToDateOrTime(left)
   113  		if err != nil {
   114  			ctx.Warn(1292, err.Error())
   115  			return nil, nil
   116  		}
   117  	}
   118  	if _, ok := right.(string); ok {
   119  		right, err = convToDateOrTime(right)
   120  		if err != nil {
   121  			ctx.Warn(1292, err.Error())
   122  			return nil, nil
   123  		}
   124  	}
   125  
   126  	// handle as date
   127  	if leftDatetime, ok := left.(time.Time); ok {
   128  		rightDatetime, ok := right.(time.Time)
   129  		if !ok {
   130  			return nil, nil
   131  		}
   132  		if leftDatetime.Location() != rightDatetime.Location() {
   133  			rightDatetime = rightDatetime.In(leftDatetime.Location())
   134  		}
   135  		ret, _, err := types.Time.Convert(leftDatetime.Sub(rightDatetime))
   136  		return ret, err
   137  	}
   138  
   139  	// handle as time
   140  	if leftTime, ok := left.(types.Timespan); ok {
   141  		rightTime, ok := right.(types.Timespan)
   142  		if !ok {
   143  			return nil, nil
   144  		}
   145  		return leftTime.Subtract(rightTime), nil
   146  	}
   147  	return nil, sql.ErrInvalidArgumentType.New("timediff")
   148  }
   149  
   150  // DateDiff returns expr1 − expr2 expressed as a value in days from one date to the other.
   151  type DateDiff struct {
   152  	expression.BinaryExpressionStub
   153  }
   154  
   155  var _ sql.FunctionExpression = (*DateDiff)(nil)
   156  var _ sql.CollationCoercible = (*DateDiff)(nil)
   157  
   158  // NewDateDiff creates a new DATEDIFF() function.
   159  func NewDateDiff(expr1, expr2 sql.Expression) sql.Expression {
   160  	return &DateDiff{
   161  		expression.BinaryExpressionStub{
   162  			LeftChild:  expr1,
   163  			RightChild: expr2,
   164  		},
   165  	}
   166  }
   167  
   168  // FunctionName implements sql.FunctionExpression
   169  func (d *DateDiff) FunctionName() string {
   170  	return "datediff"
   171  }
   172  
   173  // Description implements sql.FunctionExpression
   174  func (d *DateDiff) Description() string {
   175  	return "gets difference between two dates in result of days."
   176  }
   177  
   178  // Type implements the sql.Expression interface.
   179  func (d *DateDiff) Type() sql.Type { return types.Int64 }
   180  
   181  // CollationCoercibility implements the interface sql.CollationCoercible.
   182  func (*DateDiff) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   183  	return sql.Collation_binary, 5
   184  }
   185  
   186  // WithChildren implements the Expression interface.
   187  func (d *DateDiff) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   188  	if len(children) != 2 {
   189  		return nil, sql.ErrInvalidChildrenNumber.New(d, len(children), 2)
   190  	}
   191  	return NewDateDiff(children[0], children[1]), nil
   192  }
   193  
   194  // Eval implements the sql.Expression interface.
   195  func (d *DateDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   196  	if d.LeftChild == nil || d.RightChild == nil {
   197  		return nil, nil
   198  	}
   199  
   200  	expr1, err := d.LeftChild.Eval(ctx, row)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	if expr1 == nil {
   205  		return nil, nil
   206  	}
   207  
   208  	expr1, _, err = types.DatetimeMaxPrecision.Convert(expr1)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	expr1str := expr1.(time.Time).String()[:10]
   214  	expr1, _, _ = types.DatetimeMaxPrecision.Convert(expr1str)
   215  
   216  	expr2, err := d.RightChild.Eval(ctx, row)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	if expr2 == nil {
   221  		return nil, nil
   222  	}
   223  
   224  	expr2, _, err = types.DatetimeMaxPrecision.Convert(expr2)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	expr2str := expr2.(time.Time).String()[:10]
   230  	expr2, _, _ = types.DatetimeMaxPrecision.Convert(expr2str)
   231  
   232  	date1 := expr1.(time.Time)
   233  	date2 := expr2.(time.Time)
   234  
   235  	diff := int64(math.Round(date1.Sub(date2).Hours() / 24))
   236  
   237  	return diff, nil
   238  }
   239  
   240  func (d *DateDiff) String() string {
   241  	return fmt.Sprintf("DATEDIFF(%s, %s)", d.LeftChild, d.RightChild)
   242  }
   243  
   244  // TimestampDiff returns expr1 − expr2 expressed as a value in unit specified.
   245  type TimestampDiff struct {
   246  	unit  sql.Expression
   247  	expr1 sql.Expression
   248  	expr2 sql.Expression
   249  }
   250  
   251  var _ sql.FunctionExpression = (*TimestampDiff)(nil)
   252  var _ sql.CollationCoercible = (*TimestampDiff)(nil)
   253  
   254  // NewTimestampDiff creates a new TIMESTAMPDIFF() function.
   255  func NewTimestampDiff(u, e1, e2 sql.Expression) sql.Expression {
   256  	return &TimestampDiff{u, e1, e2}
   257  }
   258  
   259  // FunctionName implements sql.FunctionExpression
   260  func (t *TimestampDiff) FunctionName() string {
   261  	return "timestampdiff"
   262  }
   263  
   264  // Description implements sql.FunctionExpression
   265  func (t *TimestampDiff) Description() string {
   266  	return "gets difference between two dates in result of units specified."
   267  }
   268  
   269  // Children implements the sql.Expression interface.
   270  func (t *TimestampDiff) Children() []sql.Expression {
   271  	return []sql.Expression{t.unit, t.expr1, t.expr2}
   272  }
   273  
   274  // Resolved implements the sql.Expression interface.
   275  func (t *TimestampDiff) Resolved() bool {
   276  	return t.unit.Resolved() && t.expr1.Resolved() && t.expr2.Resolved()
   277  }
   278  
   279  // IsNullable implements the sql.Expression interface.
   280  func (t *TimestampDiff) IsNullable() bool {
   281  	return t.unit.IsNullable() && t.expr1.IsNullable() && t.expr2.IsNullable()
   282  }
   283  
   284  // Type implements the sql.Expression interface.
   285  func (t *TimestampDiff) Type() sql.Type { return types.Int64 }
   286  
   287  // CollationCoercibility implements the interface sql.CollationCoercible.
   288  func (*TimestampDiff) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   289  	return sql.Collation_binary, 5
   290  }
   291  
   292  // WithChildren implements the Expression interface.
   293  func (t *TimestampDiff) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   294  	if len(children) != 3 {
   295  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 3)
   296  	}
   297  	return NewTimestampDiff(children[0], children[1], children[2]), nil
   298  }
   299  
   300  // Eval implements the sql.Expression interface.
   301  func (t *TimestampDiff) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   302  	if t.unit == nil {
   303  		return nil, errors.NewKind("unit cannot be null").New(t.unit)
   304  	}
   305  	if t.expr1 == nil || t.expr2 == nil {
   306  		return nil, nil
   307  	}
   308  
   309  	expr1, err := t.expr1.Eval(ctx, row)
   310  	if err != nil {
   311  		return nil, err
   312  	}
   313  	if expr1 == nil {
   314  		return nil, nil
   315  	}
   316  
   317  	expr2, err := t.expr2.Eval(ctx, row)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  	if expr2 == nil {
   322  		return nil, nil
   323  	}
   324  
   325  	expr1, _, err = types.DatetimeMaxPrecision.Convert(expr1)
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  
   330  	expr2, _, err = types.DatetimeMaxPrecision.Convert(expr2)
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  
   335  	unit, err := t.unit.Eval(ctx, row)
   336  	if err != nil {
   337  		return nil, err
   338  	}
   339  	if unit == nil {
   340  		return nil, errors.NewKind("unit cannot be null").New(unit)
   341  	}
   342  
   343  	unit = strings.TrimPrefix(strings.ToLower(unit.(string)), "sql_tsi_")
   344  
   345  	date1 := expr1.(time.Time)
   346  	date2 := expr2.(time.Time)
   347  
   348  	diff := date2.Sub(date1)
   349  
   350  	var res int64
   351  	switch unit {
   352  	case "microsecond":
   353  		res = diff.Microseconds()
   354  	case "second":
   355  		res = int64(diff.Seconds())
   356  	case "minute":
   357  		res = int64(diff.Minutes())
   358  	case "hour":
   359  		res = int64(diff.Hours())
   360  	case "day":
   361  		res = int64(diff.Hours() / 24)
   362  	case "week":
   363  		res = int64(diff.Hours() / (24 * 7))
   364  	case "month":
   365  		res = int64(diff.Hours() / (24 * 30))
   366  		if res > 0 {
   367  			if date2.Day()-date1.Day() < 0 {
   368  				res -= 1
   369  			} else if date2.Hour()-date1.Hour() < 0 {
   370  				res -= 1
   371  			} else if date2.Minute()-date1.Minute() < 0 {
   372  				res -= 1
   373  			} else if date2.Second()-date1.Second() < 0 {
   374  				res -= 1
   375  			}
   376  		}
   377  	case "quarter":
   378  		monthRes := int64(diff.Hours() / (24 * 30))
   379  		if monthRes > 0 {
   380  			if date2.Day()-date1.Day() < 0 {
   381  				monthRes -= 1
   382  			} else if date2.Hour()-date1.Hour() < 0 {
   383  				monthRes -= 1
   384  			} else if date2.Minute()-date1.Minute() < 0 {
   385  				monthRes -= 1
   386  			} else if date2.Second()-date1.Second() < 0 {
   387  				monthRes -= 1
   388  			}
   389  		}
   390  		res = monthRes / 3
   391  	case "year":
   392  		yearRes := int64(diff.Hours() / (24 * 365))
   393  		if yearRes > 0 {
   394  			monthRes := int64(diff.Hours() / (24 * 30))
   395  			if monthRes > 0 {
   396  				if date2.Day()-date1.Day() < 0 {
   397  					monthRes -= 1
   398  				} else if date2.Hour()-date1.Hour() < 0 {
   399  					monthRes -= 1
   400  				} else if date2.Minute()-date1.Minute() < 0 {
   401  					monthRes -= 1
   402  				} else if date2.Second()-date1.Second() < 0 {
   403  					monthRes -= 1
   404  				}
   405  			}
   406  			res = monthRes / 12
   407  		} else {
   408  			res = yearRes
   409  		}
   410  
   411  	default:
   412  		return nil, errors.NewKind("invalid interval unit: %s").New(unit)
   413  	}
   414  
   415  	return res, nil
   416  }
   417  
   418  func (t *TimestampDiff) String() string {
   419  	return fmt.Sprintf("TIMESTAMPDIFF(%s, %s, %s)", t.unit, t.expr1, t.expr2)
   420  }