github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/batchmerger/merger.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package batchmerger
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"sync"
    21  
    22  	"github.com/ecodeclub/eorm/internal/rows"
    23  
    24  	"go.uber.org/multierr"
    25  
    26  	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
    27  )
    28  
    29  type Merger struct {
    30  	cols []string
    31  }
    32  
    33  func NewMerger() *Merger {
    34  	return &Merger{}
    35  }
    36  
    37  func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) {
    38  	if ctx.Err() != nil {
    39  		return nil, ctx.Err()
    40  	}
    41  	if len(results) == 0 {
    42  		return nil, errs.ErrMergerEmptyRows
    43  	}
    44  	for i := 0; i < len(results); i++ {
    45  		err := m.checkColumns(results[i])
    46  		if err != nil {
    47  			return nil, err
    48  		}
    49  	}
    50  	return &Rows{
    51  		rowsList: results,
    52  		mu:       &sync.RWMutex{},
    53  		columns:  m.cols,
    54  	}, nil
    55  }
    56  
    57  // checkColumns 检查sql.Rows列表中sql.Rows的列集是否相同,并且sql.Rows不能为nil
    58  func (m *Merger) checkColumns(rows rows.Rows) error {
    59  	if rows == nil {
    60  		return errs.ErrMergerRowsIsNull
    61  	}
    62  	cols, err := rows.Columns()
    63  	if err != nil {
    64  		return err
    65  	}
    66  	if len(m.cols) == 0 {
    67  		m.cols = cols
    68  	}
    69  	if len(m.cols) != len(cols) {
    70  		return errs.ErrMergerRowsDiff
    71  	}
    72  	for idx, colName := range cols {
    73  		if m.cols[idx] != colName {
    74  			return errs.ErrMergerRowsDiff
    75  		}
    76  	}
    77  	return nil
    78  
    79  }
    80  
// Rows is the merged cursor produced by Merger.Merge. It reads rowsList in
// order, draining each input (including its additional result sets) before
// moving to the next one.
type Rows struct {
	rowsList []rows.Rows   // inputs being merged, in iteration order
	cnt      int           // index into rowsList of the input currently being read
	mu       *sync.RWMutex // guards cnt, closed and lastErr
	columns  []string      // column list validated by Merge
	closed   bool          // set once Close has run
	lastErr  error         // error hit while advancing an input; reported by Err and Scan
}
    89  
    90  func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) {
    91  	return r.rowsList[0].ColumnTypes()
    92  }
    93  
    94  func (*Rows) NextResultSet() bool {
    95  	return false
    96  }
    97  
    98  func (r *Rows) Next() bool {
    99  	r.mu.Lock()
   100  	if r.closed {
   101  		r.mu.Unlock()
   102  		return false
   103  	}
   104  	if r.cnt >= len(r.rowsList) || r.lastErr != nil {
   105  		r.mu.Unlock()
   106  		_ = r.Close()
   107  		return false
   108  	}
   109  	canNext, err := r.nextRows()
   110  	if err != nil {
   111  		r.lastErr = err
   112  		r.mu.Unlock()
   113  		_ = r.Close()
   114  		return false
   115  	}
   116  	r.mu.Unlock()
   117  	return canNext
   118  
   119  }
   120  
   121  func (r *Rows) nextRows() (bool, error) {
   122  	row := r.rowsList[r.cnt]
   123  
   124  	if row.Next() {
   125  		return true, nil
   126  	}
   127  
   128  	for row.NextResultSet() {
   129  		if row.Next() {
   130  			return true, nil
   131  		}
   132  	}
   133  
   134  	if row.Err() != nil {
   135  		return false, row.Err()
   136  	}
   137  
   138  	for {
   139  		r.cnt++
   140  		if r.cnt >= len(r.rowsList) {
   141  			break
   142  		}
   143  		row = r.rowsList[r.cnt]
   144  
   145  		if row.Next() {
   146  			return true, nil
   147  		} else if row.Err() != nil {
   148  			return false, row.Err()
   149  		}
   150  
   151  		for row.NextResultSet() {
   152  			if row.Next() {
   153  				return true, nil
   154  			}
   155  		}
   156  	}
   157  	return false, nil
   158  }
   159  
   160  func (r *Rows) Scan(dest ...any) error {
   161  	r.mu.RLock()
   162  	defer r.mu.RUnlock()
   163  	if r.lastErr != nil {
   164  		return r.lastErr
   165  	}
   166  	if r.closed {
   167  		return errs.ErrMergerRowsClosed
   168  	}
   169  	return r.rowsList[r.cnt].Scan(dest...)
   170  
   171  }
   172  
   173  func (r *Rows) Close() error {
   174  	r.mu.Lock()
   175  	defer r.mu.Unlock()
   176  	r.closed = true
   177  	errorList := make([]error, 0, len(r.rowsList))
   178  	for i := 0; i < len(r.rowsList); i++ {
   179  		row := r.rowsList[i]
   180  		err := row.Close()
   181  		if err != nil {
   182  			errorList = append(errorList, err)
   183  		}
   184  	}
   185  	return multierr.Combine(errorList...)
   186  }
   187  
   188  func (r *Rows) Columns() ([]string, error) {
   189  	r.mu.RLock()
   190  	defer r.mu.RUnlock()
   191  	if r.closed {
   192  		return nil, errs.ErrMergerRowsClosed
   193  	}
   194  	return r.columns, nil
   195  }
   196  
   197  func (r *Rows) Err() error {
   198  	r.mu.RLock()
   199  	defer r.mu.RUnlock()
   200  	return r.lastErr
   201  }