github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/logic.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/types"
    22  )
    23  
    24  // And checks whether two expressions are true.
    25  type And struct {
    26  	BinaryExpressionStub
    27  }
    28  
    29  var _ sql.Expression = (*And)(nil)
    30  var _ sql.CollationCoercible = (*And)(nil)
    31  
    32  // NewAnd creates a new And expression.
    33  func NewAnd(left, right sql.Expression) sql.Expression {
    34  	return &And{BinaryExpressionStub{LeftChild: left, RightChild: right}}
    35  }
    36  
    37  // JoinAnd joins several expressions with And.
    38  func JoinAnd(exprs ...sql.Expression) sql.Expression {
    39  	switch len(exprs) {
    40  	case 0:
    41  		return nil
    42  	case 1:
    43  		return exprs[0]
    44  	default:
    45  		if exprs[0] == nil {
    46  			return JoinAnd(exprs[1:]...)
    47  		}
    48  		result := NewAnd(exprs[0], exprs[1])
    49  		for _, e := range exprs[2:] {
    50  			if e != nil {
    51  				result = NewAnd(result, e)
    52  			}
    53  		}
    54  		return result
    55  	}
    56  }
    57  
    58  // SplitConjunction breaks AND expressions into their left and right parts, recursively
    59  func SplitConjunction(expr sql.Expression) []sql.Expression {
    60  	if expr == nil {
    61  		return nil
    62  	}
    63  	and, ok := expr.(*And)
    64  	if !ok {
    65  		return []sql.Expression{expr}
    66  	}
    67  
    68  	return append(
    69  		SplitConjunction(and.LeftChild),
    70  		SplitConjunction(and.RightChild)...,
    71  	)
    72  }
    73  
    74  // SplitDisjunction breaks OR expressions into their left and right parts, recursively
    75  func SplitDisjunction(expr sql.Expression) []sql.Expression {
    76  	if expr == nil {
    77  		return nil
    78  	}
    79  	and, ok := expr.(*Or)
    80  	if !ok {
    81  		return []sql.Expression{expr}
    82  	}
    83  
    84  	return append(
    85  		SplitDisjunction(and.LeftChild),
    86  		SplitDisjunction(and.RightChild)...,
    87  	)
    88  }
    89  
    90  func (a *And) String() string {
    91  	return fmt.Sprintf("(%s AND %s)", a.LeftChild, a.RightChild)
    92  }
    93  
    94  func (a *And) DebugString() string {
    95  	pr := sql.NewTreePrinter()
    96  	_ = pr.WriteNode("AND")
    97  	children := []string{sql.DebugString(a.LeftChild), sql.DebugString(a.RightChild)}
    98  	_ = pr.WriteChildren(children...)
    99  	return pr.String()
   100  }
   101  
   102  // Type implements the Expression interface.
   103  func (*And) Type() sql.Type {
   104  	return types.Boolean
   105  }
   106  
   107  // CollationCoercibility implements the interface sql.CollationCoercible.
   108  func (*And) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   109  	return sql.Collation_binary, 5
   110  }
   111  
   112  // Eval implements the Expression interface.
   113  func (a *And) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   114  	lval, err := a.LeftChild.Eval(ctx, row)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	if lval != nil {
   119  		lvalBool, err := sql.ConvertToBool(ctx, lval)
   120  		if err == nil && lvalBool == false {
   121  			return false, nil
   122  		}
   123  	}
   124  
   125  	rval, err := a.RightChild.Eval(ctx, row)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	if rval != nil {
   130  		rvalBool, err := sql.ConvertToBool(ctx, rval)
   131  		if err == nil && rvalBool == false {
   132  			return false, nil
   133  		}
   134  	}
   135  
   136  	if lval == nil || rval == nil {
   137  		return nil, nil
   138  	}
   139  
   140  	return true, nil
   141  }
   142  
   143  // WithChildren implements the Expression interface.
   144  func (a *And) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   145  	if len(children) != 2 {
   146  		return nil, sql.ErrInvalidChildrenNumber.New(a, len(children), 2)
   147  	}
   148  	return NewAnd(children[0], children[1]), nil
   149  }
   150  
   151  // Or checks whether one of the two given expressions is true.
   152  type Or struct {
   153  	BinaryExpressionStub
   154  }
   155  
   156  var _ sql.Expression = (*Or)(nil)
   157  var _ sql.CollationCoercible = (*Or)(nil)
   158  
   159  // NewOr creates a new Or expression.
   160  func NewOr(left, right sql.Expression) sql.Expression {
   161  	return &Or{BinaryExpressionStub{LeftChild: left, RightChild: right}}
   162  }
   163  
   164  // JoinOr joins several expressions with Or.
   165  func JoinOr(exprs ...sql.Expression) sql.Expression {
   166  	switch len(exprs) {
   167  	case 0:
   168  		return nil
   169  	case 1:
   170  		return exprs[0]
   171  	default:
   172  		if exprs[0] == nil {
   173  			return JoinOr(exprs[1:]...)
   174  		}
   175  		result := NewOr(exprs[0], exprs[1])
   176  		for _, e := range exprs[2:] {
   177  			if e != nil {
   178  				result = NewOr(result, e)
   179  			}
   180  		}
   181  		return result
   182  	}
   183  }
   184  
   185  func (o *Or) String() string {
   186  	return fmt.Sprintf("(%s OR %s)", o.LeftChild, o.RightChild)
   187  }
   188  
   189  func (o *Or) DebugString() string {
   190  	pr := sql.NewTreePrinter()
   191  	_ = pr.WriteNode("Or")
   192  	children := []string{sql.DebugString(o.LeftChild), sql.DebugString(o.RightChild)}
   193  	_ = pr.WriteChildren(children...)
   194  	return pr.String()
   195  }
   196  
   197  // Type implements the Expression interface.
   198  func (*Or) Type() sql.Type {
   199  	return types.Boolean
   200  }
   201  
   202  // CollationCoercibility implements the interface sql.CollationCoercible.
   203  func (*Or) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   204  	return sql.Collation_binary, 5
   205  }
   206  
   207  // Eval implements the Expression interface.
   208  func (o *Or) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   209  	lval, err := o.LeftChild.Eval(ctx, row)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  	if lval != nil {
   214  		lval, err = sql.ConvertToBool(ctx, lval)
   215  		if err == nil && lval.(bool) {
   216  			return true, nil
   217  		}
   218  	}
   219  
   220  	rval, err := o.RightChild.Eval(ctx, row)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  	if rval != nil {
   225  		rval, err = sql.ConvertToBool(ctx, rval)
   226  		if err == nil && rval.(bool) {
   227  			return true, nil
   228  		}
   229  	}
   230  
   231  	// Can also be triggered by lval and rval not being bool types.
   232  	if lval == false && rval == false {
   233  		return false, nil
   234  	}
   235  
   236  	// (lval == nil && rval == nil) || (lval == false && rval == nil) || (lval == nil && rval == false)
   237  	return nil, nil
   238  }
   239  
   240  // WithChildren implements the Expression interface.
   241  func (o *Or) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   242  	if len(children) != 2 {
   243  		return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 2)
   244  	}
   245  	return NewOr(children[0], children[1]), nil
   246  }
   247  
   248  // Xor checks whether only one of the two given expressions is true.
   249  type Xor struct {
   250  	BinaryExpressionStub
   251  }
   252  
   253  var _ sql.Expression = (*Xor)(nil)
   254  var _ sql.CollationCoercible = (*Xor)(nil)
   255  
   256  // NewXor creates a new Xor expression.
   257  func NewXor(left, right sql.Expression) sql.Expression {
   258  	return &Xor{BinaryExpressionStub{LeftChild: left, RightChild: right}}
   259  }
   260  
   261  func (x *Xor) String() string {
   262  	return fmt.Sprintf("(%s XOR %s)", x.LeftChild, x.RightChild)
   263  }
   264  
   265  func (x *Xor) DebugString() string {
   266  	return fmt.Sprintf("%s XOR %s", sql.DebugString(x.LeftChild), sql.DebugString(x.RightChild))
   267  }
   268  
   269  // Type implements the Expression interface.
   270  func (*Xor) Type() sql.Type {
   271  	return types.Boolean
   272  }
   273  
   274  // CollationCoercibility implements the interface sql.CollationCoercible.
   275  func (*Xor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   276  	return sql.Collation_binary, 5
   277  }
   278  
   279  // Eval implements the Expression interface.
   280  func (x *Xor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   281  	lval, err := x.LeftChild.Eval(ctx, row)
   282  	if err != nil {
   283  		return nil, err
   284  	}
   285  	if lval == nil {
   286  		return nil, nil
   287  	}
   288  	lvalue, err := sql.ConvertToBool(ctx, lval)
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  
   293  	rval, err := x.RightChild.Eval(ctx, row)
   294  	if err != nil {
   295  		return nil, err
   296  	}
   297  	if rval == nil {
   298  		return nil, nil
   299  	}
   300  	rvalue, err := sql.ConvertToBool(ctx, rval)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	// a XOR b == (a AND (NOT b)) OR ((NOT a) and b)
   306  	if (rvalue && !lvalue) || (!rvalue && lvalue) {
   307  		return true, nil
   308  	}
   309  	return false, nil
   310  }
   311  
   312  // WithChildren implements the Expression interface.
   313  func (x *Xor) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   314  	if len(children) != 2 {
   315  		return nil, sql.ErrInvalidChildrenNumber.New(x, len(children), 2)
   316  	}
   317  	return NewXor(children[0], children[1]), nil
   318  }