github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/sort.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"container/heap"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  )
    22  
// Sorter is a sorter implementation for Row slices using SortFields for the
// comparison. Its Len, Swap and Less methods satisfy sort.Interface, so a
// *Sorter can be handed to the sort package.
type Sorter struct {
	// SortFields are the ordering keys, consulted in order; a later field is
	// only used when the rows are equal on every earlier one.
	SortFields []sql.SortField

	// Rows is the slice being ordered; Swap permutes it in place.
	Rows []sql.Row

	// LastError holds the first error raised while comparing rows. Once it is
	// set, Less reports false for every pair, so callers should check it after
	// sorting.
	LastError error

	// Ctx is passed to Eval when evaluating each sort column against a row.
	Ctx *sql.Context
}
    30  
    31  func (s *Sorter) Len() int {
    32  	return len(s.Rows)
    33  }
    34  
    35  func (s *Sorter) Swap(i, j int) {
    36  	s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i]
    37  }
    38  
    39  func (s *Sorter) Less(i, j int) bool {
    40  	if s.LastError != nil {
    41  		return false
    42  	}
    43  
    44  	a := s.Rows[i]
    45  	b := s.Rows[j]
    46  	for _, sf := range s.SortFields {
    47  		typ := sf.Column.Type()
    48  		av, err := sf.Column.Eval(s.Ctx, a)
    49  		if err != nil {
    50  			s.LastError = sql.ErrUnableSort.Wrap(err)
    51  			return false
    52  		}
    53  
    54  		bv, err := sf.Column.Eval(s.Ctx, b)
    55  		if err != nil {
    56  			s.LastError = sql.ErrUnableSort.Wrap(err)
    57  			return false
    58  		}
    59  
    60  		if sf.Order == sql.Descending {
    61  			av, bv = bv, av
    62  		}
    63  
    64  		if av == nil && bv == nil {
    65  			continue
    66  		} else if av == nil {
    67  			return sf.NullOrdering == sql.NullsFirst
    68  		} else if bv == nil {
    69  			return sf.NullOrdering != sql.NullsFirst
    70  		}
    71  
    72  		cmp, err := typ.Compare(av, bv)
    73  		if err != nil {
    74  			s.LastError = err
    75  			return false
    76  		}
    77  
    78  		switch cmp {
    79  		case -1:
    80  			return true
    81  		case 1:
    82  			return false
    83  		}
    84  	}
    85  
    86  	return false
    87  }
    88  
// Sorter2 is a version of Sorter that operates on Row2. It evaluates each
// SortField's Column2 (via Eval2/Type2/Compare2) instead of Column, and
// satisfies sort.Interface through its Len, Swap and Less methods.
type Sorter2 struct {
	// SortFields are the ordering keys, consulted in order; a later field is
	// only used when the rows are equal on every earlier one.
	SortFields []sql.SortField

	// Rows is the slice being ordered; Swap permutes it in place.
	Rows []sql.Row2

	// LastError holds the first error raised while comparing rows. Once it is
	// set, Less reports false for every pair, so callers should check it after
	// sorting.
	LastError error

	// Ctx is passed to Eval2 when evaluating each sort column against a row.
	Ctx *sql.Context
}
    96  
    97  func (s *Sorter2) Len() int {
    98  	return len(s.Rows)
    99  }
   100  
   101  func (s *Sorter2) Swap(i, j int) {
   102  	s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i]
   103  }
   104  
   105  func (s *Sorter2) Less(i, j int) bool {
   106  	if s.LastError != nil {
   107  		return false
   108  	}
   109  
   110  	a := s.Rows[i]
   111  	b := s.Rows[j]
   112  	for _, sf := range s.SortFields {
   113  		typ := sf.Column2.Type2()
   114  		av, err := sf.Column2.Eval2(s.Ctx, a)
   115  		if err != nil {
   116  			s.LastError = sql.ErrUnableSort.Wrap(err)
   117  			return false
   118  		}
   119  
   120  		bv, err := sf.Column2.Eval2(s.Ctx, b)
   121  		if err != nil {
   122  			s.LastError = sql.ErrUnableSort.Wrap(err)
   123  			return false
   124  		}
   125  
   126  		if sf.Order == sql.Descending {
   127  			av, bv = bv, av
   128  		}
   129  
   130  		if av.IsNull() && bv.IsNull() {
   131  			continue
   132  		} else if av.IsNull() {
   133  			return sf.NullOrdering == sql.NullsFirst
   134  		} else if bv.IsNull() {
   135  			return sf.NullOrdering != sql.NullsFirst
   136  		}
   137  
   138  		cmp, err := typ.Compare2(av, bv)
   139  		if err != nil {
   140  			s.LastError = err
   141  			return false
   142  		}
   143  
   144  		switch cmp {
   145  		case -1:
   146  			return true
   147  		case 1:
   148  			return false
   149  		}
   150  	}
   151  
   152  	return false
   153  }
   154  
// TopRowsHeap implements heap.Interface on top of Sorter to support TopN:
// keeping only the first N rows, in SortFields order, out of a larger input.
// Its Less inverts Sorter.Less, so the heap root is the retained row that
// sorts last. heap.Push each row into it and, whenever Len() exceeds N,
// heap.Pop to discard that root. After all rows have been seen, call Rows(),
// which drains the heap and fills its result from the back, so the retained
// rows come out in sort order.
type TopRowsHeap struct {
	// Sorter supplies the rows, the sort keys, Len/Swap and error tracking.
	Sorter
}
   164  
   165  func (h *TopRowsHeap) Less(i, j int) bool {
   166  	return !h.Sorter.Less(i, j)
   167  }
   168  
   169  func (h *TopRowsHeap) Push(x interface{}) {
   170  	h.Sorter.Rows = append(h.Sorter.Rows, x.(sql.Row))
   171  }
   172  
   173  func (h *TopRowsHeap) Pop() interface{} {
   174  	old := h.Sorter.Rows
   175  	n := len(old)
   176  	res := old[n-1]
   177  	h.Sorter.Rows = old[0 : n-1]
   178  	return res
   179  }
   180  
   181  func (h *TopRowsHeap) Rows() ([]sql.Row, error) {
   182  	l := h.Len()
   183  	res := make([]sql.Row, l)
   184  	for i := l - 1; i >= 0; i-- {
   185  		res[i] = heap.Pop(h).(sql.Row)
   186  	}
   187  	return res, h.LastError
   188  }