github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/accum.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package bigslice 6 7 import ( 8 "context" 9 "reflect" 10 11 "github.com/grailbio/bigslice/frame" 12 "github.com/grailbio/bigslice/slicefunc" 13 "github.com/grailbio/bigslice/sliceio" 14 ) 15 16 // An Accumulator represents a stateful accumulation of values of 17 // a certain type. Accumulators maintain their state in memory. 18 // 19 // Accumulators should be read only after accumulation is complete. 20 type Accumulator interface { 21 // Accumulate the provided columns of length n. 22 Accumulate(in frame.Frame, n int) 23 // Read a batch of accumulated values into keys and values. These 24 // are slices of the key type and accumulator type respectively. 25 Read(keys, values reflect.Value) (int, error) 26 } 27 28 func canMakeAccumulatorForKey(keyType reflect.Type) bool { 29 switch keyType.Kind() { 30 case reflect.String, reflect.Int, reflect.Int64: 31 return true 32 default: 33 return false 34 } 35 } 36 37 func makeAccumulator(keyType, accType reflect.Type, fn slicefunc.Func) Accumulator { 38 switch keyType.Kind() { 39 case reflect.String: 40 return &stringAccumulator{ 41 accType: accType, 42 fn: fn, 43 state: make(map[string]reflect.Value), 44 } 45 case reflect.Int: 46 return &intAccumulator{ 47 accType: accType, 48 fn: fn, 49 state: make(map[int]reflect.Value), 50 } 51 case reflect.Int64: 52 return &int64Accumulator{ 53 accType: accType, 54 fn: fn, 55 state: make(map[int64]reflect.Value), 56 } 57 default: 58 return nil 59 } 60 } 61 62 // StringAccumulator accumulates values by string keys. 63 type stringAccumulator struct { 64 accType reflect.Type 65 fn slicefunc.Func 66 state map[string]reflect.Value 67 } 68 69 func (s *stringAccumulator) Accumulate(in frame.Frame, n int) { 70 ctx := context.Background() 71 keys := in.Interface(0).([]string) 72 args := make([]reflect.Value, in.NumOut()) 73 for i := 0; i < n; i++ { 74 key := keys[i] 75 val, ok := s.state[key] 76 if !ok { 77 val = reflect.Zero(s.accType) 78 } 79 args[0] = val 80 for j := 1; j < in.NumOut(); j++ { 81 args[j] = in.Index(j, i) 82 } 83 s.state[key] = s.fn.Call(ctx, args)[0] 84 } 85 } 86 87 func (s *stringAccumulator) Read(keys, values reflect.Value) (n int, err error) { 88 max := keys.Len() 89 for key, val := range s.state { 90 if n >= max { 91 break 92 } 93 keys.Index(n).Set(reflect.ValueOf(key)) 94 values.Index(n).Set(val) 95 delete(s.state, key) 96 n++ 97 } 98 if len(s.state) == 0 { 99 return n, sliceio.EOF 100 } 101 return n, nil 102 } 103 104 // IntAccumulator accumulates values by integer keys. 105 type intAccumulator struct { 106 accType reflect.Type 107 fn slicefunc.Func 108 state map[int]reflect.Value 109 } 110 111 func (s *intAccumulator) Accumulate(in frame.Frame, n int) { 112 ctx := context.Background() 113 keys := in.Interface(0).([]int) 114 args := make([]reflect.Value, in.NumOut()) 115 for i := 0; i < n; i++ { 116 key := keys[i] 117 val, ok := s.state[key] 118 if !ok { 119 val = reflect.Zero(s.accType) 120 } 121 args[0] = val 122 for j := 1; j < in.NumOut(); j++ { 123 args[j] = in.Index(j, i) 124 } 125 s.state[key] = s.fn.Call(ctx, args)[0] 126 } 127 } 128 129 func (s *intAccumulator) Read(keys, values reflect.Value) (n int, err error) { 130 max := keys.Len() 131 for key, val := range s.state { 132 if n >= max { 133 break 134 } 135 keys.Index(n).Set(reflect.ValueOf(key)) 136 values.Index(n).Set(val) 137 delete(s.state, key) 138 n++ 139 } 140 if len(s.state) == 0 { 141 return n, sliceio.EOF 142 } 143 return n, nil 144 } 145 146 // Int64Accumulator accumulates values by integer keys. 147 type int64Accumulator struct { 148 accType reflect.Type 149 fn slicefunc.Func 150 state map[int64]reflect.Value 151 } 152 153 func (s *int64Accumulator) Accumulate(in frame.Frame, n int) { 154 ctx := context.Background() 155 keys := in.Interface(0).([]int64) 156 args := make([]reflect.Value, in.NumOut()) 157 for i := 0; i < n; i++ { 158 key := keys[i] 159 val, ok := s.state[key] 160 if !ok { 161 val = reflect.Zero(s.accType) 162 } 163 args[0] = val 164 for j := 1; j < in.NumOut(); j++ { 165 args[j] = in.Index(j, i) 166 } 167 s.state[key] = s.fn.Call(ctx, args)[0] 168 } 169 } 170 171 func (s *int64Accumulator) Read(keys, values reflect.Value) (n int, err error) { 172 max := keys.Len() 173 for key, val := range s.state { 174 if n >= max { 175 break 176 } 177 keys.Index(n).Set(reflect.ValueOf(key)) 178 values.Index(n).Set(val) 179 delete(s.state, key) 180 n++ 181 } 182 if len(s.state) == 0 { 183 return n, sliceio.EOF 184 } 185 return n, nil 186 }