github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/date_format.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"time"
    21  
    22  	"github.com/lestrrat-go/strftime"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  func panicIfErr(err error) {
    30  	if err != nil {
    31  		panic(err)
    32  	}
    33  }
    34  
    35  func monthNum(t time.Time) string {
    36  	return strconv.FormatInt(int64(t.Month()), 10)
    37  }
    38  
    39  func dayWithSuffix(t time.Time) string {
    40  	suffix := "th"
    41  	day := int64(t.Day())
    42  	if day < 4 || day > 20 {
    43  		switch day % 10 {
    44  		case 1:
    45  			suffix = "st"
    46  		case 2:
    47  			suffix = "nd"
    48  		case 3:
    49  			suffix = "rd"
    50  		}
    51  	}
    52  
    53  	return strconv.FormatInt(day, 10) + suffix
    54  }
    55  
    56  func dayOfMonth(t time.Time) string {
    57  	return strconv.FormatInt(int64(t.Day()), 10)
    58  }
    59  
    60  func microsecondsStr(t time.Time) string {
    61  	micros := t.Nanosecond() / int(time.Microsecond)
    62  	return fmt.Sprintf("%06d", micros)
    63  }
    64  
    65  func minutesStr(t time.Time) string {
    66  	return fmt.Sprintf("%02d", t.Minute())
    67  }
    68  
    69  func twelveHour(t time.Time) (int, string) {
    70  	ampm := "AM"
    71  	if t.Hour() >= 12 {
    72  		ampm = "PM"
    73  	}
    74  
    75  	hour := t.Hour() % 12
    76  	if hour == 0 {
    77  		hour = 12
    78  	}
    79  
    80  	return hour, ampm
    81  }
    82  
    83  func twelveHourPadded(t time.Time) string {
    84  	hour, _ := twelveHour(t)
    85  	return fmt.Sprintf("%02d", hour)
    86  }
    87  
    88  func twelveHourNoPadding(t time.Time) string {
    89  	hour, _ := twelveHour(t)
    90  	return fmt.Sprintf("%d", hour)
    91  }
    92  
    93  func twentyFourHourNoPadding(t time.Time) string {
    94  	return fmt.Sprintf("%d", t.Hour())
    95  }
    96  
    97  func fullMonthName(t time.Time) string {
    98  	return t.Month().String()
    99  }
   100  
   101  func ampmClockStr(t time.Time) string {
   102  	hour, ampm := twelveHour(t)
   103  	return fmt.Sprintf("%02d:%02d:%02d %s", hour, t.Minute(), t.Second(), ampm)
   104  }
   105  
   106  func secondsStr(t time.Time) string {
   107  	return fmt.Sprintf("%02d", t.Second())
   108  }
   109  
   110  func yearWeek(mode int32, t time.Time) (int32, int32) {
   111  	yw := YearWeek{expression.NewLiteral(t, types.Datetime), expression.NewLiteral(mode, types.Int32)}
   112  	res, _ := yw.Eval(nil, nil)
   113  	yr := res.(int32) / 100
   114  	wk := res.(int32) % 100
   115  
   116  	return yr, wk
   117  }
   118  
   119  func weekMode0(t time.Time) string {
   120  	yr, wk := yearWeek(0, t)
   121  
   122  	if yr < int32(t.Year()) {
   123  		wk = 0
   124  	} else if yr > int32(t.Year()) {
   125  		wk = 53
   126  	}
   127  
   128  	return fmt.Sprintf("%02d", wk)
   129  }
   130  
   131  func weekMode1(t time.Time) string {
   132  	yr, wk := yearWeek(1, t)
   133  
   134  	if yr < int32(t.Year()) {
   135  		wk = 0
   136  	} else if yr > int32(t.Year()) {
   137  		wk = 53
   138  	}
   139  
   140  	return fmt.Sprintf("%02d", wk)
   141  }
   142  
   143  func weekMode2(t time.Time) string {
   144  	_, wk := yearWeek(2, t)
   145  	return fmt.Sprintf("%02d", wk)
   146  }
   147  
   148  func weekMode3(t time.Time) string {
   149  	_, wk := yearWeek(3, t)
   150  	return fmt.Sprintf("%02d", wk)
   151  }
   152  
   153  func yearMode0(t time.Time) string {
   154  	yr, _ := yearWeek(0, t)
   155  	return strconv.FormatInt(int64(yr), 10)
   156  }
   157  
   158  func yearMode1(t time.Time) string {
   159  	yr, _ := yearWeek(1, t)
   160  	return strconv.FormatInt(int64(yr), 10)
   161  }
   162  
   163  func dayName(t time.Time) string {
   164  	return t.Weekday().String()
   165  }
   166  
   167  func yearTwoDigit(t time.Time) string {
   168  	return strconv.FormatInt(int64(t.Year())%100, 10)
   169  }
   170  
   171  type AppendFuncWrapper struct {
   172  	fn func(time.Time) string
   173  }
   174  
   175  func wrap(fn func(time.Time) string) strftime.Appender {
   176  	return AppendFuncWrapper{fn}
   177  }
   178  
   179  func (af AppendFuncWrapper) Append(bytes []byte, t time.Time) []byte {
   180  	s := af.fn(t)
   181  	return append(bytes, []byte(s)...)
   182  }
   183  
   184  var mysqlDateFormatSpec = strftime.NewSpecificationSet()
   185  var dateFormatSpecifierToFunc = map[byte]func(time.Time) string{
   186  	'a': nil,
   187  	'b': nil,
   188  	'c': monthNum,
   189  	'D': dayWithSuffix,
   190  	'd': nil,
   191  	'e': dayOfMonth,
   192  	'f': microsecondsStr,
   193  	'H': nil,
   194  	'h': twelveHourPadded,
   195  	'I': twelveHourPadded,
   196  	'i': minutesStr,
   197  	'j': nil,
   198  	'k': twentyFourHourNoPadding,
   199  	'l': twelveHourNoPadding,
   200  	'M': fullMonthName,
   201  	'm': nil,
   202  	'p': nil,
   203  	'r': ampmClockStr,
   204  	'S': nil,
   205  	's': secondsStr,
   206  	'T': nil,
   207  	'U': weekMode0,
   208  	'u': weekMode1,
   209  	'V': weekMode2,
   210  	'v': weekMode3,
   211  	'W': dayName,
   212  	'w': nil,
   213  	'X': yearMode0,
   214  	'x': yearMode1,
   215  	'Y': nil,
   216  	'y': yearTwoDigit,
   217  }
   218  
   219  func init() {
   220  	for specifier, fn := range dateFormatSpecifierToFunc {
   221  		if fn != nil {
   222  			panicIfErr(mysqlDateFormatSpec.Set(specifier, wrap(fn)))
   223  		}
   224  	}
   225  
   226  	// replace any strftime specifiers that aren't supported
   227  	fn := func(b byte) {
   228  		if _, ok := dateFormatSpecifierToFunc[b]; !ok {
   229  			panicIfErr(mysqlDateFormatSpec.Set(b, wrap(func(time.Time) string {
   230  				return string(b)
   231  			})))
   232  		}
   233  	}
   234  
   235  	capToLower := byte('a' - 'A')
   236  	for i := byte('A'); i <= 'Z'; i++ {
   237  		fn(i)
   238  		fn(i + capToLower)
   239  	}
   240  }
   241  
   242  func formatDate(format string, t time.Time) (string, error) {
   243  	formatter, err := strftime.New(format, strftime.WithSpecificationSet(mysqlDateFormatSpec))
   244  
   245  	if err != nil {
   246  		return "", err
   247  	}
   248  
   249  	return formatter.FormatString(t), nil
   250  }
   251  
   252  // DateFormat function returns a string representation of the date specified in the format specified
   253  type DateFormat struct {
   254  	expression.BinaryExpressionStub
   255  }
   256  
   257  var _ sql.FunctionExpression = (*DateFormat)(nil)
   258  var _ sql.CollationCoercible = (*DateFormat)(nil)
   259  
   260  // FunctionName implements sql.FunctionExpression
   261  func (f *DateFormat) FunctionName() string {
   262  	return "date_format"
   263  }
   264  
   265  // Description implements sql.FunctionExpression
   266  func (f *DateFormat) Description() string {
   267  	return "format date as specified."
   268  }
   269  
   270  // NewDateFormat returns a new DateFormat UDF
   271  func NewDateFormat(ex, value sql.Expression) sql.Expression {
   272  	return &DateFormat{
   273  		expression.BinaryExpressionStub{
   274  			LeftChild:  ex,
   275  			RightChild: value,
   276  		},
   277  	}
   278  }
   279  
   280  // Eval implements the Expression interface.
   281  func (f *DateFormat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   282  	if f.LeftChild == nil || f.RightChild == nil {
   283  		return nil, nil
   284  	}
   285  
   286  	left, err := f.LeftChild.Eval(ctx, row)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  
   291  	if left == nil {
   292  		return nil, nil
   293  	}
   294  
   295  	timeVal, _, err := types.DatetimeMaxPrecision.Convert(left)
   296  
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	t := timeVal.(time.Time)
   302  
   303  	right, err := f.RightChild.Eval(ctx, row)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  
   308  	if right == nil {
   309  		return nil, nil
   310  	}
   311  
   312  	formatStr, ok := right.(string)
   313  
   314  	if !ok {
   315  		return nil, sql.ErrInvalidArgumentDetails.New("DATE_FORMAT", "format must be a string")
   316  	}
   317  
   318  	return formatDate(formatStr, t)
   319  }
   320  
   321  // Type implements the Expression interface.
   322  func (f *DateFormat) Type() sql.Type {
   323  	return types.Text
   324  }
   325  
   326  // CollationCoercibility implements the interface sql.CollationCoercible.
   327  func (*DateFormat) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   328  	return ctx.GetCollation(), 4
   329  }
   330  
   331  // IsNullable implements the Expression interface.
   332  func (f *DateFormat) IsNullable() bool {
   333  	if types.IsNull(f.LeftChild) {
   334  		if types.IsNull(f.RightChild) {
   335  			return true
   336  		}
   337  		return f.RightChild.IsNullable()
   338  	}
   339  	return f.LeftChild.IsNullable()
   340  }
   341  
   342  func (f *DateFormat) String() string {
   343  	return fmt.Sprintf("date_format(%s, %s)", f.LeftChild, f.RightChild)
   344  }
   345  
   346  // WithChildren implements the Expression interface.
   347  func (f *DateFormat) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   348  	if len(children) != 2 {
   349  		return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2)
   350  	}
   351  	return NewDateFormat(children[0], children[1]), nil
   352  }