github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/conv.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strconv"
    21  	"strings"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  // Conv function converts numbers between different number bases. Returns a string representation of the number N, converted from base from_base to base to_base.
    28  type Conv struct {
    29  	n        sql.Expression
    30  	fromBase sql.Expression
    31  	toBase   sql.Expression
    32  }
    33  
    34  var _ sql.FunctionExpression = (*Conv)(nil)
    35  var _ sql.CollationCoercible = (*Conv)(nil)
    36  
    37  // NewConv returns a new Conv expression.
    38  func NewConv(n, from, to sql.Expression) sql.Expression {
    39  	return &Conv{n, from, to}
    40  }
    41  
    42  // FunctionName implements sql.FunctionExpression
    43  func (c *Conv) FunctionName() string {
    44  	return "conv"
    45  }
    46  
    47  // Description implements sql.FunctionExpression
    48  func (c *Conv) Description() string {
    49  	return "returns a string representation of the number N, converted from base from_base to base to_base."
    50  }
    51  
    52  // Type implements the Expression interface.
    53  func (c *Conv) Type() sql.Type { return types.LongText }
    54  
    55  // CollationCoercibility implements the interface sql.CollationCoercible.
    56  func (*Conv) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    57  	return ctx.GetCollation(), 4
    58  }
    59  
    60  // IsNullable implements the Expression interface.
    61  func (c *Conv) IsNullable() bool {
    62  	return c.n.IsNullable() || c.fromBase.IsNullable() || c.toBase.IsNullable()
    63  }
    64  
    65  func (c *Conv) String() string {
    66  	return fmt.Sprintf("%s(%s,%s,%s)", c.FunctionName(), c.n, c.fromBase, c.toBase)
    67  }
    68  
    69  // Eval implements the Expression interface.
    70  func (c *Conv) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    71  	n, err := c.n.Eval(ctx, row)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	if n == nil {
    76  		return nil, nil
    77  	}
    78  
    79  	from, err := c.fromBase.Eval(ctx, row)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	if from == nil {
    84  		return nil, nil
    85  	}
    86  
    87  	to, err := c.toBase.Eval(ctx, row)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if to == nil {
    92  		return nil, nil
    93  	}
    94  
    95  	n, _, err = types.LongText.Convert(n)
    96  	if err != nil {
    97  		return nil, nil
    98  	}
    99  
   100  	// valConvertedFrom is unsigned if n input is positive, signed if negative.
   101  	valConvertedFrom := convertFromBase(n.(string), from)
   102  	switch valConvertedFrom {
   103  	case nil:
   104  		return nil, nil
   105  	case 0:
   106  		return "0", nil
   107  	}
   108  
   109  	result := convertToBase(valConvertedFrom, to)
   110  	if result == "" {
   111  		return nil, nil
   112  	}
   113  
   114  	return strings.ToUpper(result), nil
   115  }
   116  
   117  // Resolved implements the Expression interface.
   118  func (c *Conv) Resolved() bool {
   119  	return c.n.Resolved() && c.fromBase.Resolved() && c.toBase.Resolved()
   120  }
   121  
   122  // Children implements the Expression interface.
   123  func (c *Conv) Children() []sql.Expression {
   124  	return []sql.Expression{c.n, c.fromBase, c.toBase}
   125  }
   126  
   127  // WithChildren implements the Expression interface.
   128  func (c *Conv) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   129  	if len(children) != 3 {
   130  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 3)
   131  	}
   132  	return NewConv(children[0], children[1], children[2]), nil
   133  }
   134  
   135  // convertFromBase returns nil if fromBase input is invalid, 0 if nVal input is invalid and converted result if nVal and fromBase inputs are valid.
   136  // This conversion truncates nVal as its first subpart that is convertable.
   137  // nVal is treated as unsigned except nVal is negative.
   138  func convertFromBase(nVal string, fromBase interface{}) interface{} {
   139  	fromBase, _, err := types.Int64.Convert(fromBase)
   140  	if err != nil {
   141  		return nil
   142  	}
   143  
   144  	fromVal := int(math.Abs(float64(fromBase.(int64))))
   145  	if fromVal < 2 || fromVal > 36 {
   146  		return nil
   147  	}
   148  
   149  	negative := false
   150  	var upper string
   151  	var lower string
   152  	if nVal[0] == '-' {
   153  		negative = true
   154  		nVal = nVal[1:]
   155  	} else if nVal[0] == '+' {
   156  		nVal = nVal[1:]
   157  	}
   158  
   159  	// check for upper and lower bound for given fromBase
   160  	if negative {
   161  		upper = strconv.FormatInt(math.MaxInt64, fromVal)
   162  		lower = strconv.FormatInt(math.MinInt64, fromVal)
   163  		if len(nVal) > len(lower) {
   164  			nVal = lower
   165  		} else if len(nVal) > len(upper) {
   166  			nVal = upper
   167  		}
   168  	} else {
   169  		upper = strconv.FormatUint(math.MaxUint64, fromVal)
   170  		lower = "0"
   171  		if len(nVal) < len(lower) {
   172  			nVal = lower
   173  		} else if len(nVal) > len(upper) {
   174  			nVal = upper
   175  		}
   176  	}
   177  
   178  	truncate := false
   179  	result := uint64(0)
   180  	i := 1
   181  	for !truncate && i <= len(nVal) {
   182  		val, err := strconv.ParseUint(nVal[:i], fromVal, 64)
   183  		if err != nil {
   184  			truncate = true
   185  			return result
   186  		}
   187  		result = val
   188  		i++
   189  	}
   190  
   191  	if negative {
   192  		return int64(result) * -1
   193  	}
   194  
   195  	return result
   196  }
   197  
   198  // convertToBase returns result of whole CONV function in string format, empty string if to input is invalid.
   199  // The sign of toBase decides whether result is formatted as signed or unsigned.
   200  func convertToBase(val interface{}, toBase interface{}) string {
   201  	toBase, _, err := types.Int64.Convert(toBase)
   202  	if err != nil {
   203  		return ""
   204  	}
   205  
   206  	toVal := int(math.Abs(float64(toBase.(int64))))
   207  	if toVal < 2 || toVal > 36 {
   208  		return ""
   209  	}
   210  
   211  	var result string
   212  	switch v := val.(type) {
   213  	case int64:
   214  		if toBase.(int64) < 0 {
   215  			result = strconv.FormatInt(v, toVal)
   216  			if err != nil {
   217  				return ""
   218  			}
   219  		} else {
   220  			result = strconv.FormatUint(uint64(v), toVal)
   221  			if err != nil {
   222  				return ""
   223  			}
   224  		}
   225  	case uint64:
   226  		if toBase.(int64) < 0 {
   227  			result = strconv.FormatInt(int64(v), toVal)
   228  			if err != nil {
   229  				return ""
   230  			}
   231  		} else {
   232  			result = strconv.FormatUint(v, toVal)
   233  			if err != nil {
   234  				return ""
   235  			}
   236  		}
   237  	}
   238  
   239  	return result
   240  }