github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/case.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  )
    22  
    23  // CaseStatement represents CASE statements, which are different from CASE expressions. These are intended for use in
    24  // triggers and stored procedures. Specifically, this implements CASE statements when comparing each conditional to a
    25  // value. The version of CASE that does not compare each conditional to a value is functionally equivalent to a series
    26  // of IF/ELSE statements, and therefore we simply use an IfElseBlock.
    27  type CaseStatement struct {
    28  	Expr   sql.Expression
    29  	IfElse *IfElseBlock
    30  }
    31  
    32  var _ sql.Node = (*CaseStatement)(nil)
    33  var _ sql.DebugStringer = (*CaseStatement)(nil)
    34  var _ sql.Expressioner = (*CaseStatement)(nil)
    35  var _ sql.CollationCoercible = (*CaseStatement)(nil)
    36  
    37  // NewCaseStatement creates a new *NewCaseStatement or *IfElseBlock node.
    38  func NewCaseStatement(caseExpr sql.Expression, ifConditionals []*IfConditional, elseStatement sql.Node) sql.Node {
    39  	if elseStatement == nil {
    40  		elseStatement = ElseCaseError{}
    41  	}
    42  	ifElse := &IfElseBlock{
    43  		IfConditionals: ifConditionals,
    44  		Else:           elseStatement,
    45  	}
    46  	if caseExpr != nil {
    47  		return &CaseStatement{
    48  			Expr:   caseExpr,
    49  			IfElse: ifElse,
    50  		}
    51  	}
    52  	return ifElse
    53  }
    54  
    55  // Resolved implements the interface sql.Node.
    56  func (c *CaseStatement) Resolved() bool {
    57  	return c.Expr.Resolved() && c.IfElse.Resolved()
    58  }
    59  
    60  func (c *CaseStatement) IsReadOnly() bool {
    61  	return c.IfElse.IsReadOnly()
    62  }
    63  
    64  // String implements the interface sql.Node.
    65  func (c *CaseStatement) String() string {
    66  	p := sql.NewTreePrinter()
    67  	_ = p.WriteNode("CASE %s", c.Expr.String())
    68  	_ = p.WriteChildren(c.IfElse.String())
    69  	return p.String()
    70  }
    71  
    72  // DebugString implements the sql.DebugStringer interface.
    73  func (c *CaseStatement) DebugString() string {
    74  	p := sql.NewTreePrinter()
    75  	_ = p.WriteNode("CASE %s", sql.DebugString(c.Expr))
    76  	_ = p.WriteChildren(sql.DebugString(c.IfElse))
    77  	return p.String()
    78  }
    79  
    80  // Schema implements the interface sql.Node.
    81  func (c *CaseStatement) Schema() sql.Schema {
    82  	return c.IfElse.Schema()
    83  }
    84  
    85  // Children implements the interface sql.Node.
    86  func (c *CaseStatement) Children() []sql.Node {
    87  	return c.IfElse.Children()
    88  }
    89  
    90  // WithChildren implements the interface sql.Node.
    91  func (c *CaseStatement) WithChildren(children ...sql.Node) (sql.Node, error) {
    92  	newIfElseNode, err := c.IfElse.WithChildren(children...)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	newIfElse, ok := newIfElseNode.(*IfElseBlock)
    97  	if !ok {
    98  		return nil, fmt.Errorf("%T: expected child %T but got %T", c, c.IfElse, newIfElseNode)
    99  	}
   100  
   101  	return &CaseStatement{
   102  		Expr:   c.Expr,
   103  		IfElse: newIfElse,
   104  	}, nil
   105  }
   106  
   107  // Expressions implements the interface sql.Node.
   108  func (c *CaseStatement) Expressions() []sql.Expression {
   109  	return []sql.Expression{c.Expr}
   110  }
   111  
   112  // WithExpressions implements the interface sql.Node.
   113  func (c *CaseStatement) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   114  	if len(exprs) != 1 {
   115  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(exprs), 1)
   116  	}
   117  
   118  	return &CaseStatement{
   119  		Expr:   exprs[0],
   120  		IfElse: c.IfElse,
   121  	}, nil
   122  }
   123  
   124  // CheckPrivileges implements the interface sql.Node.
   125  func (c *CaseStatement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   126  	return c.IfElse.CheckPrivileges(ctx, opChecker)
   127  }
   128  
   129  // CollationCoercibility implements the interface sql.CollationCoercible.
   130  func (c *CaseStatement) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   131  	return c.IfElse.CollationCoercibility(ctx)
   132  }
   133  
   134  type ElseCaseError struct{}
   135  
   136  var _ sql.Node = ElseCaseError{}
   137  
   138  // Resolved implements the interface sql.Node.
   139  func (e ElseCaseError) Resolved() bool {
   140  	return true
   141  }
   142  
   143  func (e ElseCaseError) IsReadOnly() bool {
   144  	return true
   145  }
   146  
   147  // String implements the interface sql.Node.
   148  func (e ElseCaseError) String() string {
   149  	return "ELSE CASE ERROR"
   150  }
   151  
   152  // Schema implements the interface sql.Node.
   153  func (e ElseCaseError) Schema() sql.Schema {
   154  	return nil
   155  }
   156  
   157  // Children implements the interface sql.Node.
   158  func (e ElseCaseError) Children() []sql.Node {
   159  	return nil
   160  }
   161  
   162  // WithChildren implements the interface sql.Node.
   163  func (e ElseCaseError) WithChildren(children ...sql.Node) (sql.Node, error) {
   164  	return NillaryWithChildren(e, children...)
   165  }
   166  
   167  // CheckPrivileges implements the interface sql.Node.
   168  func (e ElseCaseError) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   169  	return true
   170  }
   171  
   172  // CollationCoercibility implements the interface sql.CollationCoercible.
   173  func (e ElseCaseError) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   174  	return sql.Collation_binary, 7
   175  }