github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/format.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "math" 20 "strings" 21 22 "github.com/shopspring/decimal" 23 "golang.org/x/text/language" 24 "golang.org/x/text/message" 25 "golang.org/x/text/number" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 // Format function returns a result of NumValue rounded to NumDecimalPlaces as a string. 32 type Format struct { 33 NumValue sql.Expression 34 NumDecimalPlaces sql.Expression 35 Locale sql.Expression 36 } 37 38 var _ sql.FunctionExpression = (*Format)(nil) 39 var _ sql.CollationCoercible = (*Format)(nil) 40 41 // NewFormat returns a new Format expression. 42 func NewFormat(args ...sql.Expression) (sql.Expression, error) { 43 var numValue, numDecimalPlaces, locale sql.Expression 44 switch len(args) { 45 case 2: 46 numValue = args[0] 47 numDecimalPlaces = args[1] 48 locale = nil 49 case 3: 50 numValue = args[0] 51 numDecimalPlaces = args[1] 52 locale = args[2] 53 default: 54 return nil, sql.ErrInvalidArgumentNumber.New("FORMAT", "2 or 3", len(args)) 55 } 56 return &Format{numValue, numDecimalPlaces, locale}, nil 57 } 58 59 // FunctionName implements sql.FunctionExpression 60 func (f *Format) FunctionName() string { 61 return "format" 62 } 63 64 // Description implements sql.FunctionExpression 65 func (f *Format) Description() string { 66 return "returns a number formatted to specified number of decimal places." 67 } 68 69 // Type implements the Expression interface. 70 func (f *Format) Type() sql.Type { return types.LongText } 71 72 // CollationCoercibility implements the interface sql.CollationCoercible. 73 func (*Format) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 74 return ctx.GetCollation(), 4 75 } 76 77 // IsNullable implements the Expression interface. 78 func (f *Format) IsNullable() bool { 79 return f.NumValue.IsNullable() || f.NumDecimalPlaces.IsNullable() || (f.Locale != nil && f.Locale.IsNullable()) 80 } 81 82 func (f *Format) String() string { 83 return fmt.Sprintf("%s(%s,%s,%s)", f.FunctionName(), f.NumValue, f.NumDecimalPlaces, f.Locale) 84 } 85 86 // Eval implements the Expression interface. 87 func (f *Format) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 88 numVal, err := f.NumValue.Eval(ctx, row) 89 if err != nil { 90 return nil, err 91 } 92 if numVal == nil { 93 return nil, nil 94 } 95 96 numDP, err := f.NumDecimalPlaces.Eval(ctx, row) 97 if err != nil { 98 return nil, err 99 } 100 if numDP == nil { 101 return nil, nil 102 } 103 104 locale := language.English 105 if f.Locale != nil { 106 loc, lErr := f.Locale.Eval(ctx, row) 107 if lErr != nil { 108 return nil, lErr 109 } 110 if loc != nil { 111 locale, err = language.Parse(loc.(string)) 112 if err != nil { 113 locale = language.English 114 } 115 } 116 } 117 118 numVal, _, err = types.Float64.Convert(numVal) 119 if err != nil { 120 return nil, nil 121 } 122 numValue := numVal.(float64) 123 124 numDP, _, err = types.Float64.Convert(numDP) 125 if err != nil { 126 return nil, nil 127 } 128 numDecimalPlaces := numDP.(float64) 129 numDecimalPlaces = math.Round(numDecimalPlaces) 130 131 if numDecimalPlaces < 0 { 132 numDecimalPlaces = 0 133 } else if numDecimalPlaces > 30 { // MySQL cuts off at 30 for larger values 134 numDecimalPlaces = 30 135 } 136 137 // One way to round to a decimal place is to shift the number up by the desired decimal position, round to the 138 // nearest integer, and then shift back down. 139 // For example, we have 5.855 and want to round to 2 decimal places. 140 // In this case, numValue = 5.855 and numDecimalPlaces = 2 141 // round(numValue * 10^numDecimalPlaces) / 10^numDecimalPlaces 142 // round(5.855 * 10^2) / 10^2 143 // round(5.855 * 100) / 100 144 // round(585.5) / 100 145 // 586 / 100 146 // 5.86 147 //TODO: this can introduce rounding errors that don't show up in MySQL when the decimal places are larger than the input due to precision errors 148 roundedValue := math.Round(numValue*math.Pow(10.0, numDecimalPlaces)) / math.Pow(10.0, numDecimalPlaces) 149 150 // FORMAT(-5.932887e-08, 2); ==> -0.00 151 // FORMAT(-0.00000005932887, 2); ==> 0.00 152 // will return 0.00 for both cases 153 var whole int64 154 var fractionStr string 155 var negative string 156 if roundedValue != 0 { 157 res := decimal.NewFromFloat(roundedValue) 158 whole = res.IntPart() 159 if whole == 0 && res.IsNegative() { 160 negative = "-" 161 } 162 163 str := res.String() 164 dotIdx := strings.Index(str, ".") 165 if dotIdx == -1 { 166 fractionStr = "" 167 } else { 168 fractionStr = str[dotIdx+1:] 169 } 170 } 171 172 p := message.NewPrinter(locale) 173 formattedWhole := p.Sprintf("%v", number.Decimal(whole)) 174 if numDecimalPlaces == 0 { 175 return fmt.Sprintf("%s%s", negative, formattedWhole), nil 176 } 177 178 decimalChar := p.Sprintf("%v", number.Decimal(1.5)) 179 if len(fractionStr) < int(numDecimalPlaces) { 180 rp := int(numDecimalPlaces) - len(fractionStr) 181 fractionStr += strings.Repeat("0", rp) 182 } 183 184 result := fmt.Sprintf("%s%s%s%s", negative, formattedWhole, decimalChar[1:2], fractionStr) 185 return result, nil 186 } 187 188 // Resolved implements the Expression interface. 189 func (f *Format) Resolved() bool { 190 if f.Locale == nil { 191 return f.NumValue.Resolved() && f.NumDecimalPlaces.Resolved() 192 } 193 return f.NumValue.Resolved() && f.NumDecimalPlaces.Resolved() && f.Locale.Resolved() 194 } 195 196 // Children implements the Expression interface. 197 func (f *Format) Children() []sql.Expression { 198 if f.Locale == nil { 199 return []sql.Expression{f.NumValue, f.NumDecimalPlaces} 200 } 201 return []sql.Expression{f.NumValue, f.NumDecimalPlaces, f.Locale} 202 } 203 204 // WithChildren implements the Expression interface. 205 func (f *Format) WithChildren(children ...sql.Expression) (sql.Expression, error) { 206 if (len(children) == 2 && f.Locale == nil) || (len(children) == 3 && f.Locale != nil) { 207 return NewFormat(children...) 208 } 209 return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) 210 }