github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/merger/aggregatemerger/aggregator/sum.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregator
    16  
    17  import (
    18  	"reflect"
    19  
    20  	"github.com/ecodeclub/eorm/internal/merger"
    21  
    22  	"github.com/ecodeclub/eorm/internal/merger/internal/errs"
    23  )
    24  
// Sum is an aggregator that adds up the value found at one fixed position
// of every entry passed to Aggregate.
type Sum struct {
	// sumColumnInfo describes the summed column: its Index selects the value
	// inside each entry and its Name is what ColumnName reports.
	sumColumnInfo merger.ColumnInfo
}
    28  
    29  func (s *Sum) Aggregate(cols [][]any) (any, error) {
    30  	sumFunc, err := s.findSumFunc(cols[0])
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	return sumFunc(cols, s.sumColumnInfo.Index)
    35  }
    36  
    37  func (s *Sum) findSumFunc(col []any) (func([][]any, int) (any, error), error) {
    38  	var kind reflect.Kind
    39  	sumIndex := s.sumColumnInfo.Index
    40  	if sumIndex < 0 || sumIndex >= len(col) {
    41  		return nil, errs.ErrMergerInvalidAggregateColumnIndex
    42  	}
    43  	kind = reflect.TypeOf(col[sumIndex]).Kind()
    44  	sumFunc, ok := sumAggregateFuncMapping[kind]
    45  	if !ok {
    46  		return nil, errs.ErrMergerAggregateFuncNotFound
    47  	}
    48  	return sumFunc, nil
    49  }
    50  
    51  func (s *Sum) ColumnName() string {
    52  	return s.sumColumnInfo.Name
    53  }
    54  
    55  func NewSum(info merger.ColumnInfo) *Sum {
    56  	return &Sum{
    57  		sumColumnInfo: info,
    58  	}
    59  }
    60  
    61  func sumAggregate[T AggregateElement](cols [][]any, sumIndex int) (any, error) {
    62  	var sum T
    63  	for _, col := range cols {
    64  		sum += col[sumIndex].(T)
    65  	}
    66  	return sum, nil
    67  }
    68  
    69  var sumAggregateFuncMapping = map[reflect.Kind]func([][]any, int) (any, error){
    70  	reflect.Int:     sumAggregate[int],
    71  	reflect.Int8:    sumAggregate[int8],
    72  	reflect.Int16:   sumAggregate[int16],
    73  	reflect.Int32:   sumAggregate[int32],
    74  	reflect.Int64:   sumAggregate[int64],
    75  	reflect.Uint8:   sumAggregate[uint8],
    76  	reflect.Uint16:  sumAggregate[uint16],
    77  	reflect.Uint32:  sumAggregate[uint32],
    78  	reflect.Uint64:  sumAggregate[uint64],
    79  	reflect.Float32: sumAggregate[float32],
    80  	reflect.Float64: sumAggregate[float64],
    81  	reflect.Uint:    sumAggregate[uint],
    82  }