github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/format.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strings"
    21  
    22  	"github.com/shopspring/decimal"
    23  	"golang.org/x/text/language"
    24  	"golang.org/x/text/message"
    25  	"golang.org/x/text/number"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  // Format function returns a result of NumValue rounded to NumDecimalPlaces as a string.
    32  type Format struct {
    33  	NumValue         sql.Expression
    34  	NumDecimalPlaces sql.Expression
    35  	Locale           sql.Expression
    36  }
    37  
    38  var _ sql.FunctionExpression = (*Format)(nil)
    39  var _ sql.CollationCoercible = (*Format)(nil)
    40  
    41  // NewFormat returns a new Format expression.
    42  func NewFormat(args ...sql.Expression) (sql.Expression, error) {
    43  	var numValue, numDecimalPlaces, locale sql.Expression
    44  	switch len(args) {
    45  	case 2:
    46  		numValue = args[0]
    47  		numDecimalPlaces = args[1]
    48  		locale = nil
    49  	case 3:
    50  		numValue = args[0]
    51  		numDecimalPlaces = args[1]
    52  		locale = args[2]
    53  	default:
    54  		return nil, sql.ErrInvalidArgumentNumber.New("FORMAT", "2 or 3", len(args))
    55  	}
    56  	return &Format{numValue, numDecimalPlaces, locale}, nil
    57  }
    58  
    59  // FunctionName implements sql.FunctionExpression
    60  func (f *Format) FunctionName() string {
    61  	return "format"
    62  }
    63  
    64  // Description implements sql.FunctionExpression
    65  func (f *Format) Description() string {
    66  	return "returns a number formatted to specified number of decimal places."
    67  }
    68  
    69  // Type implements the Expression interface.
    70  func (f *Format) Type() sql.Type { return types.LongText }
    71  
    72  // CollationCoercibility implements the interface sql.CollationCoercible.
    73  func (*Format) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    74  	return ctx.GetCollation(), 4
    75  }
    76  
    77  // IsNullable implements the Expression interface.
    78  func (f *Format) IsNullable() bool {
    79  	return f.NumValue.IsNullable() || f.NumDecimalPlaces.IsNullable() || (f.Locale != nil && f.Locale.IsNullable())
    80  }
    81  
    82  func (f *Format) String() string {
    83  	return fmt.Sprintf("%s(%s,%s,%s)", f.FunctionName(), f.NumValue, f.NumDecimalPlaces, f.Locale)
    84  }
    85  
    86  // Eval implements the Expression interface.
    87  func (f *Format) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    88  	numVal, err := f.NumValue.Eval(ctx, row)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if numVal == nil {
    93  		return nil, nil
    94  	}
    95  
    96  	numDP, err := f.NumDecimalPlaces.Eval(ctx, row)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	if numDP == nil {
   101  		return nil, nil
   102  	}
   103  
   104  	locale := language.English
   105  	if f.Locale != nil {
   106  		loc, lErr := f.Locale.Eval(ctx, row)
   107  		if lErr != nil {
   108  			return nil, lErr
   109  		}
   110  		if loc != nil {
   111  			locale, err = language.Parse(loc.(string))
   112  			if err != nil {
   113  				locale = language.English
   114  			}
   115  		}
   116  	}
   117  
   118  	numVal, _, err = types.Float64.Convert(numVal)
   119  	if err != nil {
   120  		return nil, nil
   121  	}
   122  	numValue := numVal.(float64)
   123  
   124  	numDP, _, err = types.Float64.Convert(numDP)
   125  	if err != nil {
   126  		return nil, nil
   127  	}
   128  	numDecimalPlaces := numDP.(float64)
   129  	numDecimalPlaces = math.Round(numDecimalPlaces)
   130  
   131  	if numDecimalPlaces < 0 {
   132  		numDecimalPlaces = 0
   133  	} else if numDecimalPlaces > 30 { // MySQL cuts off at 30 for larger values
   134  		numDecimalPlaces = 30
   135  	}
   136  
   137  	// One way to round to a decimal place is to shift the number up by the desired decimal position, round to the
   138  	// nearest integer, and then shift back down.
   139  	// For example, we have 5.855 and want to round to 2 decimal places.
   140  	// In this case, numValue = 5.855 and numDecimalPlaces = 2
   141  	// round(numValue * 10^numDecimalPlaces) / 10^numDecimalPlaces
   142  	// round(5.855 * 10^2) / 10^2
   143  	// round(5.855 * 100) / 100
   144  	// round(585.5) / 100
   145  	// 586 / 100
   146  	// 5.86
   147  	//TODO: this can introduce rounding errors that don't show up in MySQL when the decimal places are larger than the input due to precision errors
   148  	roundedValue := math.Round(numValue*math.Pow(10.0, numDecimalPlaces)) / math.Pow(10.0, numDecimalPlaces)
   149  
   150  	// FORMAT(-5.932887e-08, 2);     		==> -0.00
   151  	// FORMAT(-0.00000005932887, 2); 		==> 0.00
   152  	// will return 0.00 for both cases
   153  	var whole int64
   154  	var fractionStr string
   155  	var negative string
   156  	if roundedValue != 0 {
   157  		res := decimal.NewFromFloat(roundedValue)
   158  		whole = res.IntPart()
   159  		if whole == 0 && res.IsNegative() {
   160  			negative = "-"
   161  		}
   162  
   163  		str := res.String()
   164  		dotIdx := strings.Index(str, ".")
   165  		if dotIdx == -1 {
   166  			fractionStr = ""
   167  		} else {
   168  			fractionStr = str[dotIdx+1:]
   169  		}
   170  	}
   171  
   172  	p := message.NewPrinter(locale)
   173  	formattedWhole := p.Sprintf("%v", number.Decimal(whole))
   174  	if numDecimalPlaces == 0 {
   175  		return fmt.Sprintf("%s%s", negative, formattedWhole), nil
   176  	}
   177  
   178  	decimalChar := p.Sprintf("%v", number.Decimal(1.5))
   179  	if len(fractionStr) < int(numDecimalPlaces) {
   180  		rp := int(numDecimalPlaces) - len(fractionStr)
   181  		fractionStr += strings.Repeat("0", rp)
   182  	}
   183  
   184  	result := fmt.Sprintf("%s%s%s%s", negative, formattedWhole, decimalChar[1:2], fractionStr)
   185  	return result, nil
   186  }
   187  
   188  // Resolved implements the Expression interface.
   189  func (f *Format) Resolved() bool {
   190  	if f.Locale == nil {
   191  		return f.NumValue.Resolved() && f.NumDecimalPlaces.Resolved()
   192  	}
   193  	return f.NumValue.Resolved() && f.NumDecimalPlaces.Resolved() && f.Locale.Resolved()
   194  }
   195  
   196  // Children implements the Expression interface.
   197  func (f *Format) Children() []sql.Expression {
   198  	if f.Locale == nil {
   199  		return []sql.Expression{f.NumValue, f.NumDecimalPlaces}
   200  	}
   201  	return []sql.Expression{f.NumValue, f.NumDecimalPlaces, f.Locale}
   202  }
   203  
   204  // WithChildren implements the Expression interface.
   205  func (f *Format) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   206  	if (len(children) == 2 && f.Locale == nil) || (len(children) == 3 && f.Locale != nil) {
   207  		return NewFormat(children...)
   208  	}
   209  	return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2)
   210  }