github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/aggregatemerger/merger.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregatemerger 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "sync" 22 _ "unsafe" 23 24 "github.com/ecodeclub/eorm/internal/rows" 25 26 "github.com/ecodeclub/ekit/sqlx" 27 28 "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" 29 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 30 "go.uber.org/multierr" 31 ) 32 33 // Merger 该实现不支持group by操作,并且聚合函数查询应该只返回一行数据。 34 type Merger struct { 35 aggregators []aggregator.Aggregator 36 colNames []string 37 } 38 39 func NewMerger(aggregators ...aggregator.Aggregator) *Merger { 40 cols := make([]string, 0, len(aggregators)) 41 for _, agg := range aggregators { 42 cols = append(cols, agg.ColumnName()) 43 } 44 return &Merger{ 45 aggregators: aggregators, 46 colNames: cols, 47 } 48 } 49 50 func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { 51 if ctx.Err() != nil { 52 return nil, ctx.Err() 53 } 54 if len(results) == 0 { 55 return nil, errs.ErrMergerEmptyRows 56 } 57 for _, res := range results { 58 if res == nil { 59 return nil, errs.ErrMergerRowsIsNull 60 } 61 } 62 63 return &Rows{ 64 rowsList: results, 65 aggregators: m.aggregators, 66 mu: &sync.RWMutex{}, 67 //聚合函数AVG传递到各个sql.Rows时会被转化为SUM和COUNT,这是一个对外不可见的转化。 68 //所以merger.Rows的列名及顺序是由上方aggregator出现的顺序及ColumnName()的返回值决定的而不是sql.Rows。 69 columns: m.colNames, 70 }, nil 71 72 } 73 74 type Rows struct { 75 rowsList []rows.Rows 76 aggregators []aggregator.Aggregator 77 closed bool 78 mu *sync.RWMutex 79 lastErr error 80 cur []any 81 columns []string 82 nextCalled bool 83 } 84 85 func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { 86 return r.rowsList[0].ColumnTypes() 87 } 88 89 func (*Rows) NextResultSet() bool { 90 return false 91 } 92 93 func (r *Rows) Next() bool { 94 r.mu.Lock() 95 if r.closed || r.lastErr != nil { 96 r.mu.Unlock() 97 return false 98 } 99 if r.nextCalled { 100 r.mu.Unlock() 101 _ = r.Close() 102 return false 103 } 104 105 rowsData, err := r.getSqlRowsData() 106 r.nextCalled = true 107 if err != nil { 108 r.lastErr = err 109 r.mu.Unlock() 110 _ = r.Close() 111 return false 112 } 113 // 进行聚合函数计算 114 res, err := r.executeAggregateCalculation(rowsData) 115 if err != nil { 116 r.lastErr = err 117 r.mu.Unlock() 118 _ = r.Close() 119 return false 120 } 121 r.cur = res 122 r.mu.Unlock() 123 return true 124 125 } 126 127 // getAggregateInfo 进行aggregate运算 128 func (r *Rows) executeAggregateCalculation(rowsData [][]any) ([]any, error) { 129 res := make([]any, 0, len(r.aggregators)) 130 for _, agg := range r.aggregators { 131 val, err := agg.Aggregate(rowsData) 132 if err != nil { 133 return nil, err 134 } 135 136 res = append(res, val) 137 } 138 return res, nil 139 } 140 141 // getSqlRowData 从sqlRows里面获取数据 142 func (r *Rows) getSqlRowsData() ([][]any, error) { 143 // 所有sql.Rows的数据 144 rowsData := make([][]any, 0, len(r.rowsList)) 145 for _, row := range r.rowsList { 146 colData, err := r.getSqlRowData(row) 147 if err != nil { 148 return nil, err 149 } 150 rowsData = append(rowsData, colData) 151 } 152 return rowsData, nil 153 } 154 func (*Rows) getSqlRowData(row rows.Rows) ([]any, error) { 155 var colsData []any 156 var err error 157 scanner, err := sqlx.NewSQLRowsScanner(row) 158 if err != nil { 159 return nil, err 160 } 161 colsData, err = scanner.Scan() 162 if errors.Is(err, sqlx.ErrNoMoreRows) { 163 return nil, errs.ErrMergerAggregateHasEmptyRows 164 } 165 return colsData, err 166 } 167 168 func (r *Rows) Scan(dest ...any) error { 169 r.mu.Lock() 170 defer r.mu.Unlock() 171 if r.lastErr != nil { 172 return r.lastErr 173 } 174 if r.closed { 175 return errs.ErrMergerRowsClosed 176 } 177 178 if len(r.cur) == 0 { 179 return errs.ErrMergerScanNotNext 180 } 181 for i := 0; i < len(dest); i++ { 182 err := rows.ConvertAssign(dest[i], r.cur[i]) 183 if err != nil { 184 return err 185 } 186 } 187 return nil 188 } 189 190 func (r *Rows) Close() error { 191 r.mu.Lock() 192 defer r.mu.Unlock() 193 r.closed = true 194 errorList := make([]error, 0, len(r.rowsList)) 195 for i := 0; i < len(r.rowsList); i++ { 196 row := r.rowsList[i] 197 err := row.Close() 198 if err != nil { 199 errorList = append(errorList, err) 200 } 201 } 202 return multierr.Combine(errorList...) 203 } 204 205 func (r *Rows) Columns() ([]string, error) { 206 r.mu.RLock() 207 defer r.mu.RUnlock() 208 if r.closed { 209 return nil, errs.ErrMergerRowsClosed 210 } 211 return r.columns, nil 212 } 213 214 func (r *Rows) Err() error { 215 r.mu.RLock() 216 defer r.mu.RUnlock() 217 return r.lastErr 218 }