github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/cmd/slicer/cogroup.go (about) 1 // Copyright 2019 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "context" 9 "errors" 10 "flag" 11 "fmt" 12 "math/rand" 13 "os" 14 "sort" 15 "strconv" 16 17 "github.com/grailbio/base/log" 18 "github.com/grailbio/bigslice" 19 "github.com/grailbio/bigslice/exec" 20 "github.com/grailbio/bigslice/sliceio" 21 ) 22 23 func randomReader(nshard, nkey int) (slice bigslice.Slice) { 24 return bigslice.ReaderFunc(nshard, func(shard int, order *[]int, keys []string, values [][]int) (n int, err error) { 25 if *order == nil { 26 r := rand.New(rand.NewSource(rand.Int63())) 27 *order = r.Perm(nkey) 28 } 29 var i int 30 for i < len(*order) && i < len(keys) { 31 keys[i] = fmt.Sprint((*order)[i]) 32 values[i] = []int{shard<<24 | (*order)[i]} 33 i++ 34 } 35 *order = (*order)[i:] 36 if len(*order) == 0 { 37 log.Printf("shard %d complete", shard) 38 return i, sliceio.EOF 39 } 40 return i, nil 41 }) 42 } 43 44 var cogroupTest = bigslice.Func(func(nshard, nkey int) (slice bigslice.Slice) { 45 log.Printf("cogroupTest(%d, %d)", nshard, nkey) 46 // Each shard produces a (shuffled) set of values for each key. 47 48 slice = randomReader(nshard, nkey) 49 slice = bigslice.Cogroup(slice) 50 return 51 }) 52 53 func cogroup(sess *exec.Session, args []string) error { 54 var ( 55 flags = flag.NewFlagSet("cogroup", flag.ExitOnError) 56 nshard = flags.Int("nshard", 64, "number of shards") 57 nkey = flags.Int("nkey", 1e6, "number of keys per shard") 58 ) 59 flags.Usage = func() { 60 fmt.Fprintln(os.Stderr, `usage: slicer cogroup [-nshard N] [-nkey N]`) 61 flags.PrintDefaults() 62 os.Exit(2) 63 } 64 if err := flags.Parse(args); err != nil { 65 log.Fatal(err) 66 } 67 68 ctx := context.Background() 69 r, err := sess.Run(ctx, cogroupTest, *nshard, *nkey) 70 if err != nil { 71 return err 72 } 73 seen := make([]bool, *nkey) 74 scan := r.Scanner() 75 defer scan.Close() 76 ok := true 77 errorf := func(format string, v ...interface{}) { 78 log.Error.Printf(format, v...) 79 ok = false 80 } 81 var ( 82 keystr string 83 values [][]int 84 ) 85 for scan.Scan(ctx, &keystr, &values) { 86 key, err := strconv.Atoi(keystr) 87 if err != nil { 88 panic(err) 89 } 90 if seen[key] { 91 errorf("saw key %v multiple times", key) 92 } 93 seen[key] = true 94 if got, want := len(values), *nshard; got != want { 95 errorf("wrong number of values for key %d: got %v, want %v", key, got, want) 96 } else { 97 flat := make([]int, len(values)) 98 for i, v := range values { 99 if got, want := len(v), 1; got != want { 100 errorf("wrong number of values for key %d: got %v, want %v", key, got, want) 101 } 102 flat[i] = v[0] 103 } 104 sort.Ints(flat) 105 for i, v := range flat { 106 if got, want := v, i<<24|key; got != want { 107 errorf("wrong value for key %d: got %v, want %v", key, got, want) 108 } 109 } 110 } 111 } 112 if err := scan.Err(); err != nil { 113 return err 114 } 115 for key, saw := range seen { 116 if !saw { 117 errorf("did not see key %v", key) 118 } 119 } 120 if !ok { 121 return errors.New("test errors") 122 } 123 fmt.Println("ok") 124 return nil 125 }