github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/slicetest/print.go (about)

     1  package slicetest
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  
    11  	"github.com/grailbio/bigslice"
    12  	"github.com/grailbio/bigslice/exec"
    13  )
    14  
    15  // Print prints slice to stdout in a deterministic order. This is useful for use
    16  // in slice function examples, as we can rely on the deterministic order in our
    17  // expected output. Print uses local evaluation, so all user functions are
    18  // executed within the same process. This makes it safe and convenient to use
    19  // shared memory in slice operations.
    20  func Print(slice bigslice.Slice) {
    21  	fn := bigslice.Func(func() bigslice.Slice { return slice })
    22  	sess := exec.Start(exec.Local)
    23  	ctx := context.Background()
    24  	res, err := sess.Run(ctx, fn)
    25  	if err != nil {
    26  		log.Panicf("unhandled error running session: %v", err)
    27  	}
    28  	var rows [][]reflect.Value
    29  	vs := make([]reflect.Value, slice.NumOut())
    30  	for i := range vs {
    31  		vs[i] = reflect.New(slice.Out(i))
    32  	}
    33  	args := make([]interface{}, slice.NumOut())
    34  	for i := range args {
    35  		args[i] = vs[i].Interface()
    36  	}
    37  	scanner := res.Scanner()
    38  	for scanner.Scan(ctx, args...) {
    39  		row := make([]reflect.Value, len(args))
    40  		for i := range row {
    41  			row[i] = reflect.ValueOf(reflect.Indirect(vs[i]).Interface())
    42  		}
    43  		rows = append(rows, row)
    44  	}
    45  	if scanner.Err() != nil {
    46  		log.Panicf("unhandled error scanning: %v", err)
    47  	}
    48  	canonicalize(rows)
    49  	for _, row := range rows {
    50  		strs := make([]string, len(row))
    51  		for j := range strs {
    52  			strs[j] = fmt.Sprintf("%v", row[j])
    53  		}
    54  		fmt.Println(strings.Join(strs, " "))
    55  	}
    56  }
    57  
    58  // canonicalize deep sorts rows to make the order deterministic.
    59  func canonicalize(rows [][]reflect.Value) {
    60  	for _, row := range rows {
    61  		for _, v := range row {
    62  			valueCanonicalize(v)
    63  		}
    64  	}
    65  	sort.Sort(canonical(rows))
    66  }
    67  
    68  // canonical provides a canonical sort of rows for printing. This used for
    69  // printing slices in deterministic order.
    70  type canonical [][]reflect.Value
    71  
    72  func (c canonical) Len() int      { return len(c) }
    73  func (c canonical) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
    74  func (c canonical) Less(i, j int) bool {
    75  	for col := range c[i] {
    76  		if valueLess(c[i][col], c[j][col]) {
    77  			return true
    78  		}
    79  		if valueLess(c[j][col], c[i][col]) {
    80  			return false
    81  		}
    82  	}
    83  	return false
    84  }
    85  
    // valueLess reports whether lhs orders strictly before rhs. Both values are
// expected to share the same kind; only lhs.Kind() is inspected.
//
// Signed and unsigned integers, floats and strings compare by value; false
// orders before true. Slices and arrays compare lexicographically element by
// element, and a proper prefix orders before any longer value. Any other kind
// causes a panic via log.Panicf.
func valueLess(lhs, rhs reflect.Value) bool {
	switch lhs.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return lhs.Int() < rhs.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return lhs.Uint() < rhs.Uint()
	case reflect.Float32, reflect.Float64:
		return lhs.Float() < rhs.Float()
	case reflect.Bool:
		return !lhs.Bool() && rhs.Bool()
	case reflect.String:
		return lhs.String() < rhs.String()
	case reflect.Slice, reflect.Array:
		// Compare only the overlapping prefix; indexing past the shorter value
		// would panic.
		n := lhs.Len()
		if m := rhs.Len(); m < n {
			n = m
		}
		for i := 0; i < n; i++ {
			if valueLess(lhs.Index(i), rhs.Index(i)) {
				return true
			}
			if valueLess(rhs.Index(i), lhs.Index(i)) {
				return false
			}
		}
		// Equal prefixes: the shorter value sorts first.
		return lhs.Len() < rhs.Len()
	}
	log.Panicf("cannot compare %v and %v", lhs.Kind(), rhs.Kind())
	return false
}
   111  
   112  func valueCanonicalize(v reflect.Value) {
   113  	switch v.Kind() {
   114  	case reflect.Slice:
   115  		for i := 0; i < v.Len(); i++ {
   116  			valueCanonicalize(v.Index(i))
   117  		}
   118  		sort.Slice(v.Interface(), func(i, j int) bool {
   119  			return valueLess(v.Index(i), v.Index(j))
   120  		})
   121  	}
   122  }