github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/batchmerger/merger.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package batchmerger 16 17 import ( 18 "context" 19 "database/sql" 20 "sync" 21 22 "github.com/ecodeclub/eorm/internal/rows" 23 24 "go.uber.org/multierr" 25 26 "github.com/ecodeclub/eorm/internal/merger/internal/errs" 27 ) 28 29 type Merger struct { 30 cols []string 31 } 32 33 func NewMerger() *Merger { 34 return &Merger{} 35 } 36 37 func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) { 38 if ctx.Err() != nil { 39 return nil, ctx.Err() 40 } 41 if len(results) == 0 { 42 return nil, errs.ErrMergerEmptyRows 43 } 44 for i := 0; i < len(results); i++ { 45 err := m.checkColumns(results[i]) 46 if err != nil { 47 return nil, err 48 } 49 } 50 return &Rows{ 51 rowsList: results, 52 mu: &sync.RWMutex{}, 53 columns: m.cols, 54 }, nil 55 } 56 57 // checkColumns 检查sql.Rows列表中sql.Rows的列集是否相同,并且sql.Rows不能为nil 58 func (m *Merger) checkColumns(rows rows.Rows) error { 59 if rows == nil { 60 return errs.ErrMergerRowsIsNull 61 } 62 cols, err := rows.Columns() 63 if err != nil { 64 return err 65 } 66 if len(m.cols) == 0 { 67 m.cols = cols 68 } 69 if len(m.cols) != len(cols) { 70 return errs.ErrMergerRowsDiff 71 } 72 for idx, colName := range cols { 73 if m.cols[idx] != colName { 74 return errs.ErrMergerRowsDiff 75 } 76 } 77 return nil 78 79 } 80 81 type Rows struct { 82 rowsList []rows.Rows 83 cnt int 84 mu *sync.RWMutex 85 columns []string 86 closed bool 87 lastErr error 88 } 89 90 func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) { 91 return r.rowsList[0].ColumnTypes() 92 } 93 94 func (*Rows) NextResultSet() bool { 95 return false 96 } 97 98 func (r *Rows) Next() bool { 99 r.mu.Lock() 100 if r.closed { 101 r.mu.Unlock() 102 return false 103 } 104 if r.cnt >= len(r.rowsList) || r.lastErr != nil { 105 r.mu.Unlock() 106 _ = r.Close() 107 return false 108 } 109 canNext, err := r.nextRows() 110 if err != nil { 111 r.lastErr = err 112 r.mu.Unlock() 113 _ = r.Close() 114 return false 115 } 116 r.mu.Unlock() 117 return canNext 118 119 } 120 121 func (r *Rows) nextRows() (bool, error) { 122 row := r.rowsList[r.cnt] 123 124 if row.Next() { 125 return true, nil 126 } 127 128 for row.NextResultSet() { 129 if row.Next() { 130 return true, nil 131 } 132 } 133 134 if row.Err() != nil { 135 return false, row.Err() 136 } 137 138 for { 139 r.cnt++ 140 if r.cnt >= len(r.rowsList) { 141 break 142 } 143 row = r.rowsList[r.cnt] 144 145 if row.Next() { 146 return true, nil 147 } else if row.Err() != nil { 148 return false, row.Err() 149 } 150 151 for row.NextResultSet() { 152 if row.Next() { 153 return true, nil 154 } 155 } 156 } 157 return false, nil 158 } 159 160 func (r *Rows) Scan(dest ...any) error { 161 r.mu.RLock() 162 defer r.mu.RUnlock() 163 if r.lastErr != nil { 164 return r.lastErr 165 } 166 if r.closed { 167 return errs.ErrMergerRowsClosed 168 } 169 return r.rowsList[r.cnt].Scan(dest...) 170 171 } 172 173 func (r *Rows) Close() error { 174 r.mu.Lock() 175 defer r.mu.Unlock() 176 r.closed = true 177 errorList := make([]error, 0, len(r.rowsList)) 178 for i := 0; i < len(r.rowsList); i++ { 179 row := r.rowsList[i] 180 err := row.Close() 181 if err != nil { 182 errorList = append(errorList, err) 183 } 184 } 185 return multierr.Combine(errorList...) 186 } 187 188 func (r *Rows) Columns() ([]string, error) { 189 r.mu.RLock() 190 defer r.mu.RUnlock() 191 if r.closed { 192 return nil, errs.ErrMergerRowsClosed 193 } 194 return r.columns, nil 195 } 196 197 func (r *Rows) Err() error { 198 r.mu.RLock() 199 defer r.mu.RUnlock() 200 return r.lastErr 201 }