github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/expreval/literal_helpers.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expreval
    16  
    17  import (
    18  	"strconv"
    19  	"time"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    23  
    24  	"github.com/dolthub/dolt/go/store/types"
    25  )
    26  
    27  func literalAsInt64(literal *expression.Literal) (int64, error) {
    28  	v := literal.Value()
    29  	switch typedVal := v.(type) {
    30  	case bool:
    31  		if typedVal {
    32  			return 1, nil
    33  		} else {
    34  			return 0, nil
    35  		}
    36  	case int:
    37  		return int64(typedVal), nil
    38  	case int8:
    39  		return int64(typedVal), nil
    40  	case int16:
    41  		return int64(typedVal), nil
    42  	case int32:
    43  		return int64(typedVal), nil
    44  	case int64:
    45  		return typedVal, nil
    46  	case uint:
    47  		return int64(typedVal), nil
    48  	case uint8:
    49  		return int64(typedVal), nil
    50  	case uint16:
    51  		return int64(typedVal), nil
    52  	case uint32:
    53  		return int64(typedVal), nil
    54  	case uint64:
    55  		if typedVal&0x8000000000000000 != 0 {
    56  			return 0, errInvalidConversion.New(literal.String(), "uint64", "int64")
    57  		}
    58  
    59  		return int64(typedVal), nil
    60  	case float64:
    61  		i64 := int64(typedVal)
    62  		if i64 == int64(typedVal+0.9999) {
    63  			return i64, nil
    64  		} else {
    65  			return 0, errInvalidConversion.New(literal.String(), "float64", "int64")
    66  		}
    67  	case float32:
    68  		i64 := int64(typedVal)
    69  		if i64 == int64(typedVal+0.9999) {
    70  			return i64, nil
    71  		} else {
    72  			return 0, errInvalidConversion.New(literal.String(), "float32", "int64")
    73  		}
    74  	case string:
    75  		return strconv.ParseInt(typedVal, 10, 64)
    76  	}
    77  
    78  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "int64")
    79  }
    80  
    81  func literalAsUint64(literal *expression.Literal) (uint64, error) {
    82  	v := literal.Value()
    83  	switch typedVal := v.(type) {
    84  	case bool:
    85  		if typedVal {
    86  			return 1, nil
    87  		} else {
    88  			return 0, nil
    89  		}
    90  	case int:
    91  		if typedVal < 0 {
    92  			return 0, errInvalidConversion.New(literal.String(), "int", "uint64")
    93  		}
    94  
    95  		return uint64(typedVal), nil
    96  	case int8:
    97  		if typedVal < 0 {
    98  			return 0, errInvalidConversion.New(literal.String(), "int8", "uint64")
    99  		}
   100  
   101  		return uint64(typedVal), nil
   102  	case int16:
   103  		if typedVal < 0 {
   104  			return 0, errInvalidConversion.New(literal.String(), "int16", "uint64")
   105  		}
   106  
   107  		return uint64(typedVal), nil
   108  	case int32:
   109  		if typedVal < 0 {
   110  			return 0, errInvalidConversion.New(literal.String(), "int32", "uint64")
   111  		}
   112  
   113  		return uint64(typedVal), nil
   114  	case int64:
   115  		if typedVal < 0 {
   116  			return 0, errInvalidConversion.New(literal.String(), "int64", "uint64")
   117  		}
   118  
   119  		return uint64(typedVal), nil
   120  	case uint:
   121  		return uint64(typedVal), nil
   122  	case uint8:
   123  		return uint64(typedVal), nil
   124  	case uint16:
   125  		return uint64(typedVal), nil
   126  	case uint32:
   127  		return uint64(typedVal), nil
   128  	case uint64:
   129  		return typedVal, nil
   130  	case float64:
   131  		if typedVal < 0 {
   132  			return 0, errInvalidConversion.New(literal.String(), "float64", "uint64")
   133  		}
   134  
   135  		u64 := uint64(typedVal)
   136  		if u64 == uint64(typedVal+0.9999) {
   137  			return u64, nil
   138  		} else {
   139  			return 0, errInvalidConversion.New(literal.String(), "float64", "uint64")
   140  		}
   141  	case float32:
   142  		u64 := uint64(typedVal)
   143  		if u64 == uint64(typedVal+0.9999) {
   144  			return u64, nil
   145  		} else {
   146  			return 0, errInvalidConversion.New(literal.String(), "float32", "uint64")
   147  		}
   148  	case string:
   149  		return strconv.ParseUint(typedVal, 10, 64)
   150  	}
   151  
   152  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "int64")
   153  }
   154  
   155  func literalAsFloat64(literal *expression.Literal) (float64, error) {
   156  	v := literal.Value()
   157  	switch typedVal := v.(type) {
   158  	case int:
   159  		return float64(typedVal), nil
   160  	case int8:
   161  		return float64(typedVal), nil
   162  	case int16:
   163  		return float64(typedVal), nil
   164  	case int32:
   165  		return float64(typedVal), nil
   166  	case int64:
   167  		return float64(typedVal), nil
   168  	case uint:
   169  		return float64(typedVal), nil
   170  	case uint8:
   171  		return float64(typedVal), nil
   172  	case uint16:
   173  		return float64(typedVal), nil
   174  	case uint32:
   175  		return float64(typedVal), nil
   176  	case uint64:
   177  		return float64(typedVal), nil
   178  	case float64:
   179  		return typedVal, nil
   180  	case float32:
   181  		return float64(typedVal), nil
   182  	case string:
   183  		return strconv.ParseFloat(typedVal, 64)
   184  	}
   185  
   186  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "float64")
   187  }
   188  
   189  func literalAsBool(literal *expression.Literal) (bool, error) {
   190  	v := literal.Value()
   191  	switch typedVal := v.(type) {
   192  	case bool:
   193  		return typedVal, nil
   194  	case string:
   195  		b, err := strconv.ParseBool(typedVal)
   196  
   197  		if err == nil {
   198  			return b, nil
   199  		}
   200  
   201  		return false, errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   202  	case int:
   203  		return typedVal != 0, nil
   204  	case int8:
   205  		return typedVal != 0, nil
   206  	case int16:
   207  		return typedVal != 0, nil
   208  	case int32:
   209  		return typedVal != 0, nil
   210  	case int64:
   211  		return typedVal != 0, nil
   212  	case uint:
   213  		return typedVal != 0, nil
   214  	case uint8:
   215  		return typedVal != 0, nil
   216  	case uint16:
   217  		return typedVal != 0, nil
   218  	case uint32:
   219  		return typedVal != 0, nil
   220  	case uint64:
   221  		return typedVal != 0, nil
   222  	}
   223  
   224  	return false, errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   225  }
   226  
   227  func literalAsString(literal *expression.Literal) (string, error) {
   228  	v := literal.Value()
   229  	switch typedVal := v.(type) {
   230  	case string:
   231  		return typedVal, nil
   232  	case int, int8, int16, int32, int64:
   233  		i64, _ := literalAsInt64(literal)
   234  		return strconv.FormatInt(i64, 10), nil
   235  	case uint, uint8, uint16, uint32, uint64:
   236  		u64, _ := literalAsUint64(literal)
   237  		return strconv.FormatUint(u64, 10), nil
   238  	case float32, float64:
   239  		f64, _ := literalAsFloat64(literal)
   240  		return strconv.FormatFloat(f64, 'f', -1, 64), nil
   241  	case bool:
   242  		return strconv.FormatBool(typedVal), nil
   243  	}
   244  
   245  	return "", errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   246  }
   247  
   248  func parseDate(s string) (time.Time, error) {
   249  	for _, layout := range gmstypes.TimestampDatetimeLayouts {
   250  		res, err := time.Parse(layout, s)
   251  
   252  		if err == nil {
   253  			return res, nil
   254  		}
   255  	}
   256  
   257  	return time.Time{}, gmstypes.ErrConvertingToTime.New(s)
   258  }
   259  
   260  func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
   261  	v := literal.Value()
   262  	switch typedVal := v.(type) {
   263  	case time.Time:
   264  		return typedVal, nil
   265  	case string:
   266  		ts, err := parseDate(typedVal)
   267  
   268  		if err != nil {
   269  			return time.Time{}, err
   270  		}
   271  
   272  		return ts, nil
   273  	}
   274  
   275  	return time.Time{}, errInvalidConversion.New(literal.String(), literal.Type().String(), "datetime")
   276  }
   277  
   278  // LiteralToNomsValue converts a go-mysql-servel Literal into a noms value.
   279  func LiteralToNomsValue(kind types.NomsKind, literal *expression.Literal) (types.Value, error) {
   280  	if literal.Value() == nil {
   281  		return types.NullValue, nil
   282  	}
   283  
   284  	switch kind {
   285  	case types.IntKind:
   286  		i64, err := literalAsInt64(literal)
   287  
   288  		if err != nil {
   289  			return nil, err
   290  		}
   291  
   292  		return types.Int(i64), nil
   293  
   294  	case types.UintKind:
   295  		u64, err := literalAsUint64(literal)
   296  
   297  		if err != nil {
   298  			return nil, err
   299  		}
   300  
   301  		return types.Uint(u64), nil
   302  
   303  	case types.FloatKind:
   304  		f64, err := literalAsFloat64(literal)
   305  
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  
   310  		return types.Float(f64), nil
   311  
   312  	case types.BoolKind:
   313  		b, err := literalAsBool(literal)
   314  
   315  		if err != nil {
   316  			return nil, err
   317  		}
   318  
   319  		return types.Bool(b), err
   320  
   321  	case types.StringKind:
   322  		s, err := literalAsString(literal)
   323  
   324  		if err != nil {
   325  			return nil, err
   326  		}
   327  
   328  		return types.String(s), nil
   329  
   330  	case types.TimestampKind:
   331  		ts, err := literalAsTimestamp(literal)
   332  
   333  		if err != nil {
   334  			return nil, err
   335  		}
   336  
   337  		return types.Timestamp(ts), nil
   338  	}
   339  
   340  	return nil, errInvalidConversion.New(literal.String(), literal.Type().String(), kind.String())
   341  }