github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/interval.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  
    24  	errors "gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  // Interval defines a time duration.
    31  type Interval struct {
    32  	UnaryExpression
    33  	Unit string
    34  }
    35  
    36  var _ sql.Expression = (*Interval)(nil)
    37  var _ sql.CollationCoercible = (*Interval)(nil)
    38  
    39  // NewInterval creates a new interval expression.
    40  func NewInterval(child sql.Expression, unit string) *Interval {
    41  	return &Interval{UnaryExpression{Child: child}, strings.ToUpper(unit)}
    42  }
    43  
    44  // Type implements the sql.Expression interface.
    45  func (i *Interval) Type() sql.Type { return types.Uint64 }
    46  
    47  // CollationCoercibility implements the interface sql.CollationCoercible.
    48  func (*Interval) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    49  	return sql.Collation_binary, 5
    50  }
    51  
    52  // IsNullable implements the sql.Expression interface.
    53  func (i *Interval) IsNullable() bool { return i.Child.IsNullable() }
    54  
    55  // Eval implements the sql.Expression interface.
    56  func (i *Interval) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    57  	panic("Interval.Eval is just a placeholder method and should not be called directly")
    58  }
    59  
    60  var (
    61  	errInvalidIntervalUnit   = errors.NewKind("invalid interval unit: %s")
    62  	errInvalidIntervalFormat = errors.NewKind("invalid interval format for %q: %s")
    63  )
    64  
    65  // EvalDelta evaluates the expression returning a TimeDelta. This method should
    66  // be used instead of Eval, as this expression returns a TimeDelta, which is not
    67  // a valid value that can be returned in Eval.
    68  func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) {
    69  	val, err := i.Child.Eval(ctx, row)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	if val == nil {
    75  		return nil, nil
    76  	}
    77  
    78  	var td TimeDelta
    79  
    80  	if r, ok := unitTextFormats[i.Unit]; ok {
    81  		val, _, err = types.LongText.Convert(val)
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  
    86  		text := val.(string)
    87  		if !r.MatchString(text) {
    88  			return nil, errInvalidIntervalFormat.New(i.Unit, text)
    89  		}
    90  
    91  		parts := textFormatParts(text, r)
    92  
    93  		switch i.Unit {
    94  		case "DAY_HOUR":
    95  			td.Days = parts[0]
    96  			td.Hours = parts[1]
    97  		case "DAY_MICROSECOND":
    98  			td.Days = parts[0]
    99  			td.Hours = parts[1]
   100  			td.Minutes = parts[2]
   101  			td.Seconds = parts[3]
   102  			td.Microseconds = parts[4]
   103  		case "DAY_MINUTE":
   104  			td.Days = parts[0]
   105  			td.Hours = parts[1]
   106  			td.Minutes = parts[2]
   107  		case "DAY_SECOND":
   108  			td.Days = parts[0]
   109  			td.Hours = parts[1]
   110  			td.Minutes = parts[2]
   111  			td.Seconds = parts[3]
   112  		case "HOUR_MICROSECOND":
   113  			td.Hours = parts[0]
   114  			td.Minutes = parts[1]
   115  			td.Seconds = parts[2]
   116  			td.Microseconds = parts[3]
   117  		case "HOUR_SECOND":
   118  			td.Hours = parts[0]
   119  			td.Minutes = parts[1]
   120  			td.Seconds = parts[2]
   121  		case "HOUR_MINUTE":
   122  			td.Hours = parts[0]
   123  			td.Minutes = parts[1]
   124  		case "MINUTE_MICROSECOND":
   125  			td.Minutes = parts[0]
   126  			td.Seconds = parts[1]
   127  			td.Microseconds = parts[2]
   128  		case "MINUTE_SECOND":
   129  			td.Minutes = parts[0]
   130  			td.Seconds = parts[1]
   131  		case "SECOND_MICROSECOND":
   132  			td.Seconds = parts[0]
   133  			td.Microseconds = parts[1]
   134  		case "YEAR_MONTH":
   135  			td.Years = parts[0]
   136  			td.Months = parts[1]
   137  		default:
   138  			return nil, errInvalidIntervalUnit.New(i.Unit)
   139  		}
   140  	} else {
   141  		val, _, err = types.Int64.Convert(val)
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  
   146  		num := val.(int64)
   147  
   148  		switch i.Unit {
   149  		case "DAY":
   150  			td.Days = num
   151  		case "HOUR":
   152  			td.Hours = num
   153  		case "MINUTE":
   154  			td.Minutes = num
   155  		case "SECOND":
   156  			td.Seconds = num
   157  		case "MICROSECOND":
   158  			td.Microseconds = num
   159  		case "QUARTER":
   160  			td.Months = num * 3
   161  		case "MONTH":
   162  			td.Months = num
   163  		case "WEEK":
   164  			td.Days = num * 7
   165  		case "YEAR":
   166  			td.Years = num
   167  		default:
   168  			return nil, errInvalidIntervalUnit.New(i.Unit)
   169  		}
   170  	}
   171  
   172  	return &td, nil
   173  }
   174  
   175  // WithChildren implements the Expression interface.
   176  func (i *Interval) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   177  	if len(children) != 1 {
   178  		return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1)
   179  	}
   180  	return NewInterval(children[0], i.Unit), nil
   181  }
   182  
   183  func (i *Interval) String() string {
   184  	return fmt.Sprintf("INTERVAL %s %s", i.Child, i.Unit)
   185  }
   186  
   187  var unitTextFormats = map[string]*regexp.Regexp{
   188  	"DAY_HOUR":           regexp.MustCompile(`^(\d+)\s+(\d+)$`),
   189  	"DAY_MICROSECOND":    regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+).(\d+)$`),
   190  	"DAY_MINUTE":         regexp.MustCompile(`^(\d+)\s+(\d+):(\d+)$`),
   191  	"DAY_SECOND":         regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+)$`),
   192  	"HOUR_MICROSECOND":   regexp.MustCompile(`^(\d+):(\d+):(\d+).(\d+)$`),
   193  	"HOUR_SECOND":        regexp.MustCompile(`^(\d+):(\d+):(\d+)$`),
   194  	"HOUR_MINUTE":        regexp.MustCompile(`^(\d+):(\d+)$`),
   195  	"MINUTE_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+).(\d+)$`),
   196  	"MINUTE_SECOND":      regexp.MustCompile(`^(\d+):(\d+)$`),
   197  	"SECOND_MICROSECOND": regexp.MustCompile(`^(\d+).(\d+)$`),
   198  	"YEAR_MONTH":         regexp.MustCompile(`^(\d+)-(\d+)$`),
   199  }
   200  
   201  func textFormatParts(text string, r *regexp.Regexp) []int64 {
   202  	parts := r.FindStringSubmatch(text)
   203  	var result []int64
   204  	for _, p := range parts[1:] {
   205  		// It is safe to ignore the error here, because at this point we know
   206  		// the string matches the regexp, and that means it can't be an
   207  		// invalid number.
   208  		n, _ := strconv.ParseInt(p, 10, 64)
   209  		result = append(result, n)
   210  	}
   211  	return result
   212  }
   213  
   214  // TimeDelta is the difference between a time and another time.
   215  type TimeDelta struct {
   216  	Years        int64
   217  	Months       int64
   218  	Days         int64
   219  	Hours        int64
   220  	Minutes      int64
   221  	Seconds      int64
   222  	Microseconds int64
   223  }
   224  
   225  // Add returns the given time plus the time delta.
   226  func (td TimeDelta) Add(t time.Time) time.Time {
   227  	return td.apply(t, 1)
   228  }
   229  
   230  // Sub returns the given time minus the time delta.
   231  func (td TimeDelta) Sub(t time.Time) time.Time {
   232  	return td.apply(t, -1)
   233  }
   234  
   235  const (
   236  	day  = 24 * time.Hour
   237  	week = 7 * day
   238  )
   239  
   240  func (td TimeDelta) apply(t time.Time, sign int64) time.Time {
   241  	y := int64(t.Year())
   242  	mo := int64(t.Month())
   243  	d := t.Day()
   244  	h := t.Hour()
   245  	min := t.Minute()
   246  	s := t.Second()
   247  	ns := t.Nanosecond()
   248  
   249  	if td.Years != 0 {
   250  		y += td.Years * sign
   251  	}
   252  
   253  	if td.Months != 0 {
   254  		m := mo + td.Months*sign
   255  		if m < 1 {
   256  			mo = 12 + (m % 12)
   257  			y += m/12 - 1
   258  		} else if m > 12 {
   259  			mo = m % 12
   260  			y += m / 12
   261  		} else {
   262  			mo = m
   263  		}
   264  
   265  		// Due to the operations done before, month may be zero, which means it's
   266  		// december.
   267  		if mo == 0 {
   268  			mo = 12
   269  		}
   270  	}
   271  
   272  	if days := daysInMonth(time.Month(mo), int(y)); days < d {
   273  		d = days
   274  	}
   275  
   276  	date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location())
   277  
   278  	if td.Days != 0 {
   279  		date = date.Add(time.Duration(td.Days) * day * time.Duration(sign))
   280  	}
   281  
   282  	if td.Hours != 0 {
   283  		date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign))
   284  	}
   285  
   286  	if td.Minutes != 0 {
   287  		date = date.Add(time.Duration(td.Minutes) * time.Minute * time.Duration(sign))
   288  	}
   289  
   290  	if td.Seconds != 0 {
   291  		date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign))
   292  	}
   293  
   294  	if td.Microseconds != 0 {
   295  		date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign))
   296  	}
   297  
   298  	return date
   299  }
   300  
   301  func daysInMonth(month time.Month, year int) int {
   302  	if month == time.December {
   303  		return 31
   304  	}
   305  
   306  	date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local)
   307  	return date.Add(-1 * day).Day()
   308  }