github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/set_op.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    // Set-operation kinds stored in SetOp.SetOpType.
const (
	// UnionType selects the UNION set operation (rendered as "Union").
	UnionType = 0
	// IntersectType selects the INTERSECT set operation (rendered as "Intersect").
	IntersectType = 1
	// ExceptType selects the EXCEPT set operation (rendered as "Except").
	ExceptType = 2
)
    29  
    30  // SetOp is a node that returns everything in Left and then everything in Right
    31  type SetOp struct {
    32  	BinaryNode
    33  	SetOpType  int
    34  	Distinct   bool
    35  	Limit      sql.Expression
    36  	Offset     sql.Expression
    37  	SortFields sql.SortFields
    38  	dispose    []sql.DisposeFunc
    39  	id         sql.TableId
    40  	cols       sql.ColSet
    41  }
    42  
    43  var _ sql.Node = (*SetOp)(nil)
    44  var _ sql.Expressioner = (*SetOp)(nil)
    45  var _ sql.CollationCoercible = (*SetOp)(nil)
    46  
    47  // var _ sql.NameableNode = (*SetOp)(nil)
    48  var _ TableIdNode = (*SetOp)(nil)
    49  
    50  // NewSetOp creates a new SetOp node with the given children.
    51  func NewSetOp(setOpType int, left, right sql.Node, distinct bool, limit, offset sql.Expression, sortFields sql.SortFields) *SetOp {
    52  	return &SetOp{
    53  		BinaryNode: BinaryNode{left: left, right: right},
    54  		Distinct:   distinct,
    55  		Limit:      limit,
    56  		Offset:     offset,
    57  		SortFields: sortFields,
    58  		SetOpType:  setOpType,
    59  	}
    60  }
    61  
    62  func (s *SetOp) Name() string {
    63  	// TODO union should have its own name, table id, cols, etc
    64  	return ""
    65  }
    66  
    67  // WithId implements sql.TableIdNode
    68  func (s *SetOp) WithId(id sql.TableId) TableIdNode {
    69  	ret := *s
    70  	ret.id = id
    71  	return &ret
    72  }
    73  
    74  // Id implements sql.TableIdNode
    75  func (s *SetOp) Id() sql.TableId {
    76  	return s.id
    77  }
    78  
    79  // WithColumns implements sql.TableIdNode
    80  func (s *SetOp) WithColumns(set sql.ColSet) TableIdNode {
    81  	ret := *s
    82  	ret.cols = set
    83  	return &ret
    84  }
    85  
    86  // Columns implements sql.TableIdNode
    87  func (s *SetOp) Columns() sql.ColSet {
    88  	return s.cols
    89  }
    90  
    91  func (s *SetOp) AddDispose(f sql.DisposeFunc) {
    92  	s.dispose = append(s.dispose, f)
    93  }
    94  
    95  func (s *SetOp) Schema() sql.Schema {
    96  	ls := s.left.Schema()
    97  	rs := s.right.Schema()
    98  	ret := make([]*sql.Column, len(ls))
    99  	for i := range ls {
   100  		c := *ls[i]
   101  		if i < len(rs) {
   102  			c.Nullable = ls[i].Nullable || rs[i].Nullable
   103  		}
   104  		ret[i] = &c
   105  	}
   106  	return ret
   107  }
   108  
   109  // Opaque implements the sql.OpaqueNode interface.
   110  // Like SubqueryAlias, the selects in a SetOp must be evaluated in isolation.
   111  func (s *SetOp) Opaque() bool {
   112  	return true
   113  }
   114  
   115  func (s *SetOp) Resolved() bool {
   116  	res := s.Left().Resolved() && s.Right().Resolved()
   117  	if s.Limit != nil {
   118  		res = res && s.Limit.Resolved()
   119  	}
   120  	if s.Offset != nil {
   121  		res = res && s.Offset.Resolved()
   122  	}
   123  	for _, sf := range s.SortFields {
   124  		res = res && sf.Column.Resolved()
   125  	}
   126  	return res
   127  }
   128  
   129  func (s *SetOp) WithDistinct(b bool) *SetOp {
   130  	ret := *s
   131  	ret.Distinct = b
   132  	return &ret
   133  }
   134  
   135  func (s *SetOp) WithLimit(e sql.Expression) *SetOp {
   136  	ret := *s
   137  	ret.Limit = e
   138  	return &ret
   139  }
   140  
   141  func (s *SetOp) WithOffset(e sql.Expression) *SetOp {
   142  	ret := *s
   143  	ret.Offset = e
   144  	return &ret
   145  }
   146  
   147  func (s *SetOp) Expressions() []sql.Expression {
   148  	var exprs []sql.Expression
   149  	if s.Limit != nil {
   150  		exprs = append(exprs, s.Limit)
   151  	}
   152  	if s.Offset != nil {
   153  		exprs = append(exprs, s.Offset)
   154  	}
   155  	if len(s.SortFields) > 0 {
   156  		exprs = append(exprs, s.SortFields.ToExpressions()...)
   157  	}
   158  	return exprs
   159  }
   160  
   161  func (s *SetOp) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
   162  	var expLim, expOff, expSort int
   163  	if s.Limit != nil {
   164  		expLim = 1
   165  	}
   166  	if s.Offset != nil {
   167  		expOff = 1
   168  	}
   169  	expSort = len(s.SortFields)
   170  
   171  	if len(exprs) != expLim+expOff+expSort {
   172  		return nil, fmt.Errorf("expected %d limit and %d sort fields", expLim, expSort)
   173  	} else if len(exprs) == 0 {
   174  		return s, nil
   175  	}
   176  
   177  	ret := *s
   178  	if expLim == 1 {
   179  		ret.Limit = exprs[0]
   180  		exprs = exprs[1:]
   181  	}
   182  	if expOff == 1 {
   183  		ret.Offset = exprs[0]
   184  		exprs = exprs[1:]
   185  	}
   186  	ret.SortFields = s.SortFields.FromExpressions(exprs...)
   187  	return &ret, nil
   188  }
   189  
   190  // WithChildren implements the Node interface.
   191  func (s *SetOp) WithChildren(children ...sql.Node) (sql.Node, error) {
   192  	if len(children) != 2 {
   193  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2)
   194  	}
   195  	ret := *s
   196  	ret.left = children[0]
   197  	ret.right = children[1]
   198  	return &ret, nil
   199  }
   200  
   201  // CheckPrivileges implements the interface sql.Node.
   202  func (s *SetOp) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   203  	return s.left.CheckPrivileges(ctx, opChecker) && s.right.CheckPrivileges(ctx, opChecker)
   204  }
   205  
   206  // CollationCoercibility implements the interface sql.CollationCoercible.
   207  func (*SetOp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   208  	// Unions are able to return differing values, therefore they cannot be used to determine coercibility
   209  	return sql.Collation_binary, 7
   210  }
   211  
   212  func (s *SetOp) Dispose() {
   213  	for _, f := range s.dispose {
   214  		f()
   215  	}
   216  }
   217  
   218  func (s *SetOp) String() string {
   219  	pr := sql.NewTreePrinter()
   220  	var distinct string
   221  	if s.Distinct {
   222  		distinct = "distinct"
   223  	} else {
   224  		distinct = "all"
   225  	}
   226  	switch s.SetOpType {
   227  	case UnionType:
   228  		_ = pr.WriteNode(fmt.Sprintf("Union %s", distinct))
   229  	case IntersectType:
   230  		_ = pr.WriteNode(fmt.Sprintf("Intersect %s", distinct))
   231  	case ExceptType:
   232  		_ = pr.WriteNode(fmt.Sprintf("Except %s", distinct))
   233  	}
   234  	var children []string
   235  	if len(s.SortFields) > 0 {
   236  		children = append(children, fmt.Sprintf("sortFields: %s", s.SortFields.ToExpressions()))
   237  	}
   238  	if s.Limit != nil {
   239  		children = append(children, fmt.Sprintf("limit: %s", s.Limit))
   240  	}
   241  	if s.Offset != nil {
   242  		children = append(children, fmt.Sprintf("offset: %s", s.Offset))
   243  	}
   244  	children = append(children, s.left.String(), s.right.String())
   245  	_ = pr.WriteChildren(children...)
   246  	return pr.String()
   247  }
   248  
   249  func (s *SetOp) IsReadOnly() bool {
   250  	return s.left.IsReadOnly() && s.right.IsReadOnly()
   251  }
   252  
   253  func (s *SetOp) DebugString() string {
   254  	pr := sql.NewTreePrinter()
   255  	var distinct string
   256  	if s.Distinct {
   257  		distinct = "distinct"
   258  	} else {
   259  		distinct = "all"
   260  	}
   261  	switch s.SetOpType {
   262  	case UnionType:
   263  		_ = pr.WriteNode(fmt.Sprintf("Union %s", distinct))
   264  	case IntersectType:
   265  		_ = pr.WriteNode(fmt.Sprintf("Intersect %s", distinct))
   266  	case ExceptType:
   267  		_ = pr.WriteNode(fmt.Sprintf("Except %s", distinct))
   268  	}
   269  	var children []string
   270  	if len(s.SortFields) > 0 {
   271  		sFields := make([]string, len(s.SortFields))
   272  		for i, e := range s.SortFields.ToExpressions() {
   273  			sFields[i] = sql.DebugString(e)
   274  		}
   275  		children = append(children, fmt.Sprintf("sortFields: %s", strings.Join(sFields, ", ")))
   276  	}
   277  	if s.Limit != nil {
   278  		children = append(children, fmt.Sprintf("limit: %s", s.Limit))
   279  	}
   280  	if s.Offset != nil {
   281  		children = append(children, fmt.Sprintf("offset: %s", s.Offset))
   282  	}
   283  	children = append(children, sql.DebugString(s.left), sql.DebugString(s.right))
   284  	_ = pr.WriteChildren(children...)
   285  	return pr.String()
   286  }