github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/aggregatemerger/merger.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregatemerger
    16  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"
	_ "unsafe"

	"github.com/ecodeclub/ekit/sqlx"
	"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
	"github.com/ecodeclub/eorm/internal/rows"
	"go.uber.org/multierr"
)
    32  
// Merger merges aggregate-query results coming from several result sets.
// This implementation does not support GROUP BY, and each aggregate query is
// expected to return only one row.
type Merger struct {
	aggregators []aggregator.Aggregator // aggregators applied, in order, to build the merged row
	colNames    []string                // ColumnName() of each aggregator, in the same order
}
    38  
    39  func NewMerger(aggregators ...aggregator.Aggregator) *Merger {
    40  	cols := make([]string, 0, len(aggregators))
    41  	for _, agg := range aggregators {
    42  		cols = append(cols, agg.ColumnName())
    43  	}
    44  	return &Merger{
    45  		aggregators: aggregators,
    46  		colNames:    cols,
    47  	}
    48  }
    49  
    50  func (m *Merger) Merge(ctx context.Context, results []rows.Rows) (rows.Rows, error) {
    51  	if ctx.Err() != nil {
    52  		return nil, ctx.Err()
    53  	}
    54  	if len(results) == 0 {
    55  		return nil, errs.ErrMergerEmptyRows
    56  	}
    57  	for _, res := range results {
    58  		if res == nil {
    59  			return nil, errs.ErrMergerRowsIsNull
    60  		}
    61  	}
    62  
    63  	return &Rows{
    64  		rowsList:    results,
    65  		aggregators: m.aggregators,
    66  		mu:          &sync.RWMutex{},
    67  		//聚合函数AVG传递到各个sql.Rows时会被转化为SUM和COUNT,这是一个对外不可见的转化。
    68  		//所以merger.Rows的列名及顺序是由上方aggregator出现的顺序及ColumnName()的返回值决定的而不是sql.Rows。
    69  		columns: m.colNames,
    70  	}, nil
    71  
    72  }
    73  
// Rows is the merged result returned by Merger.Merge. Next produces at most
// one row, computed by running every aggregator over one row scanned from
// each underlying result set.
type Rows struct {
	rowsList    []rows.Rows             // underlying result sets being merged
	aggregators []aggregator.Aggregator // applied in order to produce the merged row
	closed      bool                    // set by Close
	mu          *sync.RWMutex           // guards closed, lastErr, cur and nextCalled
	lastErr     error                   // error recorded by Next; reported by Err and Scan
	cur         []any                   // merged row built by a successful Next
	columns     []string                // column names reported by Columns
	nextCalled  bool                    // true once Next has attempted to build the row
}
    84  
    85  func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) {
    86  	return r.rowsList[0].ColumnTypes()
    87  }
    88  
    89  func (*Rows) NextResultSet() bool {
    90  	return false
    91  }
    92  
    93  func (r *Rows) Next() bool {
    94  	r.mu.Lock()
    95  	if r.closed || r.lastErr != nil {
    96  		r.mu.Unlock()
    97  		return false
    98  	}
    99  	if r.nextCalled {
   100  		r.mu.Unlock()
   101  		_ = r.Close()
   102  		return false
   103  	}
   104  
   105  	rowsData, err := r.getSqlRowsData()
   106  	r.nextCalled = true
   107  	if err != nil {
   108  		r.lastErr = err
   109  		r.mu.Unlock()
   110  		_ = r.Close()
   111  		return false
   112  	}
   113  	// 进行聚合函数计算
   114  	res, err := r.executeAggregateCalculation(rowsData)
   115  	if err != nil {
   116  		r.lastErr = err
   117  		r.mu.Unlock()
   118  		_ = r.Close()
   119  		return false
   120  	}
   121  	r.cur = res
   122  	r.mu.Unlock()
   123  	return true
   124  
   125  }
   126  
   127  // getAggregateInfo 进行aggregate运算
   128  func (r *Rows) executeAggregateCalculation(rowsData [][]any) ([]any, error) {
   129  	res := make([]any, 0, len(r.aggregators))
   130  	for _, agg := range r.aggregators {
   131  		val, err := agg.Aggregate(rowsData)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		res = append(res, val)
   137  	}
   138  	return res, nil
   139  }
   140  
   141  // getSqlRowData 从sqlRows里面获取数据
   142  func (r *Rows) getSqlRowsData() ([][]any, error) {
   143  	// 所有sql.Rows的数据
   144  	rowsData := make([][]any, 0, len(r.rowsList))
   145  	for _, row := range r.rowsList {
   146  		colData, err := r.getSqlRowData(row)
   147  		if err != nil {
   148  			return nil, err
   149  		}
   150  		rowsData = append(rowsData, colData)
   151  	}
   152  	return rowsData, nil
   153  }
   154  func (*Rows) getSqlRowData(row rows.Rows) ([]any, error) {
   155  	var colsData []any
   156  	var err error
   157  	scanner, err := sqlx.NewSQLRowsScanner(row)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	colsData, err = scanner.Scan()
   162  	if errors.Is(err, sqlx.ErrNoMoreRows) {
   163  		return nil, errs.ErrMergerAggregateHasEmptyRows
   164  	}
   165  	return colsData, err
   166  }
   167  
   168  func (r *Rows) Scan(dest ...any) error {
   169  	r.mu.Lock()
   170  	defer r.mu.Unlock()
   171  	if r.lastErr != nil {
   172  		return r.lastErr
   173  	}
   174  	if r.closed {
   175  		return errs.ErrMergerRowsClosed
   176  	}
   177  
   178  	if len(r.cur) == 0 {
   179  		return errs.ErrMergerScanNotNext
   180  	}
   181  	for i := 0; i < len(dest); i++ {
   182  		err := rows.ConvertAssign(dest[i], r.cur[i])
   183  		if err != nil {
   184  			return err
   185  		}
   186  	}
   187  	return nil
   188  }
   189  
   190  func (r *Rows) Close() error {
   191  	r.mu.Lock()
   192  	defer r.mu.Unlock()
   193  	r.closed = true
   194  	errorList := make([]error, 0, len(r.rowsList))
   195  	for i := 0; i < len(r.rowsList); i++ {
   196  		row := r.rowsList[i]
   197  		err := row.Close()
   198  		if err != nil {
   199  			errorList = append(errorList, err)
   200  		}
   201  	}
   202  	return multierr.Combine(errorList...)
   203  }
   204  
   205  func (r *Rows) Columns() ([]string, error) {
   206  	r.mu.RLock()
   207  	defer r.mu.RUnlock()
   208  	if r.closed {
   209  		return nil, errs.ErrMergerRowsClosed
   210  	}
   211  	return r.columns, nil
   212  }
   213  
   214  func (r *Rows) Err() error {
   215  	r.mu.RLock()
   216  	defer r.mu.RUnlock()
   217  	return r.lastErr
   218  }