github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/sortmerger/heap.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sortmerger 16 17 import ( 18 "database/sql/driver" 19 "reflect" 20 "time" 21 ) 22 23 var compareFuncMapping = map[reflect.Kind]func(any, any, Order) int{ 24 reflect.Int: compare[int], 25 reflect.Int8: compare[int8], 26 reflect.Int16: compare[int16], 27 reflect.Int32: compare[int32], 28 reflect.Int64: compare[int64], 29 reflect.Uint8: compare[uint8], 30 reflect.Uint16: compare[uint16], 31 reflect.Uint32: compare[uint32], 32 reflect.Uint64: compare[uint64], 33 reflect.Float32: compare[float32], 34 reflect.Float64: compare[float64], 35 reflect.String: compare[string], 36 reflect.Uint: compare[uint], 37 } 38 39 type Heap struct { 40 h []*node 41 sortColumns sortColumns 42 } 43 44 func (h *Heap) Len() int { 45 return len(h.h) 46 } 47 48 func (h *Heap) Less(i, j int) bool { 49 for k := 0; k < h.sortColumns.Len(); k++ { 50 valueI := h.h[i].sortCols[k] 51 valueJ := h.h[j].sortCols[k] 52 _, ok := valueJ.(driver.Valuer) 53 var cp func(any, any, Order) int 54 if ok { 55 cp = compareNullable 56 } else { 57 kind := reflect.TypeOf(valueI).Kind() 58 cp = compareFuncMapping[kind] 59 } 60 res := cp(valueI, valueJ, h.sortColumns.Get(k).order) 61 if res == 0 { 62 continue 63 } 64 if res == -1 { 65 return true 66 } 67 return false 68 } 69 return false 70 } 71 72 func (h *Heap) Swap(i, j int) { 73 h.h[i], h.h[j] = h.h[j], h.h[i] 74 } 75 76 func (h *Heap) Push(x any) { 77 h.h = append(h.h, x.(*node)) 78 } 79 80 func (h *Heap) Pop() any { 81 v := h.h[len(h.h)-1] 82 h.h = h.h[:len(h.h)-1] 83 return v 84 } 85 86 type node struct { 87 index int 88 sortCols []any 89 columns []any 90 } 91 92 // 升序时, -1 表示 i < j, 1 表示i > j ,0 表示两者相同 93 // 降序时,-1 表示 i > j, 1 表示 i < j ,0 表示两者相同 94 95 func compare[T Ordered](ii any, jj any, order Order) int { 96 i, j := ii.(T), jj.(T) 97 if i < j && order == ASC || i > j && order == DESC { 98 return -1 99 } else if i > j && order == ASC || i < j && order == DESC { 100 return 1 101 } else { 102 return 0 103 } 104 } 105 106 func compareNullable(ii, jj any, order Order) int { 107 i := ii.(driver.Valuer) 108 j := jj.(driver.Valuer) 109 iVal, _ := i.Value() 110 jVal, _ := j.Value() 111 // 如果i,j都为空返回0 112 // 如果val返回为空永远是最小值 113 if iVal == nil && jVal == nil { 114 return 0 115 } else if iVal == nil && order == ASC || jVal == nil && order == DESC { 116 return -1 117 } else if iVal == nil && order == DESC || jVal == nil && order == ASC { 118 return 1 119 } 120 121 vali, ok := iVal.(time.Time) 122 if ok { 123 valj := jVal.(time.Time) 124 return compare[int64](vali.UnixMilli(), valj.UnixMilli(), order) 125 } 126 kind := reflect.TypeOf(iVal).Kind() 127 return compareFuncMapping[kind](iVal, jVal, order) 128 }