github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/case.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"bytes"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/types"
    22  )
    23  
    24  // CaseBranch is a single branch of a case expression.
    25  type CaseBranch struct {
    26  	Cond  sql.Expression
    27  	Value sql.Expression
    28  }
    29  
    30  // Case is an expression that returns the value of one of its branches when a
    31  // condition is met.
    32  type Case struct {
    33  	Expr     sql.Expression
    34  	Branches []CaseBranch
    35  	Else     sql.Expression
    36  }
    37  
    38  var _ sql.Expression = (*Case)(nil)
    39  var _ sql.CollationCoercible = (*Case)(nil)
    40  
    41  // NewCase returns an new Case expression.
    42  func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression) *Case {
    43  	return &Case{expr, branches, elseExpr}
    44  }
    45  
    46  // From the description of operator typing here:
    47  // https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case
    48  func combinedCaseBranchType(left, right sql.Type) sql.Type {
    49  	if left == types.Null {
    50  		return right
    51  	}
    52  	if right == types.Null {
    53  		return left
    54  	}
    55  	if types.IsTextOnly(left) && types.IsTextOnly(right) {
    56  		return types.LongText
    57  	}
    58  	if types.IsTextBlob(left) && types.IsTextBlob(right) {
    59  		return types.LongBlob
    60  	}
    61  	if types.IsTime(left) && types.IsTime(right) {
    62  		if left == right {
    63  			return left
    64  		}
    65  		return types.DatetimeMaxPrecision
    66  	}
    67  	if types.IsNumber(left) && types.IsNumber(right) {
    68  		if left == types.Float64 || right == types.Float64 {
    69  			return types.Float64
    70  		}
    71  		if left == types.Float32 || right == types.Float32 {
    72  			return types.Float32
    73  		}
    74  		if types.IsDecimal(left) || types.IsDecimal(right) {
    75  			return types.MustCreateDecimalType(65, 10)
    76  		}
    77  		if left == types.Uint64 && types.IsSigned(right) ||
    78  			right == types.Uint64 && types.IsSigned(left) {
    79  			return types.MustCreateDecimalType(65, 10)
    80  		}
    81  		if !types.IsSigned(left) && !types.IsSigned(right) {
    82  			return types.Uint64
    83  		} else {
    84  			return types.Int64
    85  		}
    86  	}
    87  	if types.IsJSON(left) && types.IsJSON(right) {
    88  		return types.JSON
    89  	}
    90  	return types.LongText
    91  }
    92  
    93  // Type implements the sql.Expression interface.
    94  func (c *Case) Type() sql.Type {
    95  	curr := types.Null
    96  	for _, b := range c.Branches {
    97  		curr = combinedCaseBranchType(curr, b.Value.Type())
    98  	}
    99  	if c.Else != nil {
   100  		curr = combinedCaseBranchType(curr, c.Else.Type())
   101  	}
   102  	return curr
   103  }
   104  
   105  // CollationCoercibility implements the interface sql.CollationCoercible.
   106  func (c *Case) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   107  	// This should be calculated during the expression's evaluation, but that's not possible with the
   108  	// current abstraction
   109  	return c.Type().CollationCoercibility(ctx)
   110  }
   111  
   112  // IsNullable implements the sql.Expression interface.
   113  func (c *Case) IsNullable() bool {
   114  	for _, b := range c.Branches {
   115  		if b.Value.IsNullable() {
   116  			return true
   117  		}
   118  	}
   119  
   120  	return c.Else == nil || c.Else.IsNullable()
   121  }
   122  
   123  // Resolved implements the sql.Expression interface.
   124  func (c *Case) Resolved() bool {
   125  	if (c.Expr != nil && !c.Expr.Resolved()) ||
   126  		(c.Else != nil && !c.Else.Resolved()) {
   127  		return false
   128  	}
   129  
   130  	for _, b := range c.Branches {
   131  		if !b.Cond.Resolved() || !b.Value.Resolved() {
   132  			return false
   133  		}
   134  	}
   135  
   136  	return true
   137  }
   138  
   139  // Children implements the sql.Expression interface.
   140  func (c *Case) Children() []sql.Expression {
   141  	var children []sql.Expression
   142  
   143  	if c.Expr != nil {
   144  		children = append(children, c.Expr)
   145  	}
   146  
   147  	for _, b := range c.Branches {
   148  		children = append(children, b.Cond, b.Value)
   149  	}
   150  
   151  	if c.Else != nil {
   152  		children = append(children, c.Else)
   153  	}
   154  
   155  	return children
   156  }
   157  
   158  // Eval implements the sql.Expression interface.
   159  func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   160  	span, ctx := ctx.Span("expression.Case")
   161  	defer span.End()
   162  
   163  	t := c.Type()
   164  
   165  	for _, b := range c.Branches {
   166  		var cond sql.Expression
   167  		if c.Expr != nil {
   168  			cond = NewEquals(c.Expr, b.Cond)
   169  		} else {
   170  			cond = b.Cond
   171  		}
   172  
   173  		res, err := sql.EvaluateCondition(ctx, cond, row)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  
   178  		if sql.IsTrue(res) {
   179  			bval, err := b.Value.Eval(ctx, row)
   180  			if err != nil {
   181  				return nil, err
   182  			}
   183  			// When unable to convert to the type of the case, return the original value
   184  			// A common error here is "Out of bounds value for decimal type"
   185  			if ret, _, err := t.Convert(bval); err == nil {
   186  				return ret, nil
   187  			}
   188  			return bval, nil
   189  		}
   190  	}
   191  
   192  	if c.Else != nil {
   193  		val, err := c.Else.Eval(ctx, row)
   194  		if err != nil {
   195  			return nil, err
   196  		}
   197  		// When unable to convert to the type of the case, return the original value
   198  		// A common error here is "Out of bounds value for decimal type"
   199  		if ret, _, err := t.Convert(val); err == nil {
   200  			return ret, nil
   201  		}
   202  		return val, nil
   203  
   204  	}
   205  
   206  	return nil, nil
   207  }
   208  
   209  // WithChildren implements the Expression interface.
   210  func (c *Case) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   211  	var expected = len(c.Branches) * 2
   212  	if c.Expr != nil {
   213  		expected++
   214  	}
   215  
   216  	if c.Else != nil {
   217  		expected++
   218  	}
   219  
   220  	if len(children) != expected {
   221  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), expected)
   222  	}
   223  
   224  	var expr, elseExpr sql.Expression
   225  	if c.Expr != nil {
   226  		expr = children[0]
   227  		children = children[1:]
   228  	}
   229  
   230  	if c.Else != nil {
   231  		elseExpr = children[len(children)-1]
   232  		children = children[:len(children)-1]
   233  	}
   234  
   235  	var branches []CaseBranch
   236  	for i := 0; i < len(children); i += 2 {
   237  		branches = append(branches, CaseBranch{
   238  			Cond:  children[i],
   239  			Value: children[i+1],
   240  		})
   241  	}
   242  
   243  	return NewCase(expr, branches, elseExpr), nil
   244  }
   245  
   246  func (c *Case) String() string {
   247  	var buf bytes.Buffer
   248  
   249  	buf.WriteString("CASE ")
   250  	if c.Expr != nil {
   251  		buf.WriteString(c.Expr.String())
   252  	}
   253  
   254  	for _, b := range c.Branches {
   255  		buf.WriteString(" WHEN ")
   256  		buf.WriteString(b.Cond.String())
   257  		buf.WriteString(" THEN ")
   258  		buf.WriteString(b.Value.String())
   259  	}
   260  
   261  	if c.Else != nil {
   262  		buf.WriteString(" ELSE ")
   263  		buf.WriteString(c.Else.String())
   264  	}
   265  
   266  	buf.WriteString(" END")
   267  	return buf.String()
   268  }
   269  
   270  func (c *Case) DebugString() string {
   271  	var buf bytes.Buffer
   272  
   273  	buf.WriteString("CASE ")
   274  	if c.Expr != nil {
   275  		buf.WriteString(sql.DebugString(c.Expr))
   276  	}
   277  
   278  	for _, b := range c.Branches {
   279  		buf.WriteString(" WHEN ")
   280  		buf.WriteString(sql.DebugString(b.Cond))
   281  		buf.WriteString(" THEN ")
   282  		buf.WriteString(sql.DebugString(b.Value))
   283  	}
   284  
   285  	if c.Else != nil {
   286  		buf.WriteString(" ELSE ")
   287  		buf.WriteString(sql.DebugString(c.Else))
   288  	}
   289  
   290  	buf.WriteString(" END")
   291  	return buf.String()
   292  }