
     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package sortmerger
    17  import (
    18  	"database/sql/driver"
    19  	"reflect"
    20  	"time"
    21  )
    23  var compareFuncMapping = map[reflect.Kind]func(any, any, Order) int{
    24  	reflect.Int:     compare[int],
    25  	reflect.Int8:    compare[int8],
    26  	reflect.Int16:   compare[int16],
    27  	reflect.Int32:   compare[int32],
    28  	reflect.Int64:   compare[int64],
    29  	reflect.Uint8:   compare[uint8],
    30  	reflect.Uint16:  compare[uint16],
    31  	reflect.Uint32:  compare[uint32],
    32  	reflect.Uint64:  compare[uint64],
    33  	reflect.Float32: compare[float32],
    34  	reflect.Float64: compare[float64],
    35  	reflect.String:  compare[string],
    36  	reflect.Uint:    compare[uint],
    37  }
// Heap stores *node elements and orders them by their sort-column values.
//
// Its Len, Less, Swap, Push and Pop methods form the method set of
// container/heap's heap.Interface.
// NOTE(review): presumably it is driven through heap.Init/heap.Push/heap.Pop
// elsewhere in the package — confirm at the call sites.
type Heap struct {
	// h is the backing slice of nodes; every method indexes into it directly.
	h           []*node
	// sortColumns lists the columns Less compares, in priority order, together
	// with the Order (ASC/DESC) of each column.
	sortColumns sortColumns
}
    44  func (h *Heap) Len() int {
    45  	return len(h.h)
    46  }
    48  func (h *Heap) Less(i, j int) bool {
    49  	for k := 0; k < h.sortColumns.Len(); k++ {
    50  		valueI := h.h[i].sortCols[k]
    51  		valueJ := h.h[j].sortCols[k]
    52  		_, ok := valueJ.(driver.Valuer)
    53  		var cp func(any, any, Order) int
    54  		if ok {
    55  			cp = compareNullable
    56  		} else {
    57  			kind := reflect.TypeOf(valueI).Kind()
    58  			cp = compareFuncMapping[kind]
    59  		}
    60  		res := cp(valueI, valueJ, h.sortColumns.Get(k).order)
    61  		if res == 0 {
    62  			continue
    63  		}
    64  		if res == -1 {
    65  			return true
    66  		}
    67  		return false
    68  	}
    69  	return false
    70  }
    72  func (h *Heap) Swap(i, j int) {
    73  	h.h[i], h.h[j] = h.h[j], h.h[i]
    74  }
    76  func (h *Heap) Push(x any) {
    77  	h.h = append(h.h, x.(*node))
    78  }
    80  func (h *Heap) Pop() any {
    81  	v := h.h[len(h.h)-1]
    82  	h.h = h.h[:len(h.h)-1]
    83  	return v
    84  }
// node is the element type stored in Heap (Heap.h is a []*node).
type node struct {
	// index: NOTE(review): not read in this section; presumably identifies
	// which underlying source the row came from — confirm.
	index    int
	// sortCols holds one value per sort column; Heap.Less compares
	// sortCols[k] of two nodes for the k-th column of Heap.sortColumns.
	sortCols []any
	// columns: NOTE(review): not read in this section; presumably the full set
	// of row values handed back to the caller — confirm.
	columns  []any
}
    92  // 升序时, -1 表示 i < j, 1 表示i > j ,0 表示两者相同
    93  // 降序时,-1 表示 i > j, 1 表示 i < j ,0 表示两者相同
    95  func compare[T Ordered](ii any, jj any, order Order) int {
    96  	i, j := ii.(T), jj.(T)
    97  	if i < j && order == ASC || i > j && order == DESC {
    98  		return -1
    99  	} else if i > j && order == ASC || i < j && order == DESC {
   100  		return 1
   101  	} else {
   102  		return 0
   103  	}
   104  }
   106  func compareNullable(ii, jj any, order Order) int {
   107  	i := ii.(driver.Valuer)
   108  	j := jj.(driver.Valuer)
   109  	iVal, _ := i.Value()
   110  	jVal, _ := j.Value()
   111  	// 如果i,j都为空返回0
   112  	// 如果val返回为空永远是最小值
   113  	if iVal == nil && jVal == nil {
   114  		return 0
   115  	} else if iVal == nil && order == ASC || jVal == nil && order == DESC {
   116  		return -1
   117  	} else if iVal == nil && order == DESC || jVal == nil && order == ASC {
   118  		return 1
   119  	}
   121  	vali, ok := iVal.(time.Time)
   122  	if ok {
   123  		valj := jVal.(time.Time)
   124  		return compare[int64](vali.UnixMilli(), valj.UnixMilli(), order)
   125  	}
   126  	kind := reflect.TypeOf(iVal).Kind()
   127  	return compareFuncMapping[kind](iVal, jVal, order)
   128  }