github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/filter.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  )
    20  
    21  // Filter skips rows that don't match a certain expression.
    22  type Filter struct {
    23  	UnaryNode
    24  	Expression sql.Expression
    25  }
    26  
    27  var _ sql.Node = (*Filter)(nil)
    28  var _ sql.CollationCoercible = (*Filter)(nil)
    29  
    30  // NewFilter creates a new filter node.
    31  func NewFilter(expression sql.Expression, child sql.Node) *Filter {
    32  	return &Filter{
    33  		UnaryNode:  UnaryNode{Child: child},
    34  		Expression: expression,
    35  	}
    36  }
    37  
    38  // Resolved implements the Resolvable interface.
    39  func (f *Filter) Resolved() bool {
    40  	return f.UnaryNode.Child.Resolved() && f.Expression.Resolved()
    41  }
    42  
    43  func (f *Filter) IsReadOnly() bool {
    44  	return f.Child.IsReadOnly()
    45  }
    46  
    47  // WithChildren implements the Node interface.
    48  func (f *Filter) WithChildren(children ...sql.Node) (sql.Node, error) {
    49  	if len(children) != 1 {
    50  		return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 1)
    51  	}
    52  
    53  	return NewFilter(f.Expression, children[0]), nil
    54  }
    55  
    56  // CheckPrivileges implements the interface sql.Node.
    57  func (f *Filter) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    58  	return f.Child.CheckPrivileges(ctx, opChecker)
    59  }
    60  
    61  // CollationCoercibility implements the interface sql.CollationCoercible.
    62  func (f *Filter) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    63  	return sql.GetCoercibility(ctx, f.UnaryNode.Child)
    64  }
    65  
    66  // WithExpressions implements the Expressioner interface.
    67  func (f *Filter) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
    68  	if len(exprs) != 1 {
    69  		return nil, sql.ErrInvalidChildrenNumber.New(f, len(exprs), 1)
    70  	}
    71  
    72  	return NewFilter(exprs[0], f.Child), nil
    73  }
    74  
    75  // Describe implements the sql.Describable interface
    76  func (f *Filter) Describe(options sql.DescribeOptions) string {
    77  	pr := sql.NewTreePrinter()
    78  	_ = pr.WriteNode("Filter")
    79  	children := []string{sql.Describe(f.Expression, options), sql.Describe(f.Child, options)}
    80  	_ = pr.WriteChildren(children...)
    81  	return pr.String()
    82  }
    83  
    84  // String implements the fmt.Stringer interface
    85  func (f *Filter) String() string {
    86  	return f.Describe(sql.DescribeOptions{
    87  		Analyze:   false,
    88  		Estimates: false,
    89  		Debug:     false,
    90  	})
    91  }
    92  
    93  // DebugString implements the sql.DebugStringer interface
    94  func (f *Filter) DebugString() string {
    95  	return f.Describe(sql.DescribeOptions{
    96  		Analyze:   false,
    97  		Estimates: false,
    98  		Debug:     true,
    99  	})
   100  }
   101  
   102  // Expressions implements the Expressioner interface.
   103  func (f *Filter) Expressions() []sql.Expression {
   104  	return []sql.Expression{f.Expression}
   105  }
   106  
   107  // FilterIter is an iterator that filters another iterator and skips rows that
   108  // don't match the given condition.
   109  type FilterIter struct {
   110  	cond      sql.Expression
   111  	childIter sql.RowIter
   112  }
   113  
   114  // NewFilterIter creates a new FilterIter.
   115  func NewFilterIter(
   116  	cond sql.Expression,
   117  	child sql.RowIter,
   118  ) *FilterIter {
   119  	return &FilterIter{cond: cond, childIter: child}
   120  }
   121  
   122  // Next implements the RowIter interface.
   123  func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) {
   124  	for {
   125  		row, err := i.childIter.Next(ctx)
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  
   130  		res, err := sql.EvaluateCondition(ctx, i.cond, row)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		if sql.IsTrue(res) {
   136  			return row, nil
   137  		}
   138  	}
   139  }
   140  
   141  // Close implements the RowIter interface.
   142  func (i *FilterIter) Close(ctx *sql.Context) error {
   143  	return i.childIter.Close(ctx)
   144  }