github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/cogroup.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package bigslice

import (
	"container/heap"
	"context"
	"reflect"

	"github.com/grailbio/bigslice/frame"
	"github.com/grailbio/bigslice/slicefunc"
	"github.com/grailbio/bigslice/sliceio"
	"github.com/grailbio/bigslice/slicetype"
	"github.com/grailbio/bigslice/sortio"
	"github.com/grailbio/bigslice/typecheck"
)

type cogroupSlice struct {
	name     Name
	slices   []Slice
	out      []reflect.Type
	prefix   int
	numShard int
}

// Cogroup returns a slice that, for each key in any slice, contains
// the group of values for that key, in each slice. Schematically:
//
//	Cogroup(Slice<tk1, ..., tkp, t11, ..., t1n>, Slice<tk1, ..., tkp, t21, ..., t2n>, ..., Slice<tk1, ..., tkp, tm1, ..., tmn>)
//	Slice<tk1, ..., tkp, []t11, ..., []t1n, []t21, ..., []tmn>
//
// It thus implements a form of generalized JOIN and GROUP.
//
// Cogroup uses the prefix columns of each slice as its key; keys must be
// partitionable. (See cogroupUsageExample at the end of this file for an
// illustrative sketch.)
//
// TODO(marius): don't require spilling to disk when the input data
// set is small enough.
//
// TODO(marius): consider providing a version that returns scanners
// in the returned slice, so that we can stream through. This would
// require some changes downstream, however, so that buffering and
// encoding functionality also know how to read scanner values.
func Cogroup(slices ...Slice) Slice {
	if len(slices) == 0 {
		typecheck.Panic(1, "cogroup: expected at least one slice")
	}
	var keyTypes []reflect.Type
	for i, slice := range slices {
		if slice.NumOut() == 0 {
			typecheck.Panicf(1, "cogroup: slice %d has no columns", i)
		}
		if i == 0 {
			keyTypes = make([]reflect.Type, slice.Prefix())
			for j := range keyTypes {
				keyTypes[j] = slice.Out(j)
			}
		} else {
			if got, want := slice.Prefix(), len(keyTypes); got != want {
				typecheck.Panicf(1, "cogroup: prefix mismatch: expected %d but got %d", want, got)
			}
			for j := range keyTypes {
				if got, want := slice.Out(j), keyTypes[j]; got != want {
					typecheck.Panicf(1, "cogroup: key column type mismatch: expected %s but got %s", want, got)
				}
			}
		}
	}
	for i := range keyTypes {
		if !frame.CanHash(keyTypes[i]) {
			typecheck.Panicf(1, "cogroup: key column(%d) type %s cannot be hashed", i, keyTypes[i])
		}
		if !frame.CanCompare(keyTypes[i]) {
			typecheck.Panicf(1, "cogroup: key column(%d) type %s cannot be sorted", i, keyTypes[i])
		}
	}
	out := keyTypes
	for _, slice := range slices {
		for i := len(keyTypes); i < slice.NumOut(); i++ {
			out = append(out, reflect.SliceOf(slice.Out(i)))
		}
	}

	// Pick the max of the number of parent shards, so that the input
	// will be partitioned as widely as the user desires.
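	// Each dependency is declared as a shuffle dependency (see Dep and
	// ShardType below), so rows are partitioned by the hash of their key
	// prefix and equal keys land in the same shard; this is why the key
	// columns must be hashable, and the per-shard sort performed in
	// cogroupReader.Read is why they must also be comparable.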
	var numShard int
	for _, slice := range slices {
		if slice.NumShard() > numShard {
			numShard = slice.NumShard()
		}
	}

	return &cogroupSlice{
		name:     MakeName("cogroup"),
		numShard: numShard,
		slices:   slices,
		out:      out,
		prefix:   len(keyTypes),
	}
}

func (c *cogroupSlice) Name() Name             { return c.name }
func (c *cogroupSlice) NumShard() int          { return c.numShard }
func (c *cogroupSlice) ShardType() ShardType   { return HashShard }
func (c *cogroupSlice) NumOut() int            { return len(c.out) }
func (c *cogroupSlice) Out(i int) reflect.Type { return c.out[i] }
func (c *cogroupSlice) Prefix() int            { return c.prefix }
func (c *cogroupSlice) NumDep() int            { return len(c.slices) }
func (c *cogroupSlice) Dep(i int) Dep          { return Dep{c.slices[i], true, nil, false} }
func (*cogroupSlice) Combiner() slicefunc.Func { return slicefunc.Nil }

type cogroupReader struct {
	err error
	op  *cogroupSlice

	readers []sliceio.Reader

	heap *sortio.FrameBufferHeap
}

func (c *cogroupReader) Read(ctx context.Context, out frame.Frame) (int, error) {
	const (
		bufferSize = 128
		spillSize  = 1 << 25
	)
	if c.err != nil {
		return 0, c.err
	}
	if c.heap == nil {
		c.heap = new(sortio.FrameBufferHeap)
		c.heap.Buffers = make([]*sortio.FrameBuffer, 0, len(c.readers))
		// Maintain a compare buffer that's used to compare values across
		// the heterogeneously typed buffers.
		// TODO(marius): the extra copy and indirection here is unnecessary.
		lessBuf := frame.Make(slicetype.New(c.op.out[:c.op.prefix]...), 2, 2).Prefixed(c.op.prefix)
		c.heap.LessFunc = func(i, j int) bool {
			ib, jb := c.heap.Buffers[i], c.heap.Buffers[j]
			for i := 0; i < c.op.prefix; i++ {
				lessBuf.Index(i, 0).Set(ib.Frame.Index(i, ib.Index))
				lessBuf.Index(i, 1).Set(jb.Frame.Index(i, jb.Index))
			}
			return lessBuf.Less(0, 1)
		}

		// Sort each partition one-by-one. Since tasks are scheduled
		// to map onto a single CPU, we attain parallelism through sharding
		// at a higher level.
		for i := range c.readers {
			// Do the actual sort. Aim for ~30 MB spill files.
			// TODO(marius): make spill sizes configurable, or dependent
			// on the environment: for example, we could pass down a memory
			// allotment to each task from the scheduler.
			var sorted sliceio.Reader
			sorted, c.err = sortio.SortReader(ctx, spillSize, c.op.Dep(i), c.readers[i])
			if c.err != nil {
				// TODO(marius): in case this fails, we may leave open file
				// descriptors. We should make sure we close readers that
				// implement Discard.
				return 0, c.err
			}
			buf := &sortio.FrameBuffer{
				Frame:  frame.Make(c.op.Dep(i), bufferSize, bufferSize),
				Reader: sorted,
				Off:    i * bufferSize,
			}
			switch err := buf.Fill(ctx); {
			case err == sliceio.EOF:
				// No data. Skip.
			case err != nil:
				c.err = err
				return 0, err
			default:
				c.heap.Buffers = append(c.heap.Buffers, buf)
			}
		}
	}
	heap.Init(c.heap)

	// Now that we're sorted, perform a merge from each dependency.
	var (
		n       int
		max     = out.Len()
		lessBuf = frame.Make(slicetype.New(c.op.out[:c.op.prefix]...), 2, 2).Prefixed(c.op.prefix)
	)
	if max == 0 {
		panic("bigslice.Cogroup: max == 0")
	}
	// BUG: this is gnarly
	for n < max && len(c.heap.Buffers) > 0 {
		// First, gather all the records that have the same key.
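		// row[i] accumulates the rows from dependency i that carry the
		// current (minimum) key; the dependency index for the buffer at
		// the top of the heap is recovered below from its offset, which
		// was fixed to i*bufferSize when the buffer was created.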
		row := make([]frame.Frame, len(c.readers))
		var (
			key  = make([]reflect.Value, c.op.prefix)
			last = -1
		)
		// TODO(marius): the extra copy and indirection here is unnecessary.
		less := func() bool {
			buf := c.heap.Buffers[0]
			for i := 0; i < c.op.prefix; i++ {
				lessBuf.Index(i, 0).Set(row[last].Index(i, 0))
				lessBuf.Index(i, 1).Set(buf.Frame.Index(i, buf.Index))
			}
			return lessBuf.Less(0, 1)
		}

		for last < 0 || len(c.heap.Buffers) > 0 && !less() {
			// first key: need to pick the smallest one
			buf := c.heap.Buffers[0]
			idx := buf.Off / bufferSize
			row[idx] = frame.AppendFrame(row[idx], buf.Slice(buf.Index, buf.Index+1))
			buf.Index++
			if last < 0 {
				for i := 0; i < c.op.prefix; i++ {
					key[i] = row[idx].Index(i, 0)
				}
			}
			last = idx
			if buf.Index == buf.Len {
				if err := buf.Fill(ctx); err != nil && err != sliceio.EOF {
					c.err = err
					return n, err
				} else if err == sliceio.EOF {
					heap.Remove(c.heap, 0)
				} else {
					heap.Fix(c.heap, 0)
				}
			} else {
				heap.Fix(c.heap, 0)
			}
		}

		// Now that we've gathered all the row values for a given key,
		// push them into our output.
		var j int
		for i := range key {
			out.Index(j, n).Set(key[i])
			j++
		}
		// Note that here we are assuming that the key column is always first;
		// elsewhere we don't really make this assumption, even though it is
		// enforced when constructing a cogroup.
		for i := range row {
			typ := c.op.Dep(i)
			if row[i].Len() == 0 {
				for k := len(key); k < typ.NumOut(); k++ {
					out.Index(j, n).Set(reflect.Zero(c.op.out[j]))
					j++
				}
			} else {
				for k := len(key); k < typ.NumOut(); k++ {
					// TODO(marius): precompute type checks here.
					out.Index(j, n).Set(row[i].Value(k))
					j++
				}
			}
		}
		n++
	}
	if n == 0 {
		c.err = sliceio.EOF
	}
	return n, c.err
}

func (c *cogroupSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader {
	return &cogroupReader{
		op:      c,
		readers: deps,
	}
}
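
// cogroupUsageExample is a hypothetical, illustrative sketch of how
// Cogroup might be called; it is not used elsewhere in this package.
// It assumes the package's Const constructor and the default one-column
// key prefix, and the column data is made up for illustration. The
// result has type Slice<string, []int, []string>: one row per distinct
// key, with the non-key columns of each input gathered into slices.
func cogroupUsageExample() Slice {
	// Two inputs that share a string key column (the default prefix).
	ages := Const(2, []string{"alice", "bob"}, []int{30, 40})
	posts := Const(2, []string{"alice", "alice", "bob"}, []string{"p1", "p2", "p3"})
	// For example, the output row for "alice" would be
	// ("alice", []int{30}, []string{"p1", "p2"}).
	return Cogroup(ages, posts)
}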