github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/cmd/slicer/cogroup.go (about)

     1  // Copyright 2019 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"math/rand"
    13  	"os"
    14  	"sort"
    15  	"strconv"
    16  
    17  	"github.com/grailbio/base/log"
    18  	"github.com/grailbio/bigslice"
    19  	"github.com/grailbio/bigslice/exec"
    20  	"github.com/grailbio/bigslice/sliceio"
    21  )
    22  
    23  func randomReader(nshard, nkey int) (slice bigslice.Slice) {
    24  	return bigslice.ReaderFunc(nshard, func(shard int, order *[]int, keys []string, values [][]int) (n int, err error) {
    25  		if *order == nil {
    26  			r := rand.New(rand.NewSource(rand.Int63()))
    27  			*order = r.Perm(nkey)
    28  		}
    29  		var i int
    30  		for i < len(*order) && i < len(keys) {
    31  			keys[i] = fmt.Sprint((*order)[i])
    32  			values[i] = []int{shard<<24 | (*order)[i]}
    33  			i++
    34  		}
    35  		*order = (*order)[i:]
    36  		if len(*order) == 0 {
    37  			log.Printf("shard %d complete", shard)
    38  			return i, sliceio.EOF
    39  		}
    40  		return i, nil
    41  	})
    42  }
    43  
    44  var cogroupTest = bigslice.Func(func(nshard, nkey int) (slice bigslice.Slice) {
    45  	log.Printf("cogroupTest(%d, %d)", nshard, nkey)
    46  	// Each shard produces a (shuffled) set of values for each key.
    47  
    48  	slice = randomReader(nshard, nkey)
    49  	slice = bigslice.Cogroup(slice)
    50  	return
    51  })
    52  
    53  func cogroup(sess *exec.Session, args []string) error {
    54  	var (
    55  		flags  = flag.NewFlagSet("cogroup", flag.ExitOnError)
    56  		nshard = flags.Int("nshard", 64, "number of shards")
    57  		nkey   = flags.Int("nkey", 1e6, "number of keys per shard")
    58  	)
    59  	flags.Usage = func() {
    60  		fmt.Fprintln(os.Stderr, `usage: slicer cogroup [-nshard N] [-nkey N]`)
    61  		flags.PrintDefaults()
    62  		os.Exit(2)
    63  	}
    64  	if err := flags.Parse(args); err != nil {
    65  		log.Fatal(err)
    66  	}
    67  
    68  	ctx := context.Background()
    69  	r, err := sess.Run(ctx, cogroupTest, *nshard, *nkey)
    70  	if err != nil {
    71  		return err
    72  	}
    73  	seen := make([]bool, *nkey)
    74  	scan := r.Scanner()
    75  	defer scan.Close()
    76  	ok := true
    77  	errorf := func(format string, v ...interface{}) {
    78  		log.Error.Printf(format, v...)
    79  		ok = false
    80  	}
    81  	var (
    82  		keystr string
    83  		values [][]int
    84  	)
    85  	for scan.Scan(ctx, &keystr, &values) {
    86  		key, err := strconv.Atoi(keystr)
    87  		if err != nil {
    88  			panic(err)
    89  		}
    90  		if seen[key] {
    91  			errorf("saw key %v multiple times", key)
    92  		}
    93  		seen[key] = true
    94  		if got, want := len(values), *nshard; got != want {
    95  			errorf("wrong number of values for key %d: got %v, want %v", key, got, want)
    96  		} else {
    97  			flat := make([]int, len(values))
    98  			for i, v := range values {
    99  				if got, want := len(v), 1; got != want {
   100  					errorf("wrong number of values for key %d: got %v, want %v", key, got, want)
   101  				}
   102  				flat[i] = v[0]
   103  			}
   104  			sort.Ints(flat)
   105  			for i, v := range flat {
   106  				if got, want := v, i<<24|key; got != want {
   107  					errorf("wrong value for key %d: got %v, want %v", key, got, want)
   108  				}
   109  			}
   110  		}
   111  	}
   112  	if err := scan.Err(); err != nil {
   113  		return err
   114  	}
   115  	for key, saw := range seen {
   116  		if !saw {
   117  			errorf("did not see key %v", key)
   118  		}
   119  	}
   120  	if !ok {
   121  		return errors.New("test errors")
   122  	}
   123  	fmt.Println("ok")
   124  	return nil
   125  }