github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/sortmerger/merger.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sortmerger 16 17 import ( 18 "container/heap" 19 "context" 20 "database/sql" 21 "reflect" 22 "sync" 23 24 "github.com/ecodeclub/eorm/internal/rows" 25 26 "go.uber.org/multierr" 27 28 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 29 ) 30 31 type Order bool 32 33 const ( 34 // ASC 升序排序 35 ASC Order = true 36 // DESC 降序排序 37 DESC Order = false 38 ) 39 40 type Ordered interface { 41 ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string 42 } 43 44 type SortColumn struct { 45 name string 46 order Order 47 } 48 49 func NewSortColumn(colName string, order Order) SortColumn { 50 return SortColumn{ 51 name: colName, 52 order: order, 53 } 54 } 55 56 type sortColumns struct { 57 columns []SortColumn 58 colMap map[string]int 59 } 60 61 func (s sortColumns) Has(name string) bool { 62 _, ok := s.colMap[name] 63 return ok 64 } 65 66 func (s sortColumns) Find(name string) int { 67 return s.colMap[name] 68 } 69 70 func (s sortColumns) Get(index int) SortColumn { 71 return s.columns[index] 72 } 73 74 func (s sortColumns) Len() int { 75 return len(s.columns) 76 } 77 78 // Merger 如果有GroupBy子句,会导致排序是给每个分组排的,那么该实现无法运作正常 79 type Merger struct { 80 sortColumns 81 cols []string 82 } 83 84 func NewMerger(sortCols ...SortColumn) (*Merger, error) { 85 scs, err := newSortColumns(sortCols...) 86 if err != nil { 87 return nil, err 88 } 89 return &Merger{ 90 sortColumns: scs, 91 }, nil 92 } 93 94 func newSortColumns(sortCols ...SortColumn) (sortColumns, error) { 95 if len(sortCols) == 0 { 96 return sortColumns{}, errs.ErrEmptySortColumns 97 } 98 sortMap := make(map[string]int, len(sortCols)) 99 for idx, sortCol := range sortCols { 100 if _, ok := sortMap[sortCol.name]; ok { 101 return sortColumns{}, errs.NewRepeatSortColumn(sortCol.name) 102 } 103 sortMap[sortCol.name] = idx 104 } 105 scs := sortColumns{ 106 columns: sortCols, 107 colMap: sortMap, 108 } 109 return scs, nil 110 } 111 112 func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { 113 // 检测results是否符合条件 114 if ctx.Err() != nil { 115 return nil, ctx.Err() 116 } 117 if len(results) == 0 { 118 return nil, errs.ErrMergerEmptyRows 119 } 120 for i := 0; i < len(results); i++ { 121 if err := m.checkColumns(results[i]); err != nil { 122 return nil, err 123 } 124 if ctx.Err() != nil { 125 return nil, ctx.Err() 126 } 127 } 128 return m.initRows(results) 129 } 130 131 func (m *Merger) initRows(results []rows.Rows) (*Rows, error) { 132 rs := &Rows{ 133 rowsList: results, 134 sortColumns: m.sortColumns, 135 mu: &sync.RWMutex{}, 136 columns: m.cols, 137 } 138 h := &Heap{ 139 h: make([]*node, 0, len(rs.rowsList)), 140 sortColumns: rs.sortColumns, 141 } 142 rs.hp = h 143 for i := 0; i < len(rs.rowsList); i++ { 144 err := rs.nextRows(rs.rowsList[i], i) 145 if err != nil { 146 _ = rs.Close() 147 return nil, err 148 } 149 } 150 return rs, nil 151 } 152 153 func (m *Merger) checkColumns(rows rows.Rows) error { 154 if rows == nil { 155 return errs.ErrMergerRowsIsNull 156 } 157 cols, err := rows.Columns() 158 if err != nil { 159 return err 160 } 161 colMap := make(map[string]struct{}, len(cols)) 162 if len(m.cols) == 0 { 163 m.cols = cols 164 } 165 if len(m.cols) != len(cols) { 166 return errs.ErrMergerRowsDiff 167 } 168 for idx, colName := range cols { 169 if m.cols[idx] != colName { 170 return errs.ErrMergerRowsDiff 171 } 172 colMap[colName] = struct{}{} 173 } 174 175 for _, sortColumn := range m.sortColumns.columns { 176 _, ok := colMap[sortColumn.name] 177 if !ok { 178 return errs.NewInvalidSortColumn(sortColumn.name) 179 } 180 } 181 return nil 182 } 183 184 func newNode(row rows.Rows, sortCols sortColumns, index int) (*node, error) { 185 colsInfo, err := row.ColumnTypes() 186 if err != nil { 187 return nil, err 188 } 189 columns := make([]any, 0, len(colsInfo)) 190 sortColumns := make([]any, sortCols.Len()) 191 for _, colInfo := range colsInfo { 192 colName := colInfo.Name() 193 colType := colInfo.ScanType() 194 for colType.Kind() == reflect.Ptr { 195 colType = colType.Elem() 196 } 197 column := reflect.New(colType).Interface() 198 if sortCols.Has(colName) { 199 sortIndex := sortCols.Find(colName) 200 sortColumns[sortIndex] = column 201 } 202 columns = append(columns, column) 203 } 204 err = row.Scan(columns...) 205 if err != nil { 206 return nil, err 207 } 208 for i := 0; i < len(sortColumns); i++ { 209 sortColumns[i] = reflect.ValueOf(sortColumns[i]).Elem().Interface() 210 } 211 for i := 0; i < len(columns); i++ { 212 columns[i] = reflect.ValueOf(columns[i]).Elem().Interface() 213 } 214 return &node{ 215 sortCols: sortColumns, 216 columns: columns, 217 index: index, 218 }, nil 219 } 220 221 type Rows struct { 222 rowsList []rows.Rows 223 sortColumns sortColumns 224 hp *Heap 225 cur *node 226 mu *sync.RWMutex 227 lastErr error 228 closed bool 229 columns []string 230 } 231 232 func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { 233 return r.rowsList[0].ColumnTypes() 234 } 235 236 func (*Rows) NextResultSet() bool { 237 return false 238 } 239 240 func (r *Rows) Next() bool { 241 r.mu.Lock() 242 if r.closed { 243 r.mu.Unlock() 244 return false 245 } 246 if r.hp.Len() == 0 || r.lastErr != nil { 247 r.mu.Unlock() 248 _ = r.Close() 249 return false 250 } 251 r.cur = heap.Pop(r.hp).(*node) 252 row := r.rowsList[r.cur.index] 253 err := r.nextRows(row, r.cur.index) 254 if err != nil { 255 r.lastErr = err 256 r.mu.Unlock() 257 _ = r.Close() 258 return false 259 } 260 r.mu.Unlock() 261 return true 262 } 263 264 func (r *Rows) nextRows(row rows.Rows, index int) error { 265 if row.Next() { 266 n, err := newNode(row, r.sortColumns, index) 267 if err != nil { 268 return err 269 } 270 heap.Push(r.hp, n) 271 } else if row.Err() != nil { 272 return row.Err() 273 } 274 return nil 275 } 276 277 func (r *Rows) Scan(dest ...any) error { 278 r.mu.Lock() 279 defer r.mu.Unlock() 280 if r.lastErr != nil { 281 return r.lastErr 282 } 283 if r.closed { 284 return errs.ErrMergerRowsClosed 285 } 286 if r.cur == nil { 287 return errs.ErrMergerScanNotNext 288 } 289 var err error 290 for i := 0; i < len(dest); i++ { 291 err = rows.ConvertAssign(dest[i], r.cur.columns[i]) 292 if err != nil { 293 return err 294 } 295 } 296 return nil 297 } 298 299 func (r *Rows) Close() error { 300 r.mu.Lock() 301 defer r.mu.Unlock() 302 r.closed = true 303 errorList := make([]error, 0, len(r.rowsList)) 304 for i := 0; i < len(r.rowsList); i++ { 305 row := r.rowsList[i] 306 err := row.Close() 307 if err != nil { 308 errorList = append(errorList, err) 309 } 310 } 311 return multierr.Combine(errorList...) 312 } 313 314 func (r *Rows) Err() error { 315 r.mu.RLock() 316 defer r.mu.RUnlock() 317 return r.lastErr 318 } 319 320 func (r *Rows) Columns() ([]string, error) { 321 r.mu.RLock() 322 defer r.mu.RUnlock() 323 if r.closed { 324 return nil, errs.ErrMergerRowsClosed 325 } 326 return r.columns, nil 327 }