github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/expreval/literal_helpers.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expreval
    16  
    17  import (
    18  	"strconv"
    19  	"time"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  
    24  	"github.com/dolthub/dolt/go/store/types"
    25  )
    26  
    27  func literalAsInt64(literal *expression.Literal) (int64, error) {
    28  	v := literal.Value()
    29  	switch typedVal := v.(type) {
    30  	case bool:
    31  		if typedVal {
    32  			return 1, nil
    33  		} else {
    34  			return 0, nil
    35  		}
    36  	case int:
    37  		return int64(typedVal), nil
    38  	case int8:
    39  		return int64(typedVal), nil
    40  	case int16:
    41  		return int64(typedVal), nil
    42  	case int32:
    43  		return int64(typedVal), nil
    44  	case int64:
    45  		return typedVal, nil
    46  	case uint:
    47  		return int64(typedVal), nil
    48  	case uint8:
    49  		return int64(typedVal), nil
    50  	case uint16:
    51  		return int64(typedVal), nil
    52  	case uint32:
    53  		return int64(typedVal), nil
    54  	case uint64:
    55  		if typedVal&0x8000000000000000 != 0 {
    56  			return 0, errInvalidConversion.New(literal.String(), "uint64", "int64")
    57  		}
    58  
    59  		return int64(typedVal), nil
    60  	case float64:
    61  		i64 := int64(typedVal)
    62  		if i64 == int64(typedVal+0.9999) {
    63  			return i64, nil
    64  		} else {
    65  			return 0, errInvalidConversion.New(literal.String(), "float64", "int64")
    66  		}
    67  	case float32:
    68  		i64 := int64(typedVal)
    69  		if i64 == int64(typedVal+0.9999) {
    70  			return i64, nil
    71  		} else {
    72  			return 0, errInvalidConversion.New(literal.String(), "float32", "int64")
    73  		}
    74  	case string:
    75  		return strconv.ParseInt(typedVal, 10, 64)
    76  	}
    77  
    78  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "int64")
    79  }
    80  
    81  func literalAsUint64(literal *expression.Literal) (uint64, error) {
    82  	v := literal.Value()
    83  	switch typedVal := v.(type) {
    84  	case bool:
    85  		if typedVal {
    86  			return 1, nil
    87  		} else {
    88  			return 0, nil
    89  		}
    90  	case int:
    91  		if typedVal < 0 {
    92  			return 0, errInvalidConversion.New(literal.String(), "int", "uint64")
    93  		}
    94  
    95  		return uint64(typedVal), nil
    96  	case int8:
    97  		if typedVal < 0 {
    98  			return 0, errInvalidConversion.New(literal.String(), "int8", "uint64")
    99  		}
   100  
   101  		return uint64(typedVal), nil
   102  	case int16:
   103  		if typedVal < 0 {
   104  			return 0, errInvalidConversion.New(literal.String(), "int16", "uint64")
   105  		}
   106  
   107  		return uint64(typedVal), nil
   108  	case int32:
   109  		if typedVal < 0 {
   110  			return 0, errInvalidConversion.New(literal.String(), "int32", "uint64")
   111  		}
   112  
   113  		return uint64(typedVal), nil
   114  	case int64:
   115  		if typedVal < 0 {
   116  			return 0, errInvalidConversion.New(literal.String(), "int64", "uint64")
   117  		}
   118  
   119  		return uint64(typedVal), nil
   120  	case uint:
   121  		return uint64(typedVal), nil
   122  	case uint8:
   123  		return uint64(typedVal), nil
   124  	case uint16:
   125  		return uint64(typedVal), nil
   126  	case uint32:
   127  		return uint64(typedVal), nil
   128  	case uint64:
   129  		return typedVal, nil
   130  	case float64:
   131  		if typedVal < 0 {
   132  			return 0, errInvalidConversion.New(literal.String(), "float64", "uint64")
   133  		}
   134  
   135  		u64 := uint64(typedVal)
   136  		if u64 == uint64(typedVal+0.9999) {
   137  			return u64, nil
   138  		} else {
   139  			return 0, errInvalidConversion.New(literal.String(), "float64", "uint64")
   140  		}
   141  	case float32:
   142  		u64 := uint64(typedVal)
   143  		if u64 == uint64(typedVal+0.9999) {
   144  			return u64, nil
   145  		} else {
   146  			return 0, errInvalidConversion.New(literal.String(), "float32", "uint64")
   147  		}
   148  	case string:
   149  		return strconv.ParseUint(typedVal, 10, 64)
   150  	}
   151  
   152  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "int64")
   153  }
   154  
   155  func literalAsFloat64(literal *expression.Literal) (float64, error) {
   156  	v := literal.Value()
   157  	switch typedVal := v.(type) {
   158  	case int:
   159  		return float64(typedVal), nil
   160  	case int8:
   161  		return float64(typedVal), nil
   162  	case int16:
   163  		return float64(typedVal), nil
   164  	case int32:
   165  		return float64(typedVal), nil
   166  	case int64:
   167  		return float64(typedVal), nil
   168  	case uint:
   169  		return float64(typedVal), nil
   170  	case uint8:
   171  		return float64(typedVal), nil
   172  	case uint16:
   173  		return float64(typedVal), nil
   174  	case uint32:
   175  		return float64(typedVal), nil
   176  	case uint64:
   177  		return float64(typedVal), nil
   178  	case float64:
   179  		return typedVal, nil
   180  	case float32:
   181  		return float64(typedVal), nil
   182  	case string:
   183  		return strconv.ParseFloat(typedVal, 64)
   184  	}
   185  
   186  	return 0, errInvalidConversion.New(literal.String(), literal.Type().String(), "float64")
   187  }
   188  
   189  func literalAsBool(literal *expression.Literal) (bool, error) {
   190  	v := literal.Value()
   191  	switch typedVal := v.(type) {
   192  	case bool:
   193  		return typedVal, nil
   194  	case string:
   195  		b, err := strconv.ParseBool(typedVal)
   196  
   197  		if err == nil {
   198  			return b, nil
   199  		}
   200  
   201  		return false, errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   202  	case int:
   203  		return typedVal != 0, nil
   204  	case int8:
   205  		return typedVal != 0, nil
   206  	case int16:
   207  		return typedVal != 0, nil
   208  	case int32:
   209  		return typedVal != 0, nil
   210  	case int64:
   211  		return typedVal != 0, nil
   212  	case uint:
   213  		return typedVal != 0, nil
   214  	case uint8:
   215  		return typedVal != 0, nil
   216  	case uint16:
   217  		return typedVal != 0, nil
   218  	case uint32:
   219  		return typedVal != 0, nil
   220  	case uint64:
   221  		return typedVal != 0, nil
   222  	}
   223  
   224  	return false, errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   225  }
   226  
   227  func literalAsString(literal *expression.Literal) (string, error) {
   228  	v := literal.Value()
   229  	switch typedVal := v.(type) {
   230  	case string:
   231  		return typedVal, nil
   232  	case int, int8, int16, int32, int64:
   233  		i64, _ := literalAsInt64(literal)
   234  		return strconv.FormatInt(i64, 10), nil
   235  	case uint, uint8, uint16, uint32, uint64:
   236  		u64, _ := literalAsUint64(literal)
   237  		return strconv.FormatUint(u64, 10), nil
   238  	case float32, float64:
   239  		f64, _ := literalAsFloat64(literal)
   240  		return strconv.FormatFloat(f64, 'f', -1, 64), nil
   241  	case bool:
   242  		return strconv.FormatBool(typedVal), nil
   243  	}
   244  
   245  	return "", errInvalidConversion.New(literal.String(), literal.Type().String(), "bool")
   246  }
   247  
   248  func parseDate(s string) (time.Time, error) {
   249  	for _, layout := range sql.TimestampDatetimeLayouts {
   250  		res, err := time.Parse(layout, s)
   251  
   252  		if err == nil {
   253  			return res, nil
   254  		}
   255  	}
   256  
   257  	return time.Time{}, sql.ErrConvertingToTime.New(s)
   258  }
   259  
   260  func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
   261  	v := literal.Value()
   262  	switch typedVal := v.(type) {
   263  	case string:
   264  		ts, err := parseDate(typedVal)
   265  
   266  		if err != nil {
   267  			return time.Time{}, err
   268  		}
   269  
   270  		return ts, nil
   271  	}
   272  
   273  	return time.Time{}, errInvalidConversion.New(literal.String(), literal.Type().String(), "datetime")
   274  }
   275  
   276  // LiteralToNomsValue converts a go-mysql-servel Literal into a noms value.
   277  func LiteralToNomsValue(kind types.NomsKind, literal *expression.Literal) (types.Value, error) {
   278  	switch kind {
   279  	case types.IntKind:
   280  		i64, err := literalAsInt64(literal)
   281  
   282  		if err != nil {
   283  			return nil, err
   284  		}
   285  
   286  		return types.Int(i64), nil
   287  
   288  	case types.UintKind:
   289  		u64, err := literalAsUint64(literal)
   290  
   291  		if err != nil {
   292  			return nil, err
   293  		}
   294  
   295  		return types.Uint(u64), nil
   296  
   297  	case types.FloatKind:
   298  		f64, err := literalAsFloat64(literal)
   299  
   300  		if err != nil {
   301  			return nil, err
   302  		}
   303  
   304  		return types.Float(f64), nil
   305  
   306  	case types.BoolKind:
   307  		b, err := literalAsBool(literal)
   308  
   309  		if err != nil {
   310  			return nil, err
   311  		}
   312  
   313  		return types.Bool(b), err
   314  
   315  	case types.StringKind:
   316  		s, err := literalAsString(literal)
   317  
   318  		if err != nil {
   319  			return nil, err
   320  		}
   321  
   322  		return types.String(s), nil
   323  
   324  	case types.TimestampKind:
   325  		ts, err := literalAsTimestamp(literal)
   326  
   327  		if err != nil {
   328  			return nil, err
   329  		}
   330  
   331  		return types.Timestamp(ts), nil
   332  	}
   333  
   334  	return nil, errInvalidConversion.New(literal.String(), literal.Type().String(), kind.String())
   335  }