
     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package sql
    17  import (
    18  	"fmt"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    24  	""
    26  	gmstime ""
    27  )
    29  const EventDateSpaceTimeFormat = "2006-01-02 15:04:05"
    31  // EventSchedulerStatement represents a SQL statement that requires a EventScheduler
    32  // (e.g. CREATE / ALTER / DROP EVENT and DROP DATABASE).
    33  type EventSchedulerStatement interface {
    34  	Node
    35  	// WithEventScheduler returns a new instance of this EventSchedulerStatement,
    36  	// with the event scheduler notifier configured.
    37  	WithEventScheduler(controller EventScheduler) Node
    38  }
    40  // EventScheduler is an interface used for notifying the EventSchedulerStatus
    41  // for querying any events related statements. This allows plan Nodes to communicate
    42  // to the EventSchedulerStatus.
    43  type EventScheduler interface {
    44  	// AddEvent is called when there is an event created at runtime.
    45  	AddEvent(ctx *Context, edb EventDatabase, event EventDefinition)
    46  	// UpdateEvent is called when there is an event altered at runtime.
    47  	UpdateEvent(ctx *Context, edb EventDatabase, orgEventName string, event EventDefinition)
    48  	// RemoveEvent is called when there is an event dropped at runtime. This function
    49  	// removes the given event if it exists in the enabled events list of the EventSchedulerStatus.
    50  	RemoveEvent(dbName, eventName string)
    51  	// RemoveSchemaEvents is called when there is a database dropped at runtime. This function
    52  	// removes all events of given database that exist in the enabled events list of the EventSchedulerStatus.
    53  	RemoveSchemaEvents(dbName string)
    54  }
// EventDefinition describes a scheduled event: its identity and metadata, the SQL it runs,
// and the ON SCHEDULE / ON COMPLETION / status / COMMENT options that CreateEventStatement
// uses to rebuild its CREATE EVENT statement.
//
// NOTE(review): ConvertTimesFromUTCToTz converts the time fields from "+00:00", which
// suggests they are stored in UTC — confirm against the code that populates them.
type EventDefinition struct {
	// The name of this event. Event names in a database are unique.
	Name string
	// The SQL statements to be executed when this event is executed. CreateEventStatement
	// emits this text verbatim after DO.
	EventBody string
	// The timezone offset the event was created or last altered at.
	// NOTE(review): presumably in "+HH:MM" form like the other offsets in this file — confirm.
	TimezoneOffset string
	// The enabled or disabled status of this event. CreateEventStatement emits it verbatim,
	// so it is expected to be the SQL form (presumably EventStatus.String(), e.g. "ENABLE").
	Status string
	// The user or account who created this scheduled event. When non-empty,
	// CreateEventStatement emits it verbatim after "DEFINER =", so any quoting must already
	// be present in this value.
	Definer string
	// The SQL_MODE in effect when this event was created.
	SqlMode string
	// The time at which the event was created.
	CreatedAt time.Time
	// The time at which the event was last altered.
	LastAltered time.Time
	// The time at which the event was last executed. The zero value means the event has not
	// been executed yet; GetNextExecutionTime then schedules EVERY events from Starts.
	LastExecuted time.Time

	/* Fields parsed from the CREATE EVENT statement */
	// Comment is the COMMENT clause value; an empty string means no COMMENT clause.
	Comment              string
	// OnCompletionPreserve selects ON COMPLETION PRESERVE (true) or NOT PRESERVE (false).
	OnCompletionPreserve bool
	// HasExecuteAt is true for an ON SCHEDULE AT event, which runs once at ExecuteAt. When it
	// is false the event is an ON SCHEDULE EVERY event described by ExecuteEvery, Starts and
	// (if HasEnds) Ends.
	HasExecuteAt         bool
	ExecuteAt            time.Time
	// ExecuteEvery is the EVERY interval in the textual form accepted by
	// EventOnScheduleEveryIntervalFromString, e.g. "2 DAY" or "'1:2' MONTH_DAY".
	ExecuteEvery         string
	Starts               time.Time // STARTS is always defined when EVERY is defined.
	// HasEnds reports whether an ENDS time was given; Ends is only used when it is true.
	HasEnds              bool
	Ends                 time.Time
}
    88  // ConvertTimesFromUTCToTz returns a new EventDefinition with all its time values converted
    89  // from UTC TZ to the given TZ. This function should only be used when needing to display
    90  // data that includes the time values in string format for such as SHOW EVENTS or
    91  // SHOW CREATE EVENT statements.
    92  func (e *EventDefinition) ConvertTimesFromUTCToTz(tz string) *EventDefinition {
    93  	ne := *e
    94  	if ne.HasExecuteAt {
    95  		t, ok := gmstime.ConvertTimeZone(e.ExecuteAt, "+00:00", tz)
    96  		if ok {
    97  			ne.ExecuteAt = t
    98  		}
    99  	} else {
   100  		t, ok := gmstime.ConvertTimeZone(e.Starts, "+00:00", tz)
   101  		if ok {
   102  			ne.Starts = t
   103  		}
   104  		if ne.HasEnds {
   105  			t, ok = gmstime.ConvertTimeZone(e.Ends, "+00:00", tz)
   106  			if ok {
   107  				ne.Ends = t
   108  			}
   109  		}
   110  	}
   112  	t, ok := gmstime.ConvertTimeZone(e.CreatedAt, "+00:00", tz)
   113  	if ok {
   114  		ne.CreatedAt = t
   115  	}
   116  	t, ok = gmstime.ConvertTimeZone(e.LastAltered, "+00:00", tz)
   117  	if ok {
   118  		ne.LastAltered = t
   119  	}
   120  	t, ok = gmstime.ConvertTimeZone(e.LastExecuted, "+00:00", tz)
   121  	if ok {
   122  		ne.LastExecuted = t
   123  	}
   124  	return &ne
   125  }
   127  // GetNextExecutionTime returns the next execution time for the event, which depends on AT
   128  // or EVERY field of EventDefinition. It also returns whether the event is expired.
   129  func (e *EventDefinition) GetNextExecutionTime(curTime time.Time) (time.Time, bool, error) {
   130  	if e.HasExecuteAt {
   131  		return e.ExecuteAt, e.ExecuteAt.Sub(curTime).Seconds() <= -1, nil
   132  	} else {
   133  		timeDur, err := getTimeDurationFromEveryInterval(e.ExecuteEvery)
   134  		if err != nil {
   135  			return time.Time{}, true, err
   136  		}
   137  		// check for last executed, if not set, get the next time by incrementing the start time by interval
   138  		// use 'last executed' time if the event was executed before; otherwise, use 'starts' time
   139  		startTime := e.Starts
   140  		if !e.LastExecuted.IsZero() && e.LastExecuted.Sub(e.Starts).Seconds() > 0 {
   141  			startTime = e.LastExecuted
   142  		}
   144  		// if startTime > curTime, then event hasn't executed yet, so execute at startTime
   145  		if startTime.Sub(curTime).Seconds() > 0 {
   146  			return startTime, false, nil
   147  		}
   148  		// if endTime is defined and endTime < curTime, then event is ended
   149  		if e.HasEnds && e.Ends.Sub(curTime).Seconds() < 0 {
   150  			return time.Time{}, true, nil
   151  		}
   153  		diffToNext := (int64(curTime.Sub(startTime).Seconds()/timeDur.Seconds()) + 1) * int64(timeDur.Seconds())
   154  		nextTime := startTime.Add(time.Duration(diffToNext) * time.Second)
   155  		// sanity check
   156  		for nextTime.Sub(curTime).Seconds() < 0 {
   157  			nextTime = nextTime.Add(timeDur)
   158  		}
   159  		// if the next execution time is past the endTime, then the event is expired.
   160  		if e.HasEnds && e.Ends.Sub(nextTime).Seconds() < 0 {
   161  			return time.Time{}, true, nil
   162  		}
   163  		return nextTime, false, nil
   164  	}
   165  }
   167  // CreateEventStatement returns a CREATE EVENT statement for this event.
   168  func (e *EventDefinition) CreateEventStatement() string {
   169  	stmt := "CREATE"
   170  	if e.Definer != "" {
   171  		stmt = fmt.Sprintf("%s DEFINER = %s", stmt, e.Definer)
   172  	}
   173  	stmt = fmt.Sprintf("%s EVENT `%s`", stmt, e.Name)
   175  	if e.HasExecuteAt {
   176  		stmt = fmt.Sprintf("%s ON SCHEDULE AT '%s'", stmt, e.ExecuteAt.Format(EventDateSpaceTimeFormat))
   177  	} else {
   178  		// STARTS should be NOT null regardless of user definition
   179  		stmt = fmt.Sprintf("%s ON SCHEDULE EVERY %s STARTS '%s'", stmt, e.ExecuteEvery, e.Starts.Format(EventDateSpaceTimeFormat))
   180  		if e.HasEnds {
   181  			stmt = fmt.Sprintf("%s ENDS '%s'", stmt, e.Ends.Format(EventDateSpaceTimeFormat))
   182  		}
   183  	}
   185  	if e.OnCompletionPreserve {
   186  		stmt = fmt.Sprintf("%s ON COMPLETION PRESERVE", stmt)
   187  	} else {
   188  		stmt = fmt.Sprintf("%s ON COMPLETION NOT PRESERVE", stmt)
   189  	}
   191  	stmt = fmt.Sprintf("%s %s", stmt, e.Status)
   193  	if e.Comment != "" {
   194  		stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, e.Comment)
   195  	}
   197  	return fmt.Sprintf("%s DO %s", stmt, e.EventBody)
   198  }
   200  // getTimeDurationFromEveryInterval returns time.Duration converting the given EVERY interval.
   201  func getTimeDurationFromEveryInterval(every string) (time.Duration, error) {
   202  	everyInterval, err := EventOnScheduleEveryIntervalFromString(every)
   203  	if err != nil {
   204  		return 0, err
   205  	}
   206  	hours := everyInterval.Years*8766 + everyInterval.Months*730 + everyInterval.Days*24 + everyInterval.Hours
   207  	timeDur := time.Duration(hours)*time.Hour + time.Duration(everyInterval.Minutes)*time.Minute + time.Duration(everyInterval.Seconds)*time.Second
   209  	return timeDur, nil
   210  }
   212  // EventStatus represents an event status that is defined for an event.
   213  type EventStatus byte
   215  const (
   216  	EventStatus_Enable EventStatus = iota
   217  	EventStatus_Disable
   218  	EventStatus_DisableOnSlave
   219  )
   221  // String returns the original SQL representation.
   222  func (e EventStatus) String() string {
   223  	switch e {
   224  	case EventStatus_Enable:
   225  		return "ENABLE"
   226  	case EventStatus_Disable:
   227  		return "DISABLE"
   228  	case EventStatus_DisableOnSlave:
   229  		return "DISABLE ON SLAVE"
   230  	default:
   231  		panic(fmt.Errorf("invalid event status value `%d`", byte(e)))
   232  	}
   233  }
   235  // EventStatusFromString returns EventStatus based on the given string value.
   236  // This function is used in Dolt to get EventStatus value for the EventDefinition.
   237  func EventStatusFromString(status string) (EventStatus, error) {
   238  	switch strings.ToLower(status) {
   239  	case "enable":
   240  		return EventStatus_Enable, nil
   241  	case "disable":
   242  		return EventStatus_Disable, nil
   243  	case "disable on slave":
   244  		return EventStatus_DisableOnSlave, nil
   245  	default:
   246  		// use disable as default to be safe
   247  		return EventStatus_Disable, fmt.Errorf("invalid event status value: `%s`", status)
   248  	}
   249  }
   251  // EventOnScheduleEveryInterval is used to store ON SCHEDULE EVERY clause's interval definition.
   252  // It is equivalent of expression.TimeDelta without microseconds field.
   253  type EventOnScheduleEveryInterval struct {
   254  	Years   int64
   255  	Months  int64
   256  	Days    int64
   257  	Hours   int64
   258  	Minutes int64
   259  	Seconds int64
   260  }
   262  func NewEveryInterval(y, mo, d, h, mi, s int64) *EventOnScheduleEveryInterval {
   263  	return &EventOnScheduleEveryInterval{
   264  		Years:   y,
   265  		Months:  mo,
   266  		Days:    d,
   267  		Hours:   h,
   268  		Minutes: mi,
   269  		Seconds: s,
   270  	}
   271  }
   273  // GetIntervalValAndField returns ON SCHEDULE EVERY clause's interval value and field type in string format
   274  // (e.g. returns "'1:2'" and "MONTH_DAY" for 1 month and 2 day or returns "4" and "HOUR" for 4 hour intervals).
   275  func (e *EventOnScheduleEveryInterval) GetIntervalValAndField() (string, string) {
   276  	if e == nil {
   277  		return "", ""
   278  	}
   280  	var val, field []string
   281  	if e.Years != 0 {
   282  		val = append(val, fmt.Sprintf("%v", e.Years))
   283  		field = append(field, "YEAR")
   284  	}
   285  	if e.Months != 0 {
   286  		val = append(val, fmt.Sprintf("%v", e.Months))
   287  		field = append(field, "MONTH")
   288  	}
   289  	if e.Days != 0 {
   290  		val = append(val, fmt.Sprintf("%v", e.Days))
   291  		field = append(field, "DAY")
   292  	}
   293  	if e.Hours != 0 {
   294  		val = append(val, fmt.Sprintf("%v", e.Hours))
   295  		field = append(field, "HOUR")
   296  	}
   297  	if e.Minutes != 0 {
   298  		val = append(val, fmt.Sprintf("%v", e.Minutes))
   299  		field = append(field, "MINUTE")
   300  	}
   301  	if e.Seconds != 0 {
   302  		val = append(val, fmt.Sprintf("%v", e.Seconds))
   303  		field = append(field, "SECOND")
   304  	}
   306  	if len(val) == 0 {
   307  		return "", ""
   308  	} else if len(val) == 1 {
   309  		return val[0], field[0]
   310  	}
   312  	return fmt.Sprintf("'%s'", strings.Join(val, ":")), strings.Join(field, "_")
   313  }
   315  // EventOnScheduleEveryIntervalFromString returns *EventOnScheduleEveryInterval parsing given interval string
   316  // such as `2 DAY` or `'1:2' MONTH_DAY`. This function is used in Dolt to construct EventOnScheduleEveryInterval value
   317  // for the EventDefinition.
   318  func EventOnScheduleEveryIntervalFromString(every string) (*EventOnScheduleEveryInterval, error) {
   319  	errCannotParseEveryInterval := fmt.Errorf("cannot parse ON SCHEDULE EVERY interval: `%s`", every)
   320  	strs := strings.Split(every, " ")
   321  	if len(strs) != 2 {
   322  		return nil, errCannotParseEveryInterval
   323  	}
   324  	intervalVal := strs[0]
   325  	intervalField := strs[1]
   327  	intervalVal = strings.TrimSuffix(strings.TrimPrefix(intervalVal, "'"), "'")
   328  	iVals := strings.Split(intervalVal, ":")
   329  	iFields := strings.Split(intervalField, "_")
   331  	if len(iVals) != len(iFields) {
   332  		return nil, errCannotParseEveryInterval
   333  	}
   335  	var interval = &EventOnScheduleEveryInterval{}
   336  	for i, val := range iVals {
   337  		n, err := strconv.ParseInt(val, 10, 64)
   338  		if err != nil {
   339  			return nil, errCannotParseEveryInterval
   340  		}
   341  		switch iFields[i] {
   342  		case "YEAR":
   343  			interval.Years = n
   344  		case "MONTH":
   345  			interval.Months = n
   346  		case "DAY":
   347  			interval.Days = n
   348  		case "HOUR":
   349  			interval.Hours = n
   350  		case "MINUTE":
   351  			interval.Minutes = n
   352  		case "SECOND":
   353  			interval.Seconds = n
   354  		default:
   355  			return nil, errCannotParseEveryInterval
   356  		}
   357  	}
   359  	return interval, nil
   360  }
   362  // -------------------------
   363  //  Events datetime parsing
   364  // -------------------------
   366  var ErrIncorrectValue = errors.NewKind("Incorrect %s value: '%s'")
   367  var dateRegex = regexp.MustCompile(`(?m)^(\d{1,4})-(\d{1,2})-(\d{1,2})(.*)$`)
   368  var timeRegex = regexp.MustCompile(`(?m)^([ T])?(\d{1,2})?(:)?(\d{1,2})?(:)?(\d{1,2})?(\.)?(\d{1,6})?(.*)$`)
   369  var tzRegex = regexp.MustCompile(`(?m)^([+\-])(\d{2}):(\d{2})$`)
   371  // GetTimeValueFromStringInput returns time.Time in system timezone (SYSTEM = time.Now().Location()).
   372  // evaluating valid MySQL datetime and timestamp formats.
   373  func GetTimeValueFromStringInput(field, t string) (time.Time, error) {
   374  	// TODO: the time value should be in session timezone rather than system timezone.
   375  	sessTz := gmstime.SystemTimezoneOffset()
   377  	// For MySQL datetime format, it accepts any valid date format
   378  	// and tries parsing time part first and timezone part if time part is valid.
   379  	// Otherwise, any invalid time or timezone part is truncated and gives warning.
   380  	// TODO: It seems like we should be able to reuse the timestamp parsing logic from Datetime.Convert.
   381  	//       Do we need to reimplement this here?
   382  	dt := strings.Split(t, "-")
   383  	if len(dt) > 1 {
   384  		var year, month, day, hour, minute, second int
   385  		var timePart, tzPart string
   386  		var ok bool
   387  		var inputTz = sessTz
   388  		// FIRST try to get date part
   389  		year, month, day, timePart, ok = getDatePart(t)
   390  		if !ok {
   391  			return time.Time{}, ErrIncorrectValue.New(field, t)
   392  		}
   393  		// Then time part
   394  		if timePart != "" {
   395  			hour, minute, second, tzPart, ok = getTimePart(timePart)
   396  			if !ok {
   397  				return time.Time{}, ErrIncorrectValue.New(field, t)
   398  			}
   399  		}
   400  		// Then timezone part
   401  		if tzPart != "" {
   402  			if tzPart[0] != '+' && tzPart[0] != '-' {
   403  				// TODO: warning: Truncated incorrect datetime value: '...'
   404  			} else {
   405  				inputTz, ok = getTimezonePart(tzPart)
   406  				if !ok {
   407  					return time.Time{}, ErrIncorrectValue.New(field, t)
   408  				}
   409  			}
   410  		}
   412  		datetimeVal := fmt.Sprintf("%4d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
   413  		tVal, err := time.Parse(EventDateSpaceTimeFormat, datetimeVal)
   414  		if err != nil {
   415  			return time.Time{}, fmt.Errorf("invalid time zone: %s", sessTz)
   416  		}
   418  		// convert the time value to the session timezone for display and storage
   419  		tVal, ok = gmstime.ConvertTimeZone(tVal, inputTz, sessTz)
   420  		if !ok {
   421  			return time.Time{}, fmt.Errorf("invalid time zone: %s", sessTz)
   422  		}
   423  		return tVal, nil
   424  	} else {
   425  		// TODO: support timestamp input parsing (e.g. 2023526...)
   426  		return time.Time{}, fmt.Errorf("timestamp input parsing not supported yet")
   427  	}
   428  }
   430  func getDatePart(s string) (int, int, int, string, bool) {
   431  	matches := dateRegex.FindStringSubmatch(s)
   432  	if matches == nil || len(matches) != 5 {
   433  		return 0, 0, 0, "", false
   434  	}
   436  	year, ok := validateYear(getInt(matches[1]))
   437  	return year, getInt(matches[2]), getInt(matches[3]), matches[4], ok
   438  }
   440  func getTimePart(t string) (int, int, int, string, bool) {
   441  	var hour, minute, second int
   442  	matches := timeRegex.FindStringSubmatch(t)
   443  	if matches == nil || len(matches) != 10 {
   444  		return 0, 0, 0, "", false
   445  	}
   446  	hour = getInt(matches[2])
   447  	if matches[3] == "" {
   448  		return hour, minute, second, "", true
   449  	} else if matches[3] != ":" {
   450  		return 0, 0, 0, "", false
   451  	}
   452  	minute = getInt(matches[4])
   453  	if matches[5] == "" {
   454  		return hour, minute, second, "", true
   455  	} else if matches[5] != ":" {
   456  		return 0, 0, 0, "", false
   457  	}
   458  	second = getInt(matches[6])
   459  	// microsecond with dot in front of it is not needed for now
   460  	//if matches[7] != "." {
   461  	//	return 0, 0, 0, "", false
   462  	//}
   463  	//microsecond := matches[8]
   464  	return hour, minute, second, matches[9], true
   465  }
   467  func getTimezonePart(tz string) (string, bool) {
   468  	matches := tzRegex.FindStringSubmatch(tz)
   469  	if len(matches) == 4 {
   470  		symbol := matches[1]
   471  		hours := matches[2]
   472  		mins := matches[3]
   473  		return fmt.Sprintf("%s%s:%s", symbol, hours, mins), true
   474  	} else {
   475  		return "", false
   476  	}
   477  }
   479  func getInt(s string) int {
   480  	i, err := strconv.Atoi(s)
   481  	if err != nil {
   482  		return 0
   483  	}
   484  	return i
   485  }
   487  func validateYear(i int) (int, bool) {
   488  	if i >= 0 && i <= 69 {
   489  		return i + 2000, true
   490  	} else if i >= 70 && i <= 99 {
   491  		return i + 1900, true
   492  	} else if i >= 1901 && i < 2155 {
   493  		return i, true
   494  	}
   495  	return 0, false
   496  }