github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/sort.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    24  // Sort is the sort node.
    25  type Sort struct {
    26  	UnaryNode
    27  	SortFields sql.SortFields
    28  }
    29  
    30  // NewSort creates a new Sort node.
    31  func NewSort(sortFields []sql.SortField, child sql.Node) *Sort {
    32  	return &Sort{
    33  		UnaryNode:  UnaryNode{child},
    34  		SortFields: sortFields,
    35  	}
    36  }
    37  
    38  var _ sql.Expressioner = (*Sort)(nil)
    39  var _ sql.Node = (*Sort)(nil)
    40  var _ sql.CollationCoercible = (*Sort)(nil)
    41  
    42  // Resolved implements the Resolvable interface.
    43  func (s *Sort) Resolved() bool {
    44  	for _, f := range s.SortFields {
    45  		if !f.Column.Resolved() {
    46  			return false
    47  		}
    48  	}
    49  	return s.Child.Resolved()
    50  }
    51  
    52  func (s *Sort) IsReadOnly() bool {
    53  	return s.Child.IsReadOnly()
    54  }
    55  
    56  func (s *Sort) String() string {
    57  	pr := sql.NewTreePrinter()
    58  	var fields = make([]string, len(s.SortFields))
    59  	for i, f := range s.SortFields {
    60  		fields[i] = fmt.Sprintf("%s %s", f.Column, f.Order)
    61  	}
    62  	_ = pr.WriteNode("Sort(%s)", strings.Join(fields, ", "))
    63  	_ = pr.WriteChildren(s.Child.String())
    64  	return pr.String()
    65  }
    66  
    67  func (s *Sort) DebugString() string {
    68  	pr := sql.NewTreePrinter()
    69  	var fields = make([]string, len(s.SortFields))
    70  	for i, f := range s.SortFields {
    71  		fields[i] = sql.DebugString(f)
    72  	}
    73  	_ = pr.WriteNode("Sort(%s)", strings.Join(fields, ", "))
    74  	_ = pr.WriteChildren(sql.DebugString(s.Child))
    75  	return pr.String()
    76  }
    77  
    78  // Expressions implements the Expressioner interface.
    79  func (s *Sort) Expressions() []sql.Expression {
    80  	// TODO: use shared method
    81  	var exprs = make([]sql.Expression, len(s.SortFields))
    82  	for i, f := range s.SortFields {
    83  		exprs[i] = f.Column
    84  	}
    85  	return exprs
    86  }
    87  
    88  // WithChildren implements the Node interface.
    89  func (s *Sort) WithChildren(children ...sql.Node) (sql.Node, error) {
    90  	if len(children) != 1 {
    91  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1)
    92  	}
    93  
    94  	return NewSort(s.SortFields, children[0]), nil
    95  }
    96  
    97  // CheckPrivileges implements the interface sql.Node.
    98  func (s *Sort) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    99  	return s.Child.CheckPrivileges(ctx, opChecker)
   100  }
   101  
   102  // CollationCoercibility implements the interface sql.CollationCoercible.
   103  func (s *Sort) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   104  	return sql.GetCoercibility(ctx, s.Child)
   105  }
   106  
   107  // WithExpressions implements the Expressioner interface.
   108  func (s *Sort) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   109  	if len(exprs) != len(s.SortFields) {
   110  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(exprs), len(s.SortFields))
   111  	}
   112  
   113  	fields := s.SortFields.FromExpressions(exprs...)
   114  	return NewSort(fields, s.Child), nil
   115  }
   116  
   117  // TopN was a sort node that has a limit. It doesn't need to buffer everything,
   118  // but can calculate the top n on the fly.
   119  type TopN struct {
   120  	UnaryNode
   121  	Limit         sql.Expression
   122  	Fields        sql.SortFields
   123  	CalcFoundRows bool
   124  }
   125  
   126  // NewTopN creates a new TopN node.
   127  func NewTopN(fields sql.SortFields, limit sql.Expression, child sql.Node) *TopN {
   128  	return &TopN{
   129  		UnaryNode: UnaryNode{child},
   130  		Limit:     limit,
   131  		Fields:    fields,
   132  	}
   133  }
   134  
   135  var _ sql.Node = (*TopN)(nil)
   136  var _ sql.Expressioner = (*TopN)(nil)
   137  var _ sql.CollationCoercible = (*TopN)(nil)
   138  
   139  // Resolved implements the Resolvable interface.
   140  func (n *TopN) Resolved() bool {
   141  	for _, f := range n.Fields {
   142  		if !f.Column.Resolved() {
   143  			return false
   144  		}
   145  	}
   146  	return n.Child.Resolved()
   147  }
   148  
   149  func (n TopN) WithCalcFoundRows(v bool) *TopN {
   150  	n.CalcFoundRows = v
   151  	return &n
   152  }
   153  
   154  func (n *TopN) IsReadOnly() bool {
   155  	return n.Child.IsReadOnly()
   156  }
   157  
   158  func (n *TopN) String() string {
   159  	pr := sql.NewTreePrinter()
   160  	var fields = make([]string, len(n.Fields))
   161  	for i, f := range n.Fields {
   162  		fields[i] = fmt.Sprintf("%s %s", f.Column, f.Order)
   163  	}
   164  	_ = pr.WriteNode("TopN(Limit: [%s]; %s)", n.Limit.String(), strings.Join(fields, ", "))
   165  	_ = pr.WriteChildren(n.Child.String())
   166  	return pr.String()
   167  }
   168  
   169  func (n *TopN) DebugString() string {
   170  	pr := sql.NewTreePrinter()
   171  	var fields = make([]string, len(n.Fields))
   172  	for i, f := range n.Fields {
   173  		fields[i] = sql.DebugString(f)
   174  	}
   175  	_ = pr.WriteNode("TopN(Limit: [%s]; %s)", sql.DebugString(n.Limit), strings.Join(fields, ", "))
   176  	_ = pr.WriteChildren(sql.DebugString(n.Child))
   177  	return pr.String()
   178  }
   179  
   180  // Expressions implements the Expressioner interface.
   181  func (n *TopN) Expressions() []sql.Expression {
   182  	exprs := []sql.Expression{n.Limit}
   183  	exprs = append(exprs, n.Fields.ToExpressions()...)
   184  	return exprs
   185  }
   186  
   187  // WithChildren implements the Node interface.
   188  func (n *TopN) WithChildren(children ...sql.Node) (sql.Node, error) {
   189  	if len(children) != 1 {
   190  		return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 1)
   191  	}
   192  
   193  	topn := NewTopN(n.Fields, n.Limit, children[0])
   194  	topn.CalcFoundRows = n.CalcFoundRows
   195  	return topn, nil
   196  }
   197  
   198  // CheckPrivileges implements the interface sql.Node.
   199  func (n *TopN) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   200  	return n.Child.CheckPrivileges(ctx, opChecker)
   201  }
   202  
   203  // CollationCoercibility implements the interface sql.CollationCoercible.
   204  func (n *TopN) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   205  	return sql.GetCoercibility(ctx, n.Child)
   206  }
   207  
   208  // WithExpressions implements the Expressioner interface.
   209  func (n *TopN) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   210  	if len(exprs) != len(n.Fields)+1 {
   211  		return nil, sql.ErrInvalidChildrenNumber.New(n, len(exprs), len(n.Fields)+1)
   212  	}
   213  
   214  	var limit = exprs[0]
   215  	var fields = n.Fields.FromExpressions(exprs[1:]...)
   216  
   217  	topn := NewTopN(fields, limit, n.Child)
   218  	topn.CalcFoundRows = n.CalcFoundRows
   219  	return topn, nil
   220  }