github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/case.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "bytes" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/types" 22 ) 23 24 // CaseBranch is a single branch of a case expression. 25 type CaseBranch struct { 26 Cond sql.Expression 27 Value sql.Expression 28 } 29 30 // Case is an expression that returns the value of one of its branches when a 31 // condition is met. 32 type Case struct { 33 Expr sql.Expression 34 Branches []CaseBranch 35 Else sql.Expression 36 } 37 38 var _ sql.Expression = (*Case)(nil) 39 var _ sql.CollationCoercible = (*Case)(nil) 40 41 // NewCase returns an new Case expression. 42 func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression) *Case { 43 return &Case{expr, branches, elseExpr} 44 } 45 46 // From the description of operator typing here: 47 // https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case 48 func combinedCaseBranchType(left, right sql.Type) sql.Type { 49 if left == types.Null { 50 return right 51 } 52 if right == types.Null { 53 return left 54 } 55 if types.IsTextOnly(left) && types.IsTextOnly(right) { 56 return types.LongText 57 } 58 if types.IsTextBlob(left) && types.IsTextBlob(right) { 59 return types.LongBlob 60 } 61 if types.IsTime(left) && types.IsTime(right) { 62 if left == right { 63 return left 64 } 65 return types.DatetimeMaxPrecision 66 } 67 if types.IsNumber(left) && types.IsNumber(right) { 68 if left == types.Float64 || right == types.Float64 { 69 return types.Float64 70 } 71 if left == types.Float32 || right == types.Float32 { 72 return types.Float32 73 } 74 if types.IsDecimal(left) || types.IsDecimal(right) { 75 return types.MustCreateDecimalType(65, 10) 76 } 77 if left == types.Uint64 && types.IsSigned(right) || 78 right == types.Uint64 && types.IsSigned(left) { 79 return types.MustCreateDecimalType(65, 10) 80 } 81 if !types.IsSigned(left) && !types.IsSigned(right) { 82 return types.Uint64 83 } else { 84 return types.Int64 85 } 86 } 87 if types.IsJSON(left) && types.IsJSON(right) { 88 return types.JSON 89 } 90 return types.LongText 91 } 92 93 // Type implements the sql.Expression interface. 94 func (c *Case) Type() sql.Type { 95 curr := types.Null 96 for _, b := range c.Branches { 97 curr = combinedCaseBranchType(curr, b.Value.Type()) 98 } 99 if c.Else != nil { 100 curr = combinedCaseBranchType(curr, c.Else.Type()) 101 } 102 return curr 103 } 104 105 // CollationCoercibility implements the interface sql.CollationCoercible. 106 func (c *Case) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 107 // This should be calculated during the expression's evaluation, but that's not possible with the 108 // current abstraction 109 return c.Type().CollationCoercibility(ctx) 110 } 111 112 // IsNullable implements the sql.Expression interface. 113 func (c *Case) IsNullable() bool { 114 for _, b := range c.Branches { 115 if b.Value.IsNullable() { 116 return true 117 } 118 } 119 120 return c.Else == nil || c.Else.IsNullable() 121 } 122 123 // Resolved implements the sql.Expression interface. 124 func (c *Case) Resolved() bool { 125 if (c.Expr != nil && !c.Expr.Resolved()) || 126 (c.Else != nil && !c.Else.Resolved()) { 127 return false 128 } 129 130 for _, b := range c.Branches { 131 if !b.Cond.Resolved() || !b.Value.Resolved() { 132 return false 133 } 134 } 135 136 return true 137 } 138 139 // Children implements the sql.Expression interface. 140 func (c *Case) Children() []sql.Expression { 141 var children []sql.Expression 142 143 if c.Expr != nil { 144 children = append(children, c.Expr) 145 } 146 147 for _, b := range c.Branches { 148 children = append(children, b.Cond, b.Value) 149 } 150 151 if c.Else != nil { 152 children = append(children, c.Else) 153 } 154 155 return children 156 } 157 158 // Eval implements the sql.Expression interface. 159 func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 160 span, ctx := ctx.Span("expression.Case") 161 defer span.End() 162 163 t := c.Type() 164 165 for _, b := range c.Branches { 166 var cond sql.Expression 167 if c.Expr != nil { 168 cond = NewEquals(c.Expr, b.Cond) 169 } else { 170 cond = b.Cond 171 } 172 173 res, err := sql.EvaluateCondition(ctx, cond, row) 174 if err != nil { 175 return nil, err 176 } 177 178 if sql.IsTrue(res) { 179 bval, err := b.Value.Eval(ctx, row) 180 if err != nil { 181 return nil, err 182 } 183 // When unable to convert to the type of the case, return the original value 184 // A common error here is "Out of bounds value for decimal type" 185 if ret, _, err := t.Convert(bval); err == nil { 186 return ret, nil 187 } 188 return bval, nil 189 } 190 } 191 192 if c.Else != nil { 193 val, err := c.Else.Eval(ctx, row) 194 if err != nil { 195 return nil, err 196 } 197 // When unable to convert to the type of the case, return the original value 198 // A common error here is "Out of bounds value for decimal type" 199 if ret, _, err := t.Convert(val); err == nil { 200 return ret, nil 201 } 202 return val, nil 203 204 } 205 206 return nil, nil 207 } 208 209 // WithChildren implements the Expression interface. 210 func (c *Case) WithChildren(children ...sql.Expression) (sql.Expression, error) { 211 var expected = len(c.Branches) * 2 212 if c.Expr != nil { 213 expected++ 214 } 215 216 if c.Else != nil { 217 expected++ 218 } 219 220 if len(children) != expected { 221 return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), expected) 222 } 223 224 var expr, elseExpr sql.Expression 225 if c.Expr != nil { 226 expr = children[0] 227 children = children[1:] 228 } 229 230 if c.Else != nil { 231 elseExpr = children[len(children)-1] 232 children = children[:len(children)-1] 233 } 234 235 var branches []CaseBranch 236 for i := 0; i < len(children); i += 2 { 237 branches = append(branches, CaseBranch{ 238 Cond: children[i], 239 Value: children[i+1], 240 }) 241 } 242 243 return NewCase(expr, branches, elseExpr), nil 244 } 245 246 func (c *Case) String() string { 247 var buf bytes.Buffer 248 249 buf.WriteString("CASE ") 250 if c.Expr != nil { 251 buf.WriteString(c.Expr.String()) 252 } 253 254 for _, b := range c.Branches { 255 buf.WriteString(" WHEN ") 256 buf.WriteString(b.Cond.String()) 257 buf.WriteString(" THEN ") 258 buf.WriteString(b.Value.String()) 259 } 260 261 if c.Else != nil { 262 buf.WriteString(" ELSE ") 263 buf.WriteString(c.Else.String()) 264 } 265 266 buf.WriteString(" END") 267 return buf.String() 268 } 269 270 func (c *Case) DebugString() string { 271 var buf bytes.Buffer 272 273 buf.WriteString("CASE ") 274 if c.Expr != nil { 275 buf.WriteString(sql.DebugString(c.Expr)) 276 } 277 278 for _, b := range c.Branches { 279 buf.WriteString(" WHEN ") 280 buf.WriteString(sql.DebugString(b.Cond)) 281 buf.WriteString(" THEN ") 282 buf.WriteString(sql.DebugString(b.Value)) 283 } 284 285 if c.Else != nil { 286 buf.WriteString(" ELSE ") 287 buf.WriteString(sql.DebugString(c.Else)) 288 } 289 290 buf.WriteString(" END") 291 return buf.String() 292 }