github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/accum.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigslice
     6  
     7  import (
     8  	"context"
     9  	"reflect"
    10  
    11  	"github.com/grailbio/bigslice/frame"
    12  	"github.com/grailbio/bigslice/slicefunc"
    13  	"github.com/grailbio/bigslice/sliceio"
    14  )
    15  
    16  // An Accumulator represents a stateful accumulation of values of
    17  // a certain type. Accumulators maintain their state in memory.
    18  //
    19  // Accumulators should be read only after accumulation is complete.
    20  type Accumulator interface {
    21  	// Accumulate the provided columns of length n.
    22  	Accumulate(in frame.Frame, n int)
    23  	// Read a batch of accumulated values into keys and values. These
    24  	// are slices of the key type and accumulator type respectively.
    25  	Read(keys, values reflect.Value) (int, error)
    26  }
    27  
    28  func canMakeAccumulatorForKey(keyType reflect.Type) bool {
    29  	switch keyType.Kind() {
    30  	case reflect.String, reflect.Int, reflect.Int64:
    31  		return true
    32  	default:
    33  		return false
    34  	}
    35  }
    36  
    37  func makeAccumulator(keyType, accType reflect.Type, fn slicefunc.Func) Accumulator {
    38  	switch keyType.Kind() {
    39  	case reflect.String:
    40  		return &stringAccumulator{
    41  			accType: accType,
    42  			fn:      fn,
    43  			state:   make(map[string]reflect.Value),
    44  		}
    45  	case reflect.Int:
    46  		return &intAccumulator{
    47  			accType: accType,
    48  			fn:      fn,
    49  			state:   make(map[int]reflect.Value),
    50  		}
    51  	case reflect.Int64:
    52  		return &int64Accumulator{
    53  			accType: accType,
    54  			fn:      fn,
    55  			state:   make(map[int64]reflect.Value),
    56  		}
    57  	default:
    58  		return nil
    59  	}
    60  }
    61  
    62  // StringAccumulator accumulates values by string keys.
    63  type stringAccumulator struct {
    64  	accType reflect.Type
    65  	fn      slicefunc.Func
    66  	state   map[string]reflect.Value
    67  }
    68  
    69  func (s *stringAccumulator) Accumulate(in frame.Frame, n int) {
    70  	ctx := context.Background()
    71  	keys := in.Interface(0).([]string)
    72  	args := make([]reflect.Value, in.NumOut())
    73  	for i := 0; i < n; i++ {
    74  		key := keys[i]
    75  		val, ok := s.state[key]
    76  		if !ok {
    77  			val = reflect.Zero(s.accType)
    78  		}
    79  		args[0] = val
    80  		for j := 1; j < in.NumOut(); j++ {
    81  			args[j] = in.Index(j, i)
    82  		}
    83  		s.state[key] = s.fn.Call(ctx, args)[0]
    84  	}
    85  }
    86  
    87  func (s *stringAccumulator) Read(keys, values reflect.Value) (n int, err error) {
    88  	max := keys.Len()
    89  	for key, val := range s.state {
    90  		if n >= max {
    91  			break
    92  		}
    93  		keys.Index(n).Set(reflect.ValueOf(key))
    94  		values.Index(n).Set(val)
    95  		delete(s.state, key)
    96  		n++
    97  	}
    98  	if len(s.state) == 0 {
    99  		return n, sliceio.EOF
   100  	}
   101  	return n, nil
   102  }
   103  
   104  // IntAccumulator accumulates values by integer keys.
   105  type intAccumulator struct {
   106  	accType reflect.Type
   107  	fn      slicefunc.Func
   108  	state   map[int]reflect.Value
   109  }
   110  
   111  func (s *intAccumulator) Accumulate(in frame.Frame, n int) {
   112  	ctx := context.Background()
   113  	keys := in.Interface(0).([]int)
   114  	args := make([]reflect.Value, in.NumOut())
   115  	for i := 0; i < n; i++ {
   116  		key := keys[i]
   117  		val, ok := s.state[key]
   118  		if !ok {
   119  			val = reflect.Zero(s.accType)
   120  		}
   121  		args[0] = val
   122  		for j := 1; j < in.NumOut(); j++ {
   123  			args[j] = in.Index(j, i)
   124  		}
   125  		s.state[key] = s.fn.Call(ctx, args)[0]
   126  	}
   127  }
   128  
   129  func (s *intAccumulator) Read(keys, values reflect.Value) (n int, err error) {
   130  	max := keys.Len()
   131  	for key, val := range s.state {
   132  		if n >= max {
   133  			break
   134  		}
   135  		keys.Index(n).Set(reflect.ValueOf(key))
   136  		values.Index(n).Set(val)
   137  		delete(s.state, key)
   138  		n++
   139  	}
   140  	if len(s.state) == 0 {
   141  		return n, sliceio.EOF
   142  	}
   143  	return n, nil
   144  }
   145  
   146  // Int64Accumulator accumulates values by integer keys.
   147  type int64Accumulator struct {
   148  	accType reflect.Type
   149  	fn      slicefunc.Func
   150  	state   map[int64]reflect.Value
   151  }
   152  
   153  func (s *int64Accumulator) Accumulate(in frame.Frame, n int) {
   154  	ctx := context.Background()
   155  	keys := in.Interface(0).([]int64)
   156  	args := make([]reflect.Value, in.NumOut())
   157  	for i := 0; i < n; i++ {
   158  		key := keys[i]
   159  		val, ok := s.state[key]
   160  		if !ok {
   161  			val = reflect.Zero(s.accType)
   162  		}
   163  		args[0] = val
   164  		for j := 1; j < in.NumOut(); j++ {
   165  			args[j] = in.Index(j, i)
   166  		}
   167  		s.state[key] = s.fn.Call(ctx, args)[0]
   168  	}
   169  }
   170  
   171  func (s *int64Accumulator) Read(keys, values reflect.Value) (n int, err error) {
   172  	max := keys.Len()
   173  	for key, val := range s.state {
   174  		if n >= max {
   175  			break
   176  		}
   177  		keys.Index(n).Set(reflect.ValueOf(key))
   178  		values.Index(n).Set(val)
   179  		delete(s.state, key)
   180  		n++
   181  	}
   182  	if len(s.state) == 0 {
   183  		return n, sliceio.EOF
   184  	}
   185  	return n, nil
   186  }