github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/interval.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "regexp" 20 "strconv" 21 "strings" 22 "time" 23 24 errors "gopkg.in/src-d/go-errors.v1" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 // Interval defines a time duration. 31 type Interval struct { 32 UnaryExpression 33 Unit string 34 } 35 36 var _ sql.Expression = (*Interval)(nil) 37 var _ sql.CollationCoercible = (*Interval)(nil) 38 39 // NewInterval creates a new interval expression. 40 func NewInterval(child sql.Expression, unit string) *Interval { 41 return &Interval{UnaryExpression{Child: child}, strings.ToUpper(unit)} 42 } 43 44 // Type implements the sql.Expression interface. 45 func (i *Interval) Type() sql.Type { return types.Uint64 } 46 47 // CollationCoercibility implements the interface sql.CollationCoercible. 48 func (*Interval) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 49 return sql.Collation_binary, 5 50 } 51 52 // IsNullable implements the sql.Expression interface. 53 func (i *Interval) IsNullable() bool { return i.Child.IsNullable() } 54 55 // Eval implements the sql.Expression interface. 56 func (i *Interval) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 57 panic("Interval.Eval is just a placeholder method and should not be called directly") 58 } 59 60 var ( 61 errInvalidIntervalUnit = errors.NewKind("invalid interval unit: %s") 62 errInvalidIntervalFormat = errors.NewKind("invalid interval format for %q: %s") 63 ) 64 65 // EvalDelta evaluates the expression returning a TimeDelta. This method should 66 // be used instead of Eval, as this expression returns a TimeDelta, which is not 67 // a valid value that can be returned in Eval. 68 func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) { 69 val, err := i.Child.Eval(ctx, row) 70 if err != nil { 71 return nil, err 72 } 73 74 if val == nil { 75 return nil, nil 76 } 77 78 var td TimeDelta 79 80 if r, ok := unitTextFormats[i.Unit]; ok { 81 val, _, err = types.LongText.Convert(val) 82 if err != nil { 83 return nil, err 84 } 85 86 text := val.(string) 87 if !r.MatchString(text) { 88 return nil, errInvalidIntervalFormat.New(i.Unit, text) 89 } 90 91 parts := textFormatParts(text, r) 92 93 switch i.Unit { 94 case "DAY_HOUR": 95 td.Days = parts[0] 96 td.Hours = parts[1] 97 case "DAY_MICROSECOND": 98 td.Days = parts[0] 99 td.Hours = parts[1] 100 td.Minutes = parts[2] 101 td.Seconds = parts[3] 102 td.Microseconds = parts[4] 103 case "DAY_MINUTE": 104 td.Days = parts[0] 105 td.Hours = parts[1] 106 td.Minutes = parts[2] 107 case "DAY_SECOND": 108 td.Days = parts[0] 109 td.Hours = parts[1] 110 td.Minutes = parts[2] 111 td.Seconds = parts[3] 112 case "HOUR_MICROSECOND": 113 td.Hours = parts[0] 114 td.Minutes = parts[1] 115 td.Seconds = parts[2] 116 td.Microseconds = parts[3] 117 case "HOUR_SECOND": 118 td.Hours = parts[0] 119 td.Minutes = parts[1] 120 td.Seconds = parts[2] 121 case "HOUR_MINUTE": 122 td.Hours = parts[0] 123 td.Minutes = parts[1] 124 case "MINUTE_MICROSECOND": 125 td.Minutes = parts[0] 126 td.Seconds = parts[1] 127 td.Microseconds = parts[2] 128 case "MINUTE_SECOND": 129 td.Minutes = parts[0] 130 td.Seconds = parts[1] 131 case "SECOND_MICROSECOND": 132 td.Seconds = parts[0] 133 td.Microseconds = parts[1] 134 case "YEAR_MONTH": 135 td.Years = parts[0] 136 td.Months = parts[1] 137 default: 138 return nil, errInvalidIntervalUnit.New(i.Unit) 139 } 140 } else { 141 val, _, err = types.Int64.Convert(val) 142 if err != nil { 143 return nil, err 144 } 145 146 num := val.(int64) 147 148 switch i.Unit { 149 case "DAY": 150 td.Days = num 151 case "HOUR": 152 td.Hours = num 153 case "MINUTE": 154 td.Minutes = num 155 case "SECOND": 156 td.Seconds = num 157 case "MICROSECOND": 158 td.Microseconds = num 159 case "QUARTER": 160 td.Months = num * 3 161 case "MONTH": 162 td.Months = num 163 case "WEEK": 164 td.Days = num * 7 165 case "YEAR": 166 td.Years = num 167 default: 168 return nil, errInvalidIntervalUnit.New(i.Unit) 169 } 170 } 171 172 return &td, nil 173 } 174 175 // WithChildren implements the Expression interface. 176 func (i *Interval) WithChildren(children ...sql.Expression) (sql.Expression, error) { 177 if len(children) != 1 { 178 return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 1) 179 } 180 return NewInterval(children[0], i.Unit), nil 181 } 182 183 func (i *Interval) String() string { 184 return fmt.Sprintf("INTERVAL %s %s", i.Child, i.Unit) 185 } 186 187 var unitTextFormats = map[string]*regexp.Regexp{ 188 "DAY_HOUR": regexp.MustCompile(`^(\d+)\s+(\d+)$`), 189 "DAY_MICROSECOND": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+).(\d+)$`), 190 "DAY_MINUTE": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+)$`), 191 "DAY_SECOND": regexp.MustCompile(`^(\d+)\s+(\d+):(\d+):(\d+)$`), 192 "HOUR_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+):(\d+).(\d+)$`), 193 "HOUR_SECOND": regexp.MustCompile(`^(\d+):(\d+):(\d+)$`), 194 "HOUR_MINUTE": regexp.MustCompile(`^(\d+):(\d+)$`), 195 "MINUTE_MICROSECOND": regexp.MustCompile(`^(\d+):(\d+).(\d+)$`), 196 "MINUTE_SECOND": regexp.MustCompile(`^(\d+):(\d+)$`), 197 "SECOND_MICROSECOND": regexp.MustCompile(`^(\d+).(\d+)$`), 198 "YEAR_MONTH": regexp.MustCompile(`^(\d+)-(\d+)$`), 199 } 200 201 func textFormatParts(text string, r *regexp.Regexp) []int64 { 202 parts := r.FindStringSubmatch(text) 203 var result []int64 204 for _, p := range parts[1:] { 205 // It is safe to ignore the error here, because at this point we know 206 // the string matches the regexp, and that means it can't be an 207 // invalid number. 208 n, _ := strconv.ParseInt(p, 10, 64) 209 result = append(result, n) 210 } 211 return result 212 } 213 214 // TimeDelta is the difference between a time and another time. 215 type TimeDelta struct { 216 Years int64 217 Months int64 218 Days int64 219 Hours int64 220 Minutes int64 221 Seconds int64 222 Microseconds int64 223 } 224 225 // Add returns the given time plus the time delta. 226 func (td TimeDelta) Add(t time.Time) time.Time { 227 return td.apply(t, 1) 228 } 229 230 // Sub returns the given time minus the time delta. 231 func (td TimeDelta) Sub(t time.Time) time.Time { 232 return td.apply(t, -1) 233 } 234 235 const ( 236 day = 24 * time.Hour 237 week = 7 * day 238 ) 239 240 func (td TimeDelta) apply(t time.Time, sign int64) time.Time { 241 y := int64(t.Year()) 242 mo := int64(t.Month()) 243 d := t.Day() 244 h := t.Hour() 245 min := t.Minute() 246 s := t.Second() 247 ns := t.Nanosecond() 248 249 if td.Years != 0 { 250 y += td.Years * sign 251 } 252 253 if td.Months != 0 { 254 m := mo + td.Months*sign 255 if m < 1 { 256 mo = 12 + (m % 12) 257 y += m/12 - 1 258 } else if m > 12 { 259 mo = m % 12 260 y += m / 12 261 } else { 262 mo = m 263 } 264 265 // Due to the operations done before, month may be zero, which means it's 266 // december. 267 if mo == 0 { 268 mo = 12 269 } 270 } 271 272 if days := daysInMonth(time.Month(mo), int(y)); days < d { 273 d = days 274 } 275 276 date := time.Date(int(y), time.Month(mo), d, h, min, s, ns, t.Location()) 277 278 if td.Days != 0 { 279 date = date.Add(time.Duration(td.Days) * day * time.Duration(sign)) 280 } 281 282 if td.Hours != 0 { 283 date = date.Add(time.Duration(td.Hours) * time.Hour * time.Duration(sign)) 284 } 285 286 if td.Minutes != 0 { 287 date = date.Add(time.Duration(td.Minutes) * time.Minute * time.Duration(sign)) 288 } 289 290 if td.Seconds != 0 { 291 date = date.Add(time.Duration(td.Seconds) * time.Second * time.Duration(sign)) 292 } 293 294 if td.Microseconds != 0 { 295 date = date.Add(time.Duration(td.Microseconds) * time.Microsecond * time.Duration(sign)) 296 } 297 298 return date 299 } 300 301 func daysInMonth(month time.Month, year int) int { 302 if month == time.December { 303 return 31 304 } 305 306 date := time.Date(year, month+time.Month(1), 1, 0, 0, 0, 0, time.Local) 307 return date.Add(-1 * day).Day() 308 }